Support for non-string key-types for GetMapValue and element_at() #4944

Merged 7 commits on Mar 18, 2022
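For context, this PR extends the RAPIDS Accelerator so that map lookups (`a[key]`, `element_at(map, key)`) are no longer limited to string keys. A minimal sketch of the lookup semantics the new tests exercise, emulated here with a plain Python dict rather than actual Spark APIs, under the assumption that with ANSI mode off a missing key yields NULL instead of an error:

```python
def get_map_value(m, key):
    """Emulate Spark's non-ANSI GetMapValue / element_at on a MapType column.

    Spark returns NULL (None here) when the key is absent or the map itself
    is NULL; after this PR, keys may be numeric, date, timestamp, or decimal,
    not just string.
    """
    if m is None or key is None:
        return None
    return m.get(key)  # dict.get returns None for missing keys

# Keys of any primitive type behave the same way:
assert get_map_value({"key_0": 1, "key_5": 6}, "key_0") == 1
assert get_map_value({0: "a", 1: "b"}, 999) is None  # NOT_FOUND -> NULL
assert get_map_value(None, 0) is None                # NULL map -> NULL
```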
76 changes: 38 additions & 38 deletions docs/supported_ops.md
@@ -4968,23 +4968,23 @@ are limited.
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>If it's map, only string is supported.;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>If it's map, only primitive key types are supported.;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>index/key</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>ints are only supported as array indexes, not as maps keys</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>strings are only supported as map keys, not array indexes;<br/>Literal value only</em></td>
<td><b>NS</b></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>Only ints are supported as array indexes</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td><em>PS<br/>max DECIMAL precision of 18;<br/>Literal value only</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
@@ -5907,23 +5907,23 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>unsupported child types BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, DATE, TIMESTAMP, DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>key</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>Literal value only</em></td>
<td><b>NS</b></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td><em>PS<br/>max DECIMAL precision of 18;<br/>Literal value only</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<td>result</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
120 changes: 113 additions & 7 deletions integration_tests/src/main/python/map_test.py
@@ -61,17 +61,81 @@ def test_map_entries(data_gen):
# in here yet, and would need some special case code for checking equality
'map_entries(a)'))

@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_get_map_value(data_gen):

def get_map_value_gens():
def simple_struct_value_gen():
return StructGen([["child", IntegerGen()]])

def nested_struct_value_gen():
return StructGen([["child", simple_struct_value_gen()]])

def nested_map_value_gen():
return MapGen(StringGen(pattern='key_[0-9]', nullable=False), IntegerGen(), max_length=6)

def array_value_gen():
return ArrayGen(IntegerGen(), max_length=6)

return [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen,
StringGen, DateGen, TimestampGen, DecimalGen,
simple_struct_value_gen, nested_struct_value_gen, nested_map_value_gen, array_value_gen]


@pytest.mark.parametrize('data_gen',
[MapGen(StringGen(pattern='key_[0-9]', nullable=False), value(), max_length=6)
for value in get_map_value_gens()],
ids=idfn)
def test_get_map_value_string_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'a["key_0"]',
'a["key_1"]',
'a[null]',
'a["key_9"]',
'a["NOT_FOUND"]',
'a["key_5"]'))


numeric_key_gens = [key(nullable=False) if key in [FloatGen, DoubleGen, DecimalGen]
else key(nullable=False, min_val=0, max_val=100)
for key in [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, DecimalGen]]
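The comprehension above instantiates each key-generator class with range bounds only where they apply (the float, double, and decimal generators do not take `min_val`/`max_val`). A generic sketch of that conditional-instantiation pattern, using stand-in generator classes (the real `*Gen` classes live in the test framework's data-generation module):

```python
class BoundedGen:
    """Stand-in for generators that accept min/max bounds (e.g. IntegerGen)."""
    def __init__(self, nullable=True, min_val=None, max_val=None):
        self.nullable, self.min_val, self.max_val = nullable, min_val, max_val

class UnboundedGen:
    """Stand-in for generators without bounds (e.g. FloatGen, DecimalGen)."""
    def __init__(self, nullable=True):
        self.nullable = nullable

UNBOUNDED = (UnboundedGen,)  # plays the role of [FloatGen, DoubleGen, DecimalGen]

# Instantiate with bounds only for the classes that accept them.
gens = [cls(nullable=False) if cls in UNBOUNDED
        else cls(nullable=False, min_val=0, max_val=100)
        for cls in [BoundedGen, UnboundedGen]]

assert gens[0].min_val == 0 and gens[0].max_val == 100
assert not hasattr(gens[1], 'min_val')
```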

numeric_key_map_gens = [MapGen(key, value(), max_length=6)
for key in numeric_key_gens for value in get_map_value_gens()]


@pytest.mark.parametrize('data_gen', numeric_key_map_gens, ids=idfn)
def test_get_map_value_numeric_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'a[0]',
'a[1]',
'a[null]',
'a[-9]',
'a[999]'))


@pytest.mark.parametrize('data_gen',
[MapGen(DateGen(nullable=False), value(), max_length=6)
for value in get_map_value_gens()], ids=idfn)
def test_get_map_value_date_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'a[date "1997"]',
'a[date "2022-01-01"]',
'a[null]'))


@pytest.mark.parametrize('data_gen',
[MapGen(TimestampGen(nullable=False), value(), max_length=6)
for value in get_map_value_gens()], ids=idfn)
def test_get_map_value_timestamp_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'a[timestamp "1997"]',
'a[timestamp "2022-01-01"]',
'a[null]'))


@pytest.mark.parametrize('key_gen', [StringGen(nullable=False), IntegerGen(nullable=False), basic_struct_gen], ids=idfn)
@pytest.mark.parametrize('value_gen', [StringGen(nullable=True), IntegerGen(nullable=True), basic_struct_gen], ids=idfn)
def test_single_entry_map(key_gen, value_gen):
@@ -249,17 +313,59 @@ def test_map_get_map_value_ansi_not_fail(data_gen):
'a["NOT_FOUND"]'),
conf=ansi_enabled_conf)

@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_element_at_map(data_gen):

@pytest.mark.parametrize('data_gen',
[MapGen(StringGen(pattern='key_[0-9]', nullable=False), value(), max_length=6)
for value in get_map_value_gens()],
ids=idfn)
def test_element_at_map_string_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'element_at(a, "key_0")',
'element_at(a, "key_1")',
'element_at(a, "null")',
'element_at(a, "key_9")',
'element_at(a, "NOT_FOUND")',
'element_at(a, "key_5")'),
conf={'spark.sql.ansi.enabled':False})
conf={'spark.sql.ansi.enabled': False})


@pytest.mark.parametrize('data_gen', numeric_key_map_gens, ids=idfn)
def test_element_at_map_numeric_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'element_at(a, 0)',
'element_at(a, 1)',
'element_at(a, null)',
'element_at(a, -9)',
'element_at(a, 999)'),
conf={'spark.sql.ansi.enabled': False})


@pytest.mark.parametrize('data_gen',
[MapGen(DateGen(nullable=False), value(), max_length=6)
for value in get_map_value_gens()], ids=idfn)
def test_element_at_map_date_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'element_at(a, date "1997")',
'element_at(a, date "2022-01-01")',
'element_at(a, null)'),
conf={'spark.sql.ansi.enabled': False})


@pytest.mark.parametrize('data_gen',
[MapGen(TimestampGen(nullable=False), value(), max_length=6)
for value in get_map_value_gens()],
ids=idfn)
def test_element_at_map_timestamp_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'element_at(a, timestamp "1997")',
'element_at(a, timestamp "2022-01-01")',
'element_at(a, null)'),
conf={'spark.sql.ansi.enabled': False})
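The `element_at` tests above all pin `spark.sql.ansi.enabled` to false because, from Spark 3.1.1 onward, ANSI mode makes `element_at` raise on a missing map key instead of returning NULL. A dict-based sketch of that behavioral split (a hypothetical helper, not a Spark API):

```python
def element_at_map(m, key, ansi_enabled=False):
    """Emulate Spark's element_at on a map: NULL (None) for a missing key
    when ANSI mode is off, an error when it is on (Spark 3.1.1+ behavior)."""
    if m is None or key is None:
        return None
    if key in m:
        return m[key]
    if ansi_enabled:
        raise KeyError(f"Key {key!r} does not exist")
    return None

assert element_at_map({"key_0": 1}, "key_0") == 1
assert element_at_map({"key_0": 1}, "NOT_FOUND") is None
try:
    element_at_map({"key_0": 1}, "NOT_FOUND", ansi_enabled=True)
    raised = False
except KeyError:
    raised = True
assert raised
```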


@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
@@ -279,9 +279,14 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, shouldFailOnElementNotExists)
@@ -294,24 +299,28 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP)
.withPsNote(TypeEnum.MAP, "If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL))
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
@@ -278,9 +278,15 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map",
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, shouldFailOnElementNotExists)
@@ -293,24 +299,27 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP)
.withPsNote(TypeEnum.MAP, "If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL))
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit() + TypeSig.lit(TypeEnum.DECIMAL), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(