Skip to content

Commit

Permalink
Add in the GpuArrayFilter command (#10763)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored May 6, 2024
1 parent 128aab8 commit 71ecc9f
Show file tree
Hide file tree
Showing 51 changed files with 325 additions and 91 deletions.
1 change: 1 addition & 0 deletions docs/additional-functionality/advanced_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.ArrayContains"></a>spark.rapids.sql.expression.ArrayContains|`array_contains`|Returns a boolean if the array contains the passed in key|true|None|
<a name="sql.expression.ArrayExcept"></a>spark.rapids.sql.expression.ArrayExcept|`array_except`|Returns an array of the elements in array1 but not in array2, without duplicates|true|This is not 100% compatible with the Spark version because the GPU implementation treats -0.0 and 0.0 as equal, but the CPU implementation currently does not (see SPARK-39845). Also, Apache Spark 3.1.3 fixed issue SPARK-36741 where NaNs in these set like operators were not treated as being equal. We have chosen to break with compatibility for the older versions of Spark in this instance and handle NaNs the same as 3.1.3+|
<a name="sql.expression.ArrayExists"></a>spark.rapids.sql.expression.ArrayExists|`exists`|Return true if any element satisfies the predicate LambdaFunction|true|None|
<a name="sql.expression.ArrayFilter"></a>spark.rapids.sql.expression.ArrayFilter|`filter`|Filter an input array using a given predicate|true|None|
<a name="sql.expression.ArrayIntersect"></a>spark.rapids.sql.expression.ArrayIntersect|`array_intersect`|Returns an array of the elements in the intersection of array1 and array2, without duplicates|true|This is not 100% compatible with the Spark version because the GPU implementation treats -0.0 and 0.0 as equal, but the CPU implementation currently does not (see SPARK-39845). Also, Apache Spark 3.1.3 fixed issue SPARK-36741 where NaNs in these set like operators were not treated as being equal. We have chosen to break with compatibility for the older versions of Spark in this instance and handle NaNs the same as 3.1.3+|
<a name="sql.expression.ArrayMax"></a>spark.rapids.sql.expression.ArrayMax|`array_max`|Returns the maximum value in the array|true|None|
<a name="sql.expression.ArrayMin"></a>spark.rapids.sql.expression.ArrayMin|`array_min`|Returns the minimum value in the array|true|None|
Expand Down
224 changes: 146 additions & 78 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -2288,6 +2288,74 @@ are limited.
<td> </td>
</tr>
<tr>
<td rowSpan="3">ArrayFilter</td>
<td rowSpan="3">`filter`</td>
<td rowSpan="3">Filter an input array using a given predicate</td>
<td rowSpan="3">None</td>
<td rowSpan="3">project</td>
<td>argument</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>function</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">ArrayIntersect</td>
<td rowSpan="3">`array_intersect`</td>
<td rowSpan="3">Returns an array of the elements in the intersection of array1 and array2, without duplicates</td>
Expand Down Expand Up @@ -2518,6 +2586,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="3">ArrayRepeat</td>
<td rowSpan="3">`array_repeat`</td>
<td rowSpan="3">Returns the array containing the given input value (left) count (right) times</td>
Expand Down Expand Up @@ -2586,32 +2680,6 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="3">ArrayTransform</td>
<td rowSpan="3">`transform`</td>
<td rowSpan="3">Transform elements in an array using the transform function. This is similar to a `map` in functional programming</td>
Expand Down Expand Up @@ -2910,6 +2978,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="4">Asin</td>
<td rowSpan="4">`asin`</td>
<td rowSpan="4">Inverse sine</td>
Expand Down Expand Up @@ -3000,32 +3094,6 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="4">Asinh</td>
<td rowSpan="4">`asinh`</td>
<td rowSpan="4">Inverse hyperbolic sine</td>
Expand Down Expand Up @@ -3343,6 +3411,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">AttributeReference</td>
<td rowSpan="2"> </td>
<td rowSpan="2">References an input column</td>
Expand Down Expand Up @@ -3391,32 +3485,6 @@ are limited.
<td><b>NS</b></td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="3">BRound</td>
<td rowSpan="3">`bround`</td>
<td rowSpan="3">Round an expression to d decimal places using HALF_EVEN rounding mode</td>
Expand Down
29 changes: 28 additions & 1 deletion integration_tests/src/main/python/array_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -449,6 +449,33 @@ def do_it(spark):
})


@pytest.mark.parametrize('data_gen', [
ArrayGen(string_gen),
ArrayGen(int_gen),
ArrayGen(ArrayGen(int_gen)),
ArrayGen(ArrayGen(StructGen([["A", int_gen], ["B", string_gen]])))], ids=idfn)
def test_array_filter(data_gen):
def do_it(spark):
columns = ['a']
element_type = data_gen.data_type.elementType
if isinstance(element_type, IntegralType):
columns.extend([
'filter(a, item -> item % 2 = 0) as filter_even',
'filter(a, item -> item < 0) as filter_negative',
'filter(a, item -> item >= 0) as filter_non_negative'
])

if isinstance(element_type, StringType):
columns.extend(['filter(a, entry -> length(entry) > 5) as filter_longer_than_5'])

if isinstance(element_type, ArrayType):
columns.extend(['filter(a, entry -> size(entry) < 5) as filter_shorter_than_5'])

return unary_op_df(spark, data_gen).selectExpr(columns)

assert_gpu_and_cpu_are_equal_collect(do_it)


array_zips_gen = array_gens_sample + [ArrayGen(map_string_string_gen[0], max_length=5),
ArrayGen(BinaryGen(max_length=5), max_length=5)]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2737,6 +2737,25 @@ object GpuOverrides extends Logging {
)
}
}),
expr[ArrayFilter](
"Filter an input array using a given predicate",
ExprChecks.projectOnly(TypeSig.ARRAY.nested(TypeSig.commonCudfTypes +
TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all),
Seq(
ParamCheck("argument",
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))),
(in, conf, p, r) => new ExprMeta[ArrayFilter](in, conf, p, r) {
override def convertToGpu(): GpuExpression = {
GpuArrayFilter(
childExprs.head.convertToGpu(),
childExprs(1).convertToGpu()
)
}
}),
// TODO: fix the signature https://github.com/NVIDIA/spark-rapids/issues/5327
expr[ArraysZip](
"Returns a merged array of structs in which the N-th struct contains" +
Expand Down
Loading

0 comments on commit 71ecc9f

Please sign in to comment.