diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala index 84baf7e9708..916f7f463a3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,6 +58,20 @@ object GpuSemaphore { instance = new GpuSemaphore() } + /** + * A thread may try to acquire the semaphore without blocking on it. NOTE: A task completion + * listener will automatically be installed to ensure the semaphore is always released by the + * time the task completes. + */ + def tryAcquire(context: TaskContext): Boolean = { + if (context != null) { + getInstance.tryAcquire(context) + } else { + // For unit tests that might try with no context + true + } + } + /** * Tasks must call this when they begin to use the GPU. * If the task has not already acquired the GPU semaphore then it is acquired, @@ -245,6 +259,30 @@ private final class SemaphoreTaskInfo() extends Logging { } } + def tryAcquire(semaphore: Semaphore): Boolean = synchronized { + val t = Thread.currentThread() + if (hasSemaphore) { + blockedThreads.add(t) + moveToActive(t) + true + } else { + if (blockedThreads.size() == 0) { + // No other threads for this task are waiting, so we might be able to grab this directly + val ret = semaphore.tryAcquire(numPermits) + if (ret) { + hasSemaphore = true + blockedThreads.add(t) + moveToActive(t) + // no need to notify because there are no other threads and we are holding the lock + // to ensure that. + } + ret + } else { + false + } + } + } + def releaseSemaphore(semaphore: Semaphore): Unit = synchronized { val t = Thread.currentThread() activeThreads.remove(t) @@ -267,6 +305,21 @@ private final class GpuSemaphore() extends Logging { // Keep track of all tasks that are both active on the GPU and blocked waiting on the GPU private val tasks = new ConcurrentHashMap[Long, SemaphoreTaskInfo] + def tryAcquire(context: TaskContext): Boolean = { + // Make sure that the thread/task is registered before we try and block + TaskRegistryTracker.registerThreadForRetry() + val taskAttemptId = context.taskAttemptId() + val taskInfo = tasks.computeIfAbsent(taskAttemptId, _ => { + onTaskCompletion(context, completeTask) + new SemaphoreTaskInfo() + }) + val acquired = taskInfo.tryAcquire(semaphore) + if (acquired) { + GpuDeviceManager.initializeFromTask() + } + acquired + } + def acquireIfNecessary(context: TaskContext): Unit = { // Make sure that the thread/task is registered before we try and block TaskRegistryTracker.registerThreadForRetry() diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSemaphoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSemaphoreSuite.scala index c05bbf82afe..3992298735c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSemaphoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSemaphoreSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,19 +29,23 @@ import org.apache.spark.sql.SparkSession class GpuSemaphoreSuite extends AnyFunSuite with BeforeAndAfterEach with MockitoSugar with TimeLimits with TimeLimitedTests { - val timeLimit = Span(10, Seconds) + val timeLimit: Span = Span(10, Seconds) override def beforeEach(): Unit = { ScalableTaskCompletion.reset() GpuSemaphore.shutdown() // semaphore tests depend on a SparkEnv being available val activeSession = SparkSession.getActiveSession - if (activeSession.isEmpty) { - SparkSession.builder + if (activeSession.isDefined) { + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + } + SparkSession.builder .appName("semaphoreTests") .master("local[1]") + // only 1 task at a time so we can verify what blocks and what does not block + .config("spark.rapids.sql.concurrentGpuTasks", "1") .getOrCreate() - } } override def afterEach(): Unit = { @@ -79,4 +83,29 @@ class GpuSemaphoreSuite extends AnyFunSuite GpuSemaphore.acquireIfNecessary(context) verify(context, times(1)).addTaskCompletionListener[Unit](any()) } + + test("multi tryAcquire") { + GpuDeviceManager.setRmmTaskInitEnabled(false) + val context = mockContext(1) + try { + assert(GpuSemaphore.tryAcquire(context)) + assert(GpuSemaphore.tryAcquire(context)) + } finally { + GpuSemaphore.releaseIfNecessary(context) + } + } + + test("tryAcquire non-blocking") { + GpuDeviceManager.setRmmTaskInitEnabled(false) + val context1 = mockContext(1) + val context2 = mockContext(2) + try { + GpuSemaphore.acquireIfNecessary(context1) + assert(!GpuSemaphore.tryAcquire(context2)) + assert(!GpuSemaphore.tryAcquire(context2)) + } finally { + GpuSemaphore.releaseIfNecessary(context1) + GpuSemaphore.releaseIfNecessary(context2) + } + } }