Implement getShuffleRDD and fixup mismatched output types on shuffle reuse [databricks] #4257

Merged: 2 commits, Dec 2, 2021
Changes from 1 commit
integration_tests/src/main/python/repart_test.py (+11, -0)
@@ -267,3 +267,14 @@ def test_hash_repartition_exact(gen, num_parts):
            .withColumn('hashed', f.hash(*part_on))\
            .selectExpr('*', 'pmod(hashed, {})'.format(num_parts)),
        conf = allow_negative_scale_of_decimal_conf)

# Test a query that, with AQE enabled, should reuse the aggregation's shuffle on
# both sides of the union and thereby cause Spark to leverage getShuffleRDD.
@ignore_order(local=True)
def test_union_with_filter():
    def doit(spark):
        dfa = spark.range(1, 100).withColumn("id2", f.col("id"))
        dfb = dfa.groupBy("id").agg(f.size(f.collect_set("id2")).alias("idc"))
        dfc = dfb.filter(f.col("idc") == 1).select("id")
        return dfc.union(dfc)
    conf = { "spark.sql.adaptive.enabled": "true" }
    assert_gpu_and_cpu_are_equal_collect(doit, conf)
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
-import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBaseWithMetrics
+import org.apache.spark.sql.rapids.execution.{GpuShuffleExchangeExecBaseWithMetrics, ShuffledBatchRDD}

case class GpuShuffleExchangeExec(
    gpuOutputPartitioning: GpuPartitioning,
@@ -43,7 +43,7 @@ case class GpuShuffleExchangeExec(
  override def getShuffleRDD(
      partitionSpecs: Array[ShufflePartitionSpec],
      partitionSizes: Option[Array[Long]]): RDD[_] = {
-    throw new UnsupportedOperationException
+    new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics, partitionSpecs)
  }

  override def runtimeStatistics: Statistics = {
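For context, Spark's AQE shuffle readers invoke getShuffleRDD when they re-read an exchange with a custom partition layout, for example after coalescing small reducer partitions. A minimal sketch of that call pattern against the two-argument overload above; the gpuShuffle value and the chosen partition specs are illustrative assumptions, not taken from this PR:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.{CoalescedPartitionSpec, ShufflePartitionSpec}

// gpuShuffle: a GpuShuffleExchangeExec instance, assumed to be in scope.
// Hypothetical layout: AQE coalesces five reducer partitions into two reads.
val specs: Array[ShufflePartitionSpec] = Array(
  CoalescedPartitionSpec(0, 2),  // covers reducer partitions [0, 2)
  CoalescedPartitionSpec(2, 5))  // covers reducer partitions [2, 5)

// Before this change the GPU exec threw UnsupportedOperationException here;
// now it returns a ShuffledBatchRDD over the columnar shuffle dependency.
val rdd: RDD[_] = gpuShuffle.getShuffleRDD(specs, None)
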
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
-import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBaseWithMetrics
+import org.apache.spark.sql.rapids.execution.{GpuShuffleExchangeExecBaseWithMetrics, ShuffledBatchRDD}

case class GpuShuffleExchangeExec(
    gpuOutputPartitioning: GpuPartitioning,
@@ -41,7 +41,7 @@ case class GpuShuffleExchangeExec(
  override def numPartitions: Int = shuffleDependencyColumnar.partitioner.numPartitions

  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] = {
-    throw new UnsupportedOperationException
+    new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics, partitionSpecs)
  }

  override def runtimeStatistics: Statistics = {
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeLike, ShuffleOrigin}
-import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBaseWithMetrics
+import org.apache.spark.sql.rapids.execution.{GpuShuffleExchangeExecBaseWithMetrics, ShuffledBatchRDD}

case class GpuShuffleExchangeExec(
    gpuOutputPartitioning: GpuPartitioning,
@@ -42,7 +42,7 @@ case class GpuShuffleExchangeExec(
  override def numPartitions: Int = shuffleDependencyColumnar.partitioner.numPartitions

  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] = {
-    throw new UnsupportedOperationException
+    new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics, partitionSpecs)
  }

  override def runtimeStatistics: Statistics = {
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeLike, ShuffleOrigin}
-import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBase
+import org.apache.spark.sql.rapids.execution.{GpuShuffleExchangeExecBase, ShuffledBatchRDD}

case class GpuShuffleExchangeExec(
    gpuOutputPartitioning: GpuPartitioning,
@@ -54,7 +54,7 @@ case class GpuShuffleExchangeExec(
  override def getShuffleRDD(
      partitionSpecs: Array[ShufflePartitionSpec],
      partitionSizes: Option[Array[Long]]): RDD[_] = {
-    throw new UnsupportedOperationException
+    new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics, partitionSpecs)
  }

  // DB SPECIFIC - throw if called since we don't know how it's used
@@ -548,6 +548,30 @@ object GpuOverrides extends Logging {
}
}

  /**
   * Searches the plan for ReusedExchangeExec instances containing a GPU shuffle where the
   * output types between the two plan nodes do not match. In such a case the ReusedExchangeExec
   * will be updated to match the GPU shuffle output types.
   */
  def fixupReusedExchangeExecs(plan: SparkPlan): SparkPlan = {
    def outputTypesMatch(a: Seq[Attribute], b: Seq[Attribute]): Boolean =
      a.corresponds(b)((x, y) => x.dataType == y.dataType)
    plan.transformUp {
      case sqse: ShuffleQueryStageExec =>
        sqse.plan match {
          case ReusedExchangeExec(output, gsee: GpuShuffleExchangeExecBase) if (
              !outputTypesMatch(output, gsee.output)) =>
            val newOutput = sqse.plan.output.zip(gsee.output).map { case (c, g) =>
              assert(c.isInstanceOf[AttributeReference] && g.isInstanceOf[AttributeReference],
                s"Expected AttributeReference but found $c and $g")
              AttributeReference(c.name, g.dataType, c.nullable, c.metadata)(c.exprId, c.qualifier)
            }
            sqse.newReuseInstance(sqse.id, newOutput)
          case _ => sqse
        }
    }
  }

  @scala.annotation.tailrec
  def extractLit(exp: Expression): Option[Literal] = exp match {
    case l: Literal => Some(l)
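The key step in fixupReusedExchangeExecs above is rebuilding each attribute with the cached attribute's identity (name, exprId, qualifier) but the GPU shuffle's data type, so downstream references still resolve. A standalone sketch of that retyping; the BinaryType versus ArrayType pairing is an assumed illustration of a CPU-serialized versus GPU columnar aggregation buffer, not taken from this PR:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{ArrayType, BinaryType, LongType}

// Stale attribute cached by the reused exchange vs. what the GPU shuffle emits.
val cached = AttributeReference("buf", BinaryType)()           // assumed stale type
val gpu = AttributeReference("buf", ArrayType(LongType))()     // assumed GPU output type

// Keep the cached attribute's identity but adopt the GPU data type, so the
// rest of the plan sees the same column with the corrected type.
val fixed = AttributeReference(cached.name, gpu.dataType, cached.nullable,
  cached.metadata)(cached.exprId, cached.qualifier)
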
@@ -3910,7 +3934,11 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging {
    val updatedPlan = if (plan.conf.adaptiveExecutionEnabled) {
      // AQE can cause Spark to inject undesired CPU shuffles into the plan because GPU and CPU
      // distribution expressions are not semantically equal.
-     GpuOverrides.removeExtraneousShuffles(plan, conf)
+     val newPlan = GpuOverrides.removeExtraneousShuffles(plan, conf)
+
+     // AQE can cause ReusedExchangeExec instances to cache the wrong aggregation buffer type
+     // compared to the desired buffer type from a reused GPU shuffle.
+     GpuOverrides.fixupReusedExchangeExecs(newPlan)
    } else {
      plan
    }