Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for non-string key-types for GetMapValue and element_at() #4944

Merged
merged 7 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 88 additions & 7 deletions integration_tests/src/main/python/map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,60 @@ def test_map_entries(data_gen):
# in here yet, and would need some special case code for checking equality
'map_entries(a)'))

@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_get_map_value(data_gen):

# Generator classes (not instances) for map VALUE types; each parametrized test
# below instantiates one generator per class so every key-type under test is
# exercised against each of these value types.
map_value_gens = [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen, StringGen, DateGen, TimestampGen]


@pytest.mark.parametrize('data_gen',
                         [MapGen(StringGen(pattern='key_[0-9]', nullable=False), value()) for value in map_value_gens],
                         ids=idfn)
def test_get_map_value_string_keys(data_gen):
    # GetMapValue with string keys: probe keys that may be present
    # ('key_0'..'key_9'), a null key, and a key that is never present;
    # CPU and GPU results must agree (missing keys yield null).
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(
            'a["key_0"]',
            'a["key_1"]',
            'a[null]',
            'a["key_9"]',
            'a["NOT_FOUND"]',
            'a["key_5"]'))


def _make_numeric_key_gen(key):
    # Map keys must be non-nullable. Integral key generators are bounded to
    # [0, 100] so the literal keys probed in the tests below have a realistic
    # chance of being present; float/double gens take no min/max bounds.
    if key in [FloatGen, DoubleGen]:
        return key(nullable=False)
    return key(nullable=False, min_val=0, max_val=100)


numeric_key_gens = [_make_numeric_key_gen(key)
                    for key in [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen]]

# Cross product: every numeric key type paired with every value type.
numeric_key_map_gens = [MapGen(key, value()) for key in numeric_key_gens for value in map_value_gens]


@pytest.mark.parametrize('data_gen', numeric_key_map_gens, ids=idfn)
def test_get_map_value_numeric_keys(data_gen):
    # GetMapValue with numeric keys: in-range, null, negative, and
    # out-of-range lookups must all match between CPU and GPU.
    lookups = ['a[0]',
               'a[1]',
               'a[null]',
               'a[-9]',
               'a[999]']
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(*lookups))


@pytest.mark.parametrize('data_gen', [MapGen(DateGen(nullable=False), value()) for value in map_value_gens], ids=idfn)
def test_get_map_value_date_keys(data_gen):
    # GetMapValue with date keys, including a null key lookup.
    lookups = ['a[date "1997"]',
               'a[date "2022-01-01"]',
               'a[null]']
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(*lookups))


@pytest.mark.parametrize('data_gen', [MapGen(TimestampGen(nullable=False), value()) for value in map_value_gens], ids=idfn)
def test_get_map_value_timestamp_keys(data_gen):
    # GetMapValue with timestamp keys, including a null key lookup.
    lookups = ['a[timestamp "1997"]',
               'a[timestamp "2022-01-01"]',
               'a[null]']
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(*lookups))


@pytest.mark.parametrize('key_gen', [StringGen(nullable=False), IntegerGen(nullable=False), basic_struct_gen], ids=idfn)
@pytest.mark.parametrize('value_gen', [StringGen(nullable=True), IntegerGen(nullable=True), basic_struct_gen], ids=idfn)
def test_single_entry_map(key_gen, value_gen):
Expand Down Expand Up @@ -249,17 +292,55 @@ def test_map_get_map_value_ansi_not_fail(data_gen):
'a["NOT_FOUND"]'),
conf=ansi_enabled_conf)

@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_element_at_map(data_gen):

@pytest.mark.parametrize('data_gen',
                         [MapGen(StringGen(pattern='key_[0-9]', nullable=False), value()) for value in map_value_gens],
                         ids=idfn)
def test_element_at_map_string_keys(data_gen):
    # element_at with string keys. ANSI mode is disabled so a missing key
    # returns null instead of raising; CPU and GPU results must agree.
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(
            'element_at(a, "key_0")',
            'element_at(a, "key_1")',
            'element_at(a, "null")',
            'element_at(a, "key_9")',
            'element_at(a, "NOT_FOUND")',
            'element_at(a, "key_5")'),
        conf={'spark.sql.ansi.enabled': False})


@pytest.mark.parametrize('data_gen', numeric_key_map_gens, ids=idfn)
def test_element_at_map_numeric_keys(data_gen):
    # element_at with numeric keys: in-range, null, negative, and out-of-range.
    # ANSI mode off so missing keys yield null rather than an error.
    exprs = ['element_at(a, 0)',
             'element_at(a, 1)',
             'element_at(a, null)',
             'element_at(a, -9)',
             'element_at(a, 999)']
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(*exprs),
        conf={'spark.sql.ansi.enabled': False})


@pytest.mark.parametrize('data_gen', [MapGen(DateGen(nullable=False), value()) for value in map_value_gens], ids=idfn)
def test_element_at_map_date_keys(data_gen):
    # element_at with date keys, including a null key; ANSI mode off so
    # missing keys yield null rather than an error.
    exprs = ['element_at(a, date "1997")',
             'element_at(a, date "2022-01-01")',
             'element_at(a, null)']
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(*exprs),
        conf={'spark.sql.ansi.enabled': False})


@pytest.mark.parametrize('data_gen',
                         [MapGen(TimestampGen(nullable=False), value()) for value in map_value_gens],
                         ids=idfn)
def test_element_at_map_timestamp_keys(data_gen):
    # element_at with timestamp keys, including a null key; ANSI mode off so
    # missing keys yield null rather than an error.
    exprs = ['element_at(a, timestamp "1997")',
             'element_at(a, timestamp "2022-01-01")',
             'element_at(a, null)']
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen).selectExpr(*exprs),
        conf={'spark.sql.ansi.enabled': False})


@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,12 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, shouldFailOnElementNotExists)
Expand All @@ -294,24 +297,24 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types supported."),
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit())
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,12 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, shouldFailOnElementNotExists)
Expand All @@ -293,24 +296,24 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit())
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,12 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r) {
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, shouldFailOnElementNotExists)
Expand All @@ -398,24 +401,24 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP, "If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP, "If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit())
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2577,9 +2577,12 @@ object GpuOverrides extends Logging {
(in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r)),
expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)),
expr[ElementAt](
"Returns element of array at given(1-based) index in value if column is array. " +
Expand All @@ -2589,24 +2592,25 @@ object GpuOverrides extends Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit())
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
Expand Down
24 changes: 24 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,30 @@ object TypeSig {
def lit(dataType: TypeEnum.Value): TypeSig =
TypeSig.none.withLit(dataType)

/**
 * Create a TypeSig that only supports literals of the given types.
 *
 * NOTE(review): the single-type overload builds its result via
 * `TypeSig.none.withLit(dataType)`; constructing `new TypeSig(dataTypes)`
 * directly appears to mark the types as supported without the literal-only
 * restriction that `withLit` applies. Folding through `withLit` keeps the
 * two overloads consistent — confirm against the TypeSig constructor.
 */
def lit(dataTypes: TypeEnum.ValueSet): TypeSig =
  dataTypes.foldLeft(TypeSig.none)(_ withLit _)

/**
 * A TypeSig that supports only literals of the common primitive CUDF types
 * (boolean, integral, floating point, date, timestamp, and string).
 */
def commonCudfTypesLit(): TypeSig = {
  val commonTypes = Seq(
    TypeEnum.BOOLEAN,
    TypeEnum.BYTE,
    TypeEnum.SHORT,
    TypeEnum.INT,
    TypeEnum.LONG,
    TypeEnum.FLOAT,
    TypeEnum.DOUBLE,
    TypeEnum.DATE,
    TypeEnum.TIMESTAMP,
    TypeEnum.STRING)
  lit(TypeEnum.ValueSet(commonTypes: _*))
}

/**
* Create a TypeSig that has partial support for the given type.
*/
Expand Down