Skip to content

Commit

Permalink
refactor to combine metrics and reader in a single match statement
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Grove <andygrove@nvidia.com>
  • Loading branch information
andygrove committed Nov 17, 2020
1 parent dedecd4 commit 69c39a7
Showing 1 changed file with 22 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,46 +156,26 @@ class ShuffledBatchRDD(

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val shuffledBatchPartition = split.asInstanceOf[ShuffledBatchRDDPartition]
// update partition metrics
metrics(GpuMetricNames.NUM_PARTITIONS).add(1)
val partitionSize = shuffledBatchPartition.spec match {
case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
dependency.shuffleHandle.shuffleId, startReducerIndex, endReducerIndex)
blocksByAddress.flatMap(_._2).map(_._2).sum

case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex) =>
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
dependency.shuffleHandle.shuffleId, reducerIndex, reducerIndex + 1)
blocksByAddress.flatMap(_._2)
.filter(tuple => tuple._3 >= startMapIndex && tuple._3 < endMapIndex)
.map(_._2).sum

case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) =>
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
dependency.shuffleHandle.shuffleId, startReducerIndex, endReducerIndex)
blocksByAddress.flatMap(_._2)
.filter(_._3 == mapIndex)
.map(_._2).sum
}
metrics(GpuMetricNames.PARTITION_SIZE).add(partitionSize)

val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
// `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
// as well as the `tempMetrics` for basic shuffle metrics.
val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
val shuffleManagerShims = ShimLoader.getSparkShims.getShuffleManagerShims()
val reader = shuffledBatchPartition.spec match {
val (reader, partitionSize) = shuffledBatchPartition.spec match {
case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
SparkEnv.get.shuffleManager.getReader(
val reader = SparkEnv.get.shuffleManager.getReader(
dependency.shuffleHandle,
startReducerIndex,
endReducerIndex,
context,
sqlMetricsReporter)
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
dependency.shuffleHandle.shuffleId, startReducerIndex, endReducerIndex)
val partitionSize = blocksByAddress.flatMap(_._2).map(_._2).sum
(reader, partitionSize)

case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex) =>
shuffleManagerShims.getReader(
val reader = shuffleManagerShims.getReader(
SparkEnv.get.shuffleManager,
dependency.shuffleHandle,
startMapIndex,
Expand All @@ -204,9 +184,15 @@ class ShuffledBatchRDD(
reducerIndex + 1,
context,
sqlMetricsReporter)
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
dependency.shuffleHandle.shuffleId, reducerIndex, reducerIndex + 1)
val partitionSize = blocksByAddress.flatMap(_._2)
.filter(tuple => tuple._3 >= startMapIndex && tuple._3 < endMapIndex)
.map(_._2).sum
(reader, partitionSize)

case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) =>
shuffleManagerShims.getReader(
val reader = shuffleManagerShims.getReader(
SparkEnv.get.shuffleManager,
dependency.shuffleHandle,
mapIndex,
Expand All @@ -215,7 +201,15 @@ class ShuffledBatchRDD(
endReducerIndex,
context,
sqlMetricsReporter)
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
dependency.shuffleHandle.shuffleId, startReducerIndex, endReducerIndex)
val partitionSize = blocksByAddress.flatMap(_._2)
.filter(_._3 == mapIndex)
.map(_._2).sum
(reader, partitionSize)
}
metrics(GpuMetricNames.NUM_PARTITIONS).add(1)
metrics(GpuMetricNames.PARTITION_SIZE).add(partitionSize)
reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2)
}

Expand Down

0 comments on commit 69c39a7

Please sign in to comment.