Skip to content

Commit

Permalink
Add tryAcquire to GpuSemaphore
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 committed Jan 30, 2024
1 parent 7e48cc9 commit 337e060
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -58,6 +58,20 @@ object GpuSemaphore {
instance = new GpuSemaphore()
}

/**
* A thread may try to acquire the semaphore without blocking on it. NOTE: A task completion
* listener will automatically be installed to ensure the semaphore is always released by the
* time the task completes.
*/
def tryAcquire(context: TaskContext): Boolean = {
if (context != null) {
getInstance.tryAcquire(context)
} else {
// For unit tests that might try with no context
true
}
}

/**
* Tasks must call this when they begin to use the GPU.
* If the task has not already acquired the GPU semaphore then it is acquired,
Expand Down Expand Up @@ -245,6 +259,30 @@ private final class SemaphoreTaskInfo() extends Logging {
}
}

def tryAcquire(semaphore: Semaphore): Boolean = synchronized {
val t = Thread.currentThread()
if (hasSemaphore) {
blockedThreads.add(t)
moveToActive(t)
true
} else {
if (blockedThreads.size() == 0) {
// No other threads for this task are waiting, so we might be able to grab this directly
val ret = semaphore.tryAcquire(numPermits)
if (ret) {
hasSemaphore = true
blockedThreads.add(t)
moveToActive(t)
// no need to notify because there are no other threads and we are holding the lock
// to ensure that.
}
ret
} else {
false
}
}
}

def releaseSemaphore(semaphore: Semaphore): Unit = synchronized {
val t = Thread.currentThread()
activeThreads.remove(t)
Expand All @@ -267,6 +305,21 @@ private final class GpuSemaphore() extends Logging {
// Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU
private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo]

def tryAcquire(context: TaskContext): Boolean = {
// Make sure that the thread/task is registered before we try and block
TaskRegistryTracker.registerThreadForRetry()
val taskAttemptId = context.taskAttemptId()
val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo()
})
val acquired = taskInfo.tryAcquire(semaphore)
if (acquired) {
GpuDeviceManager.initializeFromTask()
}
acquired
}

def acquireIfNecessary(context: TaskContext): Unit = {
// Make sure that the thread/task is registered before we try and block
TaskRegistryTracker.registerThreadForRetry()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,19 +29,23 @@ import org.apache.spark.sql.SparkSession

class GpuSemaphoreSuite extends AnyFunSuite
with BeforeAndAfterEach with MockitoSugar with TimeLimits with TimeLimitedTests {
val timeLimit = Span(10, Seconds)
val timeLimit: Span = Span(10, Seconds)

override def beforeEach(): Unit = {
ScalableTaskCompletion.reset()
GpuSemaphore.shutdown()
// semaphore tests depend on a SparkEnv being available
val activeSession = SparkSession.getActiveSession
if (activeSession.isEmpty) {
SparkSession.builder
if (activeSession.isDefined) {
SparkSession.getActiveSession.foreach(_.stop())
SparkSession.clearActiveSession()
}
SparkSession.builder
.appName("semaphoreTests")
.master("local[1]")
// only 1 task at a time so we can verify what blocks and what does not block
.config("spark.rapids.sql.concurrentGpuTasks", "1")
.getOrCreate()
}
}

override def afterEach(): Unit = {
Expand Down Expand Up @@ -79,4 +83,29 @@ class GpuSemaphoreSuite extends AnyFunSuite
GpuSemaphore.acquireIfNecessary(context)
verify(context, times(1)).addTaskCompletionListener[Unit](any())
}

test("multi tryAcquire") {
GpuDeviceManager.setRmmTaskInitEnabled(false)
val context = mockContext(1)
try {
assert(GpuSemaphore.tryAcquire(context))
assert(GpuSemaphore.tryAcquire(context))
} finally {
GpuSemaphore.releaseIfNecessary(context)
}
}

test("tryAcquire non-blocking") {
GpuDeviceManager.setRmmTaskInitEnabled(false)
val context1 = mockContext(1)
val context2 = mockContext(2)
try {
GpuSemaphore.acquireIfNecessary(context1)
assert(!GpuSemaphore.tryAcquire(context2))
assert(!GpuSemaphore.tryAcquire(context2))
} finally {
GpuSemaphore.releaseIfNecessary(context1)
GpuSemaphore.releaseIfNecessary(context2)
}
}
}

0 comments on commit 337e060

Please sign in to comment.