From 6b7cc777858a32e1942f9355f4560d0353c511c7 Mon Sep 17 00:00:00 2001
From: Niranjan Artal <50492963+nartal1@users.noreply.github.com>
Date: Thu, 24 Sep 2020 13:04:23 -0700
Subject: [PATCH] Add AQE unit tests (#781)

* Add AQE unit tests

Signed-off-by: Niranjan Artal

* Change broadcast threshold

Signed-off-by: Niranjan Artal

* addressed review comment

Signed-off-by: Niranjan Artal

* disable compression in tests

Signed-off-by: Andy Grove

* remove failing test

Signed-off-by: Niranjan Artal

* remove parquet compression flag as it is not needed

Signed-off-by: Niranjan Artal

Co-authored-by: Andy Grove
---
 .../spark/rapids/AdaptiveQueryExecSuite.scala | 105 +++++++++++++++++-
 1 file changed, 101 insertions(+), 4 deletions(-)

diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
index 5bc0c218f83..9f58ed7a905 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
-import org.apache.spark.sql.functions.when
+import org.apache.spark.sql.functions.{col, when}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.execution.{GpuCustomShuffleReaderExec, GpuShuffledHashJoinBase}
 
@@ -293,6 +293,88 @@ class AdaptiveQueryExecSuite
     }, conf)
   }
 
+  test("Change merge join to broadcast join without local shuffle reader") {
+
+    assumeSpark301orLater
+
+    val conf = new SparkConf()
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+      .set(SQLConf.LOCAL_SHUFFLE_READER_ENABLED.key, "true")
+      .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "400")
+      .set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")
+      .set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, "50")
+      // prevent the DemoteBroadcastHashJoin rule from removing the BHJ due to empty partitions
+      .set(SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key, "0")
+
+    withGpuSparkSession(spark => {
+      setupTestData(spark)
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(spark,
+        """
+          |SELECT * FROM lowerCaseData t1 join testData2 t2
+          |ON t1.n = t2.a join testData3 t3 on t2.a = t3.a
+          |where t1.l = 1
+        """.stripMargin)
+
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 2)
+      val bhj = findTopLevelGpuBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+      // There is still an SMJ, and its two shuffles can't use the local shuffle reader.
+      checkNumLocalShuffleReaders(adaptivePlan, 2)
+    }, conf)
+  }
+
+  test("Verify the reader is LocalShuffleReaderExec") {
+
+    assumeSpark301orLater
+
+    val conf = new SparkConf()
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+      .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "400")
+      .set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, "50")
+      // prevent the DemoteBroadcastHashJoin rule from removing the BHJ due to empty partitions
+      .set(SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key, "0")
+      .set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
+      .set(RapidsConf.ENABLE_CAST_STRING_TO_INTEGER.key, "true")
+
+    withGpuSparkSession(spark => {
+      setupTestData(spark)
+
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(spark, "SELECT * FROM testData join " +
+        "testData2 ON key = a where value = '1'")
+
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 1)
+
+      val bhj = findTopLevelGpuBroadcastHashJoin(adaptivePlan)
+      assert(bhj.size == 1)
+      val localReaders = collect(adaptivePlan) {
+        case reader: GpuCustomShuffleReaderExec if reader.isLocalReader => reader
+      }
+      // verify that both of the join's shuffles are read by local shuffle readers
+      assert(localReaders.length == 2)
+    }, conf)
+  }
+
+  private def checkNumLocalShuffleReaders(
+      plan: SparkPlan,
+      numShufflesWithoutLocalReader: Int = 0): Int = {
+    val numShuffles = collect(plan) {
+      case s: ShuffleQueryStageExec => s
+    }.length
+
+    val numLocalReaders = collect(plan) {
+      case reader: GpuCustomShuffleReaderExec if reader.isLocalReader => reader
+    }
+    numLocalReaders.foreach { r =>
+      val rdd = r.executeColumnar()
+      val parts = rdd.partitions
+      assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
+    }
+    assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
+    numLocalReaders.length
+  }
+
   def skewJoinTest(fun: SparkSession => Unit) {
     assumeSpark301orLater
 
@@ -369,6 +451,7 @@ class AdaptiveQueryExecSuite
     testData(spark)
     testData2(spark)
     testData3(spark)
+    lowerCaseData(spark)
   }
 
   /** Ported from org.apache.spark.sql.test.SQLTestData */
@@ -376,7 +459,7 @@ class AdaptiveQueryExecSuite
     import spark.implicits._
     val data: Seq[(Int, String)] = (1 to 100).map(i => (i, i.toString))
     val df = data.toDF("key", "value")
-      .repartition(6)
+      .repartition(col("key"))
     registerAsParquetTable(spark, df, "testData")
   }
   /** Ported from org.apache.spark.sql.test.SQLTestData */
@@ -384,7 +467,7 @@ class AdaptiveQueryExecSuite
     import spark.implicits._
    val df = Seq[(Int, Int)]((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
       .toDF("a", "b")
-      .repartition(2)
+      .repartition(col("a"))
     registerAsParquetTable(spark, df, "testData2")
   }
 
@@ -393,10 +476,24 @@ class AdaptiveQueryExecSuite
     import spark.implicits._
     val df = Seq[(Int, Option[Int])]((1, None), (2, Some(2)))
       .toDF("a", "b")
-      .repartition(6)
+      .repartition(col("a"))
     registerAsParquetTable(spark, df, "testData3")
   }
 
+  /** Ported from org.apache.spark.sql.test.SQLTestData */
+  private def lowerCaseData(spark: SparkSession) {
+    import spark.implicits._
+    // note that this differs from the original Spark test by generating a larger data set so that
+    // we can trigger larger stats in the logical plan, preventing BHJ, and then our queries filter
+    // this down to a smaller data set so that SMJ can be replaced with BHJ at execution time when
+    // AQE is enabled
+    val data: Seq[(Int, String)] = (0 to 10000).map(i => (i, if (i < 5) i.toString else "z"))
+    val df = data
+      .toDF("n", "l")
+      .repartition(col("n"))
+    registerAsParquetTable(spark, df, "lowerCaseData")
+  }
+
   private def registerAsParquetTable(spark: SparkSession, df: Dataset[Row], name: String) {
     val path = new File(TEST_FILES_ROOT, s"$name.parquet").getAbsolutePath
     df.write
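
For readers without the plugin handy, the SMJ-to-BHJ conversion these tests exercise can be reproduced with plain Apache Spark 3.0.x. The sketch below is an illustration, not part of the patch: the object name is made up, it runs without the RAPIDS accelerator (so the final join is Spark's CPU BroadcastHashJoinExec rather than the GPU broadcast hash join the suite asserts on), and it leans on the same internal spark.sql.adaptive.nonEmptyPartitionRatioForBroadcastJoin conf the tests set.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec

object AqeSmjToBhjSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("aqe-smj-to-bhj-sketch")
      .config("spark.sql.adaptive.enabled", "true")
      // tiny static threshold (bytes) so the initial plan picks a sort-merge join
      .config("spark.sql.autoBroadcastJoinThreshold", "400")
      // as in the tests: keep DemoteBroadcastHashJoin from vetoing the BHJ when
      // most post-filter partitions are empty
      .config("spark.sql.adaptive.nonEmptyPartitionRatioForBroadcastJoin", "0")
      .getOrCreate()
    import spark.implicits._

    // mirror the shape of lowerCaseData: large enough to exceed the threshold in the
    // static plan, while the filter below keeps only one row at runtime
    (0 to 10000).map(i => (i, if (i < 5) i.toString else "z"))
      .toDF("n", "l").createOrReplaceTempView("t1")
    (0 to 1000).map(i => (i % 3 + 1, i))
      .toDF("a", "b").createOrReplaceTempView("t2")

    val df = spark.sql("SELECT * FROM t1 JOIN t2 ON t1.n = t2.a WHERE t1.l = '1'")
    df.collect() // materialize so AQE finalizes the adaptive plan

    df.queryExecution.executedPlan match {
      case a: AdaptiveSparkPlanExec =>
        val bhj = a.executedPlan.collect { case b: BroadcastHashJoinExec => b }
        println(s"broadcast hash joins in final plan: ${bhj.length}") // expect 1
      case other =>
        println(s"AQE did not apply: ${other.getClass.getSimpleName}")
    }
    spark.stop()
  }
}

The 400-byte threshold mirrors the test conf: both tables are too large for a static broadcast, but the single row surviving the filter lets AQE re-plan the sort-merge join as a broadcast hash join at runtime, which is what runAdaptiveAndVerifyResult observes by comparing the plan before and after execution.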