Skip to content

Commit

Permalink
Add in support for lists in some joins (#2702)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored Jun 11, 2021
1 parent 80383bb commit 44a48ca
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
4 changes: 2 additions & 2 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ Accelerator supports are described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
-<td><b>NS</b></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand All @@ -673,7 +673,7 @@ Accelerator supports are described below.
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
-<td><b>NS</b></td>
+<td><em>PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT)</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down
8 changes: 4 additions & 4 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def do_join(spark):
# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
-@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
+@pytest.mark.parametrize('data_gen', all_gen + single_level_array_gens, ids=idfn)
@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_cartesean_join(data_gen, batch_size):
def do_join(spark):
Expand All @@ -221,7 +221,7 @@ def do_join(spark):
@ignore_order(local=True)
@pytest.mark.xfail(condition=is_databricks_runtime(),
reason='https://github.com/NVIDIA/spark-rapids/issues/334')
-@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
+@pytest.mark.parametrize('data_gen', all_gen + single_level_array_gens, ids=idfn)
@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_cartesean_join_special_case_count(data_gen, batch_size):
def do_join(spark):
Expand Down Expand Up @@ -249,7 +249,7 @@ def do_join(spark):
# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
-@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
+@pytest.mark.parametrize('data_gen', all_gen + single_level_array_gens, ids=idfn)
@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_broadcast_nested_loop_join(data_gen, batch_size):
def do_join(spark):
Expand All @@ -262,7 +262,7 @@ def do_join(spark):
# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
-@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
+@pytest.mark.parametrize('data_gen', all_gen + single_level_array_gens, ids=idfn)
@pytest.mark.parametrize('batch_size', ['100', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_broadcast_nested_loop_join_special_case_count(data_gen, batch_size):
def do_join(spark):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2962,11 +2962,15 @@ object GpuOverrides {
(exchange, conf, p, r) => new GpuBroadcastMeta(exchange, conf, p, r)),
exec[BroadcastNestedLoopJoinExec](
"Implementation of join using brute force",
-ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, TypeSig.all),
+ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
+  TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL),
+  TypeSig.all),
(join, conf, p, r) => new GpuBroadcastNestedLoopJoinMeta(join, conf, p, r)),
exec[CartesianProductExec](
"Implementation of join using brute force",
-ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, TypeSig.all),
+ExecChecks(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
+  TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL),
+  TypeSig.all),
(join, conf, p, r) => new SparkPlanMeta[CartesianProductExec](join, conf, p, r) {
val condition: Option[BaseExprMeta[_]] =
join.condition.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
Expand Down

0 comments on commit 44a48ca

Please sign in to comment.