Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancel requests that are queued for a client/handler on error #2553

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
ucxImpl
}

// Queue of transfer requests waiting behind the throttle. Requests are compared by
// their length; `altList` also serves as the monitor guarding the pending state.
private val altList = new HashedPriorityQueue[PendingTransferRequest](
  1000,
  (lhs: PendingTransferRequest, rhs: PendingTransferRequest) =>
    java.lang.Long.compare(lhs.getLength, rhs.getLength))

// Handlers whose queued requests may still be (re)added to `altList`.
// Callers must hold the `altList` lock while reading or mutating this set.
val validHandlers: mutable.HashSet[RapidsShuffleFetchHandler] =
  mutable.HashSet.empty[RapidsShuffleFetchHandler]

override def getDirectByteBuffer(size: Long): RefCountedDirectByteBuffer = {
if (size > rapidsConf.shuffleMaxMetadataSize) {
logWarning(s"Large metadata message size $size B, larger " +
Expand Down Expand Up @@ -333,18 +343,6 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
inflightMonitor.notifyAll()
}

// Queue of transfer requests waiting behind the throttle, compared by request length.
// `java.lang.Long.compare` yields the same -1/0/1 ordering as a hand-written
// three-way comparison.
private val altList = new HashedPriorityQueue[PendingTransferRequest](
  1000,
  (t: PendingTransferRequest, t1: PendingTransferRequest) =>
    java.lang.Long.compare(t.getLength, t1.getLength))

private[this] val exec = Executors.newSingleThreadExecutor(
GpuDeviceManager.wrapThreadFactory(
new ThreadFactoryBuilder()
Expand Down Expand Up @@ -435,7 +433,13 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
// A _much_ better way of doing this would be to have separate
// lists, one per client. This should be cleaned up later.
altList.synchronized {
putBack.foreach(altList.add)
putBack.foreach { pb =>
// if this handler hasn't been invalidated, we can add the pending request back
// like with NOTE above, when this is a queue per client, this gets refactored
if (validHandlers.contains(pb.handler)) {
altList.add(pb)
}
}
}

if (perClientReq.nonEmpty) {
Expand Down Expand Up @@ -469,11 +473,41 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
/**
 * Queue `reqs` behind the throttle and wake up threads waiting on `altList`.
 *
 * The handler of every queued request is marked valid so that its requests may be
 * put back into the queue later; a handler stays valid until `cancelPending` is
 * called for it.
 *
 * @param reqs requests to queue; may be empty, in which case nothing is added
 */
override def queuePending(reqs: Seq[PendingTransferRequest]): Unit =
  altList.synchronized {
    import collection.JavaConverters._
    // Register the handler of each request rather than only `reqs.head.handler`:
    // `head` throws on an empty sequence, and requests belonging to a handler that
    // was never registered would be dropped when they are put back in the queue.
    reqs.foreach(req => validHandlers.add(req.handler))
    altList.addAll(reqs.asJava)
    logDebug(s"THROTTLING ${altList.size} queued requests")
    altList.notifyAll()
  }

/**
 * Remove every queued (not in-flight) request that belongs to `handler` and mark the
 * handler as no longer valid. Does nothing if the handler is not currently valid.
 *
 * @param handler handler whose queued requests should be dropped
 */
override def cancelPending(handler: RapidsShuffleFetchHandler): Unit = altList.synchronized {
  if (validHandlers.contains(handler)) {
    // Scanning the whole queue is costly, but a per-client queue is expected to
    // replace this. In the good case this runs once per task/peer at task
    // completion with an empty `altList`, so it is mostly a noop. On errors it is
    // invoked more often: when `RapidsShuffleFetchHandler` handles `transferError`
    // and again when the task finally fails.
    if (!altList.isEmpty) {
      val stale = new ArrayBuffer[PendingTransferRequest]()
      val queued = altList.iterator()
      while (queued.hasNext) {
        val candidate = queued.next()
        if (candidate.handler == handler) {
          stale += candidate
        }
      }
      // removal happens after the scan so the queue is not mutated while iterating
      stale.foreach(req => altList.remove(req))
    }
    // the handler is now invalid; put-backs for it will be skipped
    validHandlers.remove(handler)
  }
}

override def close(): Unit = {
logInfo("UCX transport closing")
exec.shutdown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ case class PendingTransferRequest(client: RapidsShuffleClient,
*/
class RapidsShuffleClient(
localExecutorId: Long,
connection: ClientConnection,
val connection: ClientConnection,
transport: RapidsShuffleTransport,
exec: Executor,
clientCopyExecutor: Executor,
Expand Down Expand Up @@ -262,6 +262,7 @@ class RapidsShuffleClient(
* the transport's throttle logic.
*/
private[shuffle] def doIssueBufferReceives(bufferReceiveState: BufferReceiveState): Unit = {

try {
if (!bufferReceiveState.hasIterated) {
sendTransferRequest(bufferReceiveState)
Expand Down Expand Up @@ -368,6 +369,16 @@ class RapidsShuffleClient(
}
}

/**
 * Ask the transport to drop the requests still pending for `handler` against the
 * peer this client represents.
 * @param handler instance used to locate the requests to cancel
 * @note only requests queued in the transport are cancelled at present; requests
 *       that are already in flight are not.
 */
def cancelPending(handler: RapidsShuffleFetchHandler): Unit =
  transport.cancelPending(handler)

/**
* This function handles data received in `bounceBuffers`. The data should be copied out
* of the buffers, and the function should call into `bufferReceiveState` to advance its
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ class RapidsShuffleIterator(
// `taskAttemptId`.
case class BlockIdMapIndex(id: ShuffleBlockBatchId, mapIndex: Int)

private var clientAndHandlers = Seq[(RapidsShuffleClient, RapidsShuffleFetchHandler)]()

def start(): Unit = {
logInfo(s"Fetching ${blocksByAddress.length} blocks.")

// issue local fetches first
val (local, remote) = blocksByAddress.partition(ba => ba._1.host == localHost)

var clients = Seq[RapidsShuffleClient]()

(local ++ remote).foreach {
case (blockManagerId: BlockManagerId, blockIds: Seq[(BlockId, Long, Int)]) => {
val shuffleRequestsMapIndex: Seq[BlockIdMapIndex] =
Expand Down Expand Up @@ -248,7 +248,7 @@ class RapidsShuffleIterator(
}
}

override def transferError(errorMessage: String, throwable: Throwable): Unit =
override def transferError(errorMessage: String, throwable: Throwable): Unit = {
resolvedBatches.synchronized {
// If Spark detects a single fetch failure, the whole task has failed
// as per `FetchFailedException`. In the future `mapIndex` will come from the
Expand All @@ -258,16 +258,20 @@ class RapidsShuffleIterator(
blockManagerId, id, mapIndex, errorMessage, throwable))
}
}

// tell the client to cancel pending requests
client.cancelPending(this)
}
}

logInfo(s"Client $blockManagerId triggered, for ${shuffleRequestsMapIndex.size} blocks")
client.doFetch(shuffleRequestsMapIndex.map(_.id), handler)
clients = clients :+ client
clientAndHandlers = clientAndHandlers :+ ((client, handler))
jlowe marked this conversation as resolved.
Show resolved Hide resolved
}
}

logInfo(s"RapidsShuffleIterator for ${Thread.currentThread()} started with " +
s"${clients.size} clients.")
s"${clientAndHandlers.size} clients.")
}

private[this] def receiveBufferCleaner(): Unit = resolvedBatches.synchronized {
Expand All @@ -280,6 +284,10 @@ class RapidsShuffleIterator(
GpuShuffleEnv.getReceivedCatalog.removeBuffer(bufferId)
case _ =>
}
// tell the client to cancel pending requests
clientAndHandlers.foreach {
case (client, handler) => client.cancelPending(handler)
}
}
}

Expand All @@ -290,7 +298,6 @@ class RapidsShuffleIterator(

// Task context of the current thread; wrapped in `Option` so that a null result from
// `TaskContext.get()` becomes `None`.
private[this] val taskContext: Option[TaskContext] = Option(TaskContext.get())

// When a task context is present, run `receiveBufferCleaner` on task completion.
taskContext.foreach(_.addTaskCompletionListener[Unit](_ => receiveBufferCleaner()))

def pollForResult(timeoutSeconds: Long): Option[ShuffleClientResult] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ trait Transaction extends AutoCloseable {
* needed.
*/
trait RapidsShuffleTransport extends AutoCloseable {

/**
* This function will connect (if not connected already) to a peer
* described by `blockManagerId`. Connections are cached.
Expand Down Expand Up @@ -382,6 +383,12 @@ trait RapidsShuffleTransport extends AutoCloseable {
*/
def queuePending(reqs: Seq[PendingTransferRequest]): Unit

/**
 * Cancel requests that are waiting in the queue (not in-flight) for a specific
 * handler.
 *
 * @param handler the fetch handler whose queued requests should be cancelled
 */
def cancelPending(handler: RapidsShuffleFetchHandler): Unit

/**
* (throttle) Signals that `bytesCompleted` are done, allowing more requests through the
* throttle.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
assert(cl.hasNext)
assertThrows[RapidsShuffleFetchFailedException](cl.next())

verify(mockTransport, times(1)).cancelPending(handler)

verify(testMetricsUpdater, times(1))
.update(any(), any(), any(), any())
assertResult(0)(testMetricsUpdater.totalRemoteBlocksFetched)
Expand Down Expand Up @@ -120,6 +122,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
throw rsffe
}
}

verify(mockTransport, times(1)).cancelPending(handler)

verify(testMetricsUpdater, times(1))
.update(any(), any(), any(), any())
assertResult(0)(testMetricsUpdater.totalRemoteBlocksFetched)
Expand All @@ -145,6 +150,9 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler])
when(mockTransport.makeClient(any(), any())).thenReturn(client)
doNothing().when(client).doFetch(any(), ac.capture())
cl.start()

val handler = ac.getValue.asInstanceOf[RapidsShuffleFetchHandler]

// signal a timeout to the iterator
when(cl.pollForResult(any())).thenReturn(None)
Expand Down Expand Up @@ -173,8 +181,6 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
mockCatalog,
123)

when(mockTransaction.getStatus).thenReturn(TransactionStatus.Error)

val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler])
when(mockTransport.makeClient(any(), any())).thenReturn(client)
doNothing().when(client).doFetch(any(), ac.capture())
Expand All @@ -192,6 +198,8 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
handler.start(1)
handler.batchReceived(bufferId)

verify(mockTransport, times(0)).cancelPending(handler)

assert(cl.hasNext)
assertResult(cb)(cl.next())
assertResult(1)(testMetricsUpdater.totalRemoteBlocksFetched)
Expand Down