Skip to content

Commit

Permalink
Support for non-string Map keys and values
Browse files Browse the repository at this point in the history
This commit adds support for non-string keys and values
in map columns. Specifically:
1. `GetMapValue` now supports keys that are integral, floating-point,
   chrono (date/timestamp), or string scalars. The returned column
   can be of any supported CUDF type.
2. Similarly, `ElementAt` on map inputs now supports keys that are
   integral, floating-point, chrono (date/timestamp), or string scalars.
   The returned column can be of any supported CUDF type.

Signed-off-by: MithunR <mythrocks@gmail.com>
  • Loading branch information
mythrocks committed Mar 15, 2022
1 parent 611f54d commit 66f3e2d
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 59 deletions.
95 changes: 88 additions & 7 deletions integration_tests/src/main/python/map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,60 @@ def test_map_entries(data_gen):
# in here yet, and would need some special case code for checking equality
'map_entries(a)'))

@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_get_map_value(data_gen):

# Generator classes used for the map *values* in the lookup tests of this file.
# The list holds the classes themselves; each test instantiates them with `value()`.
map_value_gens = [
    ByteGen, ShortGen, IntegerGen, LongGen,
    FloatGen, DoubleGen,
    StringGen,
    DateGen, TimestampGen,
]


# Compares GPU and CPU results of `a[key]` lookups on maps whose keys are
# non-null strings matching `key_[0-9]`; parametrized once per value generator
# class in map_value_gens.
@pytest.mark.parametrize('data_gen',
[MapGen(StringGen(pattern='key_[0-9]', nullable=False), value()) for value in map_value_gens],
ids=idfn)
def test_get_map_value_string_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
# NOTE(review): the next two lines appear to be the removed and added sides of
# one whitespace-only diff change (space before the colon); only one belongs
# in the resulting file.
lambda spark : unary_op_df(spark, data_gen).selectExpr(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'a["key_0"]',
'a["key_1"]',
# Lookup with a null key.
'a[null]',
'a["key_9"]',
# "NOT_FOUND" cannot match the key_[0-9] pattern, so this is a missing-key lookup.
'a["NOT_FOUND"]',
'a["key_5"]'))


def _numeric_key_gen(gen_type):
    """Instantiate a non-nullable key generator for the given numeric generator class.

    Integral generators are bounded with min_val=0 and max_val=100; FloatGen and
    DoubleGen are created without bounds.
    """
    # NOTE(review): presumably the float generators are left unbounded because they
    # do not accept min_val/max_val -- confirm against their constructors.
    if gen_type in [FloatGen, DoubleGen]:
        return gen_type(nullable=False)
    return gen_type(nullable=False, min_val=0, max_val=100)


# One non-nullable key generator per numeric type.
numeric_key_gens = [
    _numeric_key_gen(gen_type)
    for gen_type in [ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen]
]

# Every numeric key generator paired with a fresh instance of every value generator.
numeric_key_map_gens = [
    MapGen(key, value())
    for key in numeric_key_gens
    for value in map_value_gens
]


@pytest.mark.parametrize('data_gen', numeric_key_map_gens, ids=idfn)
def test_get_map_value_numeric_keys(data_gen):
    """Compare GPU and CPU results of `a[key]` lookups on maps with numeric keys.

    The lookups use two small integer keys, a null key, a negative key and 999.
    """
    lookups = ('a[0]', 'a[1]', 'a[null]', 'a[-9]', 'a[999]')

    def lookup_df(spark):
        return unary_op_df(spark, data_gen).selectExpr(*lookups)

    assert_gpu_and_cpu_are_equal_collect(lookup_df)


@pytest.mark.parametrize(
    'data_gen',
    [MapGen(DateGen(nullable=False), value()) for value in map_value_gens],
    ids=idfn)
def test_get_map_value_date_keys(data_gen):
    """Compare GPU and CPU results of `a[key]` lookups on maps with date keys.

    The lookups use two date literals and a null key.
    """
    lookups = ('a[date "1997"]', 'a[date "2022-01-01"]', 'a[null]')

    def lookup_df(spark):
        return unary_op_df(spark, data_gen).selectExpr(*lookups)

    assert_gpu_and_cpu_are_equal_collect(lookup_df)


@pytest.mark.parametrize(
    'data_gen',
    [MapGen(TimestampGen(nullable=False), value()) for value in map_value_gens],
    ids=idfn)
def test_get_map_value_timestamp_keys(data_gen):
    """Compare GPU and CPU results of `a[key]` lookups on maps with timestamp keys.

    The lookups use two timestamp literals and a null key.
    """
    lookups = ('a[timestamp "1997"]', 'a[timestamp "2022-01-01"]', 'a[null]')

    def lookup_df(spark):
        return unary_op_df(spark, data_gen).selectExpr(*lookups)

    assert_gpu_and_cpu_are_equal_collect(lookup_df)


@pytest.mark.parametrize('key_gen', [StringGen(nullable=False), IntegerGen(nullable=False), basic_struct_gen], ids=idfn)
@pytest.mark.parametrize('value_gen', [StringGen(nullable=True), IntegerGen(nullable=True), basic_struct_gen], ids=idfn)
def test_single_entry_map(key_gen, value_gen):
Expand Down Expand Up @@ -249,17 +292,55 @@ def test_map_get_map_value_ansi_not_fail(data_gen):
'a["NOT_FOUND"]'),
conf=ansi_enabled_conf)

@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_element_at_map(data_gen):

# Compares GPU and CPU results of `element_at(a, key)` on maps whose keys are
# non-null strings matching `key_[0-9]`, with ANSI mode disabled; parametrized
# once per value generator class in map_value_gens.
@pytest.mark.parametrize('data_gen',
[MapGen(StringGen(pattern='key_[0-9]', nullable=False), value()) for value in map_value_gens],
ids=idfn)
def test_element_at_map_string_keys(data_gen):
assert_gpu_and_cpu_are_equal_collect(
# NOTE(review): the next two lines appear to be the removed and added sides of
# one whitespace-only diff change (space before the colon); only one belongs
# in the resulting file.
lambda spark : unary_op_df(spark, data_gen).selectExpr(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'element_at(a, "key_0")',
'element_at(a, "key_1")',
# NOTE(review): "null" is a quoted string literal here, not a SQL null, so this
# is another missing-key lookup. The other element_at tests in this file pass
# a bare `null` -- confirm whether a null key was intended.
'element_at(a, "null")',
'element_at(a, "key_9")',
# "NOT_FOUND" cannot match the key_[0-9] pattern, so this is a missing-key lookup.
'element_at(a, "NOT_FOUND")',
'element_at(a, "key_5")'),
# NOTE(review): the two conf lines below likewise look like the removed and added
# sides of a whitespace-only change (space after the colon).
conf={'spark.sql.ansi.enabled':False})
conf={'spark.sql.ansi.enabled': False})


@pytest.mark.parametrize('data_gen', numeric_key_map_gens, ids=idfn)
def test_element_at_map_numeric_keys(data_gen):
    """Compare GPU and CPU results of `element_at(a, key)` on maps with numeric keys.

    Runs with ANSI mode disabled. The lookups use two small integer keys, a null
    key, a negative key and 999.
    """
    lookups = (
        'element_at(a, 0)',
        'element_at(a, 1)',
        'element_at(a, null)',
        'element_at(a, -9)',
        'element_at(a, 999)',
    )
    non_ansi_conf = {'spark.sql.ansi.enabled': False}

    def lookup_df(spark):
        return unary_op_df(spark, data_gen).selectExpr(*lookups)

    assert_gpu_and_cpu_are_equal_collect(lookup_df, conf=non_ansi_conf)


@pytest.mark.parametrize(
    'data_gen',
    [MapGen(DateGen(nullable=False), value()) for value in map_value_gens],
    ids=idfn)
def test_element_at_map_date_keys(data_gen):
    """Compare GPU and CPU results of `element_at(a, key)` on maps with date keys.

    Runs with ANSI mode disabled. The lookups use two date literals and a null key.
    """
    lookups = (
        'element_at(a, date "1997")',
        'element_at(a, date "2022-01-01")',
        'element_at(a, null)',
    )
    non_ansi_conf = {'spark.sql.ansi.enabled': False}

    def lookup_df(spark):
        return unary_op_df(spark, data_gen).selectExpr(*lookups)

    assert_gpu_and_cpu_are_equal_collect(lookup_df, conf=non_ansi_conf)


@pytest.mark.parametrize(
    'data_gen',
    [MapGen(TimestampGen(nullable=False), value()) for value in map_value_gens],
    ids=idfn)
def test_element_at_map_timestamp_keys(data_gen):
    """Compare GPU and CPU results of `element_at(a, key)` on maps with timestamp keys.

    Runs with ANSI mode disabled. The lookups use two timestamp literals and a
    null key.
    """
    lookups = (
        'element_at(a, timestamp "1997")',
        'element_at(a, timestamp "2022-01-01")',
        'element_at(a, null)',
    )
    non_ansi_conf = {'spark.sql.ansi.enabled': False}

    def lookup_df(spark):
        return unary_op_df(spark, data_gen).selectExpr(*lookups)

    assert_gpu_and_cpu_are_equal_collect(lookup_df, conf=non_ansi_conf)


@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,12 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, shouldFailOnElementNotExists)
Expand All @@ -294,24 +297,24 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit())
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,12 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r){
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, shouldFailOnElementNotExists)
Expand All @@ -293,24 +296,24 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit())
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,12 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
}),
GpuOverrides.expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r) {
override def convertToGpu(map: Expression, key: Expression): GpuExpression =
GpuGetMapValue(map, key, shouldFailOnElementNotExists)
Expand All @@ -398,24 +401,24 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP, "If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP, "If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit())
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2577,9 +2577,12 @@ object GpuOverrides extends Logging {
(in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r)),
expr[GetMapValue](
"Gets Value from a Map based on a key",
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all)),
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all)),
(in, conf, p, r) => new GpuGetMapValueMeta(in, conf, p, r)),
expr[ElementAt](
"Returns element of array at given(1-based) index in value if column is array. " +
Expand All @@ -2589,24 +2592,25 @@ object GpuOverrides extends Logging {
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.MAP.nested(TypeSig.commonCudfTypes)
.withPsNote(TypeEnum.MAP ,"If it's map, only primitive key types are supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
("index/key", (TypeSig.INT + TypeSig.lit(TypeEnum.STRING))
.withPsNote(TypeEnum.INT, "ints are only supported as array indexes, " +
"not as maps keys")
.withPsNote(TypeEnum.STRING, "strings are only supported as map keys, " +
"not array indexes"),
("index/key", (TypeSig.INT + TypeSig.commonCudfTypesLit())
.withPsNote(TypeEnum.INT, "Only ints are supported as array indexes"),
TypeSig.all)),
(in, conf, p, r) => new BinaryExprMeta[ElementAt](in, conf, p, r) {
override def tagExprForGpu(): Unit = {
// To distinguish the supported nested type between Array and Map
val checks = in.left.dataType match {
case _: MapType =>
// Match exactly with the checks for GetMapValue
ExprChecks.binaryProject(TypeSig.STRING, TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.STRING), TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.lit(TypeEnum.STRING), TypeSig.all))
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("map", TypeSig.MAP.nested(TypeSig.commonCudfTypes),
TypeSig.MAP.nested(TypeSig.all)),
("key", TypeSig.commonCudfTypesLit(), TypeSig.all))
case _: ArrayType =>
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
Expand Down
24 changes: 24 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,30 @@ object TypeSig {
def lit(dataType: TypeEnum.Value): TypeSig =
TypeSig.none.withLit(dataType)

/**
 * Create a TypeSig that only supports literals of certain given types.
 *
 * Every type is registered through `withLit`, starting from `TypeSig.none`, which is
 * exactly how the single-type `lit(dataType)` overload builds its result. This keeps the
 * two overloads consistent: `lit(ValueSet(t))` yields the same signature as `lit(t)`.
 * An empty set yields `TypeSig.none`.
 *
 * @param dataTypes the types whose literals should be accepted
 * @return a TypeSig built by applying `withLit` for each entry of `dataTypes`
 */
def lit(dataTypes: TypeEnum.ValueSet): TypeSig =
  dataTypes.foldLeft(TypeSig.none) { (sig, dataType) => sig.withLit(dataType) }

/**
 * Create a TypeSig for literals of the common primitive CUDF types.
 *
 * The set covers BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, DATE, TIMESTAMP and
 * STRING, and is handed to the `lit(TypeEnum.ValueSet)` overload.
 *
 * @return the TypeSig produced by `lit` for that set of types
 */
def commonCudfTypesLit(): TypeSig = {
  val commonPrimitiveTypes = TypeEnum.ValueSet(
    TypeEnum.BOOLEAN,
    TypeEnum.BYTE, TypeEnum.SHORT, TypeEnum.INT, TypeEnum.LONG,
    TypeEnum.FLOAT, TypeEnum.DOUBLE,
    TypeEnum.DATE, TypeEnum.TIMESTAMP,
    TypeEnum.STRING)
  lit(commonPrimitiveTypes)
}

/**
* Create a TypeSig that has partial support for the given type.
*/
Expand Down

0 comments on commit 66f3e2d

Please sign in to comment.