diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala
index 8f1fcf12043e..0719a7f90bde 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala
@@ -156,46 +156,26 @@ class ShuffledBatchRDD(
   override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val shuffledBatchPartition = split.asInstanceOf[ShuffledBatchRDDPartition]
-    // update partition metrics
-    metrics(GpuMetricNames.NUM_PARTITIONS).add(1)
-    val partitionSize = shuffledBatchPartition.spec match {
-      case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
-        val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
-          dependency.shuffleHandle.shuffleId, startReducerIndex, endReducerIndex)
-        blocksByAddress.flatMap(_._2).map(_._2).sum
-
-      case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex) =>
-        val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
-          dependency.shuffleHandle.shuffleId, reducerIndex, reducerIndex + 1)
-        blocksByAddress.flatMap(_._2)
-          .filter(tuple => tuple._3 >= startMapIndex && tuple._3 < endMapIndex)
-          .map(_._2).sum
-
-      case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) =>
-        val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
-          dependency.shuffleHandle.shuffleId, startReducerIndex, endReducerIndex)
-        blocksByAddress.flatMap(_._2)
-          .filter(_._3 == mapIndex)
-          .map(_._2).sum
-    }
-    metrics(GpuMetricNames.PARTITION_SIZE).add(partitionSize)
-
     val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
     // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
     // as well as the `tempMetrics` for basic shuffle metrics.
     val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
     val shuffleManagerShims = ShimLoader.getSparkShims.getShuffleManagerShims()
-    val reader = shuffledBatchPartition.spec match {
+    val (reader, partitionSize) = shuffledBatchPartition.spec match {
       case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
-        SparkEnv.get.shuffleManager.getReader(
+        val reader = SparkEnv.get.shuffleManager.getReader(
           dependency.shuffleHandle,
           startReducerIndex,
           endReducerIndex,
           context,
           sqlMetricsReporter)
+        val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
+          dependency.shuffleHandle.shuffleId, startReducerIndex, endReducerIndex)
+        val partitionSize = blocksByAddress.flatMap(_._2).map(_._2).sum
+        (reader, partitionSize)
 
       case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex) =>
-        shuffleManagerShims.getReader(
+        val reader = shuffleManagerShims.getReader(
           SparkEnv.get.shuffleManager,
           dependency.shuffleHandle,
           startMapIndex,
@@ -204,9 +184,15 @@ class ShuffledBatchRDD(
           reducerIndex + 1,
           context,
           sqlMetricsReporter)
+        val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
+          dependency.shuffleHandle.shuffleId, reducerIndex, reducerIndex + 1)
+        val partitionSize = blocksByAddress.flatMap(_._2)
+          .filter(tuple => tuple._3 >= startMapIndex && tuple._3 < endMapIndex)
+          .map(_._2).sum
+        (reader, partitionSize)
 
       case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) =>
-        shuffleManagerShims.getReader(
+        val reader = shuffleManagerShims.getReader(
           SparkEnv.get.shuffleManager,
           dependency.shuffleHandle,
           mapIndex,
@@ -215,7 +201,15 @@ class ShuffledBatchRDD(
           startReducerIndex,
           endReducerIndex,
           context,
           sqlMetricsReporter)
+        val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
+          dependency.shuffleHandle.shuffleId, startReducerIndex, endReducerIndex)
+        val partitionSize = blocksByAddress.flatMap(_._2)
+          .filter(_._3 == mapIndex)
+          .map(_._2).sum
+        (reader, partitionSize)
     }
+    metrics(GpuMetricNames.NUM_PARTITIONS).add(1)
+    metrics(GpuMetricNames.PARTITION_SIZE).add(partitionSize)
     reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2)
   }