Skip to content

Commit

Permalink
Scala UDF will compile child expressions in Project (NVIDIA#1153)
Browse files Browse the repository at this point in the history
* Scala UDF will compile child expressions in Project

Signed-off-by: Allen Xu <allxu@nvidia.com>

Co-authored-by: Alessandro Bellina <abellina@nvidia.com>

* Rebased to branch-0.3 to resolve conflicts

Signed-off-by: Allen Xu <allxu@nvidia.com>

* Remove flatten test case

Co-authored-by: Allen Xu <allxu@nvidia.com>
Co-authored-by: Alessandro Bellina <abellina@nvidia.com>
  • Loading branch information
3 people authored Jan 5, 2021
1 parent d9e6c0b commit 1b221f9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging {
plan match {
case project: Project =>
Project(project.projectList.map(e => attemptToReplaceExpression(plan, e))
.asInstanceOf[Seq[NamedExpression]], project.child)
.asInstanceOf[Seq[NamedExpression]], apply(project.child))
case x => {
x.transformExpressions(replacePartialFunc(plan))
}
Expand Down
43 changes: 27 additions & 16 deletions udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2333,28 +2333,39 @@ class OpcodeSuite extends FunSuite {
}

val u = makeUdf((x: String, y: String, z: Boolean) => {
var r = new mutable.ArrayBuffer[String]()
r = r :+ x
if (!cond(y)) {
r = r :+ y

if (z) {
r = r :+ transform(y)
}
}
var r = new mutable.ArrayBuffer[String]()
r = r :+ x
if (!cond(y)) {
r = r :+ y

if (z) {
r = r :+ transform(x)
r = r :+ transform(y)
}
r.distinct.toArray
})
}
if (z) {
r = r :+ transform(x)
}
r.distinct.toArray
})

val dataset = List(("######hello", null),
("world", "######hello"),
("", "@@@@target")).toDF("x", "y")
("world", "######hello"),
("", "@@@@target")).toDF("x", "y")
val result = dataset.withColumn("new", u('x, 'y, lit(true)))
val ref = List(("######hello", null, Array("######hello", "@@@@hello")),
("world", "######hello", Array("world", "######hello", "@@@@hello")),
("", "@@@@target", Array("", "@@@@target", "######target", null))).toDF
("world", "######hello", Array("world", "######hello", "@@@@hello")),
("", "@@@@target", Array("", "@@@@target", "######target", null))).toDF
checkEquiv(result, ref)
}

// Applies a UDF nested inside `explode` (i.e. the UDF is a child expression of
// another expression in the projection) and checks the exploded rows against
// the expected output.
test("compile child expresion in explode") {
  // Splits a comma-separated string into its individual fields.
  val myudf: (String) => Array[String] = a => {
    a.split(",")
  }
  val u = makeUdf(myudf)
  // NOTE(review): repartition(1) presumably keeps the row order deterministic
  // for the comparison below -- confirm against checkEquiv's semantics.
  val dataset = List("first,second").toDF("x").repartition(1)
  // `result` is never reassigned, so it is an immutable val.
  val result = dataset.withColumn("new", explode(u(col("x"))))
  val ref = List(("first,second", "first"), ("first,second", "second")).toDF("x", "new")
  checkEquiv(result, ref)
}
}

0 comments on commit 1b221f9

Please sign in to comment.