Implement support for ArrayExists expression (#4973)
This PR implements ArrayExists in two major phases:
1. Apply the lambda function to produce an array of Booleans.
2. Run a segmented ANY reduction to determine whether any of the values in each array is true.
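
The two phases can be sketched in plain Python. This is only an illustration of the idea, not the plugin's GPU code: it assumes a list column is stored as a flattened `values` list plus row `offsets`, ignores nulls, and the names `segmented_any` and `array_exists_two_phase` are hypothetical.

```python
from typing import Callable, List


def segmented_any(flags: List[bool], offsets: List[int]) -> List[bool]:
    """OR-reduce each segment flags[offsets[i]:offsets[i + 1]]."""
    return [any(flags[start:end]) for start, end in zip(offsets, offsets[1:])]


def array_exists_two_phase(values: List[int],
                           offsets: List[int],
                           predicate: Callable[[int], bool]) -> List[bool]:
    # Phase 1: apply the predicate to every element of the flattened child data.
    flags = [predicate(v) for v in values]
    # Phase 2: one ANY reduction per row, delimited by the offsets.
    return segmented_any(flags, offsets)


# Three rows: [1, 3], [], [5, 6] (hypothetical sample data)
values = [1, 3, 5, 6]
offsets = [0, 2, 2, 4]
result = array_exists_two_phase(values, offsets, lambda x: x % 2 == 0)
```

An empty row reduces to false here because an ANY over no elements has nothing to be true.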

Spark 3.x defaults to three-valued logic (3VL):
- if any element is true the array maps to true
- if no element is true and there is at least one null, the array maps to null
- if no element is true and none is null, the array maps to false

Legacy mode uses two-valued logic (2VL):
- if any element is true the array maps to true
- if no element is true, the array maps to false
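
The two rule sets can be modeled for a single array as follows. This is a minimal sketch, assuming the predicate results are already available as a list in which `None` stands for SQL null; `exists_reduce` is a hypothetical helper name, and a null array (as opposed to a null element) is not modeled.

```python
from typing import List, Optional


def exists_reduce(flags: List[Optional[bool]], three_valued: bool) -> Optional[bool]:
    """Reduce per-element predicate results for one array.

    three_valued=True follows the 3VL rules, False the legacy 2VL rules.
    """
    # Any true element decides the result in both modes.
    if any(f is True for f in flags):
        return True
    # 3VL only: no true element, but at least one null, yields null.
    if three_valued and any(f is None for f in flags):
        return None
    # Otherwise no element satisfied the predicate.
    return False


with_null_3vl = exists_reduce([False, None], three_valued=True)
with_null_2vl = exists_reduce([False, None], three_valued=False)
```

The two modes differ only when no element is true and at least one predicate result is null.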

Closes #4815

Signed-off-by: Gera Shegalov <gera@apache.org>
gerashegalov authored Mar 21, 2022
1 parent 96c90b9 commit 5ed86dd
Showing 5 changed files with 324 additions and 96 deletions.
1 change: 1 addition & 0 deletions docs/configs.md
@@ -148,6 +148,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
<a name="sql.expression.And"></a>spark.rapids.sql.expression.And|`and`|Logical AND|true|None|
<a name="sql.expression.AnsiCast"></a>spark.rapids.sql.expression.AnsiCast| |Convert a column of one type of data into another type|true|None|
<a name="sql.expression.ArrayContains"></a>spark.rapids.sql.expression.ArrayContains|`array_contains`|Returns a boolean if the array contains the passed in key|true|None|
<a name="sql.expression.ArrayExists"></a>spark.rapids.sql.expression.ArrayExists|`exists`|Return true if any element satisfies the predicate LambdaFunction|true|None|
<a name="sql.expression.ArrayMax"></a>spark.rapids.sql.expression.ArrayMax|`array_max`|Returns the maximum value in the array|true|None|
<a name="sql.expression.ArrayMin"></a>spark.rapids.sql.expression.ArrayMin|`array_min`|Returns the minimum value in the array|true|None|
<a name="sql.expression.ArrayTransform"></a>spark.rapids.sql.expression.ArrayTransform|`transform`|Transform elements in an array using the transform function. This is similar to a `map` in functional programming|true|None|
218 changes: 143 additions & 75 deletions docs/supported_ops.md
@@ -2030,12 +2030,12 @@ are limited.
<td> </td>
</tr>
<tr>
<td rowSpan="2">ArrayMax</td>
<td rowSpan="2">`array_max`</td>
<td rowSpan="2">Returns the maximum value in the array</td>
<td rowSpan="2">None</td>
<td rowSpan="2">project</td>
<td>input</td>
<td rowSpan="3">ArrayExists</td>
<td rowSpan="3">`exists`</td>
<td rowSpan="3">Return true if any element satisfies the predicate LambdaFunction</td>
<td rowSpan="3">None</td>
<td rowSpan="3">project</td>
<td>argument</td>
<td> </td>
<td> </td>
<td> </td>
@@ -2050,31 +2050,52 @@ are limited.
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>function</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<th>Expression</th>
@@ -2103,6 +2124,53 @@ are limited.
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">ArrayMax</td>
<td rowSpan="2">`array_max`</td>
<td rowSpan="2">Returns the maximum value in the array</td>
<td rowSpan="2">None</td>
<td rowSpan="2">project</td>
<td>input</td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
</tr>
<tr>
<td>result</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
</tr>
<tr>
<td rowSpan="2">ArrayMin</td>
<td rowSpan="2">`array_min`</td>
<td rowSpan="2">Returns the minimum value in the array</td>
@@ -2398,6 +2466,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">AtLeastNNonNulls</td>
<td rowSpan="2"> </td>
<td rowSpan="2">Checks if number of non null/Nan values is greater than a given value</td>
@@ -2445,32 +2539,6 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="4">Atan</td>
<td rowSpan="4">`atan`</td>
<td rowSpan="4">Inverse tangent</td>
@@ -2767,6 +2835,32 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="2">BitLength</td>
<td rowSpan="2">`bit_length`</td>
<td rowSpan="2">The bit length of string data</td>
@@ -2814,32 +2908,6 @@ are limited.
<td> </td>
</tr>
<tr>
<th>Expression</th>
<th>SQL Functions(s)</th>
<th>Description</th>
<th>Notes</th>
<th>Context</th>
<th>Param/Output</th>
<th>BOOLEAN</th>
<th>BYTE</th>
<th>SHORT</th>
<th>INT</th>
<th>LONG</th>
<th>FLOAT</th>
<th>DOUBLE</th>
<th>DATE</th>
<th>TIMESTAMP</th>
<th>STRING</th>
<th>DECIMAL</th>
<th>NULL</th>
<th>BINARY</th>
<th>CALENDAR</th>
<th>ARRAY</th>
<th>MAP</th>
<th>STRUCT</th>
<th>UDT</th>
</tr>
<tr>
<td rowSpan="6">BitwiseAnd</td>
<td rowSpan="6">`&`</td>
<td rowSpan="6">Returns the bitwise AND of the operands</td>
25 changes: 25 additions & 0 deletions integration_tests/src/main/python/array_test.py
@@ -307,3 +307,28 @@ def test_get_array_struct_fields(data_gen):
max_length=6)
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, array_struct_gen).selectExpr('a.child0'))

@pytest.mark.parametrize('data_gen', [ArrayGen(string_gen), ArrayGen(int_gen)])
@pytest.mark.parametrize('threeVL', [
    pytest.param(False, id='3VL:off'),
    pytest.param(True, id='3VL:on'),
])
def test_array_exists(data_gen, threeVL):
    def do_it(spark):
        columns = ['a']
        element_type = data_gen.data_type.elementType
        if isinstance(element_type, IntegralType):
            columns.extend([
                'exists(a, item -> item % 2 = 0) as exists_even',
                'exists(a, item -> item < 0) as exists_negative',
                'exists(a, item -> item >= 0) as exists_non_negative'
            ])

        if isinstance(element_type, StringType):
            columns.extend(['exists(a, entry -> length(entry) > 5) as exists_longer_than_5'])

        return unary_op_df(spark, data_gen).selectExpr(columns)

    assert_gpu_and_cpu_are_equal_collect(do_it, conf= {
        'spark.sql.legacy.followThreeValuedLogicInArrayExists' : threeVL,
    })
@@ -2857,6 +2857,25 @@ object GpuOverrides extends Logging {
GpuArrayTransform(childExprs.head.convertToGpu(), childExprs(1).convertToGpu())
}
}),
expr[ArrayExists](
  "Return true if any element satisfies the predicate LambdaFunction",
  ExprChecks.projectOnly(TypeSig.BOOLEAN, TypeSig.BOOLEAN,
    Seq(
      ParamCheck("argument",
        TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
          TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
        TypeSig.ARRAY.nested(TypeSig.all)),
      ParamCheck("function", TypeSig.BOOLEAN, TypeSig.BOOLEAN))),
  (in, conf, p, r) => new ExprMeta[ArrayExists](in, conf, p, r) {
    override def convertToGpu(): GpuExpression = {
      GpuArrayExists(
        childExprs.head.convertToGpu(),
        childExprs(1).convertToGpu(),
        SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC)
      )
    }
  }),

expr[TransformKeys](
"Transform keys in a map using a transform function",
ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +