Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Update all versions
  • Loading branch information
winningsix committed Nov 24, 2023
1 parent 9538f36 commit 73bb81f
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 31 deletions.
5 changes: 4 additions & 1 deletion integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime, is_emr_runtime
from data_gen import *
from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan
from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan, \
datagen_overrides
from spark_session import with_cpu_session, is_before_spark_330, is_databricks113_or_later, is_databricks_runtime

pytestmark = [pytest.mark.nightly_resource_consuming_test]
Expand Down Expand Up @@ -427,6 +428,7 @@ def do_join(spark):
# Allowing non-GPU ShuffleExchangeExec is mainly for Databricks, where its exchange is CPU-based ('Exchange SinglePartition, EXECUTOR_BROADCAST').
db_113_cpu_bhj_join_allow=["ShuffleExchangeExec"] if is_databricks113_or_later() else []


@allow_non_gpu(*db_113_cpu_bhj_join_allow)
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen()], ids=idfn)
Expand All @@ -439,6 +441,7 @@ def do_join(spark):
return left.join(right.hint("broadcast"), ((left.b == right.r_b) & (f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')))), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf = {"spark.rapids.sql.castFloatToIntegralTypes.enabled": True})


@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ abstract class GpuBroadcastHashJoinExecBase(
postBuildCondition.map(expr => expr.toAttribute)
}

// Attributes exposed by the left side of the join: the attributes of the post-build condition
// expressions when that list is non-empty and the left side is the build side, otherwise the
// left child's own output.
override lazy val leftOutput: Seq[Attribute] = {
  val projectedOnLeft = postBuildCondition.nonEmpty && buildSide == GpuBuildLeft
  if (projectedOnLeft) postBuildCondition.map(_.toAttribute) else left.output
}

// Attributes exposed by the right side of the join: the attributes of the post-build condition
// expressions when that list is non-empty and the right side is the build side, otherwise the
// right child's own output.
override lazy val rightOutput: Seq[Attribute] = {
  val projectedOnRight = postBuildCondition.nonEmpty && buildSide == GpuBuildRight
  if (projectedOnRight) postBuildCondition.map(_.toAttribute) else right.output
}

// Needed when the original join condition contains a non-AST-able condition. Such conditions are
// extracted and evaluated as a project node on top of the built batch.
private lazy val postBuildProj: Option[GpuTieredProject] = if (!postBuildCondition.isEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,15 @@ trait GpuHashJoin extends GpuExec {
case GpuBuildRight => (right, left)
}

// This can be override when a post-build project happens.
// Output attributes of the left and right sides as seen by this join. They default to the
// children's output and can be overridden when a post-build project happens: when the join
// condition contains expressions that are not AST-supported, those expressions are extracted
// from the join condition and wrapped as pre-project nodes, which changes the attributes the
// affected side exposes.
// The types are declared explicitly so that overriding members have a fixed contract.
protected lazy val leftOutput: Seq[Attribute] = left.output
protected lazy val rightOutput: Seq[Attribute] = right.output

// Attribute lists of the build-side and streamed-side plans, taken from the output of buildPlan
// and streamedPlan respectively. These can be overridden when a post-build project happens:
// when the join condition contains expressions that are not AST-supported, they are extracted
// from the join condition and wrapped as pre-project nodes.
protected lazy val buildAttrList: List[Attribute] = buildPlan.output.toList
protected lazy val streamedAttrList: List[Attribute] = streamedPlan.output.toList

Expand All @@ -919,7 +927,7 @@ trait GpuHashJoin extends GpuExec {
}

// Output attributes of the join, computed by GpuHashJoin.output from the join type and the
// leftOutput/rightOutput members (rather than the raw children's output) so that any override
// of those members is reflected in the result.
override def output: Seq[Attribute] =
  GpuHashJoin.output(joinType, leftOutput, rightOutput)

// If we have a single batch streamed in then we will produce a single batch of output
Expand Down Expand Up @@ -972,8 +980,8 @@ trait GpuHashJoin extends GpuExec {
GpuHashJoin.anyNullableStructChild(buildKeys)

protected lazy val (boundBuildKeys, boundStreamKeys) = {
val lkeys = GpuBindReferences.bindGpuReferences(leftKeys, left.output)
val rkeys = GpuBindReferences.bindGpuReferences(rightKeys, right.output)
val lkeys = GpuBindReferences.bindGpuReferences(leftKeys, leftOutput)
val rkeys = GpuBindReferences.bindGpuReferences(rightKeys, rightOutput)

buildSide match {
case GpuBuildLeft => (lkeys, rkeys)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,15 @@ class GpuBroadcastHashJoinMeta(
.output, right.output, true)

// Reconstruct the child with wrapped project node if needed.
val leftChild = if (!leftExpr.isEmpty && buildSide != GpuBuildLeft) {
GpuProjectExec(leftExpr ++ left.output, left)(true)
} else {
left
}
val rightChild = if (!rightExpr.isEmpty && buildSide == GpuBuildLeft) {
GpuProjectExec(rightExpr ++ right.output, right)(true)
} else {
right
}
val leftChild =
if (!leftExpr.isEmpty) GpuProjectExec(leftExpr ++ left.output, left)(true) else left
val rightChild =
if (!rightExpr.isEmpty) GpuProjectExec(rightExpr ++ right.output, right)(true) else right

val (postBuildAttr, postBuildCondition) = if (buildSide == GpuBuildLeft) {
(left.output, leftExpr ++ left.output)
(leftExpr.map(_.toAttribute) ++ left.output, leftExpr ++ left.output)
} else {
(right.output, rightExpr ++ right.output)
(rightExpr.map(_.toAttribute) ++ right.output, rightExpr ++ right.output)
}

val joinExec = GpuBroadcastHashJoinExec(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,29 @@ class GpuBroadcastHashJoinMeta(
.output, right.output, true)

// Reconstruct the child with wrapped project node if needed.
val leftChild = if (!leftExpr.isEmpty && buildSide != GpuBuildLeft) {
GpuProjectExec(leftExpr ++ left.output, left)(true)
} else {
left
}
val rightChild = if (!rightExpr.isEmpty && buildSide == GpuBuildLeft) {
GpuProjectExec(rightExpr ++ right.output, right)(true)
} else {
right
}
// val leftChild = if (!leftExpr.isEmpty && buildSide != GpuBuildLeft) {
// GpuProjectExec(leftExpr ++ left.output, left)(true)
// } else {
// left
// }
// val rightChild = if (!rightExpr.isEmpty && buildSide == GpuBuildLeft) {
// GpuProjectExec(rightExpr ++ right.output, right)(true)
// } else {
// right
// }

val leftChild =
if (!leftExpr.isEmpty) GpuProjectExec(leftExpr ++ left.output, left)(true) else left
val rightChild =
if (!rightExpr.isEmpty) GpuProjectExec(rightExpr ++ right.output, right)(true) else right

val (postBuildAttr, postBuildCondition) = if (buildSide == GpuBuildLeft) {
(leftExpr.map(_.toAttribute) ++ left.output, leftExpr ++ left.output)
} else {
(rightExpr.map(_.toAttribute) ++ right.output, rightExpr ++ right.output)
}


val joinExec = GpuBroadcastHashJoinExec(
leftKeys.map(_.convertToGpu()),
rightKeys.map(_.convertToGpu()),
Expand Down Expand Up @@ -163,8 +170,7 @@ case class GpuBroadcastHashJoinExec(
case ReusedExchangeExec(_, g: GpuShuffleExchangeExec) => g
case _ => throw new IllegalStateException(s"cannot locate GPU shuffle in $p")
}
// Use getBroadcastPlan to get child of project. This happens when non-AST condition split case
getBroadcastPlan(buildPlan) match {
buildPlan match {
case gpu: GpuShuffleExchangeExec => gpu
case sqse: ShuffleQueryStageExec => from(sqse)
case reused: ReusedExchangeExec => reused.child.asInstanceOf[GpuShuffleExchangeExec]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,11 @@ class GpuBroadcastHashJoinMeta(
} else {
right
}

val (postBuildAttr, postBuildCondition) = if (buildSide == GpuBuildLeft) {
(left.output.toList, leftExpr ++ left.output)
(leftExpr.map(_.toAttribute) ++ left.output, leftExpr ++ left.output)
} else {
(right.output.toList, rightExpr ++ right.output)
(rightExpr.map(_.toAttribute) ++ right.output, rightExpr ++ right.output)
}

val joinExec = GpuBroadcastHashJoinExec(
Expand Down

0 comments on commit 73bb81f

Please sign in to comment.