Support split broadcast join condition into ast and non-ast [databricks] #9760

Merged 36 commits on Dec 12, 2023

Commits
a9c32a4
Support split broadcast join condition into ast and non-ast
winningsix Nov 16, 2023
4f1ce38
Fix
winningsix Nov 20, 2023
3fe880b
Fix compile
winningsix Nov 20, 2023
5823e65
Fix
winningsix Nov 20, 2023
b6d4c3e
Fix
winningsix Nov 20, 2023
9538f36
Address comments
winningsix Nov 22, 2023
73bb81f
Fix
winningsix Nov 23, 2023
16876f1
Fix
winningsix Nov 24, 2023
94e27a5
Fix
winningsix Nov 25, 2023
8a31238
Fix Spark 311
winningsix Nov 25, 2023
944f64d
Update sql-plugin/src/main/scala/org/apache/spark/sql/rapids/executio…
winningsix Nov 27, 2023
100adb5
Merge remote-tracking branch 'origin/branch-24.02' into asthbj
winningsix Dec 5, 2023
f3fc677
Address comments
winningsix Dec 5, 2023
3ea1179
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 5, 2023
7f833fd
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 5, 2023
f385b5c
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 5, 2023
09fdc24
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 5, 2023
a8fee3a
Fix scala 2.13, code style, refactor
winningsix Dec 6, 2023
3858e9a
Minor fix
winningsix Dec 6, 2023
2b4c275
minor
winningsix Dec 6, 2023
056a6b8
Fix scala 2.13
winningsix Dec 6, 2023
0751af9
Fix DBX
winningsix Dec 6, 2023
838a580
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 6, 2023
53efed6
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 6, 2023
0513900
Fix scala 2.13, refactor
winningsix Dec 6, 2023
1ea39d0
Revert unnecessary changes
winningsix Dec 7, 2023
2d99043
Fix failed UT
winningsix Dec 7, 2023
514282d
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 7, 2023
441dd75
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 7, 2023
660c357
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 7, 2023
ca5e80a
Update sql-plugin/src/main/scala/org/apache/spark/sql/rapids/executio…
winningsix Dec 7, 2023
447c24a
Update sql-plugin/src/main/scala/org/apache/spark/sql/rapids/executio…
winningsix Dec 7, 2023
1ae0711
Update sql-plugin/src/main/scala/org/apache/spark/sql/rapids/executio…
winningsix Dec 7, 2023
ea5738f
Update sql-plugin/src/main/scala/org/apache/spark/sql/rapids/executio…
winningsix Dec 7, 2023
e7ec35a
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 7, 2023
fbde5c9
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
winningsix Dec 7, 2023
37 changes: 35 additions & 2 deletions integration_tests/src/main/python/join_test.py
@@ -19,8 +19,9 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime, is_emr_runtime, is_not_utc
from data_gen import *
from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan
from spark_session import with_cpu_session, is_before_spark_330, is_databricks_runtime
from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan, \
datagen_overrides
from spark_session import with_cpu_session, is_before_spark_330, is_databricks113_or_later, is_databricks_runtime

pytestmark = [pytest.mark.nightly_resource_consuming_test]

@@ -434,6 +435,38 @@ def do_join(spark):
return broadcast(left).join(right, left.a > f.log(right.r_a), join_type)
assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec')

# Allowing ShuffleExchangeExec off the GPU is mainly for Databricks, where the exchange is CPU-based ('Exchange SinglePartition, EXECUTOR_BROADCAST').
db_113_cpu_bhj_join_allow=["ShuffleExchangeExec"] if is_databricks113_or_later() else []


@allow_non_gpu(*db_113_cpu_bhj_join_allow)
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen()], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_hash_join_on_non_ast_condition_without_fallback(data_gen, join_type):
# Tests BHJ with a join condition that is not fully supported by AST. With the non-AST parts wrapped in extra project nodes, the join can still run on the GPU instead of falling back to the CPU.
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet
return left.join(right.hint("broadcast"), ((left.b == right.r_b) & (f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')))), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf = {"spark.rapids.sql.castFloatToIntegralTypes.enabled": True})


@allow_non_gpu('BroadcastHashJoinExec', 'BroadcastExchangeExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen()], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_hash_join_on_non_ast_condition_fallback(data_gen, join_type):
# Tests BHJ with a join condition that is not fully supported by AST. Since AST does not support double, this query falls back to the CPU.
# Inner join is not included since it can still be supported on the GPU via a post-join filter.
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet, and the condition cannot be split into project
# nodes because both sides of the join are involved in the same condition
return left.join(right.hint("broadcast"), ((left.b == right.r_b) & (left.a.cast('double') > right.r_a.cast('double'))), join_type)
assert_gpu_fallback_collect(do_join, 'BroadcastHashJoinExec')


@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
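Conceptually, the `split as project` path exercised by the first test behaves like the following DataFrame-level rewrite (a minimal Scala sketch, not code from this PR; the helper column names `_c0` and `_c1` are hypothetical):

import org.apache.spark.sql.{functions => f, DataFrame}

// Sketch of the rewrite behind "split as project" for the test condition
// left.b == right.r_b AND round(a).cast(int) > round(log(r_a)).cast(int):
// the non-AST subexpressions (round/log/cast) are projected into helper
// columns on each side first, so the join condition itself only references
// plain attributes and stays AST-able.
def splitAsProjectSketch(left: DataFrame, right: DataFrame): DataFrame = {
  val leftP  = left.withColumn("_c0", f.round(f.col("a")).cast("integer"))
  val rightP = right.withColumn("_c1", f.round(f.log(f.col("r_a"))).cast("integer"))
  leftP.join(rightP.hint("broadcast"),
      (leftP("b") === rightP("r_b")) && (leftP("_c0") > rightP("_c1")), "left")
    .drop("_c0", "_c1") // a final project removes the intermediate columns
}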
166 changes: 165 additions & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
@@ -16,11 +16,15 @@

package com.nvidia.spark.rapids

import java.io.Serializable

import com.nvidia.spark.rapids.Arm.withResource
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, Expression, ExprId, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, Expression, ExprId, NamedExpression}
import org.apache.spark.sql.rapids.catalyst.expressions.{GpuEquivalentExpressions, GpuExpressionEquals}
import org.apache.spark.sql.vectorized.ColumnarBatch


object AstUtil {
@@ -119,4 +123,164 @@ object AstUtil {
}
}
}

/**
* Transforms the original join condition into an extra filter/project when necessary.
* It targets the cases where the join condition cannot be fully evaluated by AST.
* Based on the join condition, the transformation falls into three strategies:
* (1) [NoopJoinCondSplit]: no-op when the join condition can be fully evaluated with AST.
* (2) [JoinCondSplitAsPostFilter]: the entire join condition is pulled out as a post filter
* applied after the join.
* (3) [JoinCondSplitAsProject]: the parts of the join condition not supported by AST are
* extracted into pre-project nodes on each join child. One extra project node is introduced
* to remove the intermediate attributes afterwards.
*/
abstract class JoinCondSplitStrategy(left: Seq[NamedExpression],
right: Seq[NamedExpression], buildSide: GpuBuildSide) extends Serializable {

// Actual outputs of the build/stream side projects resulting from the join condition split
private[this] val (buildOutputAttr, streamOutputAttr) = buildSide match {
case GpuBuildLeft => (joinLeftOutput, joinRightOutput)
case GpuBuildRight => (joinRightOutput, joinLeftOutput)
}

// Output expressions of the join's left child. In the `split as project` strategy, they may
// differ from the original left child's output by including attributes extracted from the
// join condition.
def leftOutput(): Seq[NamedExpression] = left

// Output expressions of the join's right child. In the `split as project` strategy, they may
// differ from the original right child's output by including attributes extracted from the
// join condition.
def rightOutput(): Seq[NamedExpression] = right

def astCondition(): Option[Expression]

def processBuildSideAndClose(input: ColumnarBatch): ColumnarBatch = input

def processStreamSideAndClose(input: ColumnarBatch): ColumnarBatch = input

def processPostJoin(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = iter

// Attribute-level view of the left-side output. In the `split as project` strategy, it may
// include extra attributes extracted from the join condition.
def joinLeftOutput(): Seq[Attribute] = leftOutput.map(expr => expr.toAttribute)

// Attribute-level view of the right-side output. In the `split as project` strategy, it may
// include extra attributes extracted from the join condition.
def joinRightOutput(): Seq[Attribute] = rightOutput.map(expr => expr.toAttribute)

// Updated build attribute list after join condition split as project node.
// It may include extra attributes from split join condition.
def buildSideOutput(): Seq[Attribute] = buildOutputAttr

// Updated stream attribute list after join condition split as project node.
// It may include extra attributes from split join condition.
def streamedSideOutput(): Seq[Attribute] = streamOutputAttr
}
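For orientation, the hooks above are meant to be invoked around the join roughly as follows (a minimal sketch, not code from this PR; `doJoin` stands in for the actual GPU join):

// Hypothetical driver showing where each JoinCondSplitStrategy hook fires.
def runJoinSketch(
    split: JoinCondSplitStrategy,
    builtBatch: ColumnarBatch,
    streamIter: Iterator[ColumnarBatch],
    doJoin: (ColumnarBatch, Iterator[ColumnarBatch], Option[Expression]) => Iterator[ColumnarBatch])
  : Iterator[ColumnarBatch] = {
  val build  = split.processBuildSideAndClose(builtBatch)       // pre-project on the build side
  val stream = streamIter.map(split.processStreamSideAndClose)  // pre-project on the stream side
  val joined = doJoin(build, stream, split.astCondition())      // join on the AST-able part only
  split.processPostJoin(joined)                                 // post filter / post project
}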

// For the case where the entire join condition can be evaluated as AST.
case class NoopJoinCondSplit(condition: Option[Expression], left: Seq[NamedExpression],
right: Seq[NamedExpression], buildSide: GpuBuildSide)
extends JoinCondSplitStrategy(left, right, buildSide) {
override def astCondition(): Option[Expression] = condition
}

// For inner joins, any part of the condition that cannot be evaluated directly in a mixed
// join via a cudf AST expression can instead be applied as a post-join filter.
case class JoinCondSplitAsPostFilter(expr: Option[Expression],
attributeSeq: Seq[Attribute], left: Seq[NamedExpression],
right: Seq[NamedExpression], buildSide: GpuBuildSide)
extends JoinCondSplitStrategy(left, right, buildSide) {
private[this] val postFilter = expr.map { e =>
GpuBindReferences.bindGpuReferencesTiered(
Seq(e), attributeSeq, false)
}

override def astCondition(): Option[Expression] = None

override def processPostJoin(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
postFilter.map { filter =>
iter.flatMap { cb =>
GpuFilter.filterAndClose(cb, filter, NoopMetric, NoopMetric, NoopMetric)
}
}.getOrElse(iter)
}
}
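This strategy relies on a standard equivalence: for inner joins, evaluating a predicate inside the join condition yields the same rows as applying it as a filter after the join. A DataFrame-level illustration of that equivalence (a sketch, not code from this PR):

import org.apache.spark.sql.{functions => f, DataFrame}

// For inner joins only: L JOIN R ON (k AND p) == (L JOIN R ON k) WHERE p,
// which is what JoinCondSplitAsPostFilter exploits for the non-AST part p.
def postFilterEquivalent(left: DataFrame, right: DataFrame): DataFrame =
  left.join(right, left("b") === right("r_b"), "inner")
    .filter(f.log(right("r_a")) < left("a")) // non-AST part applied after the join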

/**
* This is the split strategy targeting the case where the parts of the join condition not
* supported by AST can be extracted and wrapped into extra project node(s).
*
* @param astCond the join condition remaining after the non-AST-supported parts are
*                extracted
* @param left original expressions from the join's left child. They are the input
*             attributes of the left-side project.
* @param leftProj extra expressions extracted from the original join condition that are not
*                 supported by AST. They are evaluated as a project on the left-side batch.
* @param right original expressions from the join's right child. They are the input
*              attributes of the right-side project.
* @param rightProj extra expressions extracted from the original join condition that are not
*                  supported by AST. They are evaluated as a project on the right-side batch.
* @param post eliminates the extra columns introduced by the join condition split
* @param buildSide indicates which side is the build side
*/
case class JoinCondSplitAsProject(
astCond: Option[Expression],
left: Seq[NamedExpression], leftProj: Seq[NamedExpression],
right: Seq[NamedExpression], rightProj: Seq[NamedExpression],
post: Seq[NamedExpression], buildSide: GpuBuildSide
) extends JoinCondSplitStrategy(left ++ leftProj, right ++ rightProj, buildSide) {
private[this] val leftInput = left.map(_.toAttribute)
private[this] val rightInput = right.map(_.toAttribute)

// Used to build build/stream side project
private[this] val (buildOutput, streamOutput, buildInput, streamInput) = buildSide match {
case GpuBuildLeft =>
(leftOutput, rightOutput, leftInput, rightInput)
case GpuBuildRight =>
(rightOutput, leftOutput, rightInput, leftInput)
}

private[this] val buildProj = if (!buildOutput.isEmpty) {
Some(GpuBindReferences.bindGpuReferencesTiered(buildOutput, buildInput, false))
} else None

private[this] val streamProj = if (!streamOutput.isEmpty) {
Some(GpuBindReferences.bindGpuReferencesTiered(streamOutput, streamInput, false))
} else None

// Removes the intermediate attributes introduced by the left and right side project nodes.
// The output attributes need to be updated based on the join type, and the input attribute
// list covers both the original plan and the extra project nodes.
private[this] val postProj = if (!post.isEmpty) {
Some(
GpuBindReferences.bindGpuReferencesTiered(
post, (leftOutput ++ rightOutput).map(_.toAttribute), false))
} else None

override def astCondition(): Option[Expression] = astCond

override def processBuildSideAndClose(input: ColumnarBatch): ColumnarBatch = {
buildProj.map { pj =>
withResource(input) { cb =>
pj.project(cb)
}
}.getOrElse(input)
}

override def processStreamSideAndClose(input: ColumnarBatch): ColumnarBatch = {
streamProj.map { pj =>
withResource(input) { cb =>
pj.project(cb)
}
}.getOrElse(input)
}

override def processPostJoin(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
postProj.map { proj =>
iter.map { cb =>
withResource(cb) { b =>
proj.project(b)
}
}
}.getOrElse(iter)
}
}
}
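As a concrete shape, the condition exercised by the tests above would split roughly as follows (a sketch built from Catalyst expression constructors; the attribute types and the alias names `_l`/`_r` are hypothetical):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType}

// Sketch: split `b = r_b AND CAST(ROUND(a) AS INT) > CAST(ROUND(LOG(r_a)) AS INT)`.
val a  = AttributeReference("a", LongType)()
val b  = AttributeReference("b", LongType)()
val rA = AttributeReference("r_a", LongType)()
val rB = AttributeReference("r_b", LongType)()

// The non-AST subexpressions become per-side project expressions...
val leftProj  = Seq(Alias(Cast(Round(a, Literal(0)), IntegerType), "_l")())
val rightProj = Seq(Alias(Cast(Round(Log(Cast(rA, DoubleType)), Literal(0)), IntegerType), "_r")())

// ...and the remaining AST-able condition references plain attributes only.
val astCond = Some(And(EqualTo(b, rB),
  GreaterThan(leftProj.head.toAttribute, rightProj.head.toAttribute)))

// These pieces are the shape a JoinCondSplitAsProject(astCond, left, leftProj,
// right, rightProj, post, buildSide) instance would be built from.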
@@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids.execution
import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.AstUtil.JoinCondSplitStrategy
import com.nvidia.spark.rapids.shims.{GpuBroadcastJoinMeta, ShimBinaryExecNode}

import org.apache.spark.TaskContext
@@ -55,6 +56,28 @@ abstract class GpuBroadcastHashJoinMetaBase(

override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ conditionMeta

private var taggedForAstCheck = false

// Avoid checking multiple times
private var isAstCond = false

/**
* Checks whether the join condition is AST-able. Two cases qualify: 1) the entire join
* condition is AST-able; 2) the join condition becomes AST-able after the non-AST parts
* are split out and pushed down to the child plans.
*/
def canJoinCondAstAble(): Boolean = {
if (!taggedForAstCheck) {
val Seq(leftPlan, rightPlan) = childPlans
isAstCond = conditionMeta match {
case Some(e) => AstUtil.canExtractNonAstConditionIfNeed(
e, leftPlan.outputAttributes.map(_.exprId), rightPlan.outputAttributes.map(_.exprId))
case None => true
}
taggedForAstCheck = true
}
isAstCond
}
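The check delegated to `AstUtil.canExtractNonAstConditionIfNeed` hinges on which side's attributes each non-AST conjunct references. A simplified sketch of that rule (hypothetical helper names, ignoring the inner-join post-filter path; the real logic lives in AstUtil):

import org.apache.spark.sql.catalyst.expressions.{And, Expression, ExprId}

// Split the condition into AND-ed conjuncts.
def conjuncts(e: Expression): Seq[Expression] = e match {
  case And(l, r) => conjuncts(l) ++ conjuncts(r)
  case other => Seq(other)
}

// A conjunct that is not AST-supported can be pushed into a child-side project
// only if all of its attribute references come from a single join side.
def canSplitSketch(cond: Expression, leftIds: Set[ExprId], rightIds: Set[ExprId],
    isAstSupported: Expression => Boolean): Boolean =
  conjuncts(cond).forall { c =>
    val ids = c.references.toSeq.map(_.exprId).toSet
    isAstSupported(c) || ids.subsetOf(leftIds) || ids.subsetOf(rightIds)
  }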

override def tagPlanForGpu(): Unit = {
GpuHashJoin.tagJoin(this, join.joinType, buildSide, join.leftKeys, join.rightKeys,
conditionMeta)
@@ -103,6 +126,7 @@ abstract class GpuBroadcastHashJoinExecBase(
joinType: JoinType,
buildSide: GpuBuildSide,
override val condition: Option[Expression],
override val joinCondSplitStrategy: JoinCondSplitStrategy,
left: SparkPlan,
right: SparkPlan) extends ShimBinaryExecNode with GpuHashJoin {
import GpuMetric._