Move ProxyShuffleInternalManagerBase to api [databricks] #9506

Merged
6 changes: 0 additions & 6 deletions datagen/pom.xml
@@ -98,12 +98,6 @@
</dependencies>
</profile>
<profile>
<!--
Note that we are using the Spark version for all of the Databricks dependencies as well.
The jenkins/databricks/build.sh script handles installing the jars as maven artifacts.
This is to make it easier and not have to change version numbers for each individual dependency
and deal with differences between Databricks versions
-->
<id>dbdeps</id>
<activation>
<property>
6 changes: 1 addition & 5 deletions dist/unshimmed-common-from-spark311.txt
@@ -19,15 +19,11 @@ com/nvidia/spark/rapids/RapidsExecutorStartupMsg*
com/nvidia/spark/rapids/RapidsExecutorUpdateMsg*
com/nvidia/spark/rapids/RapidsShuffleHeartbeatHandler*
com/nvidia/spark/rapids/SQLExecPlugin*
com/nvidia/spark/rapids/ShimLoader*
com/nvidia/spark/rapids/ShimReflectionUtils*
com/nvidia/spark/rapids/ShimLoaderTemp*
com/nvidia/spark/rapids/SparkShims*
com/nvidia/spark/rapids/optimizer/SQLOptimizerPlugin*
com/nvidia/spark/udf/Plugin*
org/apache/spark/sql/rapids/AdaptiveSparkPlanHelperShim*
org/apache/spark/sql/rapids/ExecutionPlanCaptureCallback*
org/apache/spark/sql/rapids/ProxyRapidsShuffleInternalManagerBase*
org/apache/spark/sql/rapids/execution/Unshimmed*
org/apache/spark/sql/rapids/RapidsShuffleManagerLike*
rapids/*.py
rapids4spark-version-info.properties
6 changes: 0 additions & 6 deletions integration_tests/pom.xml
@@ -131,12 +131,6 @@
</dependencies>
</profile>
<profile>
<!--
Note that we are using the Spark version for all of the Databricks dependencies as well.
The jenkins/databricks/build.sh script handles installing the jars as maven artifacts.
This is to make it easier and not have to change version numbers for each individual dependency
and deal with differences between Databricks versions
-->
<id>dbdeps</id>
<activation>
<property>
7 changes: 7 additions & 0 deletions shim-deps/databricks/pom.xml
@@ -32,6 +32,13 @@
<!--
This module is going to be used as a provided dependency. The dependencies below
are compile-scope so they are propagated to dependents as provided

Note that we are using the custom Databricks Spark version for all of the Databricks
dependencies as well.

The jenkins/databricks/build.sh script handles installing the jars as maven artifacts.
This makes it easier: we do not have to change version numbers for each individual dependency
or deal with differences between Databricks versions
-->
<dependencies>
<dependency>
6 changes: 0 additions & 6 deletions shuffle-plugin/pom.xml
@@ -91,12 +91,6 @@
</dependencies>
</profile>
<profile>
<!--
Note that we are using the Spark version for all of the Databricks dependencies as well.
The jenkins/databricks/build.sh script handles installing the jars as maven artifacts.
This is to make it easier and not have to change version numbers for each individual dependency
and deal with differences between Databricks versions
-->
<id>dbdeps</id>
<activation>
<property>
17 changes: 17 additions & 0 deletions sql-plugin-api/pom.xml
@@ -57,6 +57,23 @@
</dependency>
</dependencies>
</profile>
<profile>
<id>dbdeps</id>
<activation>
<property>
<name>databricks</name>
</property>
</activation>
<dependencies>
<dependency>
<groupId>com.nvidia</groupId>
<artifactId>rapids-4-spark-db-bom</artifactId>
<version>${project.version}</version>
<type>pom</type>
<scope>provided</scope>
</dependency>
</dependencies>
</profile>
</profiles>
<build>
<plugins>
@@ -16,15 +16,12 @@

package com.nvidia.spark.rapids


import java.net.URL

import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.JavaConverters.enumerationAsScalaIteratorConverter
import scala.util.Try

import com.nvidia.spark.GpuCachedBatchSerializer
import com.nvidia.spark.rapids.delta.DeltaProbe
import com.nvidia.spark.rapids.iceberg.IcebergProvider
import org.apache.commons.lang3.reflect.MethodUtils

import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL, SPARK_REVISION, SPARK_VERSION, SparkConf, SparkEnv}
@@ -35,39 +32,28 @@ import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
import org.apache.spark.sql.rapids.{AdaptiveSparkPlanHelperShim, ExecutionPlanCaptureCallbackBase}
import org.apache.spark.sql.rapids.execution.UnshimmedTrampolineUtil
import org.apache.spark.util.MutableURLClassLoader

/*
Plugin jar uses non-standard class file layout. It consists of three types of areas,
"parallel worlds" in the JDK's com.sun.istack.internal.tools.ParallelWorldClassLoader parlance

1. a few publicly documented classes in the conventional layout at the top
2. a large fraction of classes whose bytecode is identical under all supported Spark versions
in spark3xx-common
3. a smaller fraction of classes that differ under one of the supported Spark versions

com/nvidia/spark/SQLPlugin.class

spark3xx-common/com/nvidia/spark/rapids/CastExprMeta.class

spark311/org/apache/spark/sql/rapids/GpuUnaryMinus.class
spark320/org/apache/spark/sql/rapids/GpuUnaryMinus.class

Each shim can see a consistent parallel world without conflicts by referencing
only one conflicting directory.

E.g., Spark 3.2.0 Shim will use

jar:file:/home/spark/rapids-4-spark_2.12-23.12.0.jar!/spark3xx-common/
jar:file:/home/spark/rapids-4-spark_2.12-23.12.0.jar!/spark320/

Spark 3.1.1 will use

jar:file:/home/spark/rapids-4-spark_2.12-23.12.0.jar!/spark3xx-common/
jar:file:/home/spark/rapids-4-spark_2.12-23.12.0.jar!/spark311/

Using these Jar URLs allows referencing different bytecode produced from identical sources
by incompatible Scala / Spark dependencies.
*/
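
To make the layout above concrete, here is a minimal sketch (not code from this change) of building a classloader over the common parallel world plus exactly one shim world; the jar path, shim id, and object name are illustrative.

import java.net.URL

import org.apache.spark.util.MutableURLClassLoader

object ParallelWorldSketch {
  def shimClassLoader(pluginJarPath: String, shimId: String): MutableURLClassLoader = {
    val commonURL = new URL(s"jar:file:$pluginJarPath!/spark3xx-common/")
    val shimURL = new URL(s"jar:file:$pluginJarPath!/$shimId/")
    // Only one shim directory is ever visible, so identically named classes compiled
    // against different Spark versions never conflict on the same classpath.
    new MutableURLClassLoader(Array(commonURL, shimURL), getClass.getClassLoader)
  }
}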
@@ -133,7 +119,7 @@ object ShimLoader extends Logging {
s"org.apache.spark.sql.rapids.shims.$shimId.RapidsShuffleInternalManager"
}

@tailrec
@scala.annotation.tailrec
private def findURLClassLoader(classLoader: ClassLoader): Option[ClassLoader] = {
// walk up the classloader hierarchy until we hit a classloader we can mutate
// in the upstream Spark, non-REPL/batch mode serdeClassLoader is already mutable
@@ -149,17 +135,14 @@
// fast path
logInfo(s"findURLClassLoader found a URLClassLoader $urlCl")
Option(urlCl)
case replCl if replCl.getClass.getName == "org.apache.spark.repl.ExecutorClassLoader" ||
replCl.getClass.getName == "org.apache.spark.executor.ExecutorClassLoader" =>
// Spark 3.5.0 changed the package of ExecutorClassLoader so we check for it being
// either old package name or new one.
case replCl if replCl.getClass.getName == "org.apache.spark.repl.ExecutorClassLoader" =>
// https://issues.apache.org/jira/browse/SPARK-18646
val parentLoader = MethodUtils.invokeMethod(replCl, true, "parentLoader")
.asInstanceOf[ClassLoader]
logInfo(s"findURLClassLoader found $replCl, trying parentLoader=$parentLoader")
findURLClassLoader(parentLoader)
case urlAddable: ClassLoader if null != MethodUtils.getMatchingMethod(
urlAddable.getClass, "addURL", classOf[java.net.URL]) =>
urlAddable.getClass, "addURL", classOf[java.net.URL]) =>
// slow defensive path
logInfo(s"findURLClassLoader found a urLAddable classloader $urlAddable")
Option(urlAddable)
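
A minimal sketch of the reflective addURL fallback shown in the "slow defensive path" above, under the assumption that the target classloader exposes an addURL(URL) method; the helper name is hypothetical and this is not the plugin's code.

import java.net.URL

import org.apache.commons.lang3.reflect.MethodUtils

object AddUrlSketch {
  def tryAddURL(classLoader: ClassLoader, url: URL): Boolean = {
    val addURL = MethodUtils.getMatchingMethod(classLoader.getClass, "addURL", classOf[URL])
    if (addURL != null) {
      // forceAccess = true allows invoking the protected URLClassLoader#addURL
      MethodUtils.invokeMethod(classLoader, true, "addURL", url)
      true
    } else {
      false
    }
  }
}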
@@ -208,8 +191,8 @@ object ShimLoader extends Logging {
tmpClassLoader = new MutableURLClassLoader(Array(shimURL, shimCommonURL),
getClass.getClassLoader)
logWarning("Found an unexpected context classloader " +
s"${Thread.currentThread().getContextClassLoader}. We will try to recover from this, " +
"but it may cause class loading problems.")
s"${Thread.currentThread().getContextClassLoader}. We will try to recover from this, " +
"but it may cause class loading problems.")
}
tmpClassLoader
} else {
@@ -257,7 +240,7 @@
}

assert(serviceProviderList.nonEmpty, "Classpath should contain the resource for " +
serviceProviderListPath)
serviceProviderListPath)

val numShimServiceProviders = serviceProviderList.size
val (matchingProviders, restProviders) = serviceProviderList.flatMap { shimServiceProviderStr =>
@@ -293,13 +276,13 @@
// this class will be loaded again by the real executor classloader
provider.getClass.getName
}.getOrElse {
val supportedVersions = restProviders.map {
case (p, _) =>
val buildVer = shimIdFromPackageName(p.getClass.getName).drop("spark".length)
s"${p.getShimVersion} {buildver=${buildVer}}"
}.mkString(", ")
throw new IllegalArgumentException(
s"This RAPIDS Plugin build does not support Spark build ${sparkVersion}. " +
val supportedVersions = restProviders.map {
case (p, _) =>
val buildVer = shimIdFromPackageName(p.getClass.getName).drop("spark".length)
s"${p.getShimVersion} {buildver=${buildVer}}"
}.mkString(", ")
throw new IllegalArgumentException(
s"This RAPIDS Plugin build does not support Spark build ${sparkVersion}. " +
s"Supported Spark versions: ${supportedVersions}. " +
"Consult the Release documentation at " +
"https://nvidia.github.io/spark-rapids/docs/download.html")
@@ -340,20 +323,12 @@
SPARK_BUILD_DATE
)

//
// Reflection-based API with Spark to switch the classloader used by the caller
//

def newOptimizerClass(className: String): Optimizer = {
ShimReflectionUtils.newInstanceOf[Optimizer](className)
}

def newInternalShuffleManager(conf: SparkConf, isDriver: Boolean): Any = {
val shuffleClassLoader = getShimClassLoader()
val shuffleClassName = getRapidsShuffleInternalClass
val shuffleClass = shuffleClassLoader.loadClass(shuffleClassName)
shuffleClass.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
.newInstance(conf, java.lang.Boolean.valueOf(isDriver))
.newInstance(conf, java.lang.Boolean.valueOf(isDriver))
}

def newDriverPlugin(): DriverPlugin = {
@@ -385,48 +360,12 @@
newInstanceOf("com.nvidia.spark.rapids.InternalExclusiveModeGpuDiscoveryPlugin")
}

def newParquetCachedBatchSerializer(): GpuCachedBatchSerializer = {
ShimReflectionUtils.newInstanceOf("com.nvidia.spark.rapids.ParquetCachedBatchSerializer")
}

def loadColumnarRDD(): Class[_] = {
ShimReflectionUtils.
loadClass("org.apache.spark.sql.rapids.execution.InternalColumnarRddConverter")
}

def newExplainPlan(): ExplainPlanBase = {
ShimReflectionUtils.newInstanceOf[ExplainPlanBase]("com.nvidia.spark.rapids.ExplainPlanImpl")
}

def newHiveProvider(): HiveProvider= {
ShimReflectionUtils.
newInstanceOf[HiveProvider]("org.apache.spark.sql.hive.rapids.HiveProviderImpl")
}

def newAvroProvider(): AvroProvider = ShimReflectionUtils.newInstanceOf[AvroProvider](
"org.apache.spark.sql.rapids.AvroProviderImpl")

def newDeltaProbe(): DeltaProbe = ShimReflectionUtils.newInstanceOf[DeltaProbe](
"com.nvidia.spark.rapids.delta.DeltaProbeImpl")

def newIcebergProvider(): IcebergProvider = ShimReflectionUtils.newInstanceOf[IcebergProvider](
"com.nvidia.spark.rapids.iceberg.IcebergProviderImpl")

def newPlanShims(): PlanShims = ShimReflectionUtils.newInstanceOf[PlanShims](
"com.nvidia.spark.rapids.shims.PlanShimsImpl"
)

def loadGpuColumnVector(): Class[_] = {
ShimReflectionUtils.loadClass("com.nvidia.spark.rapids.GpuColumnVector")
}

def newAdaptiveSparkPlanHelperShim(): AdaptiveSparkPlanHelperShim =
ShimReflectionUtils.newInstanceOf[AdaptiveSparkPlanHelperShim](
"com.nvidia.spark.rapids.AdaptiveSparkPlanHelperImpl"
)

def newExecutionPlanCaptureCallbackBase(): ExecutionPlanCaptureCallbackBase =
ShimReflectionUtils.
newInstanceOf[ExecutionPlanCaptureCallbackBase](
"org.apache.spark.sql.rapids.ShimmedExecutionPlanCaptureCallbackImpl")
}
}
@@ -0,0 +1,93 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids

import com.nvidia.spark.rapids.ShimLoader

import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.shuffle._


/**
* Trait that makes it easy to check whether we are dealing with
* a RAPIDS Shuffle Manager
*/
trait RapidsShuffleManagerLike {
def isDriver: Boolean
def initialize: Unit
}
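
A hedged usage sketch (an assumption, not part of this change): because the proxy below implements this trait, plugin code can detect a RAPIDS shuffle manager without referencing any shim-specific class.

import org.apache.spark.SparkEnv
import org.apache.spark.sql.rapids.RapidsShuffleManagerLike

object ShuffleManagerCheckSketch {
  // Returns true if the active shuffle manager is one of the RAPIDS proxies
  def isRapidsShuffleManager: Boolean = SparkEnv.get.shuffleManager match {
    case _: RapidsShuffleManagerLike => true
    case _ => false
  }
}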

/**
* A simple proxy wrapper that delays loading of the
* real implementation until ShimLoader
* has already updated the Spark classloaders.
*
* @param conf the SparkConf used to construct the real shuffle manager
* @param isDriver true when created on the driver, false on an executor
*/
class ProxyRapidsShuffleInternalManagerBase(
conf: SparkConf,
override val isDriver: Boolean
) extends RapidsShuffleManagerLike with Proxy {

// touched in the plugin code after the shim initialization
// is complete
lazy val self: ShuffleManager = ShimLoader.newInternalShuffleManager(conf, isDriver)
.asInstanceOf[ShuffleManager]

// This function touches the lazy val `self` so we actually instantiate
// the manager. It is called from both the driver and the executor.
// In the driver it mostly serves to display information on how to enable/disable the manager;
// in the executor, the UCXShuffleTransport starts and allocates memory at this time.
override def initialize: Unit = self

def getWriter[K, V](
handle: ShuffleHandle,
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter
): ShuffleWriter[K, V] = {
self.getWriter(handle, mapId, context, metrics)
}

def getReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
self.getReader(handle,
startMapIndex, endMapIndex, startPartition, endPartition,
context, metrics)
}

def registerShuffle[K, V, C](
shuffleId: Int,
dependency: ShuffleDependency[K, V, C]
): ShuffleHandle = {
self.registerShuffle(shuffleId, dependency)
}

def unregisterShuffle(shuffleId: Int): Boolean = self.unregisterShuffle(shuffleId)

def shuffleBlockResolver: ShuffleBlockResolver = self.shuffleBlockResolver

def stop(): Unit = self.stop()
}
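
A sketch of the intended wiring, assuming the usual pattern of pointing spark.shuffle.manager at a thin shim-specific subclass of this proxy; the class name in the config value and the object name are illustrative. The real manager behind the lazy `self` is only constructed once initialize runs, i.e. after ShimLoader has installed the shim classloader.

import org.apache.spark.SparkConf

object ProxyShuffleWiringSketch {
  def exampleConf(): SparkConf = {
    new SparkConf()
      .set("spark.shuffle.manager",
        "com.nvidia.spark.rapids.spark320.RapidsShuffleManager") // illustrative shim class
  }
}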
