From f7e8536e11a1dea79c023fcb16d1760a76b469d0 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 11 Aug 2020 09:24:00 -0500 Subject: [PATCH] Use fresh SparkSession when capturing to avoid late capture of previous query (#537) Signed-off-by: Jason Lowe --- .../apache/spark/sql/rapids/execution/TrampolineUtil.scala | 4 ++++ .../com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala | 3 +++ 2 files changed, 7 insertions(+) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala index 03f238ba124..a38cf364ba1 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala @@ -21,6 +21,7 @@ import org.json4s.JsonAST import org.apache.spark.{SparkContext, SparkEnv, SparkUpgradeException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.InputMetrics +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode @@ -69,4 +70,7 @@ object TrampolineUtil { cause: Throwable): SparkUpgradeException = { new SparkUpgradeException(version, message, cause) } + + /** Shuts down and cleans up any existing Spark session */ + def cleanupAnyExistingSession(): Unit = SparkSession.cleanupAnyExistingSession() } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index 32ddd666067..e1c7e579ea2 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types._ object TestResourceFinder { @@ -225,6 +226,8 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { : (Array[Row], SparkPlan, Array[Row], SparkPlan) = { conf.setIfMissing("spark.sql.shuffle.partitions", "2") + // force a new session to avoid accidentally capturing a late callback from a previous query + TrampolineUtil.cleanupAnyExistingSession() ExecutionPlanCaptureCallback.startCapture() var cpuPlan: Option[SparkPlan] = null val fromCpu =