Support split broadcast join condition into ast and non-ast [databricks] (NVIDIA#9760)

* Support split broadcast join condition into ast and non-ast

Signed-off-by: Ferdinand Xu <ferdinandx@nvidia.com>

* Fix compile; fix failures; address review comments; update all versions

* Fix Spark 311

* Apply review suggestions to AstUtil.scala, GpuBroadcastHashJoinExecBase.scala, and GpuHashJoin.scala

Co-authored-by: Jason Lowe <jlowe@nvidia.com>
Co-authored-by: Navin Kumar <97137715+NVnavkumar@users.noreply.github.com>

* Fix scala 2.13, code style, refactor

* Fix DBX

* Revert unnecessary changes

* Fix failed UT
---------

Signed-off-by: Ferdinand Xu <ferdinandx@nvidia.com>
Co-authored-by: Jason Lowe <jlowe@nvidia.com>
Co-authored-by: Navin Kumar <97137715+NVnavkumar@users.noreply.github.com>
3 people committed Jan 22, 2024
1 parent 2caf870 commit aed41d2
Showing 8 changed files with 374 additions and 84 deletions.
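
In plain terms, this change lets a broadcast hash join whose condition is only partially AST-able still run on the GPU: non-AST subtrees that reference a single side are precomputed by a project on that child, the join itself runs with a fully AST-able condition, and a post-join project drops the temporary columns (for inner joins, the whole condition can instead be applied as a post-join filter). Below is a minimal, self-contained sketch of the idea using a toy expression model rather than the plugin's Catalyst types; `Log` stands in for any operator cudf's AST cannot evaluate, and all names are illustrative.

// Toy model only, not the plugin's API.
sealed trait Expr
case class Col(side: Char, name: String) extends Expr
case class Eq(a: Expr, b: Expr) extends Expr   // AST-able
case class Gt(a: Expr, b: Expr) extends Expr   // AST-able
case class And(a: Expr, b: Expr) extends Expr  // AST-able
case class Log(c: Expr) extends Expr           // assumed NOT AST-able

object SplitSketch extends App {
  // Original condition: (l.b == r.b) && (l.a > log(r.a))
  val cond = And(Eq(Col('l', "b"), Col('r', "b")),
    Gt(Col('l', "a"), Log(Col('r', "a"))))

  // log(r.a) is not AST-able but references only the right side, so a
  // project on the right child computes it as a temporary column _e0 ...
  val rightProj = Map("_e0" -> Log(Col('r', "a")))

  // ... and the join runs with a fully AST-able condition; a post-join
  // project later drops the temporary _e0 column.
  val astCond = And(Eq(Col('l', "b"), Col('r', "b")),
    Gt(Col('l', "a"), Col('r', "_e0")))

  println(s"original condition: $cond")
  println(s"right-side project adds: $rightProj")
  println(s"AST-able join condition: $astCond")
}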
37 changes: 35 additions & 2 deletions integration_tests/src/main/python/join_test.py
@@ -19,8 +19,9 @@
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime, is_emr_runtime, is_not_utc
from data_gen import *
from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan
from spark_session import with_cpu_session, is_before_spark_330, is_databricks_runtime
from marks import ignore_order, allow_non_gpu, incompat, validate_execs_in_gpu_plan, \
datagen_overrides
from spark_session import with_cpu_session, is_before_spark_330, is_databricks113_or_later, is_databricks_runtime

pytestmark = [pytest.mark.nightly_resource_consuming_test]

@@ -436,6 +437,38 @@ def do_join(spark):
return broadcast(left).join(right, left.a > f.log(right.r_a), join_type)
assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec')

# Allowing ShuffleExchangeExec to stay on the CPU is mainly for Databricks, where the broadcast exchange is CPU based ('Exchange SinglePartition, EXECUTOR_BROADCAST').
db_113_cpu_bhj_join_allow=["ShuffleExchangeExec"] if is_databricks113_or_later() else []


@allow_non_gpu(*db_113_cpu_bhj_join_allow)
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen()], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Inner', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_hash_join_on_non_ast_condition_without_fallback(data_gen, join_type):
# Tests BHJ with a condition that is not fully supported by AST. With the unsupported parts wrapped in extra project nodes, the join can still run on the GPU instead of falling back to the CPU.
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet
return left.join(right.hint("broadcast"), ((left.b == right.r_b) & (f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')))), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf = {"spark.rapids.sql.castFloatToIntegralTypes.enabled": True})


@allow_non_gpu('BroadcastHashJoinExec', 'BroadcastExchangeExec')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen()], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_hash_join_on_non_ast_condition_fallback(data_gen, join_type):
# Tests BHJ with a condition that is not fully supported by AST. Since AST doesn't support double, this query falls back to the CPU.
# Inner join is not included since it can run on the GPU via a post-join filter.
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet, and the condition cannot be split into project
# nodes because both sides' columns are involved in the same unsupported comparison
return left.join(right.hint("broadcast"), ((left.b == right.r_b) & (left.a.cast('double') > right.r_a.cast('double'))), join_type)
assert_gpu_fallback_collect(do_join, 'BroadcastHashJoinExec')


@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
166 changes: 165 additions & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
@@ -16,11 +16,15 @@

package com.nvidia.spark.rapids

import java.io.Serializable

import com.nvidia.spark.rapids.Arm.withResource
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, Expression, ExprId, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, Expression, ExprId, NamedExpression}
import org.apache.spark.sql.rapids.catalyst.expressions.{GpuEquivalentExpressions, GpuExpressionEquals}
import org.apache.spark.sql.vectorized.ColumnarBatch


object AstUtil {
@@ -119,4 +123,164 @@ object AstUtil {
}
}
}

/**
* Transforms the original join condition into an extra filter/project when necessary.
* It targets cases where the join condition cannot be fully evaluated by AST.
* Based on the join condition, the transformation falls into three strategies:
* (1) [NoopJoinCondSplit]: no-op when the join condition can be fully evaluated with AST.
* (2) [JoinCondSplitAsPostFilter]: the entire join condition is pulled out as a filter
* applied after the join.
* (3) [JoinCondSplitAsProject]: the parts of the join condition not supported by AST are
* extracted into pre-join project nodes on each join child. One extra project node is
* introduced after the join to remove the intermediate attributes.
*/
abstract class JoinCondSplitStrategy(left: Seq[NamedExpression],
right: Seq[NamedExpression], buildSide: GpuBuildSide) extends Serializable {

// Actual output of build/stream side project due to join condition split
private[this] val (buildOutputAttr, streamOutputAttr) = buildSide match {
case GpuBuildLeft => (joinLeftOutput, joinRightOutput)
case GpuBuildRight => (joinRightOutput, joinLeftOutput)
}

// Expressions produced for the join's left side. In the `split as project` strategy, this
// may differ from the original left child's output by including attributes extracted from
// the join condition.
def leftOutput(): Seq[NamedExpression] = left

// Expressions produced for the join's right side. In the `split as project` strategy, this
// may differ from the original right child's output by including attributes extracted from
// the join condition.
def rightOutput(): Seq[NamedExpression] = right

def astCondition(): Option[Expression]

def processBuildSideAndClose(input: ColumnarBatch): ColumnarBatch = input

def processStreamSideAndClose(input: ColumnarBatch): ColumnarBatch = input

def processPostJoin(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = iter

// Attribute output of the (possibly projected) left side that feeds the join. In the
// `split as project` strategy, it may include attributes extracted from the join condition.
def joinLeftOutput(): Seq[Attribute] = leftOutput.map(expr => expr.toAttribute)

// Attribute output of the (possibly projected) right side that feeds the join. In the
// `split as project` strategy, it may include attributes extracted from the join condition.
def joinRightOutput(): Seq[Attribute] = rightOutput.map(expr => expr.toAttribute)

// Build-side attribute list after the join-condition split introduces project nodes.
// It may include extra attributes from the split join condition.
def buildSideOutput(): Seq[Attribute] = buildOutputAttr

// Streamed-side attribute list after the join-condition split introduces project nodes.
// It may include extra attributes from the split join condition.
def streamedSideOutput(): Seq[Attribute] = streamOutputAttr
}
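
// Illustrative call sequence only (not part of this commit): a join exec is expected to
// route batches through the chosen strategy roughly as below, where `buildBatch`,
// `streamIter` and `doJoin` are hypothetical placeholders for the exec's internals.
//
//   val split: JoinCondSplitStrategy = ...
//   val build  = split.processBuildSideAndClose(buildBatch)      // may add projected cols
//   val stream = streamIter.map(split.processStreamSideAndClose) // may add projected cols
//   val joined = doJoin(build, stream, split.astCondition())     // AST-only condition
//   val result = split.processPostJoin(joined)                   // filter / drop temps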

// For the case where the entire join condition can be evaluated as AST.
case class NoopJoinCondSplit(condition: Option[Expression], left: Seq[NamedExpression],
right: Seq[NamedExpression], buildSide: GpuBuildSide)
extends JoinCondSplitStrategy(left, right, buildSide) {
override def astCondition(): Option[Expression] = condition
}

// For inner joins, any condition that cannot be evaluated directly by a cudf AST
// expression can instead be applied as a filter after the join, since an inner join only
// ever removes rows; other join types would emit different rows and cannot use this.
case class JoinCondSplitAsPostFilter(expr: Option[Expression],
attributeSeq: Seq[Attribute], left: Seq[NamedExpression],
right: Seq[NamedExpression], buildSide: GpuBuildSide)
extends JoinCondSplitStrategy(left, right, buildSide) {
private[this] val postFilter = expr.map { e =>
GpuBindReferences.bindGpuReferencesTiered(
Seq(e), attributeSeq, false)
}

override def astCondition(): Option[Expression] = None

override def processPostJoin(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
postFilter.map { filter =>
iter.flatMap { cb =>
GpuFilter.filterAndClose(cb, filter, NoopMetric, NoopMetric, NoopMetric)
}
}.getOrElse(iter)
}
}

/**
* This is the split strategy for the case where the parts of the join condition not
* supported by AST can be extracted and wrapped into extra project node(s).
*
* @param astCond remaining join condition after the parts not supported by AST are
*                extracted
* @param left original expressions from the join's left child; the left project's input
*             attributes
* @param leftProj extra expressions extracted from the original join condition that are
*                 not supported by AST; evaluated as a project on the left-side batch
* @param right original expressions from the join's right child; the right project's
*              input attributes
* @param rightProj extra expressions extracted from the original join condition that are
*                  not supported by AST; evaluated as a project on the right-side batch
* @param post eliminates the extra columns introduced by the join-condition split
* @param buildSide indicates which side is the build side
*/
case class JoinCondSplitAsProject(
astCond: Option[Expression],
left: Seq[NamedExpression], leftProj: Seq[NamedExpression],
right: Seq[NamedExpression], rightProj: Seq[NamedExpression],
post: Seq[NamedExpression], buildSide: GpuBuildSide
) extends JoinCondSplitStrategy(left ++ leftProj, right ++ rightProj, buildSide) {
private[this] val leftInput = left.map(_.toAttribute)
private[this] val rightInput = right.map(_.toAttribute)

// Used to build the build-side and stream-side projects
private[this] val (buildOutput, streamOutput, buildInput, streamInput) = buildSide match {
case GpuBuildLeft =>
(leftOutput, rightOutput, leftInput, rightInput)
case GpuBuildRight =>
(rightOutput, leftOutput, rightInput, leftInput)
}

private[this] val buildProj = if (buildOutput.nonEmpty) {
Some(GpuBindReferences.bindGpuReferencesTiered(buildOutput, buildInput, false))
} else None

private[this] val streamProj = if (streamOutput.nonEmpty) {
Some(GpuBindReferences.bindGpuReferencesTiered(streamOutput, streamInput, false))
} else None

// Removes the intermediate attributes introduced by the split from the join output. The
// output attributes need to be updated based on the join type, and the input attribute
// list covers both the original plan and the extra project nodes.
private[this] val postProj = if (post.nonEmpty) {
Some(
GpuBindReferences.bindGpuReferencesTiered(
post, (leftOutput ++ rightOutput).map(_.toAttribute), false))
} else None

override def astCondition(): Option[Expression] = astCond

override def processBuildSideAndClose(input: ColumnarBatch): ColumnarBatch = {
buildProj.map { pj =>
withResource(input) { cb =>
pj.project(cb)
}
}.getOrElse(input)
}

override def processStreamSideAndClose(input: ColumnarBatch): ColumnarBatch = {
streamProj.map { pj =>
withResource(input) { cb =>
pj.project(cb)
}
}.getOrElse(input)
}

override def processPostJoin(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = {
postProj.map { proj =>
iter.map { cb =>
withResource(cb) { b =>
proj.project(b)
}
}
}.getOrElse(iter)
}
}
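
// Illustrative wiring only (hypothetical names, not the commit's code): for a condition
// like (l.b == r.b) && (round(l.a) > round(log(r.a))), the planner would build roughly:
//
//   JoinCondSplitAsProject(
//     astCond   = Some((l.b == r.b) && (_e0 > _e1)),
//     left      = l.output, leftProj  = Seq(round(l.a) AS _e0),
//     right     = r.output, rightProj = Seq(round(log(r.a)) AS _e1),
//     post      = l.output ++ r.output,  // drops _e0/_e1 after the join
//     buildSide = GpuBuildRight)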
}
sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExecBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids.execution
import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.AstUtil.JoinCondSplitStrategy
import com.nvidia.spark.rapids.shims.{GpuBroadcastJoinMeta, ShimBinaryExecNode}

import org.apache.spark.TaskContext
@@ -55,6 +56,28 @@ abstract class GpuBroadcastHashJoinMetaBase(

override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ conditionMeta

private var taggedForAstCheck = false

// Avoid checking multiple times
private var isAstCond = false

/**
* Checks whether the join condition is AST-able. Two cases qualify: 1) the entire join
* condition is AST-able; 2) the join condition becomes AST-able after the non-AST parts
* are split out and pushed down to the child plans.
*/
def canJoinCondAstAble(): Boolean = {
if (!taggedForAstCheck) {
val Seq(leftPlan, rightPlan) = childPlans
isAstCond = conditionMeta match {
case Some(e) => AstUtil.canExtractNonAstConditionIfNeed(
e, leftPlan.outputAttributes.map(_.exprId), rightPlan.outputAttributes.map(_.exprId))
case None => true
}
taggedForAstCheck = true
}
isAstCond
}
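
// Conceptually (illustrative comment only), the split check accepts a condition when
// every maximal non-AST-able subtree references attributes from at most one side:
//
//   (l.b == r.b) && (round(l.a) > round(log(r.a)))
//       -> splittable: each non-AST subtree touches a single side
//   l.a.cast(double) > r.a.cast(double)
//       -> not splittable: the unsupported double comparison itself spans both
//          sides, so the join falls back to the CPU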

override def tagPlanForGpu(): Unit = {
GpuHashJoin.tagJoin(this, join.joinType, buildSide, join.leftKeys, join.rightKeys,
conditionMeta)
@@ -103,6 +126,7 @@
joinType: JoinType,
buildSide: GpuBuildSide,
override val condition: Option[Expression],
override val joinCondSplitStrategy: JoinCondSplitStrategy,
left: SparkPlan,
right: SparkPlan) extends ShimBinaryExecNode with GpuHashJoin {
import GpuMetric._
