diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java index a1f878cb078..7d7046b2f24 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/InternalRowToColumnarBatchIterator.java @@ -147,55 +147,56 @@ public ColumnarBatch next() { // Update our estimate for number of rows with the final size used to allocate the buffers. numRowsEstimate = (int) bufsAndNumRows._2.targetSize(); long dataLength = calcDataLengthEstimate(numRowsEstimate); - try ( - SpillableHostBuffer sdb = bufsAndNumRows._1[0]; - SpillableHostBuffer sob = bufsAndNumRows._1[1]; + int used[]; + try (SpillableHostBuffer spillableDataBuffer = bufsAndNumRows._1[0]; + SpillableHostBuffer spillableOffsetsBuffer = bufsAndNumRows._1[1]; ) { - // Fill in buffer under write lock for host buffers - batchAndRange = sdb.withHostBufferWriteLock( (dataBuffer) -> { - return sob.withHostBufferWriteLock( (offsetsBuffer) -> { - int[] used = fillBatch(dataBuffer, offsetsBuffer, dataLength, numRowsEstimate); - int dataOffset = used[0]; - int currentRow = used[1]; - // We don't want to loop forever trying to copy nothing - assert (currentRow > 0); - if (numInputRows != null) { - numInputRows.add(currentRow); - } - if (numOutputRows != null) { - numOutputRows.add(currentRow); - } - if (numOutputBatches != null) { - numOutputBatches.add(1); - } - // Now that we have filled the buffers with the data, we need to turn them into a - // HostColumnVector and copy them to the device so the GPU can turn it into a Table. - // To do this we first need to make a HostColumnCoreVector for the data, and then - // put that into a HostColumnVector as its child. This the basics of building up - // a column of lists of bytes in CUDF but it is typically hidden behind the higer level - // APIs. - dataBuffer.incRefCount(); - offsetsBuffer.incRefCount(); - try (HostColumnVectorCore dataCv = - new HostColumnVectorCore(DType.INT8, dataOffset, Optional.of(0L), - dataBuffer, null, null, new ArrayList<>()); - HostColumnVector hostColumn = new HostColumnVector(DType.LIST, - currentRow, Optional.of(0L), null, null, - offsetsBuffer, Collections.singletonList(dataCv))) { + HostMemoryBuffer[] hBufs = + getHostBuffersWithRetry(spillableDataBuffer, spillableOffsetsBuffer); + try(HostMemoryBuffer dataBuffer = hBufs[0]; + HostMemoryBuffer offsetsBuffer = hBufs[1]; + ) { + used = fillBatch(dataBuffer, offsetsBuffer, dataLength, numRowsEstimate); + int dataOffset = used[0]; + int currentRow = used[1]; + // We don't want to loop forever trying to copy nothing + assert (currentRow > 0); + if (numInputRows != null) { + numInputRows.add(currentRow); + } + if (numOutputRows != null) { + numOutputRows.add(currentRow); + } + if (numOutputBatches != null) { + numOutputBatches.add(1); + } + // Now that we have filled the buffers with the data, we need to turn them into a + // HostColumnVector and copy them to the device so the GPU can turn it into a Table. + // To do this we first need to make a HostColumnCoreVector for the data, and then + // put that into a HostColumnVector as its child. This the basics of building up + // a column of lists of bytes in CUDF but it is typically hidden behind the higer level + // APIs. + dataBuffer.incRefCount(); + offsetsBuffer.incRefCount(); + try (HostColumnVectorCore dataCv = + new HostColumnVectorCore(DType.INT8, dataOffset, Optional.of(0L), + dataBuffer, null, null, new ArrayList<>()); + HostColumnVector hostColumn = new HostColumnVector(DType.LIST, + currentRow, Optional.of(0L), null, null, + offsetsBuffer, Collections.singletonList(dataCv))) { - long ct = System.nanoTime() - collectStart; - streamTime.add(ct); + long ct = System.nanoTime() - collectStart; + streamTime.add(ct); - // Grab the semaphore because we are about to put data onto the GPU. - GpuSemaphore$.MODULE$.acquireIfNecessary(TaskContext.get()); - NvtxRange range = NvtxWithMetrics.apply("RowToColumnar: build", NvtxColor.GREEN, - Option.apply(opTime)); - ColumnVector devColumn = - RmmRapidsRetryIterator.withRetryNoSplit(hostColumn::copyToDevice); - return Tuple2.apply(makeSpillableBatch(devColumn), range); - } - }); - }); + // Grab the semaphore because we are about to put data onto the GPU. + GpuSemaphore$.MODULE$.acquireIfNecessary(TaskContext.get()); + NvtxRange range = NvtxWithMetrics.apply("RowToColumnar: build", NvtxColor.GREEN, + Option.apply(opTime)); + ColumnVector devColumn = + RmmRapidsRetryIterator.withRetryNoSplit(hostColumn::copyToDevice); + batchAndRange = Tuple2.apply(makeSpillableBatch(devColumn), range); + } + } } try (NvtxRange ignored = batchAndRange._2; Table tab = @@ -208,6 +209,20 @@ public ColumnarBatch next() { } } + private HostMemoryBuffer[] getHostBuffersWithRetry( + SpillableHostBuffer spillableDataBuffer, SpillableHostBuffer spillableOffsetsBuffer) { + return RmmRapidsRetryIterator.withRetryNoSplit( () -> { + try (HostMemoryBuffer dataBuffer = spillableDataBuffer.getHostBuffer(); + HostMemoryBuffer offsetsBuffer = spillableOffsetsBuffer.getHostBuffer(); + ) { + // Increment these to keep them. + dataBuffer.incRefCount(); + offsetsBuffer.incRefCount(); + return new HostMemoryBuffer[] { dataBuffer, offsetsBuffer }; + } + }); + } + private Tuple2 allocBuffers(SpillableHostBuffer[] sBufs, AutoCloseableTargetSize numRowsWrapper) { HostMemoryBuffer[] hBufs = new HostMemoryBuffer[]{ null, null }; diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala index 657cbb33dd0..a332755745f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala @@ -21,7 +21,7 @@ import java.nio.channels.WritableByteChannel import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, MemoryBuffer, Table} +import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Table} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.StorageTier.StorageTier @@ -320,6 +320,15 @@ trait RapidsBuffer extends AutoCloseable { */ def getDeviceMemoryBuffer: DeviceMemoryBuffer + /** + * Get the host memory buffer from the underlying storage. If the buffer currently resides + * outside of host memory, a new HostMemoryBuffer is created with the data copied over. + * The caller must have successfully acquired the buffer beforehand. + * @see [[addReference]] + * @note It is the responsibility of the caller to close the buffer. + */ + def getHostMemoryBuffer: HostMemoryBuffer + /** * Try to add a reference to this buffer to acquire it. * @note The close method must be called for every successfully obtained reference. @@ -425,6 +434,9 @@ sealed class DegenerateRapidsBuffer( override def getDeviceMemoryBuffer: DeviceMemoryBuffer = throw new UnsupportedOperationException("degenerate buffer has no device memory buffer") + override def getHostMemoryBuffer: HostMemoryBuffer = + throw new UnsupportedOperationException("degenerate buffer has no host memory buffer") + override def addReference(): Boolean = true override def getSpillPriority: Long = Long.MaxValue diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala index 61a636c1708..f98b52ae022 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala @@ -613,7 +613,7 @@ class RapidsBufferCatalog( } } - def updateTiers(bufferSpill: BufferSpill): Long = bufferSpill match { + def updateTiers(bufferSpill: SpillAction): Long = bufferSpill match { case BufferSpill(spilledBuffer, maybeNewBuffer) => logDebug(s"Spilled ${spilledBuffer.id} from tier ${spilledBuffer.storageTier}. " + s"Removing. Registering ${maybeNewBuffer.map(_.id).getOrElse ("None")} " + @@ -621,6 +621,14 @@ class RapidsBufferCatalog( maybeNewBuffer.foreach(registerNewBuffer) removeBufferTier(spilledBuffer.id, spilledBuffer.storageTier) spilledBuffer.memoryUsedBytes + + case BufferUnspill(unspilledBuffer, maybeNewBuffer) => + logDebug(s"Unspilled ${unspilledBuffer.id} from tier ${unspilledBuffer.storageTier}. " + + s"Removing. Registering ${maybeNewBuffer.map(_.id).getOrElse ("None")} " + + s"${maybeNewBuffer}") + maybeNewBuffer.foreach(registerNewBuffer) + removeBufferTier(unspilledBuffer.id, unspilledBuffer.storageTier) + unspilledBuffer.memoryUsedBytes } /** @@ -647,6 +655,34 @@ class RapidsBufferCatalog( } } + /** + * Copies `buffer` to the `hostStorage` store, registering a new `RapidsBuffer` in + * the process + * + * @param buffer - buffer to copy + * @param stream - Cuda.Stream to synchronize on + * @return - The `RapidsBuffer` instance that was added to the host store. + */ + def unspillBufferToHostStore( + buffer: RapidsBuffer, + stream: Cuda.Stream): RapidsBuffer = synchronized { + // try to acquire the buffer, if it's already in the store + // do not create a new one, else add a reference + acquireBuffer(buffer.id, StorageTier.HOST) match { + case Some(existingBuffer) => existingBuffer + case None => + val maybeNewBuffer = hostStorage.copyBuffer(buffer, this, stream) + maybeNewBuffer.map { newBuffer => + logDebug(s"got new RapidsHostMemoryStore buffer ${newBuffer.id}") + newBuffer.addReference() // add a reference since we are about to use it + updateTiers(BufferUnspill(buffer, Some(newBuffer))) + buffer.safeFree() + newBuffer + }.get // the host store has to return a buffer here for now, or throw OOM + } + } + + /** * Remove a buffer ID from the catalog at the specified storage tier. * @note public for testing diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala index 511686f8557..98023259d82 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import ai.rapids.cudf.{BaseDeviceMemoryBuffer, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange} import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.StorageTier.{DEVICE, StorageTier} +import com.nvidia.spark.rapids.StorageTier.{DEVICE, HOST, StorageTier} import com.nvidia.spark.rapids.format.TableMeta import org.apache.spark.internal.Logging @@ -32,13 +32,22 @@ import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch /** - * A helper case class that contains the buffer we spilled from our current tier - * and likely a new buffer created in a spill store tier, but it can be set to None. - * If the buffer already exists in the target spill store, `newBuffer` will be None. - * @param spilledBuffer a `RapidsBuffer` we spilled from this store - * @param newBuffer an optional `RapidsBuffer` in the target spill store. + * Helper case classes that contain the buffer we spilled or unspilled from our current tier + * and likely a new buffer created in a target store tier, but it can be set to None. + * If the buffer already exists in the target store, `newBuffer` will be None. + * @param spillBuffer a `RapidsBuffer` we spilled or unspilled from this store + * @param newBuffer an optional `RapidsBuffer` in the target store. */ -case class BufferSpill(spilledBuffer: RapidsBuffer, newBuffer: Option[RapidsBuffer]) +trait SpillAction { + val spillBuffer: RapidsBuffer + val newBuffer: Option[RapidsBuffer] +} + +case class BufferSpill(spillBuffer: RapidsBuffer, newBuffer: Option[RapidsBuffer]) + extends SpillAction + +case class BufferUnspill(spillBuffer: RapidsBuffer, newBuffer: Option[RapidsBuffer]) + extends SpillAction /** * Base class for all buffer store types. @@ -307,7 +316,7 @@ abstract class RapidsBufferStore(val tier: StorageTier) // as it has already spilled. BufferSpill(nextSpillableBuffer, None) } - totalSpilled += bufferSpill.spilledBuffer.memoryUsedBytes + totalSpilled += bufferSpill.spillBuffer.memoryUsedBytes bufferSpills.append(bufferSpill) catalog.updateTiers(bufferSpill) } @@ -333,7 +342,7 @@ abstract class RapidsBufferStore(val tier: StorageTier) // the buffer via events. // https://github.com/NVIDIA/spark-rapids/issues/8610 Cuda.deviceSynchronize() - bufferSpills.foreach(_.spilledBuffer.safeFree()) + bufferSpills.foreach(_.spillBuffer.safeFree()) } } } @@ -516,6 +525,31 @@ abstract class RapidsBufferStore(val tier: StorageTier) } } + override def getHostMemoryBuffer: HostMemoryBuffer = { + (0 until MAX_UNSPILL_ATTEMPTS).foreach { _ => + catalog.acquireBuffer(id, HOST) match { + case Some(buffer) => + withResource(buffer) { _ => + return buffer.getHostMemoryBuffer + } + case _ => + try { + logDebug(s"Unspilling $this $id to $HOST") + val newBuffer = catalog.unspillBufferToHostStore( + this, + Cuda.DEFAULT_STREAM) + withResource(newBuffer) { _ => + return newBuffer.getHostMemoryBuffer + } + } catch { + case _: DuplicateBufferException => + logDebug(s"Lost host buffer registration race for buffer $id, retrying...") + } + } + } + throw new IllegalStateException(s"Unable to get host memory buffer for ID: $id") + } + /** * close() is called by client code to decrease the ref count of this RapidsBufferBase. * In the off chance that by the time close is invoked, the buffer was freed (not valid) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala index 3a4c8cf1797..5003ba46184 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala @@ -140,7 +140,6 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) meta: TableMeta, spillPriority: Long) extends RapidsBufferBase(id, meta, spillPriority) { - private[this] var hostBuffer: Option[HostMemoryBuffer] = None // FIXME: Need to be clean up. Tracked in https://github.com/NVIDIA/spark-rapids/issues/9496 override val memoryUsedBytes: Long = uncompressedSize @@ -148,54 +147,40 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) override val storageTier: StorageTier = StorageTier.DISK override def getMemoryBuffer: MemoryBuffer = synchronized { - if (hostBuffer.isEmpty) { - require(onDiskSizeInBytes > 0, - s"$this attempted an invalid 0-byte mmap of a file") - val path = id.getDiskPath(diskBlockManager) - val serializerManager = diskBlockManager.getSerializerManager() - val memBuffer = if (serializerManager.isRapidsSpill(id)) { - // Only go through serializerManager's stream wrapper for spill case - closeOnExcept(HostMemoryBuffer.allocate(uncompressedSize)) { decompressed => - GpuTaskMetrics.get.readSpillFromDiskTime { - withResource(FileChannel.open(path.toPath, StandardOpenOption.READ)) { c => - c.position(fileOffset) - withResource(Channels.newInputStream(c)) { compressed => - withResource(serializerManager.wrapStream(id, compressed)) { in => - withResource(new HostMemoryOutputStream(decompressed)) { out => - IOUtils.copy(in, out) - } - decompressed + require(onDiskSizeInBytes > 0, + s"$this attempted an invalid 0-byte mmap of a file") + val path = id.getDiskPath(diskBlockManager) + val serializerManager = diskBlockManager.getSerializerManager() + val memBuffer = if (serializerManager.isRapidsSpill(id)) { + // Only go through serializerManager's stream wrapper for spill case + closeOnExcept(HostAlloc.alloc(uncompressedSize)) { + decompressed => GpuTaskMetrics.get.readSpillFromDiskTime { + withResource(FileChannel.open(path.toPath, StandardOpenOption.READ)) { c => + c.position(fileOffset) + withResource(Channels.newInputStream(c)) { compressed => + withResource(serializerManager.wrapStream(id, compressed)) { in => + withResource(new HostMemoryOutputStream(decompressed)) { out => + IOUtils.copy(in, out) } + decompressed } } } } - } else { - // Reserved mmap read fashion for UCX shuffle path. Also it's skipping encryption and - // compression. - HostMemoryBuffer.mapFile(path, MapMode.READ_WRITE, fileOffset, onDiskSizeInBytes) } - hostBuffer = Some(memBuffer) + } else { + // Reserved mmap read fashion for UCX shuffle path. Also it's skipping encryption and + // compression. + HostMemoryBuffer.mapFile(path, MapMode.READ_WRITE, fileOffset, onDiskSizeInBytes) } - hostBuffer.foreach(_.incRefCount()) - hostBuffer.get + memBuffer } override def close(): Unit = synchronized { - if (refcount == 1) { - // free the memory mapping since this is the last active reader - hostBuffer.foreach { b => - logDebug(s"closing mmap buffer $b") - b.close() - } - hostBuffer = None - } super.close() } override protected def releaseResources(): Unit = { - require(hostBuffer.isEmpty, - "Releasing a disk buffer with non-empty host buffer") // Buffers that share paths must be cleaned up elsewhere if (id.canShareDiskPaths) { sharedBufferFiles.remove(id) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala index cdcdfea9715..32fe6229674 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala @@ -119,45 +119,57 @@ class RapidsHostMemoryStore( s"in the host store, skipping tier.") None } else { - withResource(other.getCopyIterator) { otherBufferIterator => - val isChunked = otherBufferIterator.isChunked - val totalCopySize = otherBufferIterator.getTotalCopySize - closeOnExcept(HostAlloc.tryAlloc(totalCopySize)) { hb => - hb.map { hostBuffer => - val spillNs = GpuTaskMetrics.get.spillToHostTime { - var hostOffset = 0L - val start = System.nanoTime() - while (otherBufferIterator.hasNext) { - val otherBuffer = otherBufferIterator.next() - withResource(otherBuffer) { _ => - otherBuffer match { - case devBuffer: DeviceMemoryBuffer => - hostBuffer.copyFromMemoryBufferAsync( - hostOffset, devBuffer, 0, otherBuffer.getLength, stream) - hostOffset += otherBuffer.getLength - case _ => - throw new IllegalStateException("copying from buffer without device memory") + // If the other is from the local disk store, we are unspilling to host memory. + if (other.storageTier == StorageTier.DISK) { + logDebug(s"copying RapidsDiskStore buffer ${other.id} to a HostMemoryBuffer") + val hostBuffer = other.getMemoryBuffer.asInstanceOf[HostMemoryBuffer] + Some(new RapidsHostMemoryBuffer( + other.id, + hostBuffer.getLength(), + other.meta, + applyPriorityOffset(other.getSpillPriority, HOST_MEMORY_BUFFER_SPILL_OFFSET), + hostBuffer)) + } else { + withResource(other.getCopyIterator) { otherBufferIterator => + val isChunked = otherBufferIterator.isChunked + val totalCopySize = otherBufferIterator.getTotalCopySize + closeOnExcept(HostAlloc.tryAlloc(totalCopySize)) { hb => + hb.map { hostBuffer => + val spillNs = GpuTaskMetrics.get.spillToHostTime { + var hostOffset = 0L + val start = System.nanoTime() + while (otherBufferIterator.hasNext) { + val otherBuffer = otherBufferIterator.next() + withResource(otherBuffer) { _ => + otherBuffer match { + case devBuffer: DeviceMemoryBuffer => + hostBuffer.copyFromMemoryBufferAsync( + hostOffset, devBuffer, 0, otherBuffer.getLength, stream) + hostOffset += otherBuffer.getLength + case _ => + throw new IllegalStateException("copying from buffer without device memory") + } } } + stream.sync() + System.nanoTime() - start } - stream.sync() - System.nanoTime() - start + val szMB = (totalCopySize.toDouble / 1024.0 / 1024.0).toLong + val bw = (szMB.toDouble / (spillNs.toDouble / 1000000000.0)).toLong + logDebug(s"Spill to host (chunked=$isChunked) " + + s"size=$szMB MiB bandwidth=$bw MiB/sec") + new RapidsHostMemoryBuffer( + other.id, + totalCopySize, + other.meta, + applyPriorityOffset(other.getSpillPriority, HOST_MEMORY_BUFFER_SPILL_OFFSET), + hostBuffer) + }.orElse { + // skip host + logWarning(s"Buffer $other with size ${other.memoryUsedBytes} does not fit " + + s"in the host store, skipping tier.") + None } - val szMB = (totalCopySize.toDouble / 1024.0 / 1024.0).toLong - val bw = (szMB.toDouble / (spillNs.toDouble / 1000000000.0)).toLong - logDebug(s"Spill to host (chunked=$isChunked) " + - s"size=$szMB MiB bandwidth=$bw MiB/sec") - new RapidsHostMemoryBuffer( - other.id, - totalCopySize, - other.meta, - applyPriorityOffset(other.getSpillPriority, HOST_MEMORY_BUFFER_SPILL_OFFSET), - hostBuffer) - }.orElse { - // skip host - logWarning(s"Buffer $other with size ${other.memoryUsedBytes} does not fit " + - s"in the host store, skipping tier.") - None } } } @@ -177,7 +189,9 @@ class RapidsHostMemoryStore( with MemoryBuffer.EventHandler { override val storageTier: StorageTier = StorageTier.HOST - override def getMemoryBuffer: MemoryBuffer = synchronized { + override def getMemoryBuffer: MemoryBuffer = getHostMemoryBuffer + + override def getHostMemoryBuffer: HostMemoryBuffer = synchronized { buffer.synchronized { setSpillable(this, false) buffer.incRefCount() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index 82efa7699ef..27c8bac497d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -371,41 +371,9 @@ class SpillableHostBuffer(handle: RapidsBufferHandle, handle.close() } - /** - * Acquires the underlying `RapidsBuffer` and uses - * `RapidsBuffer.withMemoryBufferReadLock` to obtain a read lock - * that will held while invoking `body` with a `HostMemoryBuffer`. - * @param body function that takes a `HostMemoryBuffer` and produces `K` - * @tparam K any return type specified by `body` - * @return the result of body(hostMemoryBuffer) - */ - def withHostBufferReadOnly[K](body: HostMemoryBuffer => K): K = { - withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.withMemoryBufferReadLock { - case hmb: HostMemoryBuffer => body(hmb) - case memoryBuffer => - throw new IllegalStateException( - s"Expected a HostMemoryBuffer but instead got ${memoryBuffer}") - } - } - } - - /** - * Acquires the underlying `RapidsBuffer` and uses - * `RapidsBuffer.withMemoryBufferWriteLock` to obtain a write lock - * that will held while invoking `body` with a `HostMemoryBuffer`. - * @param body function that takes a `HostMemoryBuffer` and produces `K` - * @tparam K any return type specified by `body` - * @return the result of body(hostMemoryBuffer) - */ - def withHostBufferWriteLock[K](body: HostMemoryBuffer => K): K = { + def getHostBuffer(): HostMemoryBuffer = { withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => - rapidsBuffer.withMemoryBufferWriteLock { - case hmb: HostMemoryBuffer => body(hmb) - case memoryBuffer => - throw new IllegalStateException( - s"Expected a HostMemoryBuffer but instead got ${memoryBuffer}") - } + rapidsBuffer.getHostMemoryBuffer } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala index 61940ffd463..9b5b37af480 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids import java.io.File -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, MemoryBuffer} +import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.StorageTier.{DEVICE, DISK, HOST, StorageTier} import com.nvidia.spark.rapids.format.TableMeta @@ -342,6 +342,7 @@ class RapidsBufferCatalogSuite extends AnyFunSuite with MockitoSugar { length: Long, stream: Cuda.Stream): Unit = {} override def getDeviceMemoryBuffer: DeviceMemoryBuffer = null + override def getHostMemoryBuffer: HostMemoryBuffer = null override def addReference(): Boolean = { if (_acquireAttempts > 0) { _acquireAttempts -= 1 diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala index 153b8da6556..1ffad031451 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala @@ -267,7 +267,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar { withResource(spillableBuffer) { _ => // the refcount of 1 is the store assertResult(1)(hmb.getRefCount) - spillableBuffer.withHostBufferReadOnly { memoryBuffer => + withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => assertResult(hmb)(memoryBuffer) assertResult(2)(memoryBuffer.getRefCount) } @@ -278,33 +278,46 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar { } test("host buffer originated: get host memory buffer after spill") { + RapidsBufferCatalog.close() val spillPriority = -10 val hostStoreMaxSize = 1L * 1024 * 1024 - val bm = new RapidsDiskBlockManager(new SparkConf()) - withResource(new RapidsDiskStore(bm)) { diskStore => - withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => - withResource(new RapidsDeviceMemoryStore) { devStore => - val catalog = new RapidsBufferCatalog(devStore, hostStore) - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(diskStore) - val hmb = HostMemoryBuffer.allocate(1L * 1024) - val spillableBuffer = SpillableHostBuffer( - hmb, - hmb.getLength, - spillPriority, - catalog) - assertResult(1)(hmb.getRefCount) - // we spill it - catalog.synchronousSpill(hostStore, 0) - withResource(spillableBuffer) { _ => - // the refcount of the original buffer is 0 because it spilled - assertResult(0)(hmb.getRefCount) - spillableBuffer.withHostBufferReadOnly { memoryBuffer => - assertResult(memoryBuffer.getLength)(hmb.getLength) + try { + val bm = new RapidsDiskBlockManager(new SparkConf()) + val (catalog, devStore, hostStore, diskStore) = + closeOnExcept(new RapidsDiskStore(bm)) { diskStore => + closeOnExcept(new RapidsDeviceMemoryStore()) { devStore => + closeOnExcept(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore => + devStore.setSpillStore(hostStore) + hostStore.setSpillStore(diskStore) + val catalog = closeOnExcept( + new RapidsBufferCatalog(devStore, hostStore)) { catalog => catalog } + (catalog, devStore, hostStore, diskStore) } } } + + RapidsBufferCatalog.setDeviceStorage(devStore) + RapidsBufferCatalog.setHostStorage(hostStore) + RapidsBufferCatalog.setDiskStorage(diskStore) + RapidsBufferCatalog.setCatalog(catalog) + + val hmb = HostMemoryBuffer.allocate(1L * 1024) + val spillableBuffer = SpillableHostBuffer( + hmb, + hmb.getLength, + spillPriority) + assertResult(1)(hmb.getRefCount) + // we spill it + RapidsBufferCatalog.synchronousSpill(RapidsBufferCatalog.getHostStorage, 0) + withResource(spillableBuffer) { _ => + // the refcount of the original buffer is 0 because it spilled + assertResult(0)(hmb.getRefCount) + withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => + assertResult(memoryBuffer.getLength)(hmb.getLength) + } } + } finally { + RapidsBufferCatalog.close() } } @@ -326,7 +339,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar { catalog) // spillable is 1K assertResult(hmb.getLength)(hostStore.currentSpillableSize) - spillableBuffer.withHostBufferReadOnly { memoryBuffer => + withResource(spillableBuffer.getHostBuffer()) { memoryBuffer => // 0 because we have a reference to the memoryBuffer assertResult(0)(hostStore.currentSpillableSize) val spilled = catalog.synchronousSpill(hostStore, 0) diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala index c4a531a8d7d..001f82ab3a0 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids import java.util.UUID -import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, MemoryBuffer} +import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} import com.nvidia.spark.rapids.{RapidsBuffer, RapidsBufferCatalog, RapidsBufferId, SpillableColumnarBatchImpl, StorageTier} import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta @@ -54,6 +54,7 @@ class SpillableColumnarBatchSuite extends AnyFunSuite { override def copyToMemoryBuffer(srcOffset: Long, dst: MemoryBuffer, dstOffset: Long, length: Long, stream: Cuda.Stream): Unit = {} override def getDeviceMemoryBuffer: DeviceMemoryBuffer = null + override def getHostMemoryBuffer: HostMemoryBuffer = null override def addReference(): Boolean = true override def free(): Unit = {} override def getSpillPriority: Long = 0