Skip to content

Commit

Permalink
Add Support for Multiple Filtering Keys for Subquery Broadcast (#10858)
Browse files Browse the repository at this point in the history
* Add support for multiple filtering keys for subquery broadcast

* Signing off

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Fixed test compilation

---------

Signed-off-by: Raza Jafri <rjafri@nvidia.com>
  • Loading branch information
razajafri authored May 28, 2024
1 parent c5da29d commit 3001852
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,14 +41,14 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ThreadUtils


class GpuSubqueryBroadcastMeta(
abstract class GpuSubqueryBroadcastMetaBase(
s: SubqueryBroadcastExec,
conf: RapidsConf,
p: Option[RapidsMeta[_, _, _]],
r: DataFromReplacementRule) extends
SparkPlanMeta[SubqueryBroadcastExec](s, conf, p, r) {

private var broadcastBuilder: () => SparkPlan = _
protected var broadcastBuilder: () => SparkPlan = _

override val childExprs: Seq[BaseExprMeta[_]] = Nil

Expand Down Expand Up @@ -140,13 +140,8 @@ class GpuSubqueryBroadcastMeta(
*/
override def convertToCpu(): SparkPlan = s

override def convertToGpu(): GpuExec = {
GpuSubqueryBroadcastExec(s.name, s.index, s.buildKeys, broadcastBuilder())(
getBroadcastModeKeyExprs)
}

/** Extract the broadcast mode key expressions if there are any. */
private def getBroadcastModeKeyExprs: Option[Seq[Expression]] = {
protected def getBroadcastModeKeyExprs: Option[Seq[Expression]] = {
val broadcastMode = s.child match {
case b: BroadcastExchangeExec =>
b.mode
Expand All @@ -170,7 +165,7 @@ class GpuSubqueryBroadcastMeta(

case class GpuSubqueryBroadcastExec(
name: String,
index: Int,
indices: Seq[Int],
buildKeys: Seq[Expression],
child: SparkPlan)(modeKeys: Option[Seq[Expression]])
extends ShimBaseSubqueryExec with GpuExec with ShimUnaryExecNode {
Expand All @@ -182,16 +177,18 @@ case class GpuSubqueryBroadcastExec(
// correctly report the output length, so that `InSubqueryExec` can know it's the single-column
// execution mode, not multi-column.
override def output: Seq[Attribute] = {
val key = buildKeys(index)
val name = key match {
case n: NamedExpression =>
n.name
case cast: Cast if cast.child.isInstanceOf[NamedExpression] =>
cast.child.asInstanceOf[NamedExpression].name
case _ =>
"key"
indices.map { index =>
val key = buildKeys(index)
val name = key match {
case n: NamedExpression =>
n.name
case cast: Cast if cast.child.isInstanceOf[NamedExpression] =>
cast.child.asInstanceOf[NamedExpression].name
case _ =>
"key"
}
AttributeReference(name, key.dataType, key.nullable)()
}
Seq(AttributeReference(name, key.dataType, key.nullable)())
}

override lazy val additionalMetrics: Map[String, GpuMetric] = Map(
Expand All @@ -200,7 +197,7 @@ case class GpuSubqueryBroadcastExec(

override def doCanonicalize(): SparkPlan = {
val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, child.output))
GpuSubqueryBroadcastExec("dpp", index, keys, child.canonicalized)(modeKeys)
GpuSubqueryBroadcastExec("dpp", indices, keys, child.canonicalized)(modeKeys)
}

@transient
Expand Down Expand Up @@ -235,28 +232,30 @@ case class GpuSubqueryBroadcastExec(
// are being extracted. The CPU already has the key projections applied in the broadcast
// data and thus does not have similar logic here.
val broadcastModeProject = modeKeys.map { keyExprs =>
val keyExpr = if (GpuHashJoin.canRewriteAsLongType(buildKeys)) {
val exprs = if (GpuHashJoin.canRewriteAsLongType(buildKeys)) {
// in this case, there is only 1 key expression since it's a packed version that encompasses
// multiple integral values into a single long using bit logic. In CPU Spark, the broadcast
// would create a LongHashedRelation instead of a standard HashedRelation.
keyExprs.head
indices.map { _ => keyExprs.head }
} else {
keyExprs(index)
indices.map { idx => keyExprs(idx) }
}
UnsafeProjection.create(keyExpr)
UnsafeProjection.create(exprs)
}

// Use the single output of the broadcast mode projection if it exists
val rowProjectIndex = if (broadcastModeProject.isDefined) 0 else index
val rowExpr = if (GpuHashJoin.canRewriteAsLongType(buildKeys)) {
val rowExprs = if (GpuHashJoin.canRewriteAsLongType(buildKeys)) {
// Since this is the expected output for a LongHashedRelation, we can extract the key from the
// long packed key using bit logic, using this method available in HashJoin to give us the
// correct key expression.
HashJoin.extractKeyExprAt(buildKeys, index)
// long packed key using bit logic, using this method available in HashJoin to give us the
// correct key expression.
indices.map { idx => HashJoin.extractKeyExprAt(buildKeys, idx) }
} else {
BoundReference(rowProjectIndex, buildKeys(index).dataType, buildKeys(index).nullable)
indices.map { idx =>
// Use the single output of the broadcast mode projection if it exists
val rowProjectIndex = if (broadcastModeProject.isDefined) 0 else idx
BoundReference(rowProjectIndex, buildKeys(idx).dataType, buildKeys(idx).nullable)
}
}
val rowProject = UnsafeProjection.create(rowExpr)
val rowProject = UnsafeProjection.create(rowExprs)

// Deserializes the batch on the host. Then, transforms it to rows and performs row-wise
// projection. We should NOT run any device operation on the driver node.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*** spark-rapids-shim-json-lines
{"spark": "311"}
{"spark": "312"}
{"spark": "313"}
{"spark": "320"}
{"spark": "321"}
{"spark": "321cdh"}
{"spark": "322"}
{"spark": "323"}
{"spark": "324"}
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332cdh"}
{"spark": "333"}
{"spark": "334"}
{"spark": "340"}
{"spark": "341"}
{"spark": "342"}
{"spark": "343"}
{"spark": "350"}
{"spark": "351"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.execution

import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuExec, RapidsConf, RapidsMeta}

import org.apache.spark.sql.execution.SubqueryBroadcastExec

class GpuSubqueryBroadcastMeta(
s: SubqueryBroadcastExec,
conf: RapidsConf,
p: Option[RapidsMeta[_, _, _]],
r: DataFromReplacementRule) extends
GpuSubqueryBroadcastMetaBase(s, conf, p, r) {
override def convertToGpu(): GpuExec = {
GpuSubqueryBroadcastExec(s.name, Seq(s.index), s.buildKeys, broadcastBuilder())(
getBroadcastModeKeyExprs)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*** spark-rapids-shim-json-lines
{"spark": "400"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.execution

import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuExec, RapidsConf, RapidsMeta}

import org.apache.spark.sql.execution.SubqueryBroadcastExec

class GpuSubqueryBroadcastMeta(
s: SubqueryBroadcastExec,
conf: RapidsConf,
p: Option[RapidsMeta[_, _, _]],
r: DataFromReplacementRule) extends
GpuSubqueryBroadcastMetaBase(s, conf, p, r) {
override def convertToGpu(): GpuExec = {
GpuSubqueryBroadcastExec(s.name, s.indices, s.buildKeys, broadcastBuilder())(
getBroadcastModeKeyExprs)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class DynamicPruningSuite
// NOTE: We remove the AdaptiveSparkPlanExec since we can't re-run the new plan
// under AQE because that fundamentally requires some rewrite and stage
// ordering which we can't do for this test.
case GpuSubqueryBroadcastExec(name, index, buildKeys, child) =>
case GpuSubqueryBroadcastExec(name, Seq(index), buildKeys, child) =>
val newChild = child match {
case a @ AdaptiveSparkPlanExec(_, _, _, _, _) =>
(new GpuTransitionOverrides()).apply(ColumnarToRowExec(a.executedPlan))
Expand Down

0 comments on commit 3001852

Please sign in to comment.