From ca3527924b3ab26865ddfa92ee575823310acb4f Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 19 Jul 2021 09:00:23 -0500 Subject: [PATCH] Fix shutdown bugs in the RAPIDS Shuffle Manager (#2950) * Fix shutdown bugs in the RAPIDS Shuffle Manager Signed-off-by: Alessandro Bellina * Protect calls to worker.signal so that they don't happen after close * Dont rely on an object for locking that can become null --- .../nvidia/spark/rapids/shuffle/ucx/UCX.scala | 38 +++++++++++-------- .../RapidsShuffleInternalManagerBase.scala | 13 +++++-- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala index da3f22def33..177eeef1a5c 100644 --- a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala +++ b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCX.scala @@ -176,7 +176,9 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: } withResource(new NvtxRange("UCX Handling Tasks", NvtxColor.CYAN)) { _ => - while (!workerTasks.isEmpty) { + // check initialized since on close we queue a "task" that sets initialized to false + // to exit the progress loop, we don't want to execute any other tasks after that. 
+ while (!workerTasks.isEmpty && initialized) { val wt = workerTasks.poll() if (wt != null) { wt() @@ -190,8 +192,11 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: } } - logDebug("Exiting UCX progress thread.") - Seq(endpointManager, worker, context).safeClose() + synchronized { + logDebug("Exiting UCX progress thread.") + Seq(endpointManager, worker, context).safeClose() + worker = null + } }) } @@ -630,7 +635,12 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: workerTasks.add(task) if (rapidsConf.shuffleUcxUseWakeup) { withResource(new NvtxRange("UCX Signal", NvtxColor.RED)) { _ => - worker.signal() + // lock on `this` (not on `worker`, which is set to null on shutdown) before signaling + synchronized { + if (worker != null) { + worker.signal() + } + } } } } @@ -694,6 +704,14 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: } override def close(): Unit = { + // put a UCX task in the progress thread. This will: + // - signal the worker, so the task is executed + // - tear down endpoints + // - remove all active messages + // - remove all memory registrations + // - set `initialized` to false, which means that no further + // tasks will get executed in the progress thread, the loop exits + // and we close the endpointManager, the worker, and the context. 
onWorkerThreadAsync(() => { amRegistrations.forEach { (activeMessageId, _) => logDebug(s"Removing Active Message registration for " + @@ -719,22 +737,10 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf: } } - if (rapidsConf.shuffleUcxUseWakeup && worker != null) { - worker.signal() - } - progressThread.shutdown() if (!progressThread.awaitTermination(500, TimeUnit.MILLISECONDS)) { logError("UCX progress thread failed to terminate correctly") } - - endpointManager.close() - - if (worker != null) { - worker.close() - } - - context.close() } private def makeClientConnection(peerExecutorId: Long): UCXClientConnection = { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala index f0e970c7f3f..dc6624bdf9e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala @@ -260,6 +260,10 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: B private lazy val localBlockManagerId = blockManager.blockManagerId + // Used to prevent stopping the RAPIDS Shuffle Manager internals multiple times. + // see the `stop` method + private var stopped: Boolean = false + // Code that expects the shuffle catalog to be initialized gets it this way, // with error checking in case we are in a bad state. 
private def getCatalogOrThrow: ShuffleBufferCatalog = @@ -404,9 +408,12 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: B override def shuffleBlockResolver: ShuffleBlockResolver = resolver - override def stop(): Unit = { + override def stop(): Unit = synchronized { wrapped.stop() - server.foreach(_.close()) - transport.foreach(_.close()) + if (!stopped) { + stopped = true + server.foreach(_.close()) + transport.foreach(_.close()) + } } }