diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index bd5d76e2337..74204e2d698 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -446,6 +446,7 @@ class RapidsDriverPlugin extends DriverPlugin with Logging { } rapidsShuffleHeartbeatManager.executorHeartbeat(id) case m: GpuCoreDumpMsg => GpuCoreDumpHandler.handleMsg(m) + case m: ProfileMsg => ProfilerOnDriver.handleMsg(m) case m => throw new IllegalStateException(s"Unknown message $m") } } @@ -458,6 +459,7 @@ class RapidsDriverPlugin extends DriverPlugin with Logging { RapidsPluginUtils.detectMultipleJars(conf) RapidsPluginUtils.logPluginMode(conf) GpuCoreDumpHandler.driverInit(sc, conf) + ProfilerOnDriver.init(sc, conf) if (GpuShuffleEnv.isRapidsShuffleAvailable(conf)) { GpuShuffleEnv.initShuffleManager() @@ -507,6 +509,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { val sparkConf = pluginContext.conf() val numCores = RapidsPluginUtils.estimateCoresOnExec(sparkConf) val conf = new RapidsConf(extraConf.asScala.toMap) + ProfilerOnExecutor.init(pluginContext, conf) // Checks if the current GPU architecture is supported by the // spark-rapids-jni and cuDF libraries. @@ -656,6 +659,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { GpuSemaphore.shutdown() PythonWorkerSemaphore.shutdown() GpuDeviceManager.shutdown() + ProfilerOnExecutor.shutdown() Option(rapidsShuffleHeartbeatEndpoint).foreach(_.close()) extraExecutorPlugins.foreach(_.shutdown()) FileCache.shutdown() @@ -692,6 +696,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { override def onTaskStart(): Unit = { startTaskNvtx(TaskContext.get) extraExecutorPlugins.foreach(_.onTaskStart()) + ProfilerOnExecutor.onTaskStart() } override def onTaskSucceeded(): Unit = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RangeConfMatcher.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RangeConfMatcher.scala new file mode 100644 index 00000000000..951e8f71990 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RangeConfMatcher.scala @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import scala.util.Try + +/** + * Determines if a value is in a comma-separated list of values and/or + * hyphenated ranges provided by the user for a configuration setting. + */ +class RangeConfMatcher(configKey: String, configValue: Option[String]) { + def this(conf: RapidsConf, entry: ConfEntry[String]) = { + this(entry.key, Some(conf.get(entry))) + } + + def this(conf: RapidsConf, entry: OptionalConfEntry[String]) = { + this(entry.key, conf.get(entry)) + } + + private val (stringSet, intRanges) = { + configValue.map { cv => + val parts = cv.split(',') + val (rangeParts, singleParts) = parts.partition(_.contains('-')) + val ranges = try { + rangeParts.map(RangeConfMatcher.parseRange) + } catch { + case e: IllegalArgumentException => + throw new IllegalArgumentException(s"Invalid range settings for $configKey: $cv", e) + } + (singleParts.map(_.trim).toSet, ranges) + }.getOrElse((Set.empty[String], Array.empty[(Int, Int)])) + } + + val isEmpty: Boolean = stringSet.isEmpty && intRanges.isEmpty + val nonEmpty: Boolean = !isEmpty + + def size: Int = { + stringSet.size + intRanges.map { + case (start, end) => end - start + 1 + }.sum + } + + /** Returns true if the string value is in the configured values or ranges. */ + def contains(v: String): Boolean = { + stringSet.contains(v) || (intRanges.nonEmpty && Try(v.toInt).map(checkRanges).getOrElse(false)) + } + + /** Returns true if the integer value is in the configured values or ranges. */ + def contains(v: Int): Boolean = { + checkRanges(v) || stringSet.contains(v.toString) + } + + private def checkRanges(v: Int): Boolean = { + intRanges.exists { + case (start, end) => start <= v && v <= end + } + } +} + +object RangeConfMatcher { + def parseRange(rangeStr: String): (Int,Int) = { + val rangePair = rangeStr.split('-') + if (rangePair.length != 2) { + throw new IllegalArgumentException(s"Invalid range: $rangeStr") + } + val start = rangePair.head.trim.toInt + val end = rangePair.last.trim.toInt + if (end < start) { + throw new IllegalArgumentException(s"Invalid range: $rangeStr") + } + (start, end) + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 27b57c1f2a0..5a97033d59c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -708,6 +708,71 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern") .checkValues(Set("DEBUG", "MODERATE", "ESSENTIAL")) .createWithDefault("MODERATE") + val PROFILE_PATH = conf("spark.rapids.profile.pathPrefix") + .doc("Enables profiling and specifies a URI path to use when writing profile data") + .internal() + .stringConf + .createOptional + + val PROFILE_EXECUTORS = conf("spark.rapids.profile.executors") + .doc("Comma-separated list of executors IDs and hyphenated ranges of executor IDs to " + + "profile when profiling is enabled") + .internal() + .stringConf + .createWithDefault("0") + + val PROFILE_TIME_RANGES_SECONDS = conf("spark.rapids.profile.timeRangesInSeconds") + .doc("Comma-separated list of start-end ranges of time, in seconds, since executor startup " + + "to start and stop profiling. For example, a value of 10-30,100-110 will have the profiler " + + "wait for 10 seconds after executor startup then profile for 20 seconds, then wait for " + + "70 seconds then profile again for the next 10 seconds") + .internal() + .stringConf + .createOptional + + val PROFILE_JOBS = conf("spark.rapids.profile.jobs") + .doc("Comma-separated list of job IDs and hyphenated ranges of job IDs to " + + "profile when profiling is enabled") + .internal() + .stringConf + .createOptional + + val PROFILE_STAGES = conf("spark.rapids.profile.stages") + .doc("Comma-separated list of stage IDs and hyphenated ranges of stage IDs to " + + "profile when profiling is enabled") + .internal() + .stringConf + .createOptional + + val PROFILE_DRIVER_POLL_MILLIS = conf("spark.rapids.profile.driverPollMillis") + .doc("Interval in milliseconds the executors will poll for job and stage completion when " + + "stage-level profiling is used.") + .internal() + .integerConf + .createWithDefault(1000) + + val PROFILE_COMPRESSION = conf("spark.rapids.profile.compression") + .doc("Specifies the compression codec to use when writing profile data, one of " + + "zstd or none") + .internal() + .stringConf + .transform(_.toLowerCase(java.util.Locale.ROOT)) + .checkValues(Set("zstd", "none")) + .createWithDefault("zstd") + + val PROFILE_FLUSH_PERIOD_MILLIS = conf("spark.rapids.profile.flushPeriodMillis") + .doc("Specifies the time period in milliseconds to flush profile records. " + + "A value <= 0 will disable time period flushing.") + .internal() + .integerConf + .createWithDefault(0) + + val PROFILE_WRITE_BUFFER_SIZE = conf("spark.rapids.profile.writeBufferSize") + .doc("Buffer size to use when writing profile records.") + .internal() + .bytesConf(ByteUnit.BYTE) + .createWithDefault(1024 * 1024) + // ENABLE/DISABLE PROCESSING val SQL_ENABLED = conf("spark.rapids.sql.enabled") @@ -2495,6 +2560,24 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val metricsLevel: String = get(METRICS_LEVEL) + lazy val profilePath: Option[String] = get(PROFILE_PATH) + + lazy val profileExecutors: String = get(PROFILE_EXECUTORS) + + lazy val profileTimeRangesSeconds: Option[String] = get(PROFILE_TIME_RANGES_SECONDS) + + lazy val profileJobs: Option[String] = get(PROFILE_JOBS) + + lazy val profileStages: Option[String] = get(PROFILE_STAGES) + + lazy val profileDriverPollMillis: Int = get(PROFILE_DRIVER_POLL_MILLIS) + + lazy val profileCompression: String = get(PROFILE_COMPRESSION) + + lazy val profileFlushPeriodMillis: Int = get(PROFILE_FLUSH_PERIOD_MILLIS) + + lazy val profileWriteBufferSize: Long = get(PROFILE_WRITE_BUFFER_SIZE) + lazy val isSqlEnabled: Boolean = get(SQL_ENABLED) lazy val isSqlExecuteOnGPU: Boolean = get(SQL_MODE).equals("executeongpu") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiler.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiler.scala new file mode 100644 index 00000000000..e6e2bcc9f7d --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/profiler.scala @@ -0,0 +1,404 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import java.lang.reflect.Method +import java.nio.ByteBuffer +import java.nio.channels.{Channels, WritableByteChannel} +import java.util.concurrent.{ConcurrentHashMap, Future, ScheduledExecutorService, TimeUnit} + +import scala.collection.mutable + +import com.nvidia.spark.rapids.jni.Profiler +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SparkContext, TaskContext} +import org.apache.spark.api.plugin.PluginContext +import org.apache.spark.internal.Logging +import org.apache.spark.io.CompressionCodec +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerStageCompleted} +import org.apache.spark.sql.rapids.execution.TrampolineUtil +import org.apache.spark.util.SerializableConfiguration + +object ProfilerOnExecutor extends Logging { + private val jobPattern = raw"SPARK_.*_JId_([0-9]+).*".r + private var writer: Option[ProfileWriter] = None + private var timeRanges: Option[Seq[(Long, Long)]] = None + private var jobRanges: RangeConfMatcher = null + private var stageRanges: RangeConfMatcher = null + // NOTE: Active sets are updated asynchronously, synchronize on ProfilerOnExecutor to access + private val activeJobs = mutable.HashSet[Int]() + private val activeStages = mutable.HashSet[Int]() + private var timer: Option[ScheduledExecutorService] = None + private var timerFuture: Option[Future[_]] = None + private var driverPollMillis = 0 + private val startTimestamp = System.nanoTime() + private var isProfileActive = false + private var currentContextMethod: Method = null + private var getContextMethod: Method = null + + def init(pluginCtx: PluginContext, conf: RapidsConf): Unit = { + require(writer.isEmpty, "Already initialized") + timeRanges = conf.profileTimeRangesSeconds.map(parseTimeRanges) + jobRanges = new RangeConfMatcher(conf, RapidsConf.PROFILE_JOBS) + stageRanges = new RangeConfMatcher(conf, RapidsConf.PROFILE_STAGES) + driverPollMillis = conf.profileDriverPollMillis + if (timeRanges.isDefined && (stageRanges.nonEmpty || jobRanges.nonEmpty)) { + throw new UnsupportedOperationException( + "Profiling with time ranges and stage or job ranges simultaneously is not supported") + } + if (jobRanges.nonEmpty) { + // Hadoop's CallerContext is used to identify the job ID of a task on the executor. + val callerContextClass = TrampolineUtil.classForName("org.apache.hadoop.ipc.CallerContext") + currentContextMethod = callerContextClass.getMethod("getCurrent") + getContextMethod = callerContextClass.getMethod("getContext") + } + writer = conf.profilePath.flatMap { pathPrefix => + val executorId = pluginCtx.executorID() + if (shouldProfile(executorId, conf)) { + logInfo("Initializing profiler") + if (jobRanges.nonEmpty) { + // Need caller context enabled to get the job ID of a task on the executor + TrampolineUtil.getSparkHadoopUtilConf.setBoolean("hadoop.caller.context.enabled", true) + } + val codec = conf.profileCompression match { + case "none" => None + case c => Some(TrampolineUtil.createCodec(pluginCtx.conf(), c)) + } + val w = new ProfileWriter(pluginCtx, pathPrefix, codec) + Profiler.init(w, conf.profileWriteBufferSize, conf.profileFlushPeriodMillis) + Some(w) + } else { + None + } + } + writer.foreach { _ => + updateAndSchedule() + } + } + + def onTaskStart(): Unit = { + if (jobRanges.nonEmpty) { + val callerCtx = currentContextMethod.invoke(null) + if (callerCtx != null) { + getContextMethod.invoke(callerCtx).asInstanceOf[String] match { + case jobPattern(jid) => + val jobId = jid.toInt + if (jobRanges.contains(jobId)) { + synchronized { + activeJobs.add(jobId) + enable() + startPollingDriver() + } + } + case _ => + } + } + } + if (stageRanges.nonEmpty) { + val taskCtx = TaskContext.get + val stageId = taskCtx.stageId + if (stageRanges.contains(stageId)) { + synchronized { + activeStages.add(taskCtx.stageId) + enable() + startPollingDriver() + } + } + } + } + + def shutdown(): Unit = { + writer.foreach { w => + timerFuture.foreach(_.cancel(false)) + timerFuture = None + Profiler.shutdown() + w.close() + } + writer = None + } + + private def enable(): Unit = { + writer.foreach { w => + if (!isProfileActive) { + Profiler.start() + isProfileActive = true + w.pluginCtx.send(ProfileStatusMsg(w.executorId, "profile started")) + } + } + } + + private def disable(): Unit = { + writer.foreach { w => + if (isProfileActive) { + Profiler.stop() + isProfileActive = false + w.pluginCtx.send(ProfileStatusMsg(w.executorId, "profile stopped")) + } + } + } + + private def shouldProfile(executorId: String, conf: RapidsConf): Boolean = { + val matcher = new RangeConfMatcher(conf, RapidsConf.PROFILE_EXECUTORS) + matcher.contains(executorId) + } + + private def parseTimeRanges(confVal: String): Seq[(Long, Long)] = { + val ranges = try { + confVal.split(',').map(RangeConfMatcher.parseRange).map { + case (start, end) => + // convert relative time in seconds to absolute time in nanoseconds + (startTimestamp + TimeUnit.SECONDS.toNanos(start), + startTimestamp + TimeUnit.SECONDS.toNanos(end)) + } + } catch { + case e: IllegalArgumentException => + throw new IllegalArgumentException( + s"Invalid range settings for ${RapidsConf.PROFILE_TIME_RANGES_SECONDS}: $confVal", e) + } + ranges.sorted.toIndexedSeq + } + + private def updateAndSchedule(): Unit = { + if (timeRanges.isDefined) { + if (timer.isEmpty) { + timer = Some(TrampolineUtil.newDaemonSingleThreadScheduledExecutor("profiler timer")) + } + val now = System.nanoTime() + // skip time ranges that have already passed + val currentRanges = timeRanges.get.dropWhile { + case (_, end) => end <= now + } + timeRanges = Some(currentRanges) + if (currentRanges.isEmpty) { + logWarning("No further time ranges to profile, shutting down") + shutdown() + } else { + currentRanges.headOption.foreach { + case (start, end) => + val delay = if (start <= now) { + enable() + end - now + } else { + disable() + start - now + } + timerFuture = Some(timer.get.schedule(new Runnable { + override def run(): Unit = try { + updateAndSchedule() + } catch { + case e: Exception => + logError(s"Error in profiler timer task", e) + } + }, delay, TimeUnit.NANOSECONDS)) + } + } + } else if (jobRanges.nonEmpty || stageRanges.nonEmpty) { + // nothing to do yet, profiling will start when tasks for targeted job/stage are seen + } else { + enable() + } + } + + private def startPollingDriver(): Unit = { + if (timerFuture.isEmpty) { + if (timer.isEmpty) { + timer = Some(TrampolineUtil.newDaemonSingleThreadScheduledExecutor("profiler timer")) + } + timerFuture = Some(timer.get.scheduleWithFixedDelay(() => try { + updateActiveFromDriver() + } catch { + case e: Exception => + logError("Profiler timer task error: ", e) + }, driverPollMillis, driverPollMillis, TimeUnit.MILLISECONDS)) + } + } + + private def stopPollingDriver(): Unit = { + timerFuture.foreach(_.cancel(false)) + timerFuture = None + } + + private def updateActiveFromDriver(): Unit = { + writer.foreach { w => + val (jobs, stages) = synchronized { + (activeJobs.toArray, activeStages.toArray) + } + val (completedJobs, completedStages, allDone) = + w.pluginCtx.ask(ProfileJobStageQueryMsg(jobs, stages)) + .asInstanceOf[(Array[Int], Array[Int], Boolean)] + if (completedJobs.nonEmpty || completedStages.nonEmpty) { + synchronized { + completedJobs.foreach(activeJobs.remove) + completedStages.foreach(activeStages.remove) + if (activeJobs.isEmpty && activeStages.isEmpty) { + disable() + stopPollingDriver() + } + } + } + if (allDone) { + logWarning("No further jobs or stages to profile, shutting down") + shutdown() + } + } + } +} + +class ProfileWriter( + val pluginCtx: PluginContext, + profilePathPrefix: String, + codec: Option[CompressionCodec]) extends Profiler.DataWriter with Logging { + val executorId: String = pluginCtx.executorID() + private val outPath = getOutputPath(profilePathPrefix, codec) + private val out = openOutput(codec) + private var isClosed = false + + override def write(data: ByteBuffer): Unit = { + if (!isClosed) { + while (data.hasRemaining) { + out.write(data) + } + } + } + + override def close(): Unit = { + if (!isClosed) { + isClosed = true + out.close() + logWarning(s"Profiling completed, output written to $outPath") + pluginCtx.send(ProfileEndMsg(executorId, outPath.toString)) + } + } + + private def getAppId: String = { + val appId = pluginCtx.conf.get("spark.app.id", "") + if (appId.isEmpty) { + java.lang.management.ManagementFactory.getRuntimeMXBean.getName + } else { + appId + } + } + + private def getOutputPath(prefix: String, codec: Option[CompressionCodec]): Path = { + val parentDir = new Path(prefix) + val suffix = codec.map(c => "." + TrampolineUtil.getCodecShortName(c.getClass.getName)) + .getOrElse("") + new Path(parentDir, s"rapids-profile-$getAppId-$executorId.bin$suffix") + } + + private def openOutput(codec: Option[CompressionCodec]): WritableByteChannel = { + logWarning(s"Profiler initialized, output will be written to $outPath") + val hadoopConf = pluginCtx.ask(ProfileInitMsg(executorId, outPath.toString)) + .asInstanceOf[SerializableConfiguration].value + val fs = outPath.getFileSystem(hadoopConf) + val fsStream = fs.create(outPath, false) + val outStream = codec.map(_.compressedOutputStream(fsStream)).getOrElse(fsStream) + Channels.newChannel(outStream) + } +} + +object ProfilerOnDriver extends Logging { + private var hadoopConf: SerializableConfiguration = null + private var jobRanges: RangeConfMatcher = null + private var numJobsToProfile: Long = 0L + private var stageRanges: RangeConfMatcher = null + private var numStagesToProfile: Long = 0L + private val completedJobs = new ConcurrentHashMap[Int, Unit]() + private val completedStages = new ConcurrentHashMap[Int, Unit]() + private var isJobsStageProfilingComplete = false + + def init(sc: SparkContext, conf: RapidsConf): Unit = { + // if no profile path, profiling is disabled and nothing to do + conf.profilePath.foreach { _ => + hadoopConf = new SerializableConfiguration(sc.hadoopConfiguration) + jobRanges = new RangeConfMatcher(conf, RapidsConf.PROFILE_JOBS) + stageRanges = new RangeConfMatcher(conf, RapidsConf.PROFILE_STAGES) + if (jobRanges.nonEmpty || stageRanges.nonEmpty) { + numJobsToProfile = jobRanges.size + numStagesToProfile = stageRanges.size + if (jobRanges.nonEmpty) { + // Need caller context enabled to get the job ID of a task on the executor + try { + TrampolineUtil.classForName("org.apache.hadoop.ipc.CallerContext") + } catch { + case _: ClassNotFoundException => + throw new UnsupportedOperationException(s"${RapidsConf.PROFILE_JOBS} requires " + + "Hadoop CallerContext which is unavailable.") + } + sc.getConf.set("hadoop.caller.context.enabled", "true") + } + sc.addSparkListener(Listener) + } + } + } + + def handleMsg(m: ProfileMsg): AnyRef = m match { + case ProfileInitMsg(executorId, path) => + logWarning(s"Profiling: Executor $executorId initialized profiler, writing to $path") + if (hadoopConf == null) { + throw new IllegalStateException("Hadoop configuration not set") + } + hadoopConf + case ProfileStatusMsg(executorId, msg) => + logWarning(s"Profiling: Executor $executorId: $msg") + null + case ProfileJobStageQueryMsg(activeJobs, activeStages) => + val filteredJobs = activeJobs.filter(j => completedJobs.containsKey(j)) + val filteredStages = activeStages.filter(s => completedStages.containsKey(s)) + (filteredJobs, filteredStages, isJobsStageProfilingComplete) + case ProfileEndMsg(executorId, path) => + logWarning(s"Profiling: Executor $executorId ended profiling, profile written to $path") + null + case _ => + throw new IllegalStateException(s"Unexpected profile msg: $m") + } + + private object Listener extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + val jobId = jobEnd.jobId + if (jobRanges.contains(jobId)) { + completedJobs.putIfAbsent(jobId, ()) + isJobsStageProfilingComplete = completedJobs.size == numJobsToProfile && + completedStages.size == numStagesToProfile + } + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + val stageId = stageCompleted.stageInfo.stageId + if (stageRanges.contains(stageId)) { + completedStages.putIfAbsent(stageId, ()) + isJobsStageProfilingComplete = completedJobs.size == numJobsToProfile && + completedStages.size == numStagesToProfile + } + } + } +} + +trait ProfileMsg + +case class ProfileInitMsg(executorId: String, path: String) extends ProfileMsg +case class ProfileStatusMsg(executorId: String, msg: String) extends ProfileMsg +case class ProfileEndMsg(executorId: String, path: String) extends ProfileMsg + +// Reply is a tuple of: +// - array of jobs that have completed +// - array of stages that have completed +// - boolean if there are no further jobs/stages to profile +case class ProfileJobStageQueryMsg( + activeJobs: Array[Int], + activeStages: Array[Int]) extends ProfileMsg diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala index db274bb6fdc..546cc7dfc43 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/TrampolineUtil.scala @@ -16,8 +16,9 @@ package org.apache.spark.sql.rapids.execution -import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.{ScheduledExecutorService, ThreadPoolExecutor} +import org.apache.hadoop.conf.Configuration import org.json4s.JsonAST import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkMasterRegex, SparkUpgradeException, TaskContext} @@ -40,7 +41,7 @@ import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim import org.apache.spark.sql.rapids.shims.SparkUpgradeExceptionShims import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.util.{ShutdownHookManager, ThreadUtils, Utils} object TrampolineUtil { def doExecuteBroadcast[T](child: SparkPlan): Broadcast[T] = child.doExecuteBroadcast() @@ -228,12 +229,17 @@ object TrampolineUtil { // We want to utilize the ThreadUtils class' ThreadPoolExecutor creation // which gives us important Hadoop config variables that are needed for the // Unity Catalog authentication - org.apache.spark.util.ThreadUtils.newDaemonCachedThreadPool(prefix, maxThreadNumber, - keepAliveSeconds) + ThreadUtils.newDaemonCachedThreadPool(prefix, maxThreadNumber, keepAliveSeconds) + } + + def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { + ThreadUtils.newDaemonSingleThreadScheduledExecutor(threadName) } def postEvent(sc: SparkContext, sparkEvent: SparkListenerEvent): Unit = { sc.listenerBus.post(sparkEvent) } + + def getSparkHadoopUtilConf: Configuration = SparkHadoopUtil.get.conf } diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/RangeConfMatcherSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/RangeConfMatcherSuite.scala new file mode 100644 index 00000000000..518bd3012f2 --- /dev/null +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/RangeConfMatcherSuite.scala @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.scalatest.funsuite.AnyFunSuite + +class RangeConfMatcherSuite extends AnyFunSuite { + + test("empty") { + val conf = new RapidsConf(Map(RapidsConf.PROFILE_EXECUTORS.key -> "")) + val matcher = new RangeConfMatcher(conf, RapidsConf.PROFILE_EXECUTORS) + assert(!matcher.contains("x")) + assert(!matcher.contains(0)) + } + + test("bad ranges") { + Seq("-", "-4", "4-", "4-3", "d-4", "4-d", "23a-24b", "3-5,8,x-y").foreach { v => + val conf = new RapidsConf(Map(RapidsConf.PROFILE_EXECUTORS.key -> v)) + assertThrows[IllegalArgumentException] { + new RangeConfMatcher(conf, RapidsConf.PROFILE_EXECUTORS) + } + } + } + + test("singles") { + Seq("driver", "0,driver", "0, driver", "driver, 0", "1, driver, x").foreach { v => + val conf = new RapidsConf(Map(RapidsConf.PROFILE_EXECUTORS.key -> v)) + val matcher = new RangeConfMatcher(conf, RapidsConf.PROFILE_EXECUTORS) + assert(matcher.contains("driver")) + assert(!matcher.contains("driverx")) + assert(!matcher.contains("xdriver")) + assert(!matcher.contains("drive")) + assert(!matcher.contains("drive")) + } + } + + test("range only") { + Seq("7-7", "3-7", "2-30", "2-3,5-7,8-10", "2-3, 5-7, 8-10", + " 2 - 3, 5 - 7, 8 - 10").foreach { v => + val conf = new RapidsConf(Map(RapidsConf.PROFILE_EXECUTORS.key -> v)) + val matcher = new RangeConfMatcher(conf, RapidsConf.PROFILE_EXECUTORS) + assert(matcher.contains("7")) + assert(matcher.contains(7)) + assert(!matcher.contains("0")) + assert(!matcher.contains(0)) + assert(!matcher.contains("70")) + assert(!matcher.contains(70)) + } + } + + test("singles range mix") { + Seq("driver,7-10", "driver, 7 - 10", "driver, 7-10", "3-5,7,1-3,driver").foreach { v => + val conf = new RapidsConf(Map(RapidsConf.PROFILE_EXECUTORS.key -> v)) + val matcher = new RangeConfMatcher(conf, RapidsConf.PROFILE_EXECUTORS) + assert(matcher.contains("driver")) + assert(!matcher.contains("driverx")) + assert(!matcher.contains("xdriver")) + assert(!matcher.contains("drive")) + assert(!matcher.contains("drive")) + assert(matcher.contains("7")) + assert(matcher.contains(7)) + assert(!matcher.contains("0")) + assert(!matcher.contains(0)) + assert(!matcher.contains("70")) + assert(!matcher.contains(70)) + } + } +}