Commit
Move ProxyShuffleInternalManagerBase to api [databricks] (#9506)
This PR is taken from #9444 to separate out the move of ProxyShuffleInternalManagerBase
to the sql-plugin-api module.

- Moves ProxyShuffleInternalManagerBase to the sql-plugin-api module along
  with the required dependencies, including ShimLoader
- Adds the dbdeps profile to sql-plugin-api because the API module now has to
  depend on Spark classes as well
- Leaves the parts of ShimLoader that cannot be moved yet as ShimLoaderTemp
    
Signed-off-by: Gera Shegalov <gera@apache.org>
gerashegalov authored Oct 24, 2023
1 parent bdc563e commit 93c03e8
Showing 25 changed files with 224 additions and 202 deletions.
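For context before the file diffs, a hedged usage sketch: the per-shim RapidsShuffleManager entry points that build on the proxy moved in this commit are selected through Spark's standard shuffle-manager setting. The spark321-qualified class name below is only an illustrative example and is not taken from this commit.

import org.apache.spark.SparkConf

// Minimal sketch, assuming a spark321 shim build: point Spark's standard
// shuffle-manager setting at the shim-qualified RAPIDS entry point.
val conf = new SparkConf()
  .set("spark.shuffle.manager", "com.nvidia.spark.rapids.spark321.RapidsShuffleManager")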
6 changes: 0 additions & 6 deletions datagen/pom.xml
@@ -98,12 +98,6 @@
</dependencies>
</profile>
<profile>
<!--
Note that we are using the Spark version for all of the Databricks dependencies as well.
The jenkins/databricks/build.sh script handles installing the jars as maven artifacts.
This is to make it easier and not have to change version numbers for each individual dependency
and deal with differences between Databricks versions
-->
<id>dbdeps</id>
<activation>
<property>
6 changes: 1 addition & 5 deletions dist/unshimmed-common-from-spark311.txt
@@ -19,15 +19,11 @@ com/nvidia/spark/rapids/RapidsExecutorStartupMsg*
com/nvidia/spark/rapids/RapidsExecutorUpdateMsg*
com/nvidia/spark/rapids/RapidsShuffleHeartbeatHandler*
com/nvidia/spark/rapids/SQLExecPlugin*
com/nvidia/spark/rapids/ShimLoader*
com/nvidia/spark/rapids/ShimReflectionUtils*
com/nvidia/spark/rapids/ShimLoaderTemp*
com/nvidia/spark/rapids/SparkShims*
com/nvidia/spark/rapids/optimizer/SQLOptimizerPlugin*
com/nvidia/spark/udf/Plugin*
org/apache/spark/sql/rapids/AdaptiveSparkPlanHelperShim*
org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback*
org/apache/spark/sql/rapids/ProxyRapidsShuffleInternalManagerBase*
org/apache/spark/sql/rapids/execution/Unshimmed*
org/apache/spark/sql/rapids/RapidsShuffleManagerLike*
rapids/*.py
rapids4spark-version-info.properties
6 changes: 0 additions & 6 deletions integration_tests/pom.xml
@@ -131,12 +131,6 @@
</dependencies>
</profile>
<profile>
<!--
Note that we are using the Spark version for all of the Databricks dependencies as well.
The jenkins/databricks/build.sh script handles installing the jars as maven artifacts.
This is to make it easier and not have to change version numbers for each individual dependency
and deal with differences between Databricks versions
-->
<id>dbdeps</id>
<activation>
<property>
7 changes: 7 additions & 0 deletions shim-deps/databricks/pom.xml
@@ -32,6 +32,13 @@
<!--
This module is going to be used as a provided dependency. The dependencies below
are compile-scope so they are propagated to dependents as provided
Note that we are using the custom Databricks Spark version for all of the Databricks
dependencies as well.
The jenkins/databricks/build.sh script handles installing the jars as maven artifacts.
This is to make it easier and not have to change version numbers for each individual dependency
and deal with differences between Databricks versions
-->
<dependencies>
<dependency>
6 changes: 0 additions & 6 deletions shuffle-plugin/pom.xml
@@ -91,12 +91,6 @@
</dependencies>
</profile>
<profile>
<!--
Note that we are using the Spark version for all of the Databricks dependencies as well.
The jenkins/databricks/build.sh script handles installing the jars as maven artifacts.
This is to make it easier and not have to change version numbers for each individual dependency
and deal with differences between Databricks versions
-->
<id>dbdeps</id>
<activation>
<property>
17 changes: 17 additions & 0 deletions sql-plugin-api/pom.xml
@@ -57,6 +57,23 @@
</dependency>
</dependencies>
</profile>
<profile>
<id>dbdeps</id>
<activation>
<property>
<name>databricks</name>
</property>
</activation>
<dependencies>
<dependency>
<groupId>com.nvidia</groupId>
<artifactId>rapids-4-spark-db-bom</artifactId>
<version>${project.version}</version>
<type>pom</type>
<scope>provided</scope>
</dependency>
</dependencies>
</profile>
</profiles>
<build>
<plugins>
@@ -16,15 +16,12 @@

package com.nvidia.spark.rapids


import java.net.URL

import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.JavaConverters.enumerationAsScalaIteratorConverter
import scala.util.Try

import com.nvidia.spark.GpuCachedBatchSerializer
import com.nvidia.spark.rapids.delta.DeltaProbe
import com.nvidia.spark.rapids.iceberg.IcebergProvider
import org.apache.commons.lang3.reflect.MethodUtils

import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL, SPARK_REVISION, SPARK_VERSION, SparkConf, SparkEnv}
@@ -35,39 +32,28 @@ import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
import org.apache.spark.sql.rapids.{AdaptiveSparkPlanHelperShim, ExecutionPlanCaptureCallbackBase}
import org.apache.spark.sql.rapids.execution.UnshimmedTrampolineUtil
import org.apache.spark.util.MutableURLClassLoader

/*
Plugin jar uses non-standard class file layout. It consists of three types of areas,
"parallel worlds" in the JDK's com.sun.istack.internal.tools.ParallelWorldClassLoader parlance
1. a few publicly documented classes in the conventional layout at the top
2. a large fraction of classes whose bytecode is identical under all supported Spark versions
in spark3xx-common
3. a smaller fraction of classes that differ under one of the supported Spark versions
com/nvidia/spark/SQLPlugin.class
spark3xx-common/com/nvidia/spark/rapids/CastExprMeta.class
spark311/org/apache/spark/sql/rapids/GpuUnaryMinus.class
spark320/org/apache/spark/sql/rapids/GpuUnaryMinus.class
Each shim can see a consistent parallel world without conflicts by referencing
only one conflicting directory.
E.g., Spark 3.2.0 Shim will use
jar:file:/home/spark/rapids-4-spark_2.12-23.12.0.jar!/spark3xx-common/
jar:file:/home/spark/rapids-4-spark_2.12-23.12.0.jar!/spark320/
Spark 3.1.1 will use
jar:file:/home/spark/rapids-4-spark_2.12-23.12.0.jar!/spark3xx-common/
jar:file:/home/spark/rapids-4-spark_2.12-23.12.0.jar!/spark311/
Using these Jar URL's allows referencing different bytecode produced from identical sources
by incompatible Scala / Spark dependencies.
*/
@@ -133,7 +119,7 @@ object ShimLoader extends Logging {
s"org.apache.spark.sql.rapids.shims.$shimId.RapidsShuffleInternalManager"
}

@tailrec
@scala.annotation.tailrec
private def findURLClassLoader(classLoader: ClassLoader): Option[ClassLoader] = {
// walk up the classloader hierarchy until we hit a classloader we can mutate
// in the upstream Spark, non-REPL/batch mode serdeClassLoader is already mutable
@@ -159,7 +145,7 @@
logInfo(s"findURLClassLoader found $replCl, trying parentLoader=$parentLoader")
findURLClassLoader(parentLoader)
case urlAddable: ClassLoader if null != MethodUtils.getMatchingMethod(
urlAddable.getClass, "addURL", classOf[java.net.URL]) =>
urlAddable.getClass, "addURL", classOf[java.net.URL]) =>
// slow defensive path
logInfo(s"findURLClassLoader found a urLAddable classloader $urlAddable")
Option(urlAddable)
@@ -208,8 +194,8 @@
tmpClassLoader = new MutableURLClassLoader(Array(shimURL, shimCommonURL),
getClass.getClassLoader)
logWarning("Found an unexpected context classloader " +
s"${Thread.currentThread().getContextClassLoader}. We will try to recover from this, " +
"but it may cause class loading problems.")
s"${Thread.currentThread().getContextClassLoader}. We will try to recover from this, " +
"but it may cause class loading problems.")
}
tmpClassLoader
} else {
@@ -257,7 +243,7 @@
}

assert(serviceProviderList.nonEmpty, "Classpath should contain the resource for " +
serviceProviderListPath)
serviceProviderListPath)

val numShimServiceProviders = serviceProviderList.size
val (matchingProviders, restProviders) = serviceProviderList.flatMap { shimServiceProviderStr =>
@@ -293,13 +279,13 @@
// this class will be loaded again by the real executor classloader
provider.getClass.getName
}.getOrElse {
val supportedVersions = restProviders.map {
case (p, _) =>
val buildVer = shimIdFromPackageName(p.getClass.getName).drop("spark".length)
s"${p.getShimVersion} {buildver=${buildVer}}"
}.mkString(", ")
throw new IllegalArgumentException(
s"This RAPIDS Plugin build does not support Spark build ${sparkVersion}. " +
val supportedVersions = restProviders.map {
case (p, _) =>
val buildVer = shimIdFromPackageName(p.getClass.getName).drop("spark".length)
s"${p.getShimVersion} {buildver=${buildVer}}"
}.mkString(", ")
throw new IllegalArgumentException(
s"This RAPIDS Plugin build does not support Spark build ${sparkVersion}. " +
s"Supported Spark versions: ${supportedVersions}. " +
"Consult the Release documentation at " +
"https://nvidia.github.io/spark-rapids/docs/download.html")
@@ -340,20 +326,12 @@
SPARK_BUILD_DATE
)

//
// Reflection-based API with Spark to switch the classloader used by the caller
//

def newOptimizerClass(className: String): Optimizer = {
ShimReflectionUtils.newInstanceOf[Optimizer](className)
}

def newInternalShuffleManager(conf: SparkConf, isDriver: Boolean): Any = {
val shuffleClassLoader = getShimClassLoader()
val shuffleClassName = getRapidsShuffleInternalClass
val shuffleClass = shuffleClassLoader.loadClass(shuffleClassName)
shuffleClass.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
.newInstance(conf, java.lang.Boolean.valueOf(isDriver))
.newInstance(conf, java.lang.Boolean.valueOf(isDriver))
}

def newDriverPlugin(): DriverPlugin = {
@@ -385,48 +363,12 @@
newInstanceOf("com.nvidia.spark.rapids.InternalExclusiveModeGpuDiscoveryPlugin")
}

def newParquetCachedBatchSerializer(): GpuCachedBatchSerializer = {
ShimReflectionUtils.newInstanceOf("com.nvidia.spark.rapids.ParquetCachedBatchSerializer")
}

def loadColumnarRDD(): Class[_] = {
ShimReflectionUtils.
loadClass("org.apache.spark.sql.rapids.execution.InternalColumnarRddConverter")
}

def newExplainPlan(): ExplainPlanBase = {
ShimReflectionUtils.newInstanceOf[ExplainPlanBase]("com.nvidia.spark.rapids.ExplainPlanImpl")
}

def newHiveProvider(): HiveProvider= {
ShimReflectionUtils.
newInstanceOf[HiveProvider]("org.apache.spark.sql.hive.rapids.HiveProviderImpl")
}

def newAvroProvider(): AvroProvider = ShimReflectionUtils.newInstanceOf[AvroProvider](
"org.apache.spark.sql.rapids.AvroProviderImpl")

def newDeltaProbe(): DeltaProbe = ShimReflectionUtils.newInstanceOf[DeltaProbe](
"com.nvidia.spark.rapids.delta.DeltaProbeImpl")

def newIcebergProvider(): IcebergProvider = ShimReflectionUtils.newInstanceOf[IcebergProvider](
"com.nvidia.spark.rapids.iceberg.IcebergProviderImpl")

def newPlanShims(): PlanShims = ShimReflectionUtils.newInstanceOf[PlanShims](
"com.nvidia.spark.rapids.shims.PlanShimsImpl"
)

def loadGpuColumnVector(): Class[_] = {
ShimReflectionUtils.loadClass("com.nvidia.spark.rapids.GpuColumnVector")
}

def newAdaptiveSparkPlanHelperShim(): AdaptiveSparkPlanHelperShim =
ShimReflectionUtils.newInstanceOf[AdaptiveSparkPlanHelperShim](
"com.nvidia.spark.rapids.AdaptiveSparkPlanHelperImpl"
)

def newExecutionPlanCaptureCallbackBase(): ExecutionPlanCaptureCallbackBase =
ShimReflectionUtils.
newInstanceOf[ExecutionPlanCaptureCallbackBase](
"org.apache.spark.sql.rapids.ShimmedExecutionPlanCaptureCallbackImpl")
}
}
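The parallel-worlds comment above explains how each shim sees exactly one version-specific directory plus spark3xx-common. As a hedged, self-contained illustration (not the plugin's actual code path; the object and method names here are made up), such a classloader could be assembled from jar: URLs like this:

import java.net.URL

import org.apache.spark.util.MutableURLClassLoader

object ShimClassLoaderSketch {
  // Build a classloader that sees a single shim directory ("parallel world")
  // plus the bytecode shared by all shims under spark3xx-common/.
  def forShim(pluginJarPath: String, shimId: String): MutableURLClassLoader = {
    val shimURL = new URL(s"jar:file:$pluginJarPath!/$shimId/")
    val commonURL = new URL(s"jar:file:$pluginJarPath!/spark3xx-common/")
    new MutableURLClassLoader(Array(shimURL, commonURL), getClass.getClassLoader)
  }
}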
@@ -0,0 +1,93 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids

import com.nvidia.spark.rapids.ShimLoader

import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.shuffle._


/**
* Trait that makes it easy to check whether we are dealing with
* a RAPIDS Shuffle Manager
*/
trait RapidsShuffleManagerLike {
def isDriver: Boolean
def initialize: Unit
}

/**
* A simple proxy wrapper that delays loading of the
* real implementation until ShimLoader
* has already updated the Spark classloaders.
*
* @param conf
* @param isDriver
*/
class ProxyRapidsShuffleInternalManagerBase(
conf: SparkConf,
override val isDriver: Boolean
) extends RapidsShuffleManagerLike with Proxy {

// touched in the plugin code after the shim initialization
// is complete
lazy val self: ShuffleManager = ShimLoader.newInternalShuffleManager(conf, isDriver)
.asInstanceOf[ShuffleManager]

// This function touches the lazy val `self` so we actually instantiate
// the manager. This is called from both the driver and executor.
// In the driver, it's mostly to display information on how to enable/disable the manager,
// in the executor, the UCXShuffleTransport starts and allocates memory at this time.
override def initialize: Unit = self

def getWriter[K, V](
handle: ShuffleHandle,
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter
): ShuffleWriter[K, V] = {
self.getWriter(handle, mapId, context, metrics)
}

def getReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
self.getReader(handle,
startMapIndex, endMapIndex, startPartition, endPartition,
context, metrics)
}

def registerShuffle[K, V, C](
shuffleId: Int,
dependency: ShuffleDependency[K, V, C]
): ShuffleHandle = {
self.registerShuffle(shuffleId, dependency)
}

def unregisterShuffle(shuffleId: Int): Boolean = self.unregisterShuffle(shuffleId)

def shuffleBlockResolver: ShuffleBlockResolver = self.shuffleBlockResolver

def stop(): Unit = self.stop()
}

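The proxy above lets Spark instantiate a shuffle manager by class name before the RAPIDS shim classloader exists; the real manager is only created when initialize forces the lazy self. A hedged sketch follows of how a per-shim entry point might wrap it and how plugin code can treat RapidsShuffleManagerLike as a marker trait; the spark321 package and the sealed-subclass shape are assumptions for illustration, not code from this commit.

package com.nvidia.spark.rapids.spark321 // hypothetical shim package

import org.apache.spark.SparkConf
import org.apache.spark.sql.rapids.{ProxyRapidsShuffleInternalManagerBase, RapidsShuffleManagerLike}

// Thin per-shim entry point: instantiable by name via spark.shuffle.manager,
// delegating all shuffle calls to the proxy (and through it to the shim-loaded manager).
sealed class RapidsShuffleManager(
    conf: SparkConf,
    isDriver: Boolean
) extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver)

object RapidsShuffleManagerInit {
  // Plugin-side check: only RAPIDS shuffle managers expose the marker trait,
  // so initialize can be forced without referencing any shim class directly.
  def initializeIfRapids(manager: Any): Unit = manager match {
    case rapids: RapidsShuffleManagerLike => rapids.initialize
    case _ => () // not a RAPIDS shuffle manager; nothing to do
  }
}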
6 changes: 0 additions & 6 deletions sql-plugin/pom.xml
@@ -192,12 +192,6 @@
</dependencies>
</profile>
<profile>
<!--
Note that we are using the Spark version for all of the Databricks dependencies as well.
The jenkins/databricks/build.sh script handles installing the jars as maven artifacts.
This is to make it easier and not have to change version numbers for each individual dependency
and deal with differences between Databricks versions
-->
<id>dbdeps</id>
<activation>
<property>