diff --git a/integration_tests/src/main/python/asserts.py b/integration_tests/src/main/python/asserts.py
index 27b33848c2c..7d97c81af04 100644
--- a/integration_tests/src/main/python/asserts.py
+++ b/integration_tests/src/main/python/asserts.py
@@ -351,24 +351,31 @@ def assert_gpu_fallback_write(write_func,
     jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.startCapture()
     gpu_start = time.time()
     gpu_path = base_path + '/GPU'
-    with_gpu_session(lambda spark : write_func(spark, gpu_path), conf=conf)
-    gpu_end = time.time()
-    jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name_list, 10000)
-    print('### WRITE: GPU TOOK {} CPU TOOK {} ###'.format(
-        gpu_end - gpu_start, cpu_end - cpu_start))
-
-    (cpu_bring_back, cpu_collect_type) = _prep_func_for_compare(
-        lambda spark: read_func(spark, cpu_path), 'COLLECT')
-    (gpu_bring_back, gpu_collect_type) = _prep_func_for_compare(
-        lambda spark: read_func(spark, gpu_path), 'COLLECT')
-
-    from_cpu = with_cpu_session(cpu_bring_back, conf=conf)
-    from_gpu = with_cpu_session(gpu_bring_back, conf=conf)
-    if should_sort_locally():
-        from_cpu.sort(key=_RowCmp)
-        from_gpu.sort(key=_RowCmp)
+    try:
+        with_gpu_session(lambda spark : write_func(spark, gpu_path), conf=conf)
+        gpu_end = time.time()
+        jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name_list, 10000)
+        print('### WRITE: GPU TOOK {} CPU TOOK {} ###'.format(
+            gpu_end - gpu_start, cpu_end - cpu_start))
+
+        (cpu_bring_back, cpu_collect_type) = _prep_func_for_compare(
+            lambda spark: read_func(spark, cpu_path), 'COLLECT')
+        (gpu_bring_back, gpu_collect_type) = _prep_func_for_compare(
+            lambda spark: read_func(spark, gpu_path), 'COLLECT')
+
+        from_cpu = with_cpu_session(cpu_bring_back, conf=conf)
+        from_gpu = with_cpu_session(gpu_bring_back, conf=conf)
+        if should_sort_locally():
+            from_cpu.sort(key=_RowCmp)
+            from_gpu.sort(key=_RowCmp)
+
+        assert_equal(from_cpu, from_gpu)
+    finally:
+        # Ensure the `shouldCapture` state is restored. If the GPU plan fails to execute,
+        # `assertCapturedAndGpuFellBack` never gets to reset `shouldCapture`.
+        # This mostly happens in xfail cases, where the error is ignored.
+        jvm.org.apache.spark.sql.rapids.ExecutionPlanCaptureCallback.endCapture()
 
-    assert_equal(from_cpu, from_gpu)
 
 def assert_cpu_and_gpu_are_equal_collect_with_capture(func,
         exist_classes='',
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala
index c082de95241..bfdebc086c0 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback.scala
@@ -26,6 +26,8 @@ trait ExecutionPlanCaptureCallbackBase {
   def captureIfNeeded(qe: QueryExecution): Unit
   def startCapture(): Unit
   def startCapture(timeoutMillis: Long): Unit
+  def endCapture(): Unit
+  def endCapture(timeoutMillis: Long): Unit
   def getResultsWithTimeout(timeoutMs: Long = 10000): Array[SparkPlan]
   def extractExecutedPlan(plan: SparkPlan): SparkPlan
   def assertContains(gpuPlan: SparkPlan, className: String): Unit
@@ -57,6 +59,10 @@ object ExecutionPlanCaptureCallback extends ExecutionPlanCaptureCallbackBase {
 
   override def startCapture(timeoutMillis: Long): Unit = impl.startCapture(timeoutMillis)
 
+  override def endCapture(): Unit = impl.endCapture()
+
+  override def endCapture(timeoutMillis: Long): Unit = impl.endCapture(timeoutMillis)
+
   override def getResultsWithTimeout(timeoutMs: Long = 10000): Array[SparkPlan] =
     impl.getResultsWithTimeout(timeoutMs)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala
index 8f811496b3d..00379026f05 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ShimmedExecutionPlanCaptureCallbackImpl.scala
@@ -57,6 +57,15 @@ class ShimmedExecutionPlanCaptureCallbackImpl extends ExecutionPlanCaptureCallba
     }
   }
 
+  override def endCapture(): Unit = endCapture(10000)
+
+  override def endCapture(timeoutMillis: Long): Unit = synchronized {
+    if (shouldCapture) {
+      shouldCapture = false
+      execPlans.clear()
+    }
+  }
+
   override def getResultsWithTimeout(timeoutMs: Long = 10000): Array[SparkPlan] = {
     try {
       val spark = SparkSession.active
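For context, the sketch below illustrates the capture lifecycle that the new `endCapture()` resets. It is a minimal, self-contained approximation (hypothetical object name `CaptureStateSketch`, `String` standing in for `SparkPlan`, no Spark dependency), not the plugin's actual `ShimmedExecutionPlanCaptureCallbackImpl`; it only shows why an idempotent reset is safe to call from the test's `finally` block even when the GPU write or the fallback assertion throws.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for the real callback: same start/capture/end shape,
// but with String plan names and no Spark session involved.
object CaptureStateSketch {
  private var shouldCapture: Boolean = false
  private val execPlans = ArrayBuffer.empty[String]

  def startCapture(): Unit = synchronized {
    execPlans.clear()
    shouldCapture = true
  }

  def captureIfNeeded(plan: String): Unit = synchronized {
    if (shouldCapture) execPlans += plan
  }

  // Idempotent reset: clears the flag and any buffered plans under the same lock,
  // so calling it from a `finally` block is safe even if capture already ended
  // or nothing was ever captured.
  def endCapture(): Unit = synchronized {
    if (shouldCapture) {
      shouldCapture = false
      execPlans.clear()
    }
  }

  def isCapturing: Boolean = synchronized { shouldCapture }

  def main(args: Array[String]): Unit = {
    startCapture()
    try {
      captureIfNeeded("SomeGpuWriteExec") // placeholder plan name
      throw new RuntimeException("simulated GPU write failure")
    } catch {
      case _: RuntimeException => () // an xfail-style test would swallow this
    } finally {
      endCapture() // capture state is restored regardless of the failure above
    }
    assert(!isCapturing, "shouldCapture must be reset after endCapture()")
  }
}
```

Note that in the patch above, `endCapture(timeoutMillis: Long)` mirrors the `startCapture` signatures but does not actually wait on the timeout; the reset itself is an immediate, synchronized clear.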