diff --git a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXShuffleTransport.scala b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXShuffleTransport.scala index 8acbbe82ce1..6ea4a9ec98f 100644 --- a/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXShuffleTransport.scala +++ b/shuffle-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/ucx/UCXShuffleTransport.scala @@ -88,6 +88,22 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon ucxImpl } + private val altList = new HashedPriorityQueue[PendingTransferRequest]( + 1000, + (t: PendingTransferRequest, t1: PendingTransferRequest) => { + if (t.getLength < t1.getLength) { + -1; + } else if (t.getLength > t1.getLength) { + 1; + } else { + 0 + } + }) + + // access to this set must hold the `altList` lock + val validClientAndHandler = + new mutable.HashSet[(RapidsShuffleClient, RapidsShuffleFetchHandler)]() + override def getDirectByteBuffer(size: Long): RefCountedDirectByteBuffer = { if (size > rapidsConf.shuffleMaxMetadataSize) { logWarning(s"Large metadata message size $size B, larger " + @@ -333,18 +349,6 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon inflightMonitor.notifyAll() } - private val altList = new HashedPriorityQueue[PendingTransferRequest]( - 1000, - (t: PendingTransferRequest, t1: PendingTransferRequest) => { - if (t.getLength < t1.getLength) { - -1; - } else if (t.getLength > t1.getLength) { - 1; - } else { - 0 - } - }) - private[this] val exec = Executors.newSingleThreadExecutor( GpuDeviceManager.wrapThreadFactory( new ThreadFactoryBuilder() @@ -435,7 +439,13 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon // A _much_ better way of doing this would be to have separate // lists, one per client. This should be cleaned up later. altList.synchronized { - putBack.foreach(altList.add) + putBack.foreach { pb => + // if this client+handler hasn't been invalidated, we can add the pending request back + // like with NOTE above, when this is a queue per client, this gets refactored + if (validClientAndHandler.contains((pb.client, pb.handler))) { + altList.add(pb) + } + } } if (perClientReq.nonEmpty) { @@ -469,11 +479,42 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon override def queuePending(reqs: Seq[PendingTransferRequest]): Unit = altList.synchronized { import collection.JavaConverters._ + val clientAndHandler = (reqs.head.client, reqs.head.handler) + validClientAndHandler.add(clientAndHandler) altList.addAll(reqs.asJava) logDebug(s"THROTTLING ${altList.size} queued requests") altList.notifyAll() } + override def cancelPending(client: RapidsShuffleClient, + handler: RapidsShuffleFetchHandler): Unit = { + altList.synchronized { + val clientAndHandler = (client, handler) + if (validClientAndHandler.contains(clientAndHandler)) { + // This is expensive, but will be refactored with a queue per client. + // As it stands, in the good case it should be invoked once per task/peer, + // on task completion, and `altList` should be empty + // will be skipped. + // When there are errors, we would get more invocations. + if (!altList.isEmpty) { + val it = altList.iterator() + val toRemove = new ArrayBuffer[PendingTransferRequest]() + while (it.hasNext) { + val pending = it.next() + if (pending.client == client && pending.handler == handler) { + toRemove.append(pending) + } + } + if (toRemove.nonEmpty) { + toRemove.foreach(altList.remove) + } + } + // invalidate the client+handler pair + validClientAndHandler.remove(clientAndHandler) + } + } + } + override def close(): Unit = { logInfo("UCX transport closing") exec.shutdown() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala index 1c7bcaf738f..72b441a2cda 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala @@ -95,7 +95,7 @@ case class PendingTransferRequest(client: RapidsShuffleClient, */ class RapidsShuffleClient( localExecutorId: Long, - connection: ClientConnection, + val connection: ClientConnection, transport: RapidsShuffleTransport, exec: Executor, clientCopyExecutor: Executor, @@ -262,6 +262,7 @@ class RapidsShuffleClient( * the transport's throttle logic. */ private[shuffle] def doIssueBufferReceives(bufferReceiveState: BufferReceiveState): Unit = { + try { if (!bufferReceiveState.hasIterated) { sendTransferRequest(bufferReceiveState) @@ -368,6 +369,16 @@ class RapidsShuffleClient( } } + /** + * Cancel pending requests for handler `handler` to the peer represented by this client. + * @param handler instance to use to find requests to cancel + * @note this currentl only cancels pending requests that are queued in the transport, + * and not in flight. + */ + def cancelPending(handler: RapidsShuffleFetchHandler): Unit = { + transport.cancelPending(this, handler) + } + /** * This function handles data received in `bounceBuffers`. The data should be copied out * of the buffers, and the function should call into `bufferReceiveState` to advance its diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala index c1f2463c38d..969f2c955d6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala @@ -152,14 +152,14 @@ class RapidsShuffleIterator( // `taskAttemptId`. case class BlockIdMapIndex(id: ShuffleBlockBatchId, mapIndex: Int) + private var clientAndHandlers = Seq[(RapidsShuffleClient, RapidsShuffleFetchHandler)]() + def start(): Unit = { logInfo(s"Fetching ${blocksByAddress.length} blocks.") // issue local fetches first val (local, remote) = blocksByAddress.partition(ba => ba._1.host == localHost) - var clients = Seq[RapidsShuffleClient]() - (local ++ remote).foreach { case (blockManagerId: BlockManagerId, blockIds: Seq[(BlockId, Long, Int)]) => { val shuffleRequestsMapIndex: Seq[BlockIdMapIndex] = @@ -248,7 +248,7 @@ class RapidsShuffleIterator( } } - override def transferError(errorMessage: String, throwable: Throwable): Unit = + override def transferError(errorMessage: String, throwable: Throwable): Unit = { resolvedBatches.synchronized { // If Spark detects a single fetch failure, the whole task has failed // as per `FetchFailedException`. In the future `mapIndex` will come from the @@ -258,16 +258,20 @@ class RapidsShuffleIterator( blockManagerId, id, mapIndex, errorMessage, throwable)) } } + + // tell the client to cancel pending requests + client.cancelPending(this) + } } logInfo(s"Client $blockManagerId triggered, for ${shuffleRequestsMapIndex.size} blocks") client.doFetch(shuffleRequestsMapIndex.map(_.id), handler) - clients = clients :+ client + clientAndHandlers = clientAndHandlers :+ ((client, handler)) } } logInfo(s"RapidsShuffleIterator for ${Thread.currentThread()} started with " + - s"${clients.size} clients.") + s"${clientAndHandlers.size} clients.") } private[this] def receiveBufferCleaner(): Unit = resolvedBatches.synchronized { @@ -280,6 +284,10 @@ class RapidsShuffleIterator( GpuShuffleEnv.getReceivedCatalog.removeBuffer(bufferId) case _ => } + // tell the client to cancel pending requests + clientAndHandlers.foreach { + case (client, handler) => client.cancelPending(handler) + } } } @@ -290,7 +298,6 @@ class RapidsShuffleIterator( private[this] val taskContext: Option[TaskContext] = Option(TaskContext.get()) - //TODO: on task completion we currently don't ask clients to stop/clean resources taskContext.foreach(_.addTaskCompletionListener[Unit](_ => receiveBufferCleaner())) def pollForResult(timeoutSeconds: Long): Option[ShuffleClientResult] = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala index 5ae535e9b37..81d5ecb6991 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTransport.scala @@ -336,6 +336,7 @@ trait Transaction extends AutoCloseable { * needed. */ trait RapidsShuffleTransport extends AutoCloseable { + /** * This function will connect (if not connected already) to a peer * described by `blockManagerId`. Connections are cached. @@ -382,6 +383,13 @@ trait RapidsShuffleTransport extends AutoCloseable { */ def queuePending(reqs: Seq[PendingTransferRequest]) + /** + * Cancel requests that are waiting in the queue (not in-flight) for a specific + * client/handler + */ + def cancelPending(client: RapidsShuffleClient, + handler: RapidsShuffleFetchHandler): Unit + /** * (throttle) Signals that `bytesCompleted` are done, allowing more requests through the * throttle. diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala index ee6e627610d..d76db48ebd9 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala @@ -77,6 +77,8 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { assert(cl.hasNext) assertThrows[RapidsShuffleFetchFailedException](cl.next()) + verify(mockTransport, times(1)).cancelPending(client, handler) + verify(testMetricsUpdater, times(1)) .update(any(), any(), any(), any()) assertResult(0)(testMetricsUpdater.totalRemoteBlocksFetched) @@ -120,6 +122,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { throw rsffe } } + + verify(mockTransport, times(1)).cancelPending(client, handler) + verify(testMetricsUpdater, times(1)) .update(any(), any(), any(), any()) assertResult(0)(testMetricsUpdater.totalRemoteBlocksFetched) @@ -145,6 +150,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any(), any())).thenReturn(client) doNothing().when(client).doFetch(any(), ac.capture()) + cl.start() + + val handler = ac.getValue.asInstanceOf[RapidsShuffleFetchHandler] // signal a timeout to the iterator when(cl.pollForResult(any())).thenReturn(None) @@ -173,8 +181,6 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { mockCatalog, 123) - when(mockTransaction.getStatus).thenReturn(TransactionStatus.Error) - val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any(), any())).thenReturn(client) doNothing().when(client).doFetch(any(), ac.capture()) @@ -192,6 +198,8 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { handler.start(1) handler.batchReceived(bufferId) + verify(mockTransport, times(0)).cancelPending(client, handler) + assert(cl.hasNext) assertResult(cb)(cl.next()) assertResult(1)(testMetricsUpdater.totalRemoteBlocksFetched)