Skip to content

Commit

Permalink
Add cudaStreamSynchronize when a new device buffer is added to the sp…
Browse files Browse the repository at this point in the history
…ill framework (NVIDIA#5485)

* Adds a stream synchronize in addBuffer to ensure we safely spill

* Small cleanup in copyBuffer, add note about createBuffer synchronization requirements

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>

* Remove extra nvtx range

* When adding contiguous_split buffers in RapidsShuffleManager, synchronize once

* Fix RapidsShuffleTestHelper

* Fix RapidsShuffleClientSuite
  • Loading branch information
abellina authored and anthony-chang committed May 17, 2022
1 parent 39960cc commit a807a06
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ trait GpuPartitioning extends Partitioning with Arm {
cts.foreach { ct => splits.append(GpuPackedTableColumn.from(ct)) }
}
}
// Synchronize our stream once here to ensure we have caught up with contiguous split.
// Downstream consumers (RapidsShuffleManager) will add hundreds of buffers
// to the spill framework, so syncing here lets them skip a sync per buffer.
Cuda.DEFAULT_STREAM.sync()
splits.toArray
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ abstract class RapidsBufferStore(
def copyBuffer(buffer: RapidsBuffer, memoryBuffer: MemoryBuffer, stream: Cuda.Stream)
: RapidsBufferBase = {
freeOnExcept(createBuffer(buffer, memoryBuffer, stream)) { newBuffer =>
buffers.add(newBuffer)
catalog.registerNewBuffer(newBuffer)
addBuffer(newBuffer)
newBuffer
}
}
Expand Down Expand Up @@ -209,6 +208,7 @@ abstract class RapidsBufferStore(
* If the data transfer will be performed asynchronously, this method is responsible for
* adding a reference to the existing buffer and later closing it when the transfer completes.
* @note DO NOT close the buffer unless adding a reference!
* @note `createBuffer` impls should synchronize against `stream` before returning, if needed.
* @param buffer data from another store
* @param memoryBuffer memory buffer obtained from the specified Rapids buffer. The ownership
* for `memoryBuffer` is transferred to this store. The store may close
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -78,7 +78,7 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog
spillCallback)) { buffer =>
logDebug(s"Adding table for: [id=$id, size=${buffer.size}, " +
s"meta_id=${buffer.meta.bufferMeta.id}, meta_size=${buffer.meta.bufferMeta.size}]")
addBuffer(buffer)
addDeviceBuffer(buffer, needsSync = true)
}
}

Expand All @@ -92,12 +92,15 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog
* @param initialSpillPriority starting spill priority value for the buffer
* @param spillCallback a callback when the buffer is spilled. This should be very light weight.
* It should never allocate GPU memory and really just be used for metrics.
* @param needsSync whether the spill framework should stream synchronize while adding
* this device buffer (defaults to true)
*/
def addContiguousTable(
id: RapidsBufferId,
contigTable: ContiguousTable,
initialSpillPriority: Long,
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit = {
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback,
needsSync: Boolean = true): Unit = {
val contigBuffer = contigTable.getBuffer
val size = contigBuffer.getLength
val meta = MetaUtils.buildTableMeta(id.tableId, contigTable)
Expand All @@ -114,7 +117,7 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog
logDebug(s"Adding table for: [id=$id, size=${buffer.size}, " +
s"uncompressed=${buffer.meta.bufferMeta.uncompressedSize}, " +
s"meta_id=${buffer.meta.bufferMeta.id}, meta_size=${buffer.meta.bufferMeta.size}]")
addBuffer(buffer)
addDeviceBuffer(buffer, needsSync)
}
}

Expand All @@ -126,13 +129,16 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog
* @param initialSpillPriority starting spill priority value for the buffer
* @param spillCallback a callback when the buffer is spilled. This should be very light weight.
* It should never allocate GPU memory and really just be used for metrics.
* @param needsSync whether the spill framework should stream synchronize while adding
* this device buffer (defaults to true)
*/
def addBuffer(
id: RapidsBufferId,
buffer: DeviceMemoryBuffer,
tableMeta: TableMeta,
initialSpillPriority: Long,
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit = {
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback,
needsSync: Boolean = true): Unit = {
freeOnExcept(
new RapidsDeviceMemoryBuffer(
id,
Expand All @@ -146,8 +152,21 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog
s"uncompressed=${buff.meta.bufferMeta.uncompressedSize}, " +
s"meta_id=${tableMeta.bufferMeta.id}, " +
s"meta_size=${tableMeta.bufferMeta.size}]")
addBuffer(buff)
addDeviceBuffer(buff, needsSync)
}
}

/**
 * Registers a device buffer with the spill framework, optionally blocking on the
 * default CUDA stream first.
 *
 * The stream synchronization ensures that the buffer is fully materialized by its
 * producer, so that it can be safely copied as part of a spill.
 * @param buffer the device buffer to add to this store
 * @param needsSync true if we should stream synchronize before adding the buffer;
 *                  callers that have already synchronized may pass false
 */
private def addDeviceBuffer(buffer: RapidsDeviceMemoryBuffer, needsSync: Boolean): Unit = {
  // the sync (when requested) happens before the buffer is handed to the store
  if (needsSync) Cuda.DEFAULT_STREAM.sync()
  addBuffer(buffer)
}

class RapidsDeviceMemoryBuffer(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -430,7 +430,12 @@ class RapidsShuffleClient(
logDebug(s"Adding buffer id ${id} to catalog")
if (buffer != null) {
// add the buffer to the catalog so it is available for spill
devStorage.addBuffer(id, buffer, meta, SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY)
devStorage.addBuffer(id, buffer, meta,
SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY,
// set needsSync to false because we have already synchronized the stream after
// consuming the bounce buffer, so we know these buffers are synchronized
// w.r.t. the CPU
needsSync = false)
} else {
// no device data, just tracking metadata
catalog.registerNewBuffer(new DegenerateRapidsBuffer(id, meta))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ class RapidsCachingWriter[K, V](
shuffleStorage.addContiguousTable(
bufferId,
contigTable,
SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY)
SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY,
// we don't need to sync here, because we sync on the cuda
// stream after sliceInternalOnGpu (contiguous_split)
needsSync = false)
case c: GpuCompressedColumnVector =>
val buffer = c.getTableBuffer
buffer.incRefCount()
Expand All @@ -127,7 +130,10 @@ class RapidsCachingWriter[K, V](
bufferId,
buffer,
tableMeta,
SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY)
SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY,
// we don't need to sync here, because we sync on the cuda
// stream after compression.
needsSync = false)
case c => throw new IllegalStateException(s"Unexpected column type: ${c.getClass}")
}
bytesWritten += partSize
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -231,7 +231,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
verify(client, times(1)).track(any[DeviceMemoryBuffer](), tmCaptor.capture())
verifyTableMeta(tableMeta, tmCaptor.getValue.asInstanceOf[TableMeta])
verify(mockStorage, times(1))
.addBuffer(any(), dmbCaptor.capture(), any(), any(), any())
.addBuffer(any(), dmbCaptor.capture(), any(), any(), any(), any())

val receivedBuff = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer]
assertResult(tableMeta.bufferMeta().size())(receivedBuff.getLength)
Expand Down Expand Up @@ -282,7 +282,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
verify(client, times(1)).track(any[DeviceMemoryBuffer](), tmCaptor.capture())
verifyTableMeta(tableMeta, tmCaptor.getValue.asInstanceOf[TableMeta])
verify(mockStorage, times(1))
.addBuffer(any(), dmbCaptor.capture(), any(), any(), any())
.addBuffer(any(), dmbCaptor.capture(), any(), any(), any(), any())
verify(mockCatalog, times(1)).removeBuffer(any())

val receivedBuff = dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer]
Expand Down Expand Up @@ -335,7 +335,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
}

verify(mockStorage, times(5))
.addBuffer(any(), dmbCaptor.capture(), any(), any(), any())
.addBuffer(any(), dmbCaptor.capture(), any(), any(), any(), any())

assertResult(totalExpectedSize)(
dmbCaptor.getAllValues().toArray().map(_.asInstanceOf[DeviceMemoryBuffer].getLength).sum)
Expand Down Expand Up @@ -388,7 +388,7 @@ class RapidsShuffleClientSuite extends RapidsShuffleTestHelper {
}

verify(mockStorage, times(20))
.addBuffer(any(), dmbCaptor.capture(), any(), any(), any())
.addBuffer(any(), dmbCaptor.capture(), any(), any(), any(), any())

assertResult(totalExpectedSize)(
dmbCaptor.getAllValues().toArray().map(_.asInstanceOf[DeviceMemoryBuffer].getLength).sum)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -141,9 +141,10 @@ class RapidsShuffleTestHelper extends FunSuite
testMetricsUpdater = spy(new TestShuffleMetricsUpdater)

val dmbCaptor = ArgumentCaptor.forClass(classOf[DeviceMemoryBuffer])
when(mockStorage.addBuffer(any(), dmbCaptor.capture(), any(), any(), any())).thenAnswer(_ => {
buffersToClose.append(dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer])
})
when(mockStorage.addBuffer(any(), dmbCaptor.capture(), any(), any(), any(), any()))
.thenAnswer(_ => {
buffersToClose.append(dmbCaptor.getValue.asInstanceOf[DeviceMemoryBuffer])
})

client = spy(new RapidsShuffleClient(
mockConnection,
Expand Down

0 comments on commit a807a06

Please sign in to comment.