diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala
index 5856a9c7e30..0f228a0735f 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala
@@ -49,7 +49,7 @@ case class ShuffleBufferId(
 /** Catalog for lookup of shuffle buffers by block ID */
 class ShuffleBufferCatalog(
     catalog: RapidsBufferCatalog,
-    diskBlockManager: RapidsDiskBlockManager) extends Logging {
+    diskBlockManager: RapidsDiskBlockManager) extends Arm with Logging {
   /**
    * Information stored for each active shuffle.
    * NOTE: ArrayBuffer in blockMap must be explicitly locked when using it!
@@ -176,6 +176,16 @@ class ShuffleBufferCatalog(
    */
   def registerNewBuffer(buffer: RapidsBuffer): Unit = catalog.registerNewBuffer(buffer)
 
+  /**
+   * Update the spill priority of a shuffle buffer that soon will be read locally.
+   * @param id shuffle buffer identifier of buffer to update
+   */
+  def updateSpillPriorityForLocalRead(id: ShuffleBufferId): Unit = {
+    withResource(catalog.acquireBuffer(id)) { buffer =>
+      buffer.setSpillPriority(SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY)
+    }
+  }
+
   /**
    * Lookup the shuffle buffer that corresponds to the specified shuffle buffer ID and acquire it.
    * NOTE: It is the responsibility of the caller to close the buffer.
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala
index 59097606276..ce67d61351f 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala
@@ -25,7 +25,6 @@ import com.nvidia.spark.rapids.shuffle.{RapidsShuffleIterator, RapidsShuffleTran
 import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.shuffle.{ShuffleReader, ShuffleReadMetricsReporter}
-import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator
 
@@ -50,19 +49,18 @@ class RapidsCachingReader[K, C](
     rapidsConf: RapidsConf,
     localId: BlockManagerId,
     blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
-    gpuHandle: GpuShuffleHandle[_, _],
     context: TaskContext,
     metrics: ShuffleReadMetricsReporter,
     transport: Option[RapidsShuffleTransport],
     catalog: ShuffleBufferCatalog)
-  extends ShuffleReader[K, C] with Logging {
+  extends ShuffleReader[K, C] with Arm with Logging {
 
   override def read(): Iterator[Product2[K, C]] = {
     val readRange = new NvtxRange(s"RapidsCachingReader.read", NvtxColor.DARK_GREEN)
     try {
       val blocksForRapidsTransport = new ArrayBuffer[(BlockManagerId, Seq[(BlockId, Long, Int)])]()
       val cachedBlocks = new ArrayBuffer[BlockId]()
-      val cachedBatches = new ArrayBuffer[ColumnarBatch]()
+      val cachedBufferIds = new ArrayBuffer[ShuffleBufferId]()
       val blocksByAddressMap: Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = blocksByAddress.toMap
 
       blocksByAddressMap.keys.foreach(blockManagerId => {
@@ -75,28 +73,25 @@ class RapidsCachingReader[K, C](
         blockInfos.foreach(
           blockInfo => {
             val blockId = blockInfo._1
-            val shuffleBufferIds: Seq[ShuffleBufferId] = blockId match {
+            val shuffleBufferIds: IndexedSeq[ShuffleBufferId] = blockId match {
               case sbbid: ShuffleBlockBatchId =>
                 (sbbid.startReduceId to sbbid.endReduceId).flatMap { reduceId =>
                   cachedBlocks.append(blockId)
                   val sBlockId = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId)
-                  catalog.blockIdToBuffersIds(sBlockId).toSeq
+                  catalog.blockIdToBuffersIds(sBlockId)
                 }
               case sbid: ShuffleBlockId =>
                 cachedBlocks.append(blockId)
-                catalog.blockIdToBuffersIds(sbid).toSeq
+                catalog.blockIdToBuffersIds(sbid)
               case _ => throw new IllegalArgumentException(
                 s"${blockId.getClass} $blockId is not currently supported")
             }
 
-            shuffleBufferIds.foreach { id =>
-              val asb = catalog.acquireBuffer(id)
-              try {
-                cachedBatches += asb.getColumnarBatch
-              } finally {
-                asb.close()
-              }
-            }
+            cachedBufferIds ++= shuffleBufferIds
+
+            // Update the spill priorities of these buffers to indicate they are about
+            // to be read and therefore should not be spilled if possible.
+            shuffleBufferIds.foreach(catalog.updateSpillPriorityForLocalRead)
 
             if (shuffleBufferIds.nonEmpty) {
               metrics.incLocalBlocksFetched(1)
@@ -135,8 +130,11 @@
 
       val itRange = new NvtxRange("Shuffle Iterator prep", NvtxColor.BLUE)
       try {
-        val cachedIt = cachedBatches.iterator.map(cb => {
+        val cachedIt = cachedBufferIds.iterator.map(bufferId => {
           GpuSemaphore.acquireIfNecessary(context)
+          val cb = withResource(catalog.acquireBuffer(bufferId)) { buffer =>
+            buffer.getColumnarBatch
+          }
           val cachedBytesRead = GpuColumnVector.getTotalDeviceMemoryUsed(cb)
           metrics.incLocalBytesRead(cachedBytesRead)
           metrics.incRecordsRead(cb.numRows())
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala
index e5ffb4abe0c..bcaa6f9c0ad 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala
@@ -29,7 +29,6 @@ import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage._
 
@@ -331,7 +330,6 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole
       new RapidsCachingReader(rapidsConf, localBlockManagerId,
         blocksByAddress,
-        gpu,
         context,
         metrics,
         transport,
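
The net effect of the RapidsCachingReader change is that locally cached shuffle buffers are no longer materialized as ColumnarBatches eagerly while blocks are being resolved; the reader records ShuffleBufferId values, raises their spill priority, and acquires each buffer only when the result iterator reaches it. Below is a minimal, self-contained Scala sketch of that deferred-acquisition pattern. The Catalog and Buffer traits, the Int buffer IDs, and the priority constant are hypothetical stand-ins for illustration, not the plugin's real RapidsBufferCatalog/RapidsBuffer API.

// Hypothetical sketch: Catalog/Buffer stand in for RapidsBufferCatalog and
// RapidsBuffer; the priority constant is made up for illustration.
object DeferredReadSketch {
  trait Buffer extends AutoCloseable {
    def setSpillPriority(priority: Long): Unit
    def getBatch: String // stands in for RapidsBuffer.getColumnarBatch
  }
  trait Catalog {
    def acquireBuffer(id: Int): Buffer
  }

  // Loan-pattern helper equivalent to the plugin's Arm.withResource.
  def withResource[R <: AutoCloseable, T](resource: R)(body: R => T): T =
    try body(resource) finally resource.close()

  // Assumed convention: a larger priority value means "spill this last".
  val InputFromShufflePriority: Long = Long.MaxValue

  def read(catalog: Catalog, ids: Seq[Int]): Iterator[String] = {
    // Raise spill priority up front so the buffers about to be read are the
    // least likely to be evicted from device memory...
    ids.foreach { id =>
      withResource(catalog.acquireBuffer(id)) { buffer =>
        buffer.setSpillPriority(InputFromShufflePriority)
      }
    }
    // ...then acquire and materialize lazily, holding at most one acquired
    // buffer at a time instead of pinning every batch for the whole read.
    ids.iterator.map { id =>
      withResource(catalog.acquireBuffer(id))(_.getBatch)
    }
  }
}

Compared to the old code, this keeps peak device-memory pressure lower because only the buffer currently being consumed is acquired; the trade-off, which the patch's own comment acknowledges ("should not be spilled if possible"), is that a buffer may still spill before it is reached and must then be read back when finally acquired.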