Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tryAcquire to GpuSemaphore #10330

Merged
merged 2 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -58,6 +58,20 @@ object GpuSemaphore {
instance = new GpuSemaphore()
}

/**
 * Non-blocking attempt by the calling thread to take the GPU semaphore.
 *
 * NOTE: a task completion listener is installed automatically, so the semaphore is
 * guaranteed to be released no later than task completion.
 *
 * @param context the task context of the caller; may be null (e.g. in unit tests)
 * @return the result of the underlying semaphore attempt, or true when there is no context
 */
def tryAcquire(context: TaskContext): Boolean =
  // A missing context (unit tests may have none) is treated as a successful acquire:
  // `forall` on an empty Option yields true without touching the semaphore instance.
  Option(context).forall(ctx => getInstance.tryAcquire(ctx))

/**
* Tasks must call this when they begin to use the GPU.
* If the task has not already acquired the GPU semaphore then it is acquired,
Expand Down Expand Up @@ -245,6 +259,30 @@ private final class SemaphoreTaskInfo() extends Logging {
}
}

/**
 * Non-blocking attempt to obtain the semaphore on behalf of the calling thread.
 * The whole body runs inside `synchronized` on this task-info object.
 *
 * @param semaphore the semaphore from which `numPermits` permits are requested
 * @return true if this task already held the semaphore, or if the permits were obtained by
 *         this call; false if other threads are recorded in `blockedThreads` or the
 *         permits could not be obtained immediately
 */
def tryAcquire(semaphore: Semaphore): Boolean = synchronized {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think there are some smarter heuristics to be had if we return the number of waiters when the semaphore cannot be acquired via Either[Int,Boolean] ?

val t = Thread.currentThread()
// Fast path: the task already holds the semaphore, so only thread bookkeeping is needed.
if (hasSemaphore) {
// NOTE(review): presumably moveToActive transfers t from blockedThreads to the active
// set, which is why t is added to blockedThreads first -- confirm against moveToActive.
blockedThreads.add(t)
moveToActive(t)
jlowe marked this conversation as resolved.
Show resolved Hide resolved
true
} else {
if (blockedThreads.size() == 0) {
// No other threads for this task are waiting, so we might be able to grab this directly
val ret = semaphore.tryAcquire(numPermits)
if (ret) {
// Record ownership and register this thread using the same add-then-move sequence as above.
hasSemaphore = true
blockedThreads.add(t)
moveToActive(t)
jlowe marked this conversation as resolved.
Show resolved Hide resolved
// no need to notify because there are no other threads and we are holding the lock
// to ensure that.
}
ret
} else {
// Other threads are already recorded as waiting; report failure without trying.
false
}
}
}

def releaseSemaphore(semaphore: Semaphore): Unit = synchronized {
val t = Thread.currentThread()
activeThreads.remove(t)
Expand All @@ -267,6 +305,21 @@ private final class GpuSemaphore() extends Logging {
// Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU
private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo]

/**
 * Non-blocking counterpart of `acquireIfNecessary`: tries to take the semaphore for the
 * task described by `context` and reports whether it succeeded.
 *
 * @param context the (non-null) task context of the caller
 * @return true if the task now holds the semaphore, false otherwise
 */
def tryAcquire(context: TaskContext): Boolean = {
  // Same first step as the blocking acquire path: register the thread/task for retry.
  TaskRegistryTracker.registerThreadForRetry()
  val attemptId = context.taskAttemptId()
  // The first request from a task installs the completion callback and creates its
  // tracking entry; later requests reuse the existing entry.
  val info = tasks.computeIfAbsent(attemptId, _ => {
    onTaskCompletion(context, completeTask)
    new SemaphoreTaskInfo()
  })
  if (info.tryAcquire(semaphore)) {
    // Only initialize the device for this task once the semaphore is actually held.
    GpuDeviceManager.initializeFromTask()
    true
  } else {
    false
  }
}

def acquireIfNecessary(context: TaskContext): Unit = {
// Make sure that the thread/task is registered before we try and block
TaskRegistryTracker.registerThreadForRetry()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,19 +29,23 @@ import org.apache.spark.sql.SparkSession

class GpuSemaphoreSuite extends AnyFunSuite
with BeforeAndAfterEach with MockitoSugar with TimeLimits with TimeLimitedTests {
// Upper bound on the run time of each test in this suite (used by TimeLimitedTests).
val timeLimit: Span = Span(10, Seconds)

/**
 * Resets semaphore-related global state and starts a fresh local SparkSession before
 * every test, so each test runs with exactly one concurrent GPU task configured.
 */
override def beforeEach(): Unit = {
  ScalableTaskCompletion.reset()
  GpuSemaphore.shutdown()
  // semaphore tests depend on a SparkEnv being available
  // Stop any session left active by an earlier test so the config below is applied to a
  // newly built session instead of being ignored by getOrCreate().
  SparkSession.getActiveSession.foreach { session =>
    session.stop()
    SparkSession.clearActiveSession()
  }
  SparkSession.builder
    .appName("semaphoreTests")
    .master("local[1]")
    // only 1 task at a time so we can verify what blocks and what does not block
    .config("spark.rapids.sql.concurrentGpuTasks", "1")
    .getOrCreate()
}

override def afterEach(): Unit = {
Expand Down Expand Up @@ -79,4 +83,29 @@ class GpuSemaphoreSuite extends AnyFunSuite
GpuSemaphore.acquireIfNecessary(context)
verify(context, times(1)).addTaskCompletionListener[Unit](any())
}

// A task that already holds the semaphore must succeed when it calls tryAcquire again.
test("multi tryAcquire") {
  GpuDeviceManager.setRmmTaskInitEnabled(false)
  val taskContext = mockContext(1)
  try {
    // Two consecutive attempts from the same task; each one must report success.
    (1 to 2).foreach { _ =>
      assert(GpuSemaphore.tryAcquire(taskContext))
    }
  } finally {
    GpuSemaphore.releaseIfNecessary(taskContext)
  }
}

// With a single permit held by one task, tryAcquire from a second task must return false
// instead of blocking (the suite's time limit would catch a hang).
test("tryAcquire non-blocking") {
  GpuDeviceManager.setRmmTaskInitEnabled(false)
  val holder = mockContext(1)
  val contender = mockContext(2)
  try {
    GpuSemaphore.acquireIfNecessary(holder)
    // Repeat the attempt to show that a failed try leaves nothing behind that would let
    // a later try succeed.
    (1 to 2).foreach { _ =>
      assert(!GpuSemaphore.tryAcquire(contender))
    }
  } finally {
    GpuSemaphore.releaseIfNecessary(holder)
    GpuSemaphore.releaseIfNecessary(contender)
  }
}
}
Loading