
Add AQE unit tests (#781)
* Add AQE unit tests

Signed-off-by: Niranjan Artal <nartal@nvidia.com>

* Change broadcast threshold

Signed-off-by: Niranjan Artal <nartal@nvidia.com>

* addressed review comment

Signed-off-by: Niranjan Artal <nartal@nvidia.com>

* disable compression in tests

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* remove failing test

Signed-off-by: Niranjan Artal <nartal@nvidia.com>

* remove parquet compression flag as it is not needed

Signed-off-by: Niranjan Artal <nartal@nvidia.com>

Co-authored-by: Andy Grove <andygrove@nvidia.com>
nartal1 and andygrove authored Sep 24, 2020
1 parent 8acf721 commit 6b7cc77
Showing 1 changed file with 101 additions and 4 deletions.
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
-import org.apache.spark.sql.functions.when
+import org.apache.spark.sql.functions.{col, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.execution.{GpuCustomShuffleReaderExec, GpuShuffledHashJoinBase}

@@ -293,6 +293,88 @@ class AdaptiveQueryExecSuite
}, conf)
}

test("Change merge join to broadcast join without local shuffle reader") {

assumeSpark301orLater

val conf = new SparkConf()
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
.set(SQLConf.LOCAL_SHUFFLE_READER_ENABLED.key, "true")
.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "400")
.set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")
.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, "50")
// disable DemoteBroadcastHashJoin rule from removing BHJ due to empty partitions
.set(SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key, "0")

withGpuSparkSession(spark => {
setupTestData(spark)
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(spark,
"""
|SELECT * FROM lowerCaseData t1 join testData2 t2
|ON t1.n = t2.a join testData3 t3 on t2.a = t3.a
|where t1.l = 1
""".stripMargin)

val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 2)
val bhj = findTopLevelGpuBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
// There is still a SMJ, and its two shuffles can't apply local reader.
checkNumLocalShuffleReaders(adaptivePlan, 2)
}, conf)
}
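Note: the find* helpers used above are defined elsewhere in this suite. Following the pattern of Spark's own AdaptiveQueryExecSuite, they presumably just collect matching nodes from the plan; a minimal sketch (illustrative, not this suite's exact code):

// Sketch: collect top-level sort-merge joins from a physical plan. TreeNode.collect
// does a pre-order traversal; the stage-aware variant sketched after
// checkNumLocalShuffleReaders below is needed once AQE wraps subplans in query stages.
def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
  plan.collect {
    case j: SortMergeJoinExec => j
  }
}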

test("Verify the reader is LocalShuffleReaderExec") {

assumeSpark301orLater

val conf = new SparkConf()
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "400")
.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, "50")
// disable DemoteBroadcastHashJoin rule from removing BHJ due to empty partitions
.set(SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key, "0")
.set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
.set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")

withGpuSparkSession(spark => {
setupTestData(spark)

val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(spark, "SELECT * FROM testData join " +
"testData2 ON key = a where value = '1'")

val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)

val bhj = findTopLevelGpuBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
val localReaders = collect(adaptivePlan) {
case reader: GpuCustomShuffleReaderExec if reader.isLocalReader => reader
}
// Verify local readers length
assert(localReaders.length == 2)
}, conf)
}
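For reference, the SQLConf constants used above map to plain string keys in Spark 3.0.x, so the same AQE behavior can be reproduced outside the test harness. A minimal CPU-only session sketch (RapidsConf.ENABLE_CAST_STRING_TO_INTEGER is plugin-specific and omitted here):

import org.apache.spark.sql.SparkSession

// The string keys below are the Spark 3.0.x equivalents of the SQLConf constants above.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("aqe-local-reader-demo")
  .config("spark.sql.adaptive.enabled", "true")
  .config("spark.sql.autoBroadcastJoinThreshold", "400")
  .config("spark.sql.adaptive.advisoryPartitionSizeInBytes", "50")
  .config("spark.sql.adaptive.nonEmptyPartitionRatioForBroadcastJoin", "0")
  .config("spark.sql.shuffle.partitions", "5")
  .getOrCreate()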

  private def checkNumLocalShuffleReaders(
      plan: SparkPlan,
      numShufflesWithoutLocalReader: Int = 0): Int = {
    val numShuffles = collect(plan) {
      case s: ShuffleQueryStageExec => s
    }.length

    val numLocalReaders = collect(plan) {
      case reader: GpuCustomShuffleReaderExec if reader.isLocalReader => reader
    }
    numLocalReaders.foreach { r =>
      val rdd = r.executeColumnar()
      val parts = rdd.partitions
      assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
    }
    assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
    numLocalReaders.length
  }
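The collect used by this helper is presumably the stage-aware one from Spark's AdaptiveSparkPlanHelper rather than TreeNode.collect: once AQE kicks in, the interesting operators hide inside query stages and the final adaptive plan, which plain child traversal does not reach. A simplified sketch of that traversal (assumed, not the trait's exact implementation):

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}

// Walk the plan, descending into adaptive subplans and materialized query stages.
def collectAdaptive[B](p: SparkPlan)(pf: PartialFunction[SparkPlan, B]): Seq[B] = {
  val children = p match {
    case a: AdaptiveSparkPlanExec => Seq(a.executedPlan)
    case s: QueryStageExec => Seq(s.plan)
    case other => other.children
  }
  pf.lift(p).toSeq ++ children.flatMap(collectAdaptive(_)(pf))
}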

def skewJoinTest(fun: SparkSession => Unit) {
assumeSpark301orLater

@@ -369,22 +451,23 @@ class AdaptiveQueryExecSuite
    testData(spark)
    testData2(spark)
    testData3(spark)
    lowerCaseData(spark)
  }

  /** Ported from org.apache.spark.sql.test.SQLTestData */
  private def testData(spark: SparkSession) {
    import spark.implicits._
    val data: Seq[(Int, String)] = (1 to 100).map(i => (i, i.toString))
    val df = data.toDF("key", "value")
-     .repartition(6)
+     .repartition(col("key"))
    registerAsParquetTable(spark, df, "testData")
  }
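The switch from a numeric repartition to a column one changes the physical shuffle these tests see: repartition(n) round-robins rows into exactly n partitions, while repartition(col(...)) hash-partitions on the column with spark.sql.shuffle.partitions partitions, so the exchanges look like the key-based shuffles a join could actually reuse. In short:

// Round-robin into exactly 6 partitions, independent of key values.
df.repartition(6)
// Hash-partition on "key"; the partition count comes from spark.sql.shuffle.partitions.
df.repartition(col("key"))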

  /** Ported from org.apache.spark.sql.test.SQLTestData */
  private def testData2(spark: SparkSession) {
    import spark.implicits._
    val df = Seq[(Int, Int)]((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
      .toDF("a", "b")
-     .repartition(2)
+     .repartition(col("a"))
    registerAsParquetTable(spark, df, "testData2")
  }

@@ -393,10 +476,24 @@ class AdaptiveQueryExecSuite
    import spark.implicits._
    val df = Seq[(Int, Option[Int])]((1, None), (2, Some(2)))
      .toDF("a", "b")
-     .repartition(6)
+     .repartition(col("a"))
    registerAsParquetTable(spark, df, "testData3")
  }

  /** Ported from org.apache.spark.sql.test.SQLTestData */
  private def lowerCaseData(spark: SparkSession) {
    import spark.implicits._
    // note that this differs from the original Spark test: it generates a larger data set so
    // that the stats in the logical plan are big enough to prevent BHJ up front, while the
    // queries filter it down to a small enough set that SMJ can be replaced with BHJ at
    // execution time when AQE is enabled
    val data: Seq[(Int, String)] = (0 to 10000).map(i => (i, if (i < 5) i.toString else "z"))
    val df = data
      .toDF("n", "l")
      .repartition(col("n"))
    registerAsParquetTable(spark, df, "lowercaseData")
  }
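A quick arithmetic check of the shape this generator produces: only i = 0..4 keep their numeric label, so the where t1.l = 1 filter in the first test matches exactly one row, while the unfiltered table carries 10001 rows and correspondingly large stats:

val data = (0 to 10000).map(i => (i, if (i < 5) i.toString else "z"))
assert(data.size == 10001)
assert(data.count(_._2 == "1") == 1)     // post-filter join side: tiny, broadcastable
assert(data.count(_._2 == "z") == 9996)  // the bulk that inflates the logical-plan stats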

  private def registerAsParquetTable(spark: SparkSession, df: Dataset[Row], name: String) {
    val path = new File(TEST_FILES_ROOT, s"$name.parquet").getAbsolutePath
    df.write
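The diff view truncates the method body at this point. A typical completion, assuming the helper writes the DataFrame to Parquet and registers the result as a temp view (a sketch, not necessarily this repository's exact code):

import org.apache.spark.sql.SaveMode

// Hypothetical remainder of registerAsParquetTable: write, read back, register a view.
df.write
  .mode(SaveMode.Overwrite)
  .parquet(path)
spark.read.parquet(path).createOrReplaceTempView(name)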
