Skip to content

Commit

Permalink
Merge pull request NVIDIA#799 from NVIDIA/branch-0.2
Browse files Browse the repository at this point in the history
[auto-merge] branch-0.2 to branch-0.3 [skip ci] [bot]
  • Loading branch information
nvauto authored Sep 18, 2020
2 parents 1cf9c4d + b69f399 commit 4aaf0b5
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ class RapidsDriverPlugin extends DriverPlugin with Logging {
override def init(sc: SparkContext, pluginContext: PluginContext): util.Map[String, String] = {
val sparkConf = pluginContext.conf
RapidsPluginUtils.fixupConfigs(sparkConf)
new RapidsConf(sparkConf).rapidsConfMap
val conf = new RapidsConf(sparkConf)
if (conf.shimsProviderOverride.isDefined) {
ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get)
}
conf.rapidsConfMap
}
}

Expand All @@ -120,6 +124,9 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
extraConf: util.Map[String, String]): Unit = {
try {
val conf = new RapidsConf(extraConf.asScala.toMap)
if (conf.shimsProviderOverride.isDefined) {
ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get)
}

// we rely on the Rapids Plugin being run with 1 GPU per executor so we can initialize
// on executor startup.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, SparkConf}
import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION}
import org.apache.spark.internal.Logging

object ShimLoader extends Logging {
private var shimProviderClass: String = null
private var sparkShims: SparkShims = null

private def detectShimProvider(): SparkShimServiceProvider = {
Expand All @@ -47,14 +48,12 @@ object ShimLoader extends Logging {
}

private def findShimProvider(): SparkShimServiceProvider = {
val conf = new RapidsConf(new SparkConf())
if (conf.shimsProviderOverride.isEmpty) {
if (shimProviderClass == null) {
detectShimProvider()
} else {
val classname = conf.shimsProviderOverride.get
logWarning(s"Overriding Spark shims provider to $classname. " +
logWarning(s"Overriding Spark shims provider to $shimProviderClass. " +
"This may be an untested configuration!")
val providerClass = Class.forName(classname)
val providerClass = Class.forName(shimProviderClass)
val constructor = providerClass.getConstructor()
constructor.newInstance().asInstanceOf[SparkShimServiceProvider]
}
Expand All @@ -76,4 +75,8 @@ object ShimLoader extends Logging {
SPARK_VERSION
}
}

def setSparkShimProviderClass(classname: String): Unit = {
shimProviderClass = classname
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boole

private val rapidsConf = new RapidsConf(conf)

// set the shim override if specified since the shuffle manager loads early
if (rapidsConf.shimsProviderOverride.isDefined) {
ShimLoader.setSparkShimProviderClass(rapidsConf.shimsProviderOverride.get)
}

protected val wrapped = new SortShuffleManager(conf)
GpuShuffleEnv.setRapidsShuffleManagerInitialized(true, this.getClass.getCanonicalName)
logWarning("Rapids Shuffle Plugin Enabled")
Expand Down

0 comments on commit 4aaf0b5

Please sign in to comment.