Skip to content

Commit

Permalink
Cancel requests that are queued for a client/handler on error
Browse files Browse the repository at this point in the history
  • Loading branch information
abellina committed Jun 1, 2021
1 parent 27250f7 commit d08a569
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
ucxImpl
}

// Queue of pending transfer requests. The comparator orders two requests by their
// `getLength`: negative when the left one is shorter, positive when it is longer,
// and zero when both lengths are equal.
private val altList = new HashedPriorityQueue[PendingTransferRequest](
  1000,
  (lhs: PendingTransferRequest, rhs: PendingTransferRequest) => {
    val lhsLength = lhs.getLength
    val rhsLength = rhs.getLength
    if (lhsLength < rhsLength) -1
    else if (lhsLength > rhsLength) 1
    else 0
  })

// Client/handler pairs whose requests may still be added (back) to `altList`.
// Guarded by the `altList` lock: only touch this set inside `altList.synchronized`.
val validClientAndHandler =
  mutable.HashSet.empty[(RapidsShuffleClient, RapidsShuffleFetchHandler)]

override def getDirectByteBuffer(size: Long): RefCountedDirectByteBuffer = {
if (size > rapidsConf.shuffleMaxMetadataSize) {
logWarning(s"Large metadata message size $size B, larger " +
Expand Down Expand Up @@ -333,18 +349,6 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
inflightMonitor.notifyAll()
}

// Queue of pending transfer requests. The comparator orders two requests by their
// `getLength`: negative when the left one is shorter, positive when it is longer,
// and zero when both lengths are equal.
private val altList = new HashedPriorityQueue[PendingTransferRequest](
  1000,
  (lhs: PendingTransferRequest, rhs: PendingTransferRequest) => {
    val lhsLength = lhs.getLength
    val rhsLength = rhs.getLength
    if (lhsLength < rhsLength) -1
    else if (lhsLength > rhsLength) 1
    else 0
  })

private[this] val exec = Executors.newSingleThreadExecutor(
GpuDeviceManager.wrapThreadFactory(
new ThreadFactoryBuilder()
Expand Down Expand Up @@ -435,7 +439,13 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
// A _much_ better way of doing this would be to have separate
// lists, one per client. This should be cleaned up later.
altList.synchronized {
putBack.foreach(altList.add)
putBack.foreach { pb =>
// If this client+handler pair hasn't been invalidated, add the pending request back.
// As with the NOTE above, this check should be refactored once there is a queue per client.
if (validClientAndHandler.contains((pb.client, pb.handler))) {
altList.add(pb)
}
}
}

if (perClientReq.nonEmpty) {
Expand Down Expand Up @@ -469,11 +479,42 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
/**
 * Adds `reqs` to the queue of pending requests and wakes up every thread waiting on
 * that queue.
 *
 * The client/handler pair of every request in `reqs` is marked as valid, so those
 * requests remain eligible to be put back on the queue until the pair is cancelled.
 * An empty `reqs` is accepted and registers nothing.
 *
 * @param reqs pending transfer requests to enqueue
 */
override def queuePending(reqs: Seq[PendingTransferRequest]): Unit =
  altList.synchronized {
    import collection.JavaConverters._
    // Register the pair of every request rather than only the first one: a request
    // whose pair is not registered is dropped when it is put back on the queue.
    // Iterating also avoids `reqs.head`, which throws on an empty `Seq`.
    reqs.foreach { req =>
      validClientAndHandler.add((req.client, req.handler))
    }
    altList.addAll(reqs.asJava)
    logDebug(s"THROTTLING ${altList.size} queued requests")
    // wake up anyone blocked waiting for work on `altList`
    altList.notifyAll()
  }

/**
 * Removes every queued request that belongs to the given client/handler pair and
 * then invalidates the pair. Does nothing when the pair is not currently valid.
 *
 * @param client  client whose queued requests should be removed
 * @param handler handler whose queued requests should be removed
 */
override def cancelPending(client: RapidsShuffleClient,
    handler: RapidsShuffleFetchHandler): Unit = altList.synchronized {
  import collection.JavaConverters._
  val pair = (client, handler)
  if (validClientAndHandler.contains(pair)) {
    // Scanning the whole queue is expensive, and is expected to be replaced by a
    // queue per client. In the good case this runs once per task/peer, on task
    // completion, when `altList` should already be empty and the scan is skipped.
    // Errors cause more invocations.
    if (!altList.isEmpty) {
      // materialize the matches first, so nothing is removed while iterating
      val matching = altList.iterator().asScala
        .filter(queued => queued.client == client && queued.handler == handler)
        .toList
      matching.foreach(queued => altList.remove(queued))
    }
    // invalidate the client+handler pair
    validClientAndHandler.remove(pair)
  }
}

override def close(): Unit = {
logInfo("UCX transport closing")
exec.shutdown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ case class PendingTransferRequest(client: RapidsShuffleClient,
*/
class RapidsShuffleClient(
localExecutorId: Long,
connection: ClientConnection,
val connection: ClientConnection,
transport: RapidsShuffleTransport,
exec: Executor,
clientCopyExecutor: Executor,
Expand Down Expand Up @@ -262,6 +262,7 @@ class RapidsShuffleClient(
* the transport's throttle logic.
*/
private[shuffle] def doIssueBufferReceives(bufferReceiveState: BufferReceiveState): Unit = {

try {
if (!bufferReceiveState.hasIterated) {
sendTransferRequest(bufferReceiveState)
Expand Down Expand Up @@ -368,6 +369,16 @@ class RapidsShuffleClient(
}
}

/**
 * Asks the transport to cancel the requests that `handler` has pending against the
 * peer represented by this client.
 *
 * @param handler instance used to find the requests to cancel
 * @note this currently only cancels pending requests that are queued in the transport,
 *       not requests that are already in flight.
 */
def cancelPending(handler: RapidsShuffleFetchHandler): Unit =
  transport.cancelPending(this, handler)

/**
* This function handles data received in `bounceBuffers`. The data should be copied out
* of the buffers, and the function should call into `bufferReceiveState` to advance its
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ class RapidsShuffleIterator(
// `taskAttemptId`.
case class BlockIdMapIndex(id: ShuffleBlockBatchId, mapIndex: Int)

private var clientAndHandlers = Seq[(RapidsShuffleClient, RapidsShuffleFetchHandler)]()

def start(): Unit = {
logInfo(s"Fetching ${blocksByAddress.length} blocks.")

// issue local fetches first
val (local, remote) = blocksByAddress.partition(ba => ba._1.host == localHost)

var clients = Seq[RapidsShuffleClient]()

(local ++ remote).foreach {
case (blockManagerId: BlockManagerId, blockIds: Seq[(BlockId, Long, Int)]) => {
val shuffleRequestsMapIndex: Seq[BlockIdMapIndex] =
Expand Down Expand Up @@ -248,7 +248,7 @@ class RapidsShuffleIterator(
}
}

override def transferError(errorMessage: String, throwable: Throwable): Unit =
override def transferError(errorMessage: String, throwable: Throwable): Unit = {
resolvedBatches.synchronized {
// If Spark detects a single fetch failure, the whole task has failed
// as per `FetchFailedException`. In the future `mapIndex` will come from the
Expand All @@ -258,16 +258,20 @@ class RapidsShuffleIterator(
blockManagerId, id, mapIndex, errorMessage, throwable))
}
}

// tell the client to cancel pending requests
client.cancelPending(this)
}
}

logInfo(s"Client $blockManagerId triggered, for ${shuffleRequestsMapIndex.size} blocks")
client.doFetch(shuffleRequestsMapIndex.map(_.id), handler)
clients = clients :+ client
clientAndHandlers = clientAndHandlers :+ ((client, handler))
}
}

logInfo(s"RapidsShuffleIterator for ${Thread.currentThread()} started with " +
s"${clients.size} clients.")
s"${clientAndHandlers.size} clients.")
}

private[this] def receiveBufferCleaner(): Unit = resolvedBatches.synchronized {
Expand All @@ -280,6 +284,10 @@ class RapidsShuffleIterator(
GpuShuffleEnv.getReceivedCatalog.removeBuffer(bufferId)
case _ =>
}
// tell the client to cancel pending requests
clientAndHandlers.foreach {
case (client, handler) => client.cancelPending(handler)
}
}
}

Expand All @@ -290,7 +298,6 @@ class RapidsShuffleIterator(

private[this] val taskContext: Option[TaskContext] = Option(TaskContext.get())

//TODO: on task completion we currently don't ask clients to stop/clean resources
taskContext.foreach(_.addTaskCompletionListener[Unit](_ => receiveBufferCleaner()))

def pollForResult(timeoutSeconds: Long): Option[ShuffleClientResult] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ trait Transaction extends AutoCloseable {
* needed.
*/
trait RapidsShuffleTransport extends AutoCloseable {

/**
* This function will connect (if not connected already) to a peer
* described by `blockManagerId`. Connections are cached.
Expand Down Expand Up @@ -382,6 +383,13 @@ trait RapidsShuffleTransport extends AutoCloseable {
*/
def queuePending(reqs: Seq[PendingTransferRequest]): Unit

/**
 * Cancels the requests of a specific client/handler pair that are still waiting in
 * the queue. Requests that are already in flight are not cancelled.
 *
 * @param client  client whose queued requests should be cancelled
 * @param handler handler whose queued requests should be cancelled
 */
def cancelPending(client: RapidsShuffleClient, handler: RapidsShuffleFetchHandler): Unit

/**
* (throttle) Signals that `bytesCompleted` are done, allowing more requests through the
* throttle.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
assert(cl.hasNext)
assertThrows[RapidsShuffleFetchFailedException](cl.next())

verify(mockTransport, times(1)).cancelPending(client, handler)

verify(testMetricsUpdater, times(1))
.update(any(), any(), any(), any())
assertResult(0)(testMetricsUpdater.totalRemoteBlocksFetched)
Expand Down Expand Up @@ -120,6 +122,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
throw rsffe
}
}

verify(mockTransport, times(1)).cancelPending(client, handler)

verify(testMetricsUpdater, times(1))
.update(any(), any(), any(), any())
assertResult(0)(testMetricsUpdater.totalRemoteBlocksFetched)
Expand All @@ -145,6 +150,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler])
when(mockTransport.makeClient(any(), any())).thenReturn(client)
doNothing().when(client).doFetch(any(), ac.capture())
cl.start()

val handler = ac.getValue.asInstanceOf[RapidsShuffleFetchHandler]

// signal a timeout to the iterator
when(cl.pollForResult(any())).thenReturn(None)
Expand Down Expand Up @@ -173,8 +181,6 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
mockCatalog,
123)

when(mockTransaction.getStatus).thenReturn(TransactionStatus.Error)

val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler])
when(mockTransport.makeClient(any(), any())).thenReturn(client)
doNothing().when(client).doFetch(any(), ac.capture())
Expand All @@ -192,6 +198,8 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
handler.start(1)
handler.batchReceived(bufferId)

verify(mockTransport, times(0)).cancelPending(client, handler)

assert(cl.hasNext)
assertResult(cb)(cl.next())
assertResult(1)(testMetricsUpdater.totalRemoteBlocksFetched)
Expand Down

0 comments on commit d08a569

Please sign in to comment.