Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shutdown bugs in the RAPIDS Shuffle Manager #2950

Merged
merged 3 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
}

withResource(new NvtxRange("UCX Handling Tasks", NvtxColor.CYAN)) { _ =>
while (!workerTasks.isEmpty) {
// Check `initialized`: on close we queue a "task" that sets initialized to false
// in order to exit the progress loop; we don't want to execute any other tasks after that.
while (!workerTasks.isEmpty && initialized) {
val wt = workerTasks.poll()
if (wt != null) {
wt()
Expand All @@ -190,8 +192,11 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
}
}

logDebug("Exiting UCX progress thread.")
Seq(endpointManager, worker, context).safeClose()
worker.synchronized {
logDebug("Exiting UCX progress thread.")
Seq(endpointManager, worker, context).safeClose()
worker = null
jlowe marked this conversation as resolved.
Show resolved Hide resolved
}
})
}

Expand Down Expand Up @@ -630,7 +635,12 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
workerTasks.add(task)
if (rapidsConf.shuffleUcxUseWakeup) {
withResource(new NvtxRange("UCX Signal", NvtxColor.RED)) { _ =>
worker.signal()
// take the worker object lock to protect against a concurrent `.close`
worker.synchronized {
if (worker != null) {
worker.signal()
}
}
}
}
}
Expand Down Expand Up @@ -694,6 +704,14 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
}

override def close(): Unit = {
// Put a UCX task in the progress thread. This will:
// - signal the worker, so the task is executed
// - tear down endpoints
// - remove all active messages
// - remove all memory registrations
// - set `initialized` to false, which means that no further
//   tasks will get executed in the progress thread, the loop exits,
//   and we close the endpointManager, the worker, and the context.
onWorkerThreadAsync(() => {
amRegistrations.forEach { (activeMessageId, _) =>
logDebug(s"Removing Active Message registration for " +
Expand All @@ -719,22 +737,10 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
}
}

if (rapidsConf.shuffleUcxUseWakeup && worker != null) {
worker.signal()
}

progressThread.shutdown()
if (!progressThread.awaitTermination(500, TimeUnit.MILLISECONDS)) {
logError("UCX progress thread failed to terminate correctly")
}

endpointManager.close()

if (worker != null) {
worker.close()
}

context.close()
}

private def makeClientConnection(peerExecutorId: Long): UCXClientConnection = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: B

private lazy val localBlockManagerId = blockManager.blockManagerId

// Used to prevent stopping the RAPIDS Shuffle Manager internals multiple times.
// See the `stop` method.
private var stopped: Boolean = false

// Code that expects the shuffle catalog to be initialized gets it this way,
// with error checking in case we are in a bad state.
private def getCatalogOrThrow: ShuffleBufferCatalog =
Expand Down Expand Up @@ -404,9 +408,12 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: B

override def shuffleBlockResolver: ShuffleBlockResolver = resolver

override def stop(): Unit = {
override def stop(): Unit = synchronized {
wrapped.stop()
server.foreach(_.close())
transport.foreach(_.close())
if (!stopped) {
stopped = true
server.foreach(_.close())
transport.foreach(_.close())
}
}
}