diff --git a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
index dcc1541795cc..338d717e9f76 100644
--- a/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
+++ b/integration_tests/src/main/scala/com/nvidia/spark/rapids/tests/common/BenchUtils.scala
@@ -15,7 +15,7 @@
  */
 package com.nvidia.spark.rapids.tests.common
 
-import java.io.{File, FileOutputStream, FileWriter, PrintWriter}
+import java.io.{File, FileOutputStream, FileWriter, PrintWriter, StringWriter}
 import java.time.Instant
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.TimeUnit.NANOSECONDS
@@ -150,34 +150,46 @@ object BenchUtils {
     val queryStartTime = Instant.now()
 
     val queryPlansWithMetrics = new ListBuffer[SparkPlanNode]()
+    val exceptions = new ListBuffer[String]()
     var df: DataFrame = null
     val queryTimes = new ListBuffer[Long]()
     for (i <- 0 until iterations) {
-      // capture spark plan metrics on the final run
-      if (i+1 == iterations) {
-        spark.listenerManager.register(new BenchmarkListener(queryPlansWithMetrics))
+      // capture spark plan metrics on the first run
+      if (i == 0) {
+        spark.listenerManager.register(new BenchmarkListener(queryPlansWithMetrics, exceptions))
       }
 
       println(s"*** Start iteration $i:")
       val start = System.nanoTime()
-      df = createDataFrame(spark)
-
-      resultsAction match {
-        case Collect() => df.collect()
-        case WriteCsv(path, mode, options) =>
-          df.write.mode(mode).options(options).csv(path)
-        case WriteOrc(path, mode, options) =>
-          df.write.mode(mode).options(options).orc(path)
-        case WriteParquet(path, mode, options) =>
-          df.write.mode(mode).options(options).parquet(path)
-      }
+      try {
+        df = createDataFrame(spark)
+
+        resultsAction match {
+          case Collect() => df.collect()
+          case WriteCsv(path, mode, options) =>
+            df.write.mode(mode).options(options).csv(path)
+          case WriteOrc(path, mode, options) =>
+            df.write.mode(mode).options(options).orc(path)
+          case WriteParquet(path, mode, options) =>
+            df.write.mode(mode).options(options).parquet(path)
+        }
 
-      val end = System.nanoTime()
-      val elapsed = NANOSECONDS.toMillis(end - start)
-      queryTimes.append(elapsed)
-      println(s"*** Iteration $i took $elapsed msec.")
+        val end = System.nanoTime()
+        val elapsed = NANOSECONDS.toMillis(end - start)
+        queryTimes.append(elapsed)
+        println(s"*** Iteration $i took $elapsed msec.")
+
+      } catch {
+        case e: Exception =>
+          val end = System.nanoTime()
+          val elapsed = NANOSECONDS.toMillis(end - start)
+          println(s"*** Iteration $i failed after $elapsed msec.")
+          queryTimes.append(-1)
+          exceptions.append(BenchUtils.toString(e))
+          e.printStackTrace()
+      }
 
       // cause Spark to call unregisterShuffle
       if (gcBetweenRuns) {
@@ -186,24 +198,29 @@ object BenchUtils {
       }
     }
 
-    // summarize all query times
-    for (i <- 0 until iterations) {
-      println(s"Iteration $i took ${queryTimes(i)} msec.")
-    }
+    // only show query times if there were no failed queries
+    if (!queryTimes.contains(-1)) {
 
-    // for multiple runs, summarize cold/hot timings
-    if (iterations > 1) {
-      println(s"Cold run: ${queryTimes(0)} msec.")
-      val hotRuns = queryTimes.drop(1)
-      val numHotRuns = hotRuns.length
-      println(s"Best of $numHotRuns hot run(s): ${hotRuns.min} msec.")
-      println(s"Worst of $numHotRuns hot run(s): ${hotRuns.max} msec.")
-      println(s"Average of $numHotRuns hot run(s): " +
-          s"${hotRuns.sum.toDouble / numHotRuns} msec.")
+      // summarize all query times
+      for (i <- 0 until iterations) {
+        println(s"Iteration $i took ${queryTimes(i)} msec.")
+      }
+
+      // for multiple runs, summarize cold/hot timings
+      if (iterations > 1) {
+        println(s"Cold run: ${queryTimes(0)} msec.")
+        val hotRuns = queryTimes.drop(1)
+        val numHotRuns = hotRuns.length
+        println(s"Best of $numHotRuns hot run(s): ${hotRuns.min} msec.")
+        println(s"Worst of $numHotRuns hot run(s): ${hotRuns.max} msec.")
+        println(s"Average of $numHotRuns hot run(s): " +
+            s"${hotRuns.sum.toDouble / numHotRuns} msec.")
+      }
     }
 
     // write results to file
-    val filename = s"$filenameStub-${queryStartTime.toEpochMilli}.json"
+    val suffix = if (exceptions.isEmpty) "" else "-failed"
+    val filename = s"$filenameStub-${queryStartTime.toEpochMilli}$suffix.json"
     println(s"Saving benchmark report to $filename")
 
     // try not to leak secrets
@@ -236,7 +253,8 @@ object BenchUtils {
         queryDescription,
         queryPlan,
         queryPlansWithMetrics,
-        queryTimes)
+        queryTimes,
+        exceptions)
       case w: WriteCsv => BenchmarkReport(
         filename,
@@ -248,7 +266,8 @@ object BenchUtils {
         queryDescription,
         queryPlan,
         queryPlansWithMetrics,
-        queryTimes)
+        queryTimes,
+        exceptions)
       case w: WriteOrc => BenchmarkReport(
         filename,
@@ -260,7 +279,8 @@ object BenchUtils {
         queryDescription,
         queryPlan,
         queryPlansWithMetrics,
-        queryTimes)
+        queryTimes,
+        exceptions)
       case w: WriteParquet => BenchmarkReport(
         filename,
@@ -272,7 +292,8 @@ object BenchUtils {
         queryDescription,
         queryPlan,
         queryPlansWithMetrics,
-        queryTimes)
+        queryTimes,
+        exceptions)
     }
 
     writeReport(report, filename)
@@ -317,7 +338,6 @@ object BenchUtils {
     }
   }
 
-
   /**
    * Generate a DOT graph for one query plan, or showing differences between two query plans.
    *
@@ -614,37 +634,49 @@ object BenchUtils {
       case (a, b) => a == b
     }
   }
+
+  def toString(e: Exception): String = {
+    val sw = new StringWriter()
+    val w = new PrintWriter(sw)
+    e.printStackTrace(w)
+    w.close()
+    sw.toString
+  }
 }
 
-class BenchmarkListener(list: ListBuffer[SparkPlanNode]) extends QueryExecutionListener {
+class BenchmarkListener(
+    queryPlans: ListBuffer[SparkPlanNode],
+    exceptions: ListBuffer[String]) extends QueryExecutionListener {
   override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-    def toJson(plan: SparkPlan): SparkPlanNode = {
-      plan match {
-        case WholeStageCodegenExec(child) => toJson(child)
-        case InputAdapter(child) => toJson(child)
-        case _ =>
-          val children: Seq[SparkPlanNode] = plan match {
-            case s: AdaptiveSparkPlanExec => Seq(toJson(s.executedPlan))
-            case s: QueryStageExec => Seq(toJson(s.plan))
-            case _ => plan.children.map(child => toJson(child))
-          }
-          val metrics: Seq[SparkSQLMetric] = plan.metrics
-              .map(m => SparkSQLMetric(m._1, m._2.metricType, m._2.value)).toSeq
-
-          SparkPlanNode(
-            plan.id,
-            plan.nodeName,
-            plan.simpleStringWithNodeId(),
-            metrics,
-            children)
-      }
-    }
-    list += toJson(qe.executedPlan)
+    queryPlans += toJson(qe.executedPlan)
   }
 
   override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
-    exception.printStackTrace()
+    queryPlans += toJson(qe.executedPlan)
+    exceptions += BenchUtils.toString(exception)
+  }
+
+  private def toJson(plan: SparkPlan): SparkPlanNode = {
+    plan match {
+      case WholeStageCodegenExec(child) => toJson(child)
+      case InputAdapter(child) => toJson(child)
+      case _ =>
+        val children: Seq[SparkPlanNode] = plan match {
+          case s: AdaptiveSparkPlanExec => Seq(toJson(s.executedPlan))
+          case s: QueryStageExec => Seq(toJson(s.plan))
+          case _ => plan.children.map(child => toJson(child))
+        }
+        val metrics: Seq[SparkSQLMetric] = plan.metrics
+            .map(m => SparkSQLMetric(m._1, m._2.metricType, m._2.value)).toSeq
+
+        SparkPlanNode(
+          plan.id,
+          plan.nodeName,
+          plan.simpleStringWithNodeId(),
+          metrics,
+          children)
+    }
   }
 }
@@ -668,7 +700,8 @@ case class BenchmarkReport(
     query: String,
     queryPlan: QueryPlan,
     queryPlans: Seq[SparkPlanNode],
-    queryTimes: Seq[Long])
+    queryTimes: Seq[Long],
+    exceptions: Seq[String])
 
 /** Configuration options that affect how the tests are run */
 case class TestConfiguration(
diff --git a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala
index 8d6faa1fa75a..415550519db1 100644
--- a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala
+++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/common/BenchUtilsSuite.scala
@@ -50,7 +50,8 @@ class BenchUtilsSuite extends FunSuite with BeforeAndAfterEach {
       query = "q1",
      queryPlan = QueryPlan("logical", "physical"),
       Seq.empty,
-      queryTimes = Seq(99, 88, 77))
+      queryTimes = Seq(99, 88, 77),
+      exceptions = Seq.empty)
 
     val filename = s"$TEST_FILES_ROOT/BenchUtilsSuite-${System.currentTimeMillis()}.json"
     BenchUtils.writeReport(report, filename)
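
Note (illustrative, not part of this patch): because the patch appends a "-failed" suffix to the report filename whenever an iteration throws, a results directory can be scanned for failed runs without parsing each JSON report. A minimal sketch follows; the "./benchmark-results" directory and the FindFailedRuns object name are assumptions for illustration only.

import java.io.File

object FindFailedRuns {
  def main(args: Array[String]): Unit = {
    // Hypothetical output directory; the patch itself only controls the filename suffix.
    val resultsDir = new File("./benchmark-results")
    val reports = Option(resultsDir.listFiles()).getOrElse(Array.empty[File])
    // Reports from runs that hit an exception end in "-failed.json" (see the suffix logic above).
    val failed = reports.filter(_.getName.endsWith("-failed.json"))
    failed.foreach(f => println(s"Failed benchmark run: ${f.getName}"))
  }
}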