[SPARK-49249][SPARK-49122] Makes SparkSession.addArtifact work with REPL #48120

Open · wants to merge 5 commits into base: master
Binary file added repl/src/test/resources/IntSumUdf.class
22 changes: 22 additions & 0 deletions repl/src/test/resources/IntSumUdf.scala
@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import org.apache.spark.sql.api.java.UDF2

class IntSumUdf extends UDF2[Long, Long, Long] {
override def call(t1: Long, t2: Long): Long = t1 + t2
}
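For context, this fixture is what the new REPL tests below exercise: the class is shipped to the session with `addArtifact` and then registered by name. A minimal spark-shell sketch (the path is illustrative; `spark` is the shell's SparkSession):

// Minimal spark-shell sketch, assuming IntSumUdf.class was compiled from the source above.
import org.apache.spark.sql.types.DataTypes

spark.addArtifact("/path/to/IntSumUdf.class")                      // illustrative path
spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)  // register by class name
spark.sql("SELECT intSum(id, id + 1) FROM range(5)").show()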
63 changes: 63 additions & 0 deletions repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -396,4 +396,67 @@ class ReplSuite extends SparkFunSuite {
Main.sparkContext.stop()
System.clearProperty("spark.driver.port")
}

test("register UDF via SparkSession.addArtifact") {
val artifactPath = new File("src/test/resources").toPath
val intSumUdfPath = artifactPath.resolve("IntSumUdf.class")
val output = runInterpreterInPasteMode("local",
s"""
|import org.apache.spark.sql.api.java.UDF2
|import org.apache.spark.sql.types.DataTypes
|
|spark.addArtifact("${intSumUdfPath.toString}")
|
|spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)
|
|val r = spark.range(5)
| .withColumn("id2", col("id") + 1)
| .selectExpr("intSum(id, id2)")
| .collect()
|assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
|
""".stripMargin)
assertContains("Array([1], [3], [5], [7], [9])", output)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertDoesNotContain("assertion failed", output)

// The UDF should not work in a new REPL session.
val anotherOutput = runInterpreterInPasteMode("local",
s"""
|val r = spark.range(5)
| .withColumn("id2", col("id") + 1)
| .selectExpr("intSum(id, id2)")
| .collect()
|
""".stripMargin)
assertContains(
"[UNRESOLVED_ROUTINE] Cannot resolve routine `intSum` on search path",
anotherOutput)
}

test("register a class via SparkSession.addArtifact") {
val artifactPath = new File("src/test/resources").toPath
val intSumUdfPath = artifactPath.resolve("IntSumUdf.class")
val output = runInterpreterInPasteMode("local",
s"""
|import org.apache.spark.sql.functions.udf
|
|spark.addArtifact("${intSumUdfPath.toString}")
|
|val intSumUdf = udf((x: Long, y: Long) => new IntSumUdf().call(x, y))
|spark.udf.register("intSum", intSumUdf)
|
|val r = spark.range(5)
| .withColumn("id2", col("id") + 1)
| .selectExpr("intSum(id, id2)")
| .collect()
|assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
|
""".stripMargin)
assertContains("Array([1], [3], [5], [7], [9])", output)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertDoesNotContain("assertion failed", output)
}
}
@@ -739,8 +739,10 @@ class SparkSession private(
// active session once we are done.
val old = SparkSession.activeThreadSession.get()
SparkSession.setActiveSession(this)
try block finally {
SparkSession.setActiveSession(old)
artifactManager.withResources {
try block finally {
SparkSession.setActiveSession(old)
}
}
}

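For readers unfamiliar with it, `ArtifactManager.withResources` is what makes the session's added artifacts visible to the wrapped block. A plausible sketch of the mechanism, not the verbatim implementation — the exact shape and the use of `Utils.withContextClassLoader` are assumptions:

// Sketch (assumed shape): run `f` with this session's artifact state active, so
// class resolution and jobs started inside the block see addArtifact'ed files.
def withResources[T](f: => T): T = {
  Utils.withContextClassLoader(classloader) {          // session-scoped classloader
    JobArtifactSet.withActiveJobArtifactState(state) { // `state` appears later in this diff
      f
    }
  }
}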
21 changes: 12 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils

/**
* Functions for registering user-defined functions. Use `SparkSession.udf` to access this:
@@ -44,7 +43,7 @@ import org.apache.spark.util.Utils
* @since 1.3.0
*/
@Stable
class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)
class UDFRegistration private[sql] (session: SparkSession, functionRegistry: FunctionRegistry)
extends api.UDFRegistration
with Logging {
protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
@@ -121,7 +120,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)
*/
private[sql] def registerJavaUDAF(name: String, className: String): Unit = {
try {
val clazz = Utils.classForName[AnyRef](className)
val clazz = session.artifactManager.classloader.loadClass(className)
Review comment (Contributor): One follow-up here would be to cache the ArtifactManager classloader. I think we create that thing over and over.

Reply (@xupefei, Author, Sep 23, 2024): Agree. We can invalidate the cache when a new JAR is added.

if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
throw QueryCompilationErrors
.classDoesNotImplementUserDefinedAggregateFunctionError(className)
@@ -138,16 +137,20 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)

// scalastyle:off line.size.limit
/**
* Register a Java UDF class using reflection, for use from pyspark
* Register a Java UDF class using its class name. The class must implement one of the UDF
* interfaces in the [[org.apache.spark.sql.api.java]] package, and be discoverable by the current
* session's class loader.
*
* @param name udf name
* @param className fully qualified class name of udf
* @param returnDataType return type of udf. If it is null, spark would try to infer
* @param name Name of the UDF.
* @param className Fully qualified class name of the UDF.
* @param returnDataType Return type of UDF. If it is `null`, Spark would try to infer
* via reflection.
*
* @since 4.0.0
*/
private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
Review comments on lines -148 to +151:

(@xupefei, Author): I have to make this method public so I can call it from REPL.

(Contributor): I am not against this. I am trying to understand the user-facing consequences though. I'd probably prefer that we add support for Scala UDFs as well. That can be done in a follow-up though.

(Contributor): Can you file a follow-up?

(@xupefei, Author): Will do.
try {
val clazz = Utils.classForName[AnyRef](className)
val clazz = session.artifactManager.classloader.loadClass(className)
val udfInterfaces = clazz.getGenericInterfaces
.filter(_.isInstanceOf[ParameterizedType])
.map(_.asInstanceOf[ParameterizedType])
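With `registerJava` now public, user code can register a compiled Java UDF by its fully qualified class name, provided the class is reachable through the session's artifact classloader. A short usage sketch, assuming `IntSumUdf.class` was already shipped via `addArtifact`:

import org.apache.spark.sql.types.DataTypes

// Register the previously added class by name; per the doc comment above,
// passing null as the return type asks Spark to infer it via reflection.
spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)
spark.range(3).selectExpr("intSum(id, id)").show()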
@@ -71,7 +71,7 @@ class ArtifactManager(session: SparkSession) extends Logging {
(ArtifactUtils.concatenatePaths(artifactPath, "classes"),
s"$artifactURI${File.separator}classes${File.separator}")

protected[artifact] val state: JobArtifactState =
protected[sql] val state: JobArtifactState =
JobArtifactState(session.sessionUUID, Option(classURI))

def withResources[T](f: => T): T = {
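The classloader-caching follow-up discussed in the review thread above could look roughly like this — a hypothetical sketch, not part of this PR; `buildClassloader` and the field names are invented stand-ins:

// Hypothetical follow-up (not in this PR): cache the classloader, rebuild on change.
@volatile private var cachedClassloader: Option[ClassLoader] = None

def classloader: ClassLoader = cachedClassloader.getOrElse(synchronized {
  cachedClassloader.getOrElse {
    val cl = buildClassloader()      // invented stand-in for the existing construction logic
    cachedClassloader = Some(cl)
    cl
  }
})

// Called whenever a new JAR or class file is added, per the author's reply above.
private def invalidateClassloaderCache(): Unit = synchronized {
  cachedClassloader = None
}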
@@ -281,7 +281,10 @@ object SQLExecution extends Logging {
val activeSession = sparkSession
val sc = sparkSession.sparkContext
val localProps = Utils.cloneProperties(sc.getLocalProperties)
val artifactState = JobArtifactSet.getCurrentJobArtifactState.orNull
// `getCurrentJobArtifactState` will return a state only in Spark Connect mode. In non-Connect
// mode, we default back to the resources of the current Spark session.
val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse(
  activeSession.artifactManager.state)

Review comment (Contributor): I think it should be safe to use the SparkSession's jobArtifactState. They should be the same. cc @vicennial.
exec.submit(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
val originalSession = SparkSession.getActiveSession
val originalLocalProps = sc.getLocalProperties
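The practical effect in non-Connect mode: execution threads now carry the owning session's artifact state, so classes added via `addArtifact` resolve during query execution yet stay isolated per session. A rough illustration mirroring the ReplSuite tests above (path illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DataTypes

val s1 = SparkSession.builder().master("local").getOrCreate()
s1.addArtifact("/path/to/IntSumUdf.class")    // illustrative path
s1.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)
s1.sql("SELECT intSum(1L, 2L)").show()        // resolves: the job runs under s1's artifact state

val s2 = s1.newSession()                      // fresh session state, shared SparkContext
// `intSum` is not registered in s2; using it fails with [UNRESOLVED_ROUTINE],
// as the second ReplSuite test above demonstrates.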
@@ -181,7 +181,7 @@ abstract class BaseSessionStateBuilder(
* Note 1: The user-defined functions must be deterministic.
* Note 2: This depends on the `functionRegistry` field.
*/
protected def udfRegistration: UDFRegistration = new UDFRegistration(functionRegistry)
protected def udfRegistration: UDFRegistration = new UDFRegistration(session, functionRegistry)

protected def udtfRegistration: UDTFRegistration = new UDTFRegistration(tableFunctionRegistry)

@@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.DataTypes
@@ -331,24 +330,17 @@ class ArtifactManagerSuite extends SharedSparkSession {
}
}

test("Add UDF as artifact") {
test("Added artifact can be loaded by the current SparkSession") {
val buffer = Files.readAllBytes(artifactPath.resolve("IntSumUdf.class"))
spark.addArtifact(buffer, "IntSumUdf.class")

val instance = artifactManager.classloader
.loadClass("IntSumUdf")
.getDeclaredConstructor()
.newInstance()
.asInstanceOf[UDF2[Long, Long, Long]]
spark.udf.register("intSum", instance, DataTypes.LongType)

artifactManager.withResources {
val r = spark.range(5)
.withColumn("id2", col("id") + 1)
.selectExpr("intSum(id, id2)")
.collect()
assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
}
spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)

val r = spark.range(5)
.withColumn("id2", col("id") + 1)
.selectExpr("intSum(id, id2)")
.collect()
assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
}

private def testAddArtifactToLocalSession(