[SPARK-44736][CONNECT] Add Dataset.explode to Spark Connect Scala Client
### What changes were proposed in this pull request?
This PR adds Dataset.explode to the Spark Connect Scala Client.

### Why are the changes needed?
To increase compatibility with the existing Dataset API in sql/core.

### Does this PR introduce _any_ user-facing change?
Yes, it adds the new `explode` methods to the Scala client.
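
For illustration, here is a minimal usage sketch of the two new overloads as called from the Connect Scala client. It mirrors the E2E test added by this PR; the data, column names, and the `spark` session binding are illustrative, not part of the change itself.

```scala
import org.apache.spark.sql.{Row, SparkSession}

val session: SparkSession = spark // assumes an existing Spark Connect session
import session.implicits._

val ds = Seq((1, "a b c"), (2, "a b")).toDF("number", "letters")

// Multi-column variant: each input row expands to zero or more output rows,
// which are implicitly joined with the input columns.
val pairs = ds.explode($"letters") { case Row(letters: String) =>
  letters.split(' ').map(Tuple1.apply).toSeq
}

// Single-column variant: explode one input column into a named output column.
val words = ds.explode("letters", "letter") { letters: String =>
  letters.split(' ').toSeq
}
```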

### How was this patch tested?
I added a test to `UserDefinedFunctionE2ETestSuite`.

Closes #42418 from hvanhovell/SPARK-44736.

Lead-authored-by: Herman van Hovell <herman@databricks.com>
Co-authored-by: itholic <haejoon.lee@databricks.com>
Co-authored-by: Juliusz Sompolski <julek@databricks.com>
Co-authored-by: Martin Grund <martin.grund@databricks.com>
Co-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Co-authored-by: Kent Yao <yao@apache.org>
Co-authored-by: Wenchen Fan <wenchen@databricks.com>
Co-authored-by: Wei Liu <wei.liu@databricks.com>
Co-authored-by: Ruifeng Zheng <ruifengz@apache.org>
Co-authored-by: Gengliang Wang <gengliang@apache.org>
Co-authored-by: Yuming Wang <yumwang@ebay.com>
Co-authored-by: Herman van Hovell <hvanhovell@databricks.com>
Co-authored-by: 余良 <yul165@chinaunicom.cn>
Co-authored-by: Dongjoon Hyun <dhyun@apple.com>
Co-authored-by: Jack Chen <jack.chen@databricks.com>
Co-authored-by: srielau <serge@rielau.com>
Co-authored-by: zhyhimont <zhyhimont@gmail.com>
Co-authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Co-authored-by: Dongjoon Hyun <dongjoon@apache.org>
Co-authored-by: Zhyhimont Dmitry <zhyhimont.d@profitero.com>
Co-authored-by: Sandip Agarwala <131817656+sandip-db@users.noreply.github.com>
Co-authored-by: yangjie01 <yangjie01@baidu.com>
Co-authored-by: Yihong He <yihong.he@databricks.com>
Co-authored-by: Rameshkrishnan Muthusamy <rameshkrishnan_muthusamy@apple.com>
Co-authored-by: Jia Fan <fanjiaeminem@qq.com>
Co-authored-by: allisonwang-db <allison.wang@databricks.com>
Co-authored-by: Utkarsh <utkarsh.agarwal@databricks.com>
Co-authored-by: Cheng Pan <chengpan@apache.org>
Co-authored-by: Jason Li <jason.li@databricks.com>
Co-authored-by: Shu Wang <swang7@linkedin.com>
Co-authored-by: Nicolas Fraison <nicolas.fraison@datadoghq.com>
Co-authored-by: Max Gekk <max.gekk@gmail.com>
Co-authored-by: panbingkun <pbk1982@gmail.com>
Co-authored-by: Ziqi Liu <ziqi.liu@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
1 parent 7070b36 commit f496cd1
Showing 5 changed files with 136 additions and 2 deletions.
@@ -21,12 +21,14 @@ import java.util.{Collections, Locale}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.function._
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.OrderUtils
@@ -2728,6 +2730,74 @@ class Dataset[T] private[sql] (
flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder)
}

/**
* (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows
* by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the
* input row are implicitly joined with each row that is output by the function.
*
* Since this method is deprecated, you can explode columns using either `functions.explode()`
* or `flatMap()` instead. The following example uses these alternatives to count the number of
* books that contain a given word:
*
* {{{
* case class Book(title: String, words: String)
* val ds: Dataset[Book]
*
* val allWords = ds.select($"title", explode(split($"words", " ")).as("word"))
*
* val bookCountPerWord = allWords.groupBy("word").agg(count_distinct("title"))
* }}}
*
* Using `flatMap()`, this can similarly be expressed as:
*
* {{{
* ds.flatMap(_.words.split(" "))
* }}}
*
* @group untypedrel
* @since 3.5.0
*/
@deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
def explode[A <: Product: TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
val generator = ScalarUserDefinedFunction(
UdfUtils.traversableOnceToSeq(f),
UnboundRowEncoder :: Nil,
ScalaReflection.encoderFor[Seq[A]])
select(col("*"), functions.inline(generator(struct(input: _*))))
}

/**
* (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or
* more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
* columns of the input row are implicitly joined with each value that is output by the
* function.
*
* Since this method is deprecated, you can instead explode the column using
* `functions.explode()`:
*
* {{{
* ds.select(explode(split($"words", " ")).as("word"))
* }}}
*
* or `flatMap()`:
*
* {{{
* ds.flatMap(_.words.split(" "))
* }}}
*
* @group untypedrel
* @since 3.5.0
*/
@deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)(
f: A => TraversableOnce[B]): DataFrame = {
val generator = ScalarUserDefinedFunction(
UdfUtils.traversableOnceToSeq(f),
Nil,
ScalaReflection.encoderFor[Seq[B]])
select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn)))
}

/**
* Applies a function `f` to all rows.
*
@@ -95,6 +95,66 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest {
rows.forEach(x => assert(x == 42))
}

test("(deprecated) Dataset explode") {
val session: SparkSession = spark
import session.implicits._
val result1 = spark
.range(3)
.filter(col("id") =!= 1L)
.explode(col("id") + 41, col("id") + 10) { case Row(x: Long, y: Long) =>
Iterator((x, x - 1), (y, y + 1))
}
.as[(Long, Long, Long)]
.collect()
.toSeq
assert(result1 === Seq((0L, 41L, 40L), (0L, 10L, 11L), (2L, 43L, 42L), (2L, 12L, 13L)))

val result2 = Seq((1, "a b c"), (2, "a b"), (3, "a"))
.toDF("number", "letters")
.explode('letters) { case Row(letters: String) =>
letters.split(' ').map(Tuple1.apply).toSeq
}
.as[(Int, String, String)]
.collect()
.toSeq
assert(
result2 === Seq(
(1, "a b c", "a"),
(1, "a b c", "b"),
(1, "a b c", "c"),
(2, "a b", "a"),
(2, "a b", "b"),
(3, "a", "a")))

val result3 = Seq("a b c", "d e")
.toDF("words")
.explode("words", "word") { word: String =>
word.split(' ').toSeq
}
.select(col("word"))
.as[String]
.collect()
.toSeq
assert(result3 === Seq("a", "b", "c", "d", "e"))

val result4 = Seq("a b c", "d e")
.toDF("words")
.explode("words", "word") { word: String =>
word.split(' ').map(s => s -> s.head.toInt).toSeq
}
.select(col("word"), col("words"))
.as[((String, Int), String)]
.collect()
.toSeq
assert(
result4 === Seq(
(("a", 97), "a b c"),
(("b", 98), "a b c"),
(("c", 99), "a b c"),
(("d", 100), "d e"),
(("e", 101), "d e")))
}

test("Dataset typed flat map - java") {
val rows = spark
.range(5)
@@ -184,7 +184,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"), // deprecated
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"),
@@ -131,6 +131,10 @@ private[sql] object UdfUtils extends Serializable {

def noOp[V, K](): V => K = _ => null.asInstanceOf[K]

def traversableOnceToSeq[A, B](f: A => TraversableOnce[B]): A => Seq[B] = { value =>
f(value).toSeq
}

// (1 to 22).foreach { i =>
// val extTypeArgs = (0 to i).map(_ => "_").mkString(", ")
// val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
@@ -508,7 +508,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
val commonUdf = rel.getFunc
commonUdf.getFunctionCase match {
case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
transformTypedMapPartitions(commonUdf, baseRel)
val analyzed = session.sessionState.executePlan(baseRel).analyzed
transformTypedMapPartitions(commonUdf, analyzed)
case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
val pythonUdf = transformPythonUDF(commonUdf)
val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false
