Add in the GpuArrayFilter command #10763

Merged
merged 3 commits into from
May 6, 2024
1 change: 1 addition & 0 deletions docs/additional-functionality/advanced_configs.md
@@ -184,6 +184,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.ArrayContains"></a>spark.rapids.sql.expression.ArrayContains|`array_contains`|Returns a boolean if the array contains the passed in key|true|None|
<a name="sql.expression.ArrayExcept"></a>spark.rapids.sql.expression.ArrayExcept|`array_except`|Returns an array of the elements in array1 but not in array2, without duplicates|true|This is not 100% compatible with the Spark version because the GPU implementation treats -0.0 and 0.0 as equal, but the CPU implementation currently does not (see SPARK-39845). Also, Apache Spark 3.1.3 fixed issue SPARK-36741 where NaNs in these set like operators were not treated as being equal. We have chosen to break with compatibility for the older versions of Spark in this instance and handle NaNs the same as 3.1.3+|
<a name="sql.expression.ArrayExists"></a>spark.rapids.sql.expression.ArrayExists|`exists`|Return true if any element satisfies the predicate LambdaFunction|true|None|
<a name="sql.expression.ArrayFilter"></a>spark.rapids.sql.expression.ArrayFilter|`filter`|Filter an input array using a given predicate|true|None|
<a name="sql.expression.ArrayIntersect"></a>spark.rapids.sql.expression.ArrayIntersect|`array_intersect`|Returns an array of the elements in the intersection of array1 and array2, without duplicates|true|This is not 100% compatible with the Spark version because the GPU implementation treats -0.0 and 0.0 as equal, but the CPU implementation currently does not (see SPARK-39845). Also, Apache Spark 3.1.3 fixed issue SPARK-36741 where NaNs in these set like operators were not treated as being equal. We have chosen to break with compatibility for the older versions of Spark in this instance and handle NaNs the same as 3.1.3+|
<a name="sql.expression.ArrayMax"></a>spark.rapids.sql.expression.ArrayMax|`array_max`|Returns the maximum value in the array|true|None|
<a name="sql.expression.ArrayMin"></a>spark.rapids.sql.expression.ArrayMin|`array_min`|Returns the minimum value in the array|true|None|
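
The new `spark.rapids.sql.expression.ArrayFilter` row defaults to `true`. As a minimal sketch, assuming the job is launched through `spark-submit`, the switch could be turned off by passing it as a `--conf` flag, which should keep `filter` off the GPU. The helper name `conf_flags` is hypothetical and only illustrates the flag format; the config key and its default are the only parts taken from the table above.

```python
# Config key taken from the table row above; everything else is illustrative.
ARRAY_FILTER_KEY = "spark.rapids.sql.expression.ArrayFilter"


def conf_flags(overrides):
    """Hypothetical helper: render config entries as spark-submit arguments."""
    flags = []
    for key, value in sorted(overrides.items()):
        # spark-submit accepts repeated `--conf key=value` pairs.
        flags.extend(["--conf", f"{key}={value}"])
    return flags


# The expression is enabled by default; "false" opts this expression out.
flags = conf_flags({ARRAY_FILTER_KEY: "false"})
print(" ".join(flags))
```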
224 changes: 146 additions & 78 deletions docs/supported_ops.md
@@ -2288,6 +2288,74 @@ are limited.
<td> </td>
</tr>
<tr>
<td rowSpan="3">ArrayFilter</td>
<td rowSpan="3">`filter`</td>
<td rowSpan="3">Filter an input array using a given predicate</td>
<td rowSpan="3">None</td>
<td rowSpan="3">project</td>
<td>argument</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>function</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td rowSpan="3">ArrayIntersect</td>
<td rowSpan="3">`array_intersect`</td>
<td rowSpan="3">Returns an array of the elements in the intersection of array1 and array2, without duplicates</td>
@@ -2518,6 +2586,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="3">ArrayRepeat</td>
<td rowSpan="3">`array_repeat`</td>
<td rowSpan="3">Returns the array containing the given input value (left) count (right) times</td>
@@ -2586,32 +2680,6 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="3">ArrayTransform</td>
<td rowSpan="3">`transform`</td>
<td rowSpan="3">Transform elements in an array using the transform function. This is similar to a `map` in functional programming</td>
@@ -2910,6 +2978,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="4">Asin</td>
<td rowSpan="4">`asin`</td>
<td rowSpan="4">Inverse sine</td>
@@ -3000,32 +3094,6 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="4">Asinh</td>
<td rowSpan="4">`asinh`</td>
<td rowSpan="4">Inverse hyperbolic sine</td>
@@ -3343,6 +3411,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">AttributeReference</td>
<td rowSpan="2"> </td>
<td rowSpan="2">References an input column</td>
@@ -3391,32 +3485,6 @@ are limited.
<td><b>NS</b></td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="3">BRound</td>
<td rowSpan="3">`bround`</td>
<td rowSpan="3">Round an expression to d decimal places using HALF_EVEN rounding mode</td>
29 changes: 28 additions & 1 deletion integration_tests/src/main/python/array_test.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -449,6 +449,33 @@ def do_it(spark):
})


@pytest.mark.parametrize('data_gen', [
    ArrayGen(string_gen),
    ArrayGen(int_gen),
    ArrayGen(ArrayGen(int_gen)),
    ArrayGen(ArrayGen(StructGen([["A", int_gen], ["B", string_gen]])))], ids=idfn)
def test_array_filter(data_gen):
    def do_it(spark):
        columns = ['a']
        element_type = data_gen.data_type.elementType
        if isinstance(element_type, IntegralType):
            columns.extend([
                'filter(a, item -> item % 2 = 0) as filter_even',
                'filter(a, item -> item < 0) as filter_negative',
                'filter(a, item -> item >= 0) as filter_non_negative'
            ])

        if isinstance(element_type, StringType):
            columns.extend(['filter(a, entry -> length(entry) > 5) as filter_longer_than_5'])

        if isinstance(element_type, ArrayType):
            columns.extend(['filter(a, entry -> size(entry) < 5) as filter_shorter_than_5'])

        return unary_op_df(spark, data_gen).selectExpr(columns)

    assert_gpu_and_cpu_are_equal_collect(do_it)
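
For readers without a Spark session at hand, the semantics that the test above compares between CPU and GPU can be approximated in plain Python. This is only a rough reference model, not code from the PR: `sql_filter` and `null_safe` are hypothetical names, and the null handling (a null array yields null, and an element whose predicate evaluates to null is dropped) reflects my understanding of Spark's `filter` rather than anything stated in this diff.

```python
def null_safe(op):
    """Hypothetical wrapper: a SQL-style predicate that yields None for a null input."""
    return lambda x: None if x is None else op(x)


def sql_filter(arr, pred):
    """Hypothetical model of `filter(a, x -> pred(x))` on a single row.

    Assumptions: a null array stays null, and an element is kept only
    when the predicate is exactly True (None counts as "not kept").
    """
    if arr is None:
        return None
    return [x for x in arr if pred(x) is True]


row = [1, 2, None, -4, 0]

# Counterparts of the three integral expressions used in the test.
filter_even = sql_filter(row, null_safe(lambda item: item % 2 == 0))
filter_negative = sql_filter(row, null_safe(lambda item: item < 0))
filter_non_negative = sql_filter(row, null_safe(lambda item: item >= 0))
```

The element order of the input is preserved, so the result depends only on which elements satisfy the predicate.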


array_zips_gen = array_gens_sample + [ArrayGen(map_string_string_gen[0], max_length=5),
ArrayGen(BinaryGen(max_length=5), max_length=5)]

@@ -2737,6 +2737,25 @@ object GpuOverrides extends Logging {
)
}
}),
expr[ArrayFilter](
  "Filter an input array using a given predicate",
  ExprChecks.projectOnly(TypeSig.ARRAY.nested(TypeSig.commonCudfTypes +
      TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
    TypeSig.ARRAY.nested(TypeSig.all),
    Seq(
      ParamCheck("argument",
        TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
            TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
        TypeSig.ARRAY.nested(TypeSig.all)),
      ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))),
  (in, conf, p, r) => new ExprMeta[ArrayFilter](in, conf, p, r) {
    override def convertToGpu(): GpuExpression = {
      GpuArrayFilter(
        childExprs.head.convertToGpu(),
        childExprs(1).convertToGpu()
      )
    }
  }),
// TODO: fix the signature https://github.com/NVIDIA/spark-rapids/issues/5327
expr[ArraysZip](
"Returns a merged array of structs in which the N-th struct contains" +