diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py
index 6660e663c92..9e7f5a13cb9 100644
--- a/integration_tests/src/main/python/join_test.py
+++ b/integration_tests/src/main/python/join_test.py
@@ -19,8 +19,9 @@
 from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
 from conftest import is_databricks_runtime, is_emr_runtime, is_not_utc
 from data_gen import *
-from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan
-from spark_session import with_cpu_session, is_before_spark_330, is_databricks_runtime
+from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan, \
+    datagen_overrides
+from spark_session import with_cpu_session, is_before_spark_330, is_databricks113_or_later, is_databricks_runtime
 
 pytestmark = [pytest.mark.nightly_resource_consuming_test]
@@ -434,6 +435,38 @@ def do_join(spark):
         return broadcast(left).join(right, left.a > f.log(right.r_a), join_type)
     assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec')
 
+# Allow ShuffleExchangeExec on the CPU, mainly for Databricks, where the broadcast exchange is
+# CPU based ('Exchange SinglePartition, EXECUTOR_BROADCAST').
+db_113_cpu_bhj_join_allow = ["ShuffleExchangeExec"] if is_databricks113_or_later() else []
+
+
+@allow_non_gpu(*db_113_cpu_bhj_join_allow)
+@ignore_order(local=True)
+@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen()], ids=idfn)
+@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
+def test_broadcast_hash_join_on_non_ast_condition_without_fallback(data_gen, join_type):
+    # Test BHJ with a join condition that is not fully supported by AST. With the non-AST parts
+    # wrapped in extra project nodes, the join can still run on the GPU instead of falling back.
+    def do_join(spark):
+        left, right = create_df(spark, data_gen, 50, 25)
+        # AST does not support cast or logarithm yet
+        return left.join(right.hint("broadcast"), ((left.b == right.r_b) & (f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')))), join_type)
+    assert_gpu_and_cpu_are_equal_collect(do_join, conf={"spark.rapids.sql.castFloatToIntegralTypes.enabled": True})
+
+
+@allow_non_gpu('BroadcastHashJoinExec', 'BroadcastExchangeExec')
+@ignore_order(local=True)
+@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen()], ids=idfn)
+@pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
+def test_broadcast_hash_join_on_non_ast_condition_fallback(data_gen, join_type):
+    # Test BHJ with a join condition that is not fully supported by AST. Since AST does not
+    # support double, this query falls back to the CPU.
+    # Inner join is not included since it can run on the GPU via a post-join filter.
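+    # For these join types the condition determines which rows pair up (and which get
+    # null-extended or filtered), so it cannot be applied as a filter after the join; the
+    # whole join has to fall back to the CPU instead.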
+    def do_join(spark):
+        left, right = create_df(spark, data_gen, 50, 25)
+        # AST does not support cast or logarithm yet, and the condition cannot be split into
+        # project nodes because both sides of the join are involved in the comparison
+        return left.join(right.hint("broadcast"), ((left.b == right.r_b) & (left.a.cast('double') > right.r_a.cast('double'))), join_type)
+    assert_gpu_fallback_collect(do_join, 'BroadcastHashJoinExec')
+
+
 @ignore_order(local=True)
 @pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
index 5062d8e4a99..3f68f5d3d60 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
@@ -16,11 +16,15 @@ package com.nvidia.spark.rapids
+import java.io.Serializable
+
 import scala.collection.mutable
 import scala.collection.mutable.ListBuffer
+
+import com.nvidia.spark.rapids.Arm.withResource
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, Expression, ExprId, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, Expression, ExprId, NamedExpression}
 import org.apache.spark.sql.rapids.catalyst.expressions.{GpuEquivalentExpressions, GpuExpressionEquals}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 object AstUtil {
@@ -119,4 +123,164 @@ object AstUtil {
       }
     }
   }
+
+  /**
+   * Transforms the original join condition into an extra filter/project when necessary.
+   * It targets the cases where the join condition cannot be fully evaluated by AST.
+   * Depending on the join condition, one of three strategies applies:
+   * (1) [NoopJoinCondSplit]: no-op when the join condition can be fully evaluated with AST.
+   * (2) [JoinCondSplitAsPostFilter]: the entire join condition is pulled out as a post-join
+   * filter.
+   * (3) [JoinCondSplitAsProject]: the parts of the join condition not supported by AST are
+   * extracted into pre-project nodes on each join child, and one extra post-project node is
+   * introduced to remove the intermediate attributes (see the illustrative comment in the
+   * class body below).
+   */
+  abstract class JoinCondSplitStrategy(left: Seq[NamedExpression],
+      right: Seq[NamedExpression], buildSide: GpuBuildSide) extends Serializable {
+
+    // Actual output of the build/stream side project due to the join condition split
+    private[this] val (buildOutputAttr, streamOutputAttr) = buildSide match {
+      case GpuBuildLeft => (joinLeftOutput, joinRightOutput)
+      case GpuBuildRight => (joinRightOutput, joinLeftOutput)
+    }
+
+    // Output of the join's left child. Under the `split as project` strategy it may differ from
+    // the original left child's output by extra attributes extracted from the join condition.
+    def leftOutput(): Seq[NamedExpression] = left
+
+    // Output of the join's right child. Under the `split as project` strategy it may differ from
+    // the original right child's output by extra attributes extracted from the join condition.
+    def rightOutput(): Seq[NamedExpression] = right
+
+    def astCondition(): Option[Expression]
+
+    def processBuildSideAndClose(input: ColumnarBatch): ColumnarBatch = input
+
+    def processStreamSideAndClose(input: ColumnarBatch): ColumnarBatch = input
+
+    def processPostJoin(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = iter
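+    // Illustrative example (hypothetical helper names, mirroring the integration tests): for a
+    // condition like (l.b = r.b) AND (CAST(l.a AS INT) > CAST(LOG(r.a) AS INT)), the cast/log
+    // parts are not AST-able but each references a single side, so the split can compute
+    // `_nonAst0 = CAST(l.a AS INT)` in a left-side project and `_nonAst1 = CAST(LOG(r.a) AS INT)`
+    // in a right-side project, keep (l.b = r.b) AND (_nonAst0 > _nonAst1) as the AST join
+    // condition, and drop the helper columns in the post-project.
+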
+    // Attribute-level output of the join's left child. Under the `split as project` strategy it
+    // may include extra attributes extracted from the join condition.
+    def joinLeftOutput(): Seq[Attribute] = leftOutput.map(expr => expr.toAttribute)
+
+    // Attribute-level output of the join's right child. Under the `split as project` strategy it
+    // may include extra attributes extracted from the join condition.
+    def joinRightOutput(): Seq[Attribute] = rightOutput.map(expr => expr.toAttribute)
+
+    // Updated build-side attribute list after the join condition is split into project nodes.
+    // It may include extra attributes from the split join condition.
+    def buildSideOutput(): Seq[Attribute] = buildOutputAttr
+
+    // Updated stream-side attribute list after the join condition is split into project nodes.
+    // It may include extra attributes from the split join condition.
+    def streamedSideOutput(): Seq[Attribute] = streamOutputAttr
+  }
+
+  // For the case where the entire join condition can be evaluated as AST.
+  case class NoopJoinCondSplit(condition: Option[Expression], left: Seq[NamedExpression],
+      right: Seq[NamedExpression], buildSide: GpuBuildSide)
+      extends JoinCondSplitStrategy(left, right, buildSide) {
+    override def astCondition(): Option[Expression] = condition
+  }
+
+  // For inner joins we can apply a post-join condition for any conditions that cannot be
+  // evaluated directly in a mixed join that leverages a cudf AST expression.
+  case class JoinCondSplitAsPostFilter(expr: Option[Expression],
+      attributeSeq: Seq[Attribute], left: Seq[NamedExpression],
+      right: Seq[NamedExpression], buildSide: GpuBuildSide)
+      extends JoinCondSplitStrategy(left, right, buildSide) {
+    private[this] val postFilter = expr.map { e =>
+      GpuBindReferences.bindGpuReferencesTiered(
+        Seq(e), attributeSeq, false)
+    }
+
+    override def astCondition(): Option[Expression] = None
+
+    override def processPostJoin(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
+      postFilter.map { filter =>
+        iter.flatMap { cb =>
+          GpuFilter.filterAndClose(cb, filter, NoopMetric, NoopMetric, NoopMetric)
+        }
+      }.getOrElse(iter)
+    }
+  }
+
+  /**
+   * The split strategy for the case where the parts of the join condition that AST does not
+   * support can be extracted and wrapped into extra project node(s).
+   *
+   * @param astCond the remaining join condition after the non-AST-supported parts are extracted
+   * @param left original expressions from the join's left child; the input attributes of the
+   *             left-side project.
+   * @param leftProj extra expressions extracted from the original join condition that are not
+   *                 supported by AST; evaluated as a project on the left-side batches.
+   * @param right original expressions from the join's right child; the input attributes of the
+   *              right-side project.
+   * @param rightProj extra expressions extracted from the original join condition that are not
+   *                  supported by AST; evaluated as a project on the right-side batches.
+   * @param post the projection that removes the extra columns introduced by the join
+   *             condition split
+   * @param buildSide indicates which side is the build side
+   */
+  case class JoinCondSplitAsProject(
+      astCond: Option[Expression],
+      left: Seq[NamedExpression], leftProj: Seq[NamedExpression],
+      right: Seq[NamedExpression], rightProj: Seq[NamedExpression],
+      post: Seq[NamedExpression], buildSide: GpuBuildSide
+  ) extends JoinCondSplitStrategy(left ++ leftProj, right ++ rightProj, buildSide) {
+    private[this] val leftInput = left.map(_.toAttribute)
+    private[this] val rightInput = right.map(_.toAttribute)
+
+    // Used to build the build-side/stream-side projects
+    private[this] val (buildOutput, streamOutput, buildInput, streamInput) = buildSide match {
+      case GpuBuildLeft =>
+        (leftOutput, rightOutput, leftInput, rightInput)
+      case GpuBuildRight =>
+        (rightOutput, leftOutput, rightInput, leftInput)
+    }
+
+    private[this] val buildProj = if (buildOutput.nonEmpty) {
+      Some(GpuBindReferences.bindGpuReferencesTiered(buildOutput, buildInput, false))
+    } else None
+
+    private[this] val streamProj = if (streamOutput.nonEmpty) {
+      Some(GpuBindReferences.bindGpuReferencesTiered(streamOutput, streamInput, false))
+    } else None
+
+    // Removes the intermediate attributes, introduced by the split, from the left and right side
+    // project nodes. The output attributes need to be updated based on the join type, and they
+    // cover both the original plan output and the extra project nodes.
+    private[this] val postProj = if (post.nonEmpty) {
+      Some(
+        GpuBindReferences.bindGpuReferencesTiered(
+          post, (leftOutput ++ rightOutput).map(_.toAttribute), false))
+    } else None
+
+    override def astCondition(): Option[Expression] = astCond
+
+    override def processBuildSideAndClose(input: ColumnarBatch): ColumnarBatch = {
+      buildProj.map { pj =>
+        withResource(input) { cb =>
+          pj.project(cb)
+        }
+      }.getOrElse(input)
+    }
+
+    override def processStreamSideAndClose(input: ColumnarBatch): ColumnarBatch = {
+      streamProj.map { pj =>
+        withResource(input) { cb =>
+          pj.project(cb)
+        }
+      }.getOrElse(input)
+    }
+
+    override def processPostJoin(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
+      postProj.map { proj =>
+        iter.map { cb =>
+          withResource(cb) { b =>
+            proj.project(b)
+          }
+        }
+      }.getOrElse(iter)
+    }
+  }
 }
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala
index 4982c6e3c9c..7531223fba6 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids.execution
 import ai.rapids.cudf.{NvtxColor, NvtxRange}
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.AstUtil.JoinCondSplitStrategy
 import com.nvidia.spark.rapids.shims.{GpuBroadcastJoinMeta, ShimBinaryExecNode}
 
 import org.apache.spark.TaskContext
@@ -55,6 +56,28 @@ abstract class GpuBroadcastHashJoinMetaBase(
 
   override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ conditionMeta
 
+  private var taggedForAstCheck = false
+
+  // Avoid checking multiple times
+  private var isAstCond = false
+
+  /**
+   * Check whether the join condition is AST-able. Two cases qualify: 1) the entire join
+   * condition is AST-able; 2) the join condition is AST-able after the non-AST parts are
+   * split out and pushed down to the child plans.
+   */
+  def canJoinCondAstAble(): Boolean = {
+    if (!taggedForAstCheck) {
+      val Seq(leftPlan, rightPlan) = childPlans
+      isAstCond = conditionMeta match {
+        case Some(e) => AstUtil.canExtractNonAstConditionIfNeed(
+          e, leftPlan.outputAttributes.map(_.exprId), rightPlan.outputAttributes.map(_.exprId))
+        case None => true
+      }
+      taggedForAstCheck = true
+    }
+    isAstCond
+  }
+
   override def tagPlanForGpu(): Unit = {
     GpuHashJoin.tagJoin(this, join.joinType, buildSide, join.leftKeys, join.rightKeys,
       conditionMeta)
@@ -103,6 +126,7 @@ abstract class GpuBroadcastHashJoinExecBase(
     joinType: JoinType,
     buildSide: GpuBuildSide,
     override val condition: Option[Expression],
+    override val joinCondSplitStrategy: JoinCondSplitStrategy,
     left: SparkPlan,
     right: SparkPlan) extends ShimBinaryExecNode with GpuHashJoin {
   import GpuMetric._
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
index cbaa1cbe47c..f86f75104a3 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala
@@ -19,6 +19,7 @@ import ai.rapids.cudf.{ColumnView, DType, GatherMap, GroupByAggregation, NullEqu
 import ai.rapids.cudf.ast.CompiledExpression
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.AstUtil.{JoinCondSplitStrategy, NoopJoinCondSplit}
 import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
 import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{withRestoreOnRetry, withRetryNoSplit}
 import com.nvidia.spark.rapids.jni.GpuOOM
@@ -116,9 +117,11 @@
     joinType match {
       case _: InnerLike =>
       case RightOuter | LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
-        conditionMeta.foreach(meta.requireAstForGpuOn)
+        // First check whether the non-AST parts of the condition can be split out. Only if they
+        // cannot, fall through to requireAstForGpuOn to report the cannot-run-on-GPU reason.
+        conditionMeta.foreach(cond => if (!canJoinCondAstAble(meta)) meta.requireAstForGpuOn(cond))
       case FullOuter =>
-        conditionMeta.foreach(meta.requireAstForGpuOn)
+        conditionMeta.foreach(cond => if (!canJoinCondAstAble(meta)) meta.requireAstForGpuOn(cond))
         // FullOuter join cannot support with struct keys as two issues below
         //  * https://github.com/NVIDIA/spark-rapids/issues/2126
         //  * https://github.com/rapidsai/cudf/issues/7947
@@ -138,6 +141,15 @@
     }
   }
 
+  // Check whether the entire condition tree is AST-able, or whether the non-AST-able parts can
+  // be split out into the child nodes. Currently only broadcast hash join supports the split.
+  private[this] def canJoinCondAstAble(meta: SparkPlanMeta[_]): Boolean = {
+    meta match {
+      case meta: GpuBroadcastHashJoinMeta => meta.canJoinCondAstAble
+      case _ => false
+    }
+  }
+
   /** Determine if this type of join supports using the right side of the join as the build side.
   */
  def canBuildRight(joinType: JoinType): Boolean = joinType match {
    case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | FullOuter | _: ExistenceJoin => true
@@ -254,6 +266,25 @@
     keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
       keys.map(_.dataType.defaultSize).sum <= 8
   }
+
+  def output(joinType: JoinType, left: Seq[Attribute], right: Seq[Attribute]): Seq[Attribute] = {
+    joinType match {
+      case _: InnerLike =>
+        left ++ right
+      case LeftOuter =>
+        left ++ right.map(_.withNullability(true))
+      case RightOuter =>
+        left.map(_.withNullability(true)) ++ right
+      case j: ExistenceJoin =>
+        left :+ j.exists
+      case LeftExistence(_) =>
+        left
+      case FullOuter =>
+        left.map(_.withNullability(true)) ++ right.map(_.withNullability(true))
+      case x =>
+        throw new IllegalArgumentException(s"GpuHashJoin should not take $x as the JoinType")
+    }
+  }
 }
 
 abstract class BaseHashJoinIterator(
@@ -866,6 +897,8 @@ trait GpuHashJoin extends GpuExec {
   def leftKeys: Seq[Expression]
   def rightKeys: Seq[Expression]
   def buildSide: GpuBuildSide
+  def joinCondSplitStrategy: JoinCondSplitStrategy = NoopJoinCondSplit(
+    condition, left.output, right.output, buildSide)
 
   protected lazy val (buildPlan, streamedPlan) = buildSide match {
     case GpuBuildLeft => (left, right)
@@ -885,22 +918,7 @@ trait GpuHashJoin extends GpuExec {
   }
 
   override def output: Seq[Attribute] = {
-    joinType match {
-      case _: InnerLike =>
-        left.output ++ right.output
-      case LeftOuter =>
-        left.output ++ right.output.map(_.withNullability(true))
-      case RightOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output
-      case j: ExistenceJoin =>
-        left.output :+ j.exists
-      case LeftExistence(_) =>
-        left.output
-      case FullOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
-      case x =>
-        throw new IllegalArgumentException(s"GpuHashJoin should not take $x as the JoinType")
-    }
+    GpuHashJoin.output(joinType, left.output, right.output)
   }
 
   // If we have a single batch streamed in then we will produce a single batch of output
@@ -953,8 +971,10 @@
     GpuHashJoin.anyNullableStructChild(buildKeys)
 
   protected lazy val (boundBuildKeys, boundStreamKeys) = {
-    val lkeys = GpuBindReferences.bindGpuReferences(leftKeys, left.output)
-    val rkeys = GpuBindReferences.bindGpuReferences(rightKeys, right.output)
+    val lkeys =
+      GpuBindReferences.bindGpuReferences(leftKeys, joinCondSplitStrategy.joinLeftOutput)
+    val rkeys =
+      GpuBindReferences.bindGpuReferences(rightKeys, joinCondSplitStrategy.joinRightOutput)
 
     buildSide match {
       case GpuBuildLeft => (lkeys, rkeys)
@@ -963,14 +983,25 @@
   }
 
   protected lazy val (numFirstConditionTableColumns, boundCondition) = {
-    val (joinLeft, joinRight) = joinType match {
-      case RightOuter => (right, left)
-      case _ => (left, right)
+    val joinLeft = joinType match {
+      case RightOuter =>
+        if (buildSide == GpuBuildRight) {
+          joinCondSplitStrategy.buildSideOutput
+        } else {
+          joinCondSplitStrategy.streamedSideOutput
+        }
+      case _ =>
+        if (buildSide == GpuBuildRight) {
+          joinCondSplitStrategy.streamedSideOutput
+        } else {
+          joinCondSplitStrategy.buildSideOutput
+        }
     }
     val boundCondition = condition.map { c =>
-      GpuBindReferences.bindGpuReference(c, joinLeft.output ++ joinRight.output)
+      GpuBindReferences.bindGpuReference(c,
+        joinCondSplitStrategy.streamedSideOutput ++ joinCondSplitStrategy.buildSideOutput)
     }
-    (joinLeft.output.size, boundCondition)
+    (joinLeft.size, boundCondition)
   }
 
   def doJoin(
@@ -994,13 +1025,14 @@ trait GpuHashJoin extends GpuExec {
       builtBatch
     }
 
-    val spillableBuiltBatch = withResource(nullFiltered) {
+    val spillableBuiltBatch = withResource(joinCondSplitStrategy
+        .processBuildSideAndClose(nullFiltered)) {
       LazySpillableColumnarBatch(_, "built")
     }
 
     val lazyStream = stream.map { cb =>
-      withResource(cb) { cb =>
-        LazySpillableColumnarBatch(cb, "stream_batch")
+      withResource(joinCondSplitStrategy.processStreamSideAndClose(cb)) { updatedBatch =>
+        LazySpillableColumnarBatch(updatedBatch, "stream_batch")
       }
     }
@@ -1019,25 +1051,29 @@ trait GpuHashJoin extends GpuExec {
           opTime,
           joinTime)
       case FullOuter =>
-        new HashFullJoinIterator(spillableBuiltBatch, boundBuildKeys, lazyStream,
-          boundStreamKeys, streamedPlan.output, boundCondition, numFirstConditionTableColumns,
-          targetSize, buildSide, compareNullsEqual, opTime, joinTime)
+        new HashFullJoinIterator(
+          spillableBuiltBatch, boundBuildKeys, lazyStream,
+          boundStreamKeys, joinCondSplitStrategy.streamedSideOutput, boundCondition,
+          numFirstConditionTableColumns, targetSize, buildSide, compareNullsEqual, opTime,
+          joinTime)
       case _ =>
         if (boundCondition.isDefined) {
           // ConditionalHashJoinIterator will close the compiled condition
           val compiledCondition =
             boundCondition.get.convertToAst(numFirstConditionTableColumns).compile()
-          new ConditionalHashJoinIterator(spillableBuiltBatch, boundBuildKeys, lazyStream,
-            boundStreamKeys, streamedPlan.output, compiledCondition,
+          new ConditionalHashJoinIterator(
+            spillableBuiltBatch, boundBuildKeys, lazyStream,
+            boundStreamKeys, joinCondSplitStrategy.streamedSideOutput, compiledCondition,
             targetSize, joinType, buildSide, compareNullsEqual, opTime, joinTime)
         } else {
-          new HashJoinIterator(spillableBuiltBatch, boundBuildKeys, lazyStream, boundStreamKeys,
-            streamedPlan.output, targetSize, joinType, buildSide, compareNullsEqual,
-            opTime, joinTime)
+          new HashJoinIterator(
+            spillableBuiltBatch, boundBuildKeys, lazyStream, boundStreamKeys,
+            joinCondSplitStrategy.streamedSideOutput, targetSize, joinType, buildSide,
+            compareNullsEqual, opTime, joinTime)
         }
     }
 
-    joinIterator.map { cb =>
+    joinCondSplitStrategy.processPostJoin(joinIterator).map { cb =>
       joinOutputRows += cb.numRows()
       numOutputRows += cb.numRows()
       numOutputBatches += 1
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala
index 33ca8c906f6..816aa4ac07d 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala
@@ -614,7 +614,7 @@ trait GpuSubPartitionHashJoin extends Logging { self: GpuHashJoin =>
         }
       }
       // Leverage the original join iterators
-      val joinIter = doJoin(buildCb, streamIter, targetSize,
+      val joinIter = doJoin(buildCb, streamIter, targetSize, 
         numOutputRows, joinOutputRows, numOutputBatches, opTime, joinTime)
       Some(joinIter)
     }
diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
index 0b1be70234b..cd640057a58 100644
--- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
+++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
@@ -35,6 +35,7 @@ spark-rapids-shim-json-lines
 ***/
 package org.apache.spark.sql.rapids.execution
 
 import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.AstUtil.{JoinCondSplitAsPostFilter, JoinCondSplitAsProject, JoinCondSplitStrategy, NoopJoinCondSplit}
 
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.JoinType
@@ -48,12 +49,6 @@ class GpuBroadcastHashJoinMeta(
     rule: DataFromReplacementRule)
     extends GpuBroadcastHashJoinMetaBase(join, conf, parent, rule) {
   override def convertToGpu(): GpuExec = {
-    val condition = conditionMeta.map(_.convertToGpu())
-    val (joinCondition, filterCondition) = if (conditionMeta.forall(_.canThisBeAst)) {
-      (condition, None)
-    } else {
-      (None, condition)
-    }
     val Seq(left, right) = childPlans.map(_.convertIfNeeded())
     // The broadcast part of this must be a BroadcastExchangeExec
     val buildSideMeta = buildSide match {
@@ -61,16 +56,32 @@ class GpuBroadcastHashJoinMeta(
       case GpuBuildRight => right
     }
     verifyBuildSideWasReplaced(buildSideMeta)
-    val joinExec = GpuBroadcastHashJoinExec(
+    // Check whether the join condition is AST-able, possibly after splitting out its non-AST
+    // parts. If not, apply the whole condition as a post-join filter (tagging ensures only
+    // inner joins reach that path); otherwise split the non-AST parts into project nodes.
+    val nonAstJoinCond = if (!canJoinCondAstAble()) {
+      JoinCondSplitAsPostFilter(
+        conditionMeta.map(_.convertToGpu()), GpuHashJoin.output(
+          join.joinType, left.output, right.output), left.output, right.output, buildSide)
+    } else {
+      val (remain, leftExpr, rightExpr) = AstUtil.extractNonAstFromJoinCond(
+        conditionMeta, left.output, right.output, true)
+      if (leftExpr.isEmpty && rightExpr.isEmpty) {
+        NoopJoinCondSplit(remain, left.output, right.output, buildSide)
+      } else {
+        JoinCondSplitAsProject(
+          remain, left.output, leftExpr, right.output, rightExpr,
+          GpuHashJoin.output(join.joinType, left.output, right.output), buildSide)
+      }
+    }
+
+    GpuBroadcastHashJoinExec(
       leftKeys.map(_.convertToGpu()),
       rightKeys.map(_.convertToGpu()),
       join.joinType,
       buildSide,
-      joinCondition,
+      nonAstJoinCond.astCondition(),
+      nonAstJoinCond,
       left,
       right)
-    // For inner joins we can apply a post-join condition for any conditions that cannot be
-    // evaluated directly in a mixed join that leverages a cudf AST expression
-    filterCondition.map(c => GpuFilterExec(c, joinExec)()).getOrElse(joinExec)
   }
 }
@@ -80,6 +91,7 @@ case class GpuBroadcastHashJoinExec(
     joinType: JoinType,
     buildSide: GpuBuildSide,
     override val condition: Option[Expression],
+    override val joinCondSplitStrategy: JoinCondSplitStrategy,
     left: SparkPlan,
     right: SparkPlan)
   extends GpuBroadcastHashJoinExecBase(
-    leftKeys, rightKeys, joinType, buildSide, condition, left, right)
\ No newline at end of file
+    leftKeys, rightKeys, joinType, buildSide, condition, joinCondSplitStrategy, left, right)
\ No newline at end of file
diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
index ca4b0dfa31a..5db2c11c3b1 100644
--- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
+++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
@@ -24,6 +24,7 @@ package org.apache.spark.sql.rapids.execution
 import ai.rapids.cudf.{NvtxColor, NvtxRange}
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.AstUtil.{JoinCondSplitAsPostFilter, JoinCondSplitAsProject, JoinCondSplitStrategy, NoopJoinCondSplit}
 
 import org.apache.spark.TaskContext
 import org.apache.spark.rapids.shims.GpuShuffleExchangeExec
@@ -45,12 +46,6 @@ class GpuBroadcastHashJoinMeta(
     rule: DataFromReplacementRule)
     extends GpuBroadcastHashJoinMetaBase(join, conf, parent, rule) {
   override def convertToGpu(): GpuExec = {
-    val condition = conditionMeta.map(_.convertToGpu())
-    val (joinCondition, filterCondition) = if (conditionMeta.forall(_.canThisBeAst)) {
-      (condition, None)
-    } else {
-      (None, condition)
-    }
     val Seq(left, right) = childPlans.map(_.convertIfNeeded())
     // The broadcast part of this must be a BroadcastExchangeExec
     val buildSideMeta = buildSide match {
@@ -58,18 +53,32 @@ class GpuBroadcastHashJoinMeta(
       case GpuBuildRight => right
     }
     verifyBuildSideWasReplaced(buildSideMeta)
-    val joinExec = GpuBroadcastHashJoinExec(
+    // Check whether the join condition is AST-able, possibly after splitting out its non-AST
+    // parts. If not, apply the whole condition as a post-join filter (tagging ensures only
+    // inner joins reach that path); otherwise split the non-AST parts into project nodes.
+    val nonAstJoinCond = if (!canJoinCondAstAble()) {
+      JoinCondSplitAsPostFilter(
+        conditionMeta.map(_.convertToGpu()), GpuHashJoin.output(
+          join.joinType, left.output, right.output), left.output, right.output, buildSide)
+    } else {
+      val (remain, leftExpr, rightExpr) = AstUtil.extractNonAstFromJoinCond(
+        conditionMeta, left.output, right.output, true)
+      if (leftExpr.isEmpty && rightExpr.isEmpty) {
+        NoopJoinCondSplit(remain, left.output, right.output, buildSide)
+      } else {
+        JoinCondSplitAsProject(
+          remain, left.output, leftExpr, right.output, rightExpr,
+          GpuHashJoin.output(join.joinType, left.output, right.output), buildSide)
+      }
+    }
+
+    GpuBroadcastHashJoinExec(
       leftKeys.map(_.convertToGpu()),
       rightKeys.map(_.convertToGpu()),
       join.joinType,
       buildSide,
-      joinCondition,
-      left,
-      right,
-      join.isExecutorBroadcast)
-    // For inner joins we can apply a post-join condition for any conditions that cannot be
-    // evaluated directly in a mixed join that leverages a cudf AST expression
-    filterCondition.map(c => GpuFilterExec(c, joinExec)()).getOrElse(joinExec)
+      nonAstJoinCond.astCondition(),
+      nonAstJoinCond,
+      left, right, join.isExecutorBroadcast)
   }
 }
@@ -79,11 +88,12 @@ case class GpuBroadcastHashJoinExec(
     joinType: JoinType,
     buildSide: GpuBuildSide,
     override val condition: Option[Expression],
+    override val joinCondSplitStrategy: JoinCondSplitStrategy,
     left: SparkPlan,
-    right: SparkPlan,
+    right: SparkPlan, 
     executorBroadcast: Boolean)
-  extends GpuBroadcastHashJoinExecBase(
-    leftKeys, rightKeys, joinType, buildSide, condition, left, right) {
+  extends GpuBroadcastHashJoinExecBase(leftKeys, rightKeys, joinType, buildSide,
+    condition, joinCondSplitStrategy, left, right) {
   import GpuMetric._
 
   override lazy val additionalMetrics: Map[String, GpuMetric] = Map(
@@ -147,8 +157,8 @@ case class GpuBroadcastHashJoinExec(
       GpuSemaphore.acquireIfNecessary(TaskContext.get())
     }
   }
-    val buildBatch = GpuExecutorBroadcastHelper.getExecutorBroadcastBatch(buildRelation,
-      buildSchema, buildOutput, metricsMap, targetSize)
+    val buildBatch = GpuExecutorBroadcastHelper.getExecutorBroadcastBatch(
+      buildRelation, buildSchema, buildOutput, metricsMap, targetSize)
     (buildBatch, bufferedStreamIter)
   }
 }
diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
index 4985d791829..e2a8bb51ba5 100644
--- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
+++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
@@ -22,6 +22,7 @@ spark-rapids-shim-json-lines
 ***/
 package org.apache.spark.sql.rapids.execution
 
 import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.AstUtil.{JoinCondSplitAsPostFilter, JoinCondSplitAsProject, JoinCondSplitStrategy, NoopJoinCondSplit}
 
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.JoinType
@@ -35,12 +36,6 @@ class GpuBroadcastHashJoinMeta(
     rule: DataFromReplacementRule)
     extends GpuBroadcastHashJoinMetaBase(join, conf, parent, rule) {
   override def convertToGpu(): GpuExec = {
-    val condition = conditionMeta.map(_.convertToGpu())
-    val (joinCondition, filterCondition) = if (conditionMeta.forall(_.canThisBeAst)) {
-      (condition, None)
-    } else {
-      (None, condition)
-    }
     val Seq(left, right) = childPlans.map(_.convertIfNeeded())
     // The broadcast part of this must be a BroadcastExchangeExec
     val buildSideMeta = buildSide match {
@@ -48,16 +43,31 @@ class GpuBroadcastHashJoinMeta(
       case GpuBuildRight => right
     }
     verifyBuildSideWasReplaced(buildSideMeta)
-    val joinExec = GpuBroadcastHashJoinExec(
+    // Check whether the join condition is AST-able, possibly after splitting out its non-AST
+    // parts. If not, apply the whole condition as a post-join filter (tagging ensures only
+    // inner joins reach that path); otherwise split the non-AST parts into project nodes.
+    val nonAstJoinCond = if (!canJoinCondAstAble()) {
+      JoinCondSplitAsPostFilter(conditionMeta.map(_.convertToGpu()), GpuHashJoin.output(
+        join.joinType, left.output, right.output), left.output, right.output, buildSide)
+    } else {
+      val (remain, leftExpr, rightExpr) = AstUtil.extractNonAstFromJoinCond(
+        conditionMeta, left.output, right.output, true)
+      if (leftExpr.isEmpty && rightExpr.isEmpty) {
+        NoopJoinCondSplit(remain, left.output, right.output, buildSide)
+      } else {
+        JoinCondSplitAsProject(
+          remain, left.output, leftExpr, right.output, rightExpr,
+          GpuHashJoin.output(join.joinType, left.output, right.output), buildSide)
+      }
+    }
+
+    GpuBroadcastHashJoinExec(
       leftKeys.map(_.convertToGpu()),
       rightKeys.map(_.convertToGpu()),
       join.joinType,
       buildSide,
-      joinCondition,
+      nonAstJoinCond.astCondition(),
+      nonAstJoinCond,
       left,
       right)
-    // For inner joins we can apply a post-join condition for any conditions that cannot be
-    // evaluated directly in a mixed join that leverages a cudf AST expression
-    filterCondition.map(c => GpuFilterExec(c, joinExec)()).getOrElse(joinExec)
   }
 }
@@ -67,6 +77,7 @@ case class GpuBroadcastHashJoinExec(
     joinType: JoinType,
     buildSide: GpuBuildSide,
     override val condition: Option[Expression],
+    override val joinCondSplitStrategy: JoinCondSplitStrategy,
     left: SparkPlan,
     right: SparkPlan)
   extends GpuBroadcastHashJoinExecBase(
-    leftKeys, rightKeys, joinType, buildSide, condition, left, right)
+    leftKeys, rightKeys, joinType, buildSide, condition, joinCondSplitStrategy, left, right)
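
Reviewer note: the sketch below is a minimal, self-contained toy model of the splitting idea implemented by AstUtil.extractNonAstFromJoinCond in this patch, not plugin code. The Expr ADT, the `astable` rule, and names such as `_nonAst0_R` are hypothetical stand-ins for illustration: non-AST subtrees that reference only one join side are replaced by fresh column references and collected as per-side project expressions; a non-AST leaf spanning both sides forces a fallback.

object JoinCondSplitSketch {
  sealed trait Expr { def sides: Set[Char] } // which join sides ('L'/'R') the subtree references
  case class Col(name: String, side: Char) extends Expr { val sides = Set(side) }
  case class Eq(l: Expr, r: Expr) extends Expr { val sides = l.sides ++ r.sides }
  case class Gt(l: Expr, r: Expr) extends Expr { val sides = l.sides ++ r.sides }
  case class And(l: Expr, r: Expr) extends Expr { val sides = l.sides ++ r.sides }
  case class Log(e: Expr) extends Expr { val sides = e.sides } // stand-in for a non-AST op

  def astable(e: Expr): Boolean = e match {
    case Log(_) => false // toy rule: the AST has no logarithm
    case Col(_, _) => true
    case Eq(l, r) => astable(l) && astable(r)
    case Gt(l, r) => astable(l) && astable(r)
    case And(l, r) => astable(l) && astable(r)
  }

  /** Returns (rewritten condition, left projects, right projects), or None when some non-AST
   *  leaf references both sides and therefore cannot be pushed into a one-side project. */
  def split(cond: Expr): Option[(Expr, List[(String, Expr)], List[(String, Expr)])] = {
    val left = scala.collection.mutable.ListBuffer.empty[(String, Expr)]
    val right = scala.collection.mutable.ListBuffer.empty[(String, Expr)]
    var ok = true
    def go(e: Expr): Expr =
      if (astable(e)) {
        e // fully AST-able subtree: keep it in the join condition
      } else if (e.sides.size == 1) {
        // one-sided non-AST subtree: compute it in a pre-project, refer to it by name
        val side = e.sides.head
        val buf = if (side == 'L') left else right
        val name = s"_nonAst${buf.size}_$side"
        buf += name -> e
        Col(name, side)
      } else e match {
        // two-sided: recurse, hoping the non-AST parts are one-sided further down
        case And(l, r) => And(go(l), go(r))
        case Eq(l, r) => Eq(go(l), go(r))
        case Gt(l, r) => Gt(go(l), go(r))
        case other => ok = false; other // a two-sided non-AST leaf: give up (CPU fallback)
      }
    val rewritten = go(cond)
    if (ok) Some((rewritten, left.toList, right.toList)) else None
  }

  def main(args: Array[String]): Unit = {
    // (l.b = r.b) AND (l.a > log(r.a)): log(r.a) moves into a right-side project
    val cond = And(Eq(Col("b", 'L'), Col("b", 'R')), Gt(Col("a", 'L'), Log(Col("a", 'R'))))
    println(split(cond))
  }
}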