Only manifest the current batch in cached block shuffle read iterator (#892)

Signed-off-by: Jason Lowe <jlowe@nvidia.com>
jlowe authored Sep 30, 2020
1 parent 2603a95 commit 666f89b
Showing 3 changed files with 25 additions and 19 deletions.
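In short: RapidsCachingReader previously acquired every locally cached shuffle buffer while scanning the block list and materialized all of their columnar batches before the read iterator was even returned. After this commit it only records the shuffle buffer IDs, raises their spill priority, and manifests each batch on demand as the iterator is consumed, so at most the current batch is live per consumer step. A minimal sketch of the before/after pattern, using hypothetical BufferId/Buffer stand-ins rather than the real spark-rapids classes:

// Hypothetical stand-ins for ShuffleBufferId/RapidsBuffer; a sketch only.
object LazyManifestSketch {
  final case class BufferId(value: Int)

  final class Buffer(id: BufferId) extends AutoCloseable {
    def materialize(): String = s"batch-${id.value}" // stands in for getColumnarBatch
    override def close(): Unit = ()                  // releases the acquired buffer
  }

  def acquire(id: BufferId): Buffer = new Buffer(id)

  // Before: every batch is manifested while the block list is scanned,
  // so all of them are live at once.
  def eager(ids: Seq[BufferId]): Iterator[String] = {
    val batches = ids.map { id =>
      val buf = acquire(id)
      try buf.materialize() finally buf.close()
    }
    batches.iterator
  }

  // After: only the IDs are collected; each batch is manifested when the
  // consumer actually reaches it.
  def onDemand(ids: Seq[BufferId]): Iterator[String] =
    ids.iterator.map { id =>
      val buf = acquire(id)
      try buf.materialize() finally buf.close()
    }

  def main(args: Array[String]): Unit = {
    val first = onDemand((1 to 3).map(BufferId)).next()
    println(first) // batch-1: the other two batches have not been manifested
  }
}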
@@ -49,7 +49,7 @@ case class ShuffleBufferId(
 /** Catalog for lookup of shuffle buffers by block ID */
 class ShuffleBufferCatalog(
     catalog: RapidsBufferCatalog,
-    diskBlockManager: RapidsDiskBlockManager) extends Logging {
+    diskBlockManager: RapidsDiskBlockManager) extends Arm with Logging {
   /**
    * Information stored for each active shuffle.
    * NOTE: ArrayBuffer in blockMap must be explicitly locked when using it!
@@ -176,6 +176,16 @@ class ShuffleBufferCatalog(
    */
  def registerNewBuffer(buffer: RapidsBuffer): Unit = catalog.registerNewBuffer(buffer)

+  /**
+   * Update the spill priority of a shuffle buffer that soon will be read locally.
+   * @param id shuffle buffer identifier of buffer to update
+   */
+  def updateSpillPriorityForLocalRead(id: ShuffleBufferId): Unit = {
+    withResource(catalog.acquireBuffer(id)) { buffer =>
+      buffer.setSpillPriority(SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY)
+    }
+  }
+
   /**
    * Lookup the shuffle buffer that corresponds to the specified shuffle buffer ID and acquire it.
    * NOTE: It is the responsibility of the caller to close the buffer.
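The new method relies on withResource from the Arm mixin added to the class signature above. For readers unfamiliar with that helper, here is a minimal loan-pattern stand-in, assuming only that the resource is AutoCloseable; the real Arm trait in spark-rapids offers more overloads than this sketch:

import java.io.StringReader

// Minimal loan-pattern sketch; the real Arm trait is richer than this.
object LoanPatternSketch {
  def withResource[T <: AutoCloseable, V](resource: T)(block: T => V): V = {
    try {
      block(resource) // hand the resource to the caller's code...
    } finally {
      resource.close() // ...and release it even if that code throws
    }
  }

  def main(args: Array[String]): Unit = {
    val firstChar = withResource(new StringReader("buffer")) { reader =>
      reader.read().toChar
    }
    println(firstChar) // b
  }
}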
@@ -25,7 +25,6 @@ import com.nvidia.spark.rapids.shuffle.{RapidsShuffleIterator, RapidsShuffleTransport}
 import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.shuffle.{ShuffleReader, ShuffleReadMetricsReporter}
-import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator

@@ -50,19 +49,18 @@ class RapidsCachingReader[K, C](
     rapidsConf: RapidsConf,
     localId: BlockManagerId,
     blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
-    gpuHandle: GpuShuffleHandle[_, _],
     context: TaskContext,
     metrics: ShuffleReadMetricsReporter,
     transport: Option[RapidsShuffleTransport],
     catalog: ShuffleBufferCatalog)
-  extends ShuffleReader[K, C] with Logging {
+  extends ShuffleReader[K, C] with Arm with Logging {

   override def read(): Iterator[Product2[K, C]] = {
     val readRange = new NvtxRange(s"RapidsCachingReader.read", NvtxColor.DARK_GREEN)
     try {
       val blocksForRapidsTransport = new ArrayBuffer[(BlockManagerId, Seq[(BlockId, Long, Int)])]()
       val cachedBlocks = new ArrayBuffer[BlockId]()
-      val cachedBatches = new ArrayBuffer[ColumnarBatch]()
+      val cachedBufferIds = new ArrayBuffer[ShuffleBufferId]()
       val blocksByAddressMap: Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = blocksByAddress.toMap

       blocksByAddressMap.keys.foreach(blockManagerId => {
@@ -75,28 +73,25 @@ class RapidsCachingReader[K, C](
         blockInfos.foreach(
           blockInfo => {
             val blockId = blockInfo._1
-            val shuffleBufferIds: Seq[ShuffleBufferId] = blockId match {
+            val shuffleBufferIds: IndexedSeq[ShuffleBufferId] = blockId match {
               case sbbid: ShuffleBlockBatchId =>
                 (sbbid.startReduceId to sbbid.endReduceId).flatMap { reduceId =>
                   cachedBlocks.append(blockId)
                   val sBlockId = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId)
-                  catalog.blockIdToBuffersIds(sBlockId).toSeq
+                  catalog.blockIdToBuffersIds(sBlockId)
                 }
               case sbid: ShuffleBlockId =>
                 cachedBlocks.append(blockId)
-                catalog.blockIdToBuffersIds(sbid).toSeq
+                catalog.blockIdToBuffersIds(sbid)
               case _ => throw new IllegalArgumentException(
                 s"${blockId.getClass} $blockId is not currently supported")
             }

-            shuffleBufferIds.foreach { id =>
-              val asb = catalog.acquireBuffer(id)
-              try {
-                cachedBatches += asb.getColumnarBatch
-              } finally {
-                asb.close()
-              }
-            }
+            cachedBufferIds ++= shuffleBufferIds
+
+            // Update the spill priorities of these buffers to indicate they are about
+            // to be read and therefore should not be spilled if possible.
+            shuffleBufferIds.foreach(catalog.updateSpillPriorityForLocalRead)

             if (shuffleBufferIds.nonEmpty) {
               metrics.incLocalBlocksFetched(1)
@@ -135,8 +130,11 @@ class RapidsCachingReader[K, C](

       val itRange = new NvtxRange("Shuffle Iterator prep", NvtxColor.BLUE)
       try {
-        val cachedIt = cachedBatches.iterator.map(cb => {
+        val cachedIt = cachedBufferIds.iterator.map(bufferId => {
           GpuSemaphore.acquireIfNecessary(context)
+          val cb = withResource(catalog.acquireBuffer(bufferId)) { buffer =>
+            buffer.getColumnarBatch
+          }
           val cachedBytesRead = GpuColumnVector.getTotalDeviceMemoryUsed(cb)
           metrics.incLocalBytesRead(cachedBytesRead)
           metrics.incRecordsRead(cb.numRows())
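Why the spill-priority update earlier in read() matters: assuming the buffer store picks its lowest-priority buffers as spill victims first (which is what the "should not be spilled if possible" comment implies), buffers flagged for an imminent local read are passed over and are still resident when the iterator above finally manifests them. A hedged sketch of that victim-selection idea, with hypothetical types and a made-up priority constant:

// Hypothetical sketch of priority-driven spill victim selection; the real
// RapidsBufferCatalog/SpillPriorities types are more involved than this.
object SpillPrioritySketch {
  final case class Entry(id: Int, var spillPriority: Long)

  // Spill the lowest-priority entry first, so raising a buffer's priority
  // (as updateSpillPriorityForLocalRead does) makes it a later victim.
  def chooseSpillVictim(entries: Seq[Entry]): Option[Entry] =
    entries.sortBy(_.spillPriority).headOption

  def main(args: Array[String]): Unit = {
    val InputFromShufflePriority = 1000L // made-up value for the sketch
    val entries = Seq(Entry(1, 0L), Entry(2, 0L), Entry(3, 0L))
    entries.head.spillPriority = InputFromShufflePriority // about to be read locally
    println(chooseSpillVictim(entries).map(_.id)) // Some(2): entry 1 is passed over
  }
}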
@@ -29,7 +29,6 @@ import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage._

@@ -331,7 +330,6 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boolean

     new RapidsCachingReader(rapidsConf, localBlockManagerId,
       blocksByAddress,
-      gpu,
       context,
       metrics,
       transport,
