diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala index 03f238ba124..a38cf364ba1 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala @@ -21,6 +21,7 @@ import org.json4s.JsonAST import org.apache.spark.{SparkContext, SparkEnv, SparkUpgradeException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.InputMetrics +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode @@ -69,4 +70,7 @@ object TrampolineUtil { cause: Throwable): SparkUpgradeException = { new SparkUpgradeException(version, message, cause) } + + /** Shuts down and cleans up any existing Spark session */ + def cleanupAnyExistingSession(): Unit = SparkSession.cleanupAnyExistingSession() } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index 32ddd666067..e1c7e579ea2 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types._ object TestResourceFinder { @@ -225,6 +226,8 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { : (Array[Row], SparkPlan, Array[Row], SparkPlan) = { conf.setIfMissing("spark.sql.shuffle.partitions", "2") + // force a new session to avoid accidentally capturing a late callback from a previous query + TrampolineUtil.cleanupAnyExistingSession() ExecutionPlanCaptureCallback.startCapture() var cpuPlan: Option[SparkPlan] = null val fromCpu =