[SPARK-48985][CONNECT] Connect Compatible Expression Constructors
### What changes were proposed in this pull request?
There are a number of hardcoded expressions in the SparkConnectPlanner. Most of these expressions are hardcoded because they are missing a proper constructor, or because they are not registered in the FunctionRegistry. The Column API has a similar problem. This PR fixes most of these special cases.
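
To illustrate the pattern with a stand-in sketch (the Expr/Lit types below are placeholders, not Spark's Expression/Literal; this snippet is editorial, not code from the PR): an auxiliary Expression-typed constructor lets a generic registry build the expression from a flat argument list, so the planner no longer needs a dedicated case.

    // Stand-in sketch of the constructor pattern this PR applies.
    sealed trait Expr
    final case class Lit(value: Any) extends Expr
    final case class NthValue(input: Expr, offset: Expr, ignoreNulls: Boolean) extends Expr {
      // The auxiliary constructor unwraps the boolean literal itself, so a
      // registry can invoke it with Expr-typed arguments only.
      def this(input: Expr, offset: Expr, ignoreNulls: Expr) =
        this(input, offset, ignoreNulls match {
          case Lit(b: Boolean) => b
          case other => throw new IllegalArgumentException(
            s"ignoreNulls expects a boolean literal, but got $other")
        })
    }
    // new NthValue(Lit("x"), Lit(2), Lit(true)) == NthValue(Lit("x"), Lit(2), ignoreNulls = true)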

### Why are the changes needed?
Reduce the number of hardcoded expressions in the SparkConnectPlanner and the Column API. This will make it significantly easier to create an implementation-agnostic Column API.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #47464 from hvanhovell/SPARK-48985.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
hvanhovell committed Jul 28, 2024
1 parent c463e07 commit 80223bb
Showing 17 changed files with 135 additions and 139 deletions.
15 changes: 15 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -2612,6 +2612,11 @@
"expects an integer value in [0, <upper>), but got <invalidValue>."
]
},
"BOOLEAN" : {
"message" : [
"expects a boolean literal, but got <invalidValue>."
]
},
"CHARSET" : {
"message" : [
"expects one of the <charsets>, but got <charset>."
@@ -2622,11 +2627,21 @@
"expects one of the units without quotes YEAR, QUARTER, MONTH, WEEK, DAY, DAYOFYEAR, HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND, but got the string literal <invalidValue>."
]
},
"INTEGER" : {
"message" : [
"expects an integer literal, but got <invalidValue>."
]
},
"LENGTH" : {
"message" : [
"Expects `length` greater than or equal to 0, but got <length>."
]
},
"LONG" : {
"message" : [
"expects a long literal, but got <invalidValue>."
]
},
"NULL" : {
"message" : [
"expects a non-NULL value."
@@ -1,4 +1,4 @@
Project [window#0 AS window#0]
Project [window#0]
+- Project [named_struct(start, knownnullable(precisetimestampconversion(((precisetimestampconversion(t#0, TimestampType, LongType) - CASE WHEN (((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) < cast(0 as bigint)) THEN (((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) + 1000000) ELSE ((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) END) - 0), LongType, TimestampType)), end, knownnullable(precisetimestampconversion((((precisetimestampconversion(t#0, TimestampType, LongType) - CASE WHEN (((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) < cast(0 as bigint)) THEN (((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) + 1000000) ELSE ((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) END) - 0) + 1000000), LongType, TimestampType))) AS window#0, d#0, t#0, s#0, x#0L, wt#0]
+- Filter isnotnull(t#0)
+- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
@@ -1838,44 +1838,6 @@ class SparkConnectPlanner(
.Product(transformExpression(fun.getArgumentsList.asScala.head))
.toAggregateExpression())

case "when" if fun.getArgumentsCount > 0 =>
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
Some(CaseWhen.createFromParser(children))

case "in" if fun.getArgumentsCount > 0 =>
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
Some(In(children.head, children.tail))

case "nth_value" if fun.getArgumentsCount == 3 =>
// NthValue does not have a constructor which accepts Expression typed 'ignoreNulls'
val children = fun.getArgumentsList.asScala.map(transformExpression)
val ignoreNulls = extractBoolean(children(2), "ignoreNulls")
Some(NthValue(children(0), children(1), ignoreNulls))

case "like" if fun.getArgumentsCount == 3 =>
// Like does not have a constructor which accepts Expression typed 'escapeChar'
val children = fun.getArgumentsList.asScala.map(transformExpression)
val escapeChar = extractString(children(2), "escapeChar")
Some(Like(children(0), children(1), escapeChar.charAt(0)))

case "ilike" if fun.getArgumentsCount == 3 =>
// ILike does not have a constructor which accepts Expression typed 'escapeChar'
val children = fun.getArgumentsList.asScala.map(transformExpression)
val escapeChar = extractString(children(2), "escapeChar")
Some(ILike(children(0), children(1), escapeChar.charAt(0)))

case "lag" if fun.getArgumentsCount == 4 =>
// Lag does not have a constructor which accepts Expression typed 'ignoreNulls'
val children = fun.getArgumentsList.asScala.map(transformExpression)
val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
Some(Lag(children.head, children(1), children(2), ignoreNulls))

case "lead" if fun.getArgumentsCount == 4 =>
// Lead does not have a constructor which accepts Expression typed 'ignoreNulls'
val children = fun.getArgumentsList.asScala.map(transformExpression)
val ignoreNulls = extractBoolean(children(3), "ignoreNulls")
Some(Lead(children.head, children(1), children(2), ignoreNulls))

case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
// [col, expectedNumItems: Long, numBits: Long]
val children = fun.getArgumentsList.asScala.map(transformExpression)
@@ -1893,33 +1855,6 @@
val unit = extractString(children(0), "unit")
Some(TimestampAdd(unit, children(1), children(2)))

case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
val timeCol = children.head
val windowDuration = extractString(children(1), "windowDuration")
var slideDuration = windowDuration
if (fun.getArgumentsCount >= 3) {
slideDuration = extractString(children(2), "slideDuration")
}
var startTime = "0 second"
if (fun.getArgumentsCount == 4) {
startTime = extractString(children(3), "startTime")
}
Some(
Alias(TimeWindow(timeCol, windowDuration, slideDuration, startTime), "window")(
nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)))

case "session_window" if fun.getArgumentsCount == 2 =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
val timeCol = children.head
val sessionWindow = children.last match {
case Literal(s, StringType) if s != null => SessionWindow(timeCol, s.toString)
case other => SessionWindow(timeCol, other)
}
Some(
Alias(sessionWindow, "session_window")(nonInheritableMetadataKeys =
Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)))

case "bucket" if fun.getArgumentsCount == 2 =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
(children.head, children.last) match {
@@ -2067,18 +2002,6 @@
val (msgName, desc, options) = extractProtobufArgs(children.toSeq)
Some(CatalystDataToProtobuf(children(0), msgName, desc, options))

case "uuid" if fun.getArgumentsCount == 1 =>
// Uuid does not have a constructor which accepts Expression typed 'seed'
val children = fun.getArgumentsList.asScala.map(transformExpression)
val seed = extractLong(children(0), "seed")
Some(Uuid(Some(seed)))

case "shuffle" if fun.getArgumentsCount == 2 =>
// Shuffle does not have a constructor which accepts Expression typed 'seed'
val children = fun.getArgumentsList.asScala.map(transformExpression)
val seed = extractLong(children(1), "seed")
Some(Shuffle(children(0), Some(seed)))

case _ => None
}
}
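
No user-facing behavior should change here: these functions still work over Connect, but they now resolve through the FunctionRegistry and the new constructors instead of the cases deleted above. A hedged sketch of the unchanged surface (assumes a DataFrame `df` with a timestamp column `ts`, a numeric column `x`, and an array column `xs`):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    val w = Window.orderBy(col("ts"))
    df.select(
      when(col("x") > 0, "pos").otherwise("neg"), // "when" -> CaseWhen.registryEntry
      lead(col("x"), 1, null, true).over(w),      // ignoreNulls via the new Lead constructor
      nth_value(col("x"), 2, true).over(w),       // likewise NthValue
      shuffle(col("xs")))                         // seed handled by the new Shuffle constructor
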
@@ -384,7 +384,7 @@ object FunctionRegistry {
expression[Rand]("random", true, Some("3.0.0")),
expression[Randn]("randn"),
expression[Stack]("stack"),
expression[CaseWhen]("when"),
CaseWhen.registryEntry,

// math functions
expression[Acos]("acos"),
@@ -231,6 +231,8 @@ object SessionWindowing extends Rule[LogicalPlan] {
val sessionStart =
PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
val gapDuration = session.gapDuration match {
case expr if expr.dataType == CalendarIntervalType =>
expr
case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
Cast(expr, CalendarIntervalType)
case other =>
@@ -257,10 +259,11 @@
case s: SessionWindow => sessionAttr
}

val filterByTimeRange = session.gapDuration match {
case Literal(interval: CalendarInterval, CalendarIntervalType) =>
interval == null || interval.months + interval.days + interval.microseconds <= 0
case _ => true
val filterByTimeRange = if (gapDuration.foldable) {
val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
interval == null || interval.months + interval.days + interval.microseconds <= 0
} else {
true
}

// As same as tumbling window, we add a filter to filter out nulls.
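
The rewritten filterByTimeRange check constant-folds any foldable gap expression instead of matching only a direct CalendarInterval literal, so a gap supplied as a castable string literal is still evaluated at analysis time. A hedged sketch (assumes a SparkSession `spark` and a view `events(time TIMESTAMP, id INT)`):

    spark.sql(
      """SELECT session_window(time, '5 minutes') AS sw, count(*)
        |FROM events
        |GROUP BY session_window(time, '5 minutes')""".stripMargin)
    // The rule casts '5 minutes' to CalendarIntervalType; that Cast is foldable,
    // so eval() can verify at planning time that the gap is positive.
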
@@ -18,9 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.trees.TreePattern.{SESSION_WINDOW, TreePattern}
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Represent the session window.
@@ -105,12 +103,4 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extend

object SessionWindow {
val marker = "spark.sessionWindow"

def apply(
timeColumn: Expression,
gapDuration: String): SessionWindow = {
SessionWindow(timeColumn,
Literal(IntervalUtils.safeStringToInterval(UTF8String.fromString(gapDuration)),
CalendarIntervalType))
}
}
@@ -1263,6 +1263,9 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends U

def this(child: Expression) = this(child, None)

def this(child: Expression, seed: Expression) =
this(child, ExpressionWithRandomSeed.expressionToSeed(seed, "shuffle"))

override def stateful: Boolean = true

override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed)
@@ -18,7 +18,8 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -391,6 +392,7 @@ case class CaseWhen(

/** Factory methods for CaseWhen. */
object CaseWhen {

def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
CaseWhen(branches, Option(elseValue))
}
@@ -408,6 +410,10 @@
val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None
CaseWhen(cases, elseValue)
}

val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = {
("when", (FunctionRegistryBase.expressionInfo[CaseWhen]("when", None), createFromParser))
}
}
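
CaseWhen takes Seq[(Expression, Expression)] plus an optional else value, a shape the registry's reflective builders cannot produce from a flat argument list, hence this explicit entry pairing the ExpressionInfo with createFromParser. Its pairing logic, as a minimal stand-in (not Spark's code):

    // Mirrors createFromParser: consecutive (condition, value) pairs, with a
    // trailing odd element becoming the else branch.
    def pairBranches[E](branches: Seq[E]): (Seq[(E, E)], Option[E]) = {
      val cases = branches.grouped(2).collect { case Seq(cond, value) => (cond, value) }.toSeq
      val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None
      (cases, elseValue)
    }
    // pairBranches(Seq("c1", "v1", "c2", "v2", "else"))
    //   == (Seq(("c1", "v1"), ("c2", "v2")), Some("else"))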

/**
@@ -244,6 +244,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non

def this() = this(None)

def this(seed: Expression) = this(ExpressionWithRandomSeed.expressionToSeed(seed, "UUID"))

override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed)

override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed))
@@ -445,6 +445,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {

require(list != null, "list should not be null")

def this(expressions: Seq[Expression]) = {
this(expressions.head, expressions.tail)
}
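
The Seq-based constructor lets the registry build `in` from a flat argument list. By construction (illustrative):

    // new In(Seq(a, b, c)) == In(a, Seq(b, c))   i.e. a IN (b, c)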

override def checkInputDataTypes(): TypeCheckResult = {
val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType,
ignoreNullability = true))
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, TreePattern}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
import org.apache.spark.util.random.XORShiftRandom

@@ -71,6 +72,14 @@ trait ExpressionWithRandomSeed extends Expression {
def withNewSeed(seed: Long): Expression
}

private[catalyst] object ExpressionWithRandomSeed {
def expressionToSeed(e: Expression, source: String): Option[Long] = e match {
case LongLiteral(seed) => Some(seed)
case Literal(null, _) => None
case _ => throw QueryCompilationErrors.invalidRandomSeedParameter(source, e)
}
}
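
This helper pins down the seed contract that the new Uuid and Shuffle constructors above rely on. A hedged sketch of its behavior (the method is private[catalyst], so this only compiles inside that package):

    import org.apache.spark.sql.catalyst.expressions.{ExpressionWithRandomSeed, Literal}

    ExpressionWithRandomSeed.expressionToSeed(Literal(42L), "shuffle")  // Some(42L)
    ExpressionWithRandomSeed.expressionToSeed(Literal(null), "shuffle") // None -> UnresolvedSeed later
    // Any non-literal seed fails analysis with INVALID_PARAMETER_VALUE.LONG:
    //   ExpressionWithRandomSeed.expressionToSeed(UnresolvedAttribute("s"), "shuffle")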

/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
// scalastyle:off line.size.limit
@ExpressionDescription(
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern}
import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -82,6 +82,13 @@ abstract class StringRegexExpression extends BinaryExpression
}
}

private[catalyst] object StringRegexExpression {
def expressionToEscapeChar(e: Expression): Char = e match {
case StringLiteral(v) if v.length == 1 => v.charAt(0)
case _ => throw QueryCompilationErrors.invalidEscapeChar(e)
}
}
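
The new Like and ILike constructors below delegate to this helper, which accepts only a one-character string literal. A hedged sketch of the contract (again private[catalyst]):

    import org.apache.spark.sql.catalyst.expressions.{Literal, StringRegexExpression}

    StringRegexExpression.expressionToEscapeChar(Literal("#"))  // '#'
    // A multi-character or non-literal escape now fails analysis:
    //   StringRegexExpression.expressionToEscapeChar(Literal("##"))  // invalidEscapeChar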

// scalastyle:off line.contains.tab line.size.limit
/**
* Simple RegEx pattern matching function
@@ -137,6 +144,9 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)

def this(left: Expression, right: Expression) = this(left, right, '\\')

def this(left: Expression, right: Expression, escapeChar: Expression) =
this(left, right, StringRegexExpression.expressionToEscapeChar(escapeChar))

override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar)

override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
@@ -259,6 +269,9 @@ case class ILike(
escapeChar: Char) extends RuntimeReplaceable
with ImplicitCastInputTypes with BinaryLike[Expression] {

def this(left: Expression, right: Expression, escapeChar: Expression) =
this(left, right, StringRegexExpression.expressionToEscapeChar(escapeChar))

override lazy val replacement: Expression = Like(Lower(left), Lower(right), escapeChar)

def this(left: Expression, right: Expression) =
@@ -333,6 +333,11 @@ object WindowExpression {
def hasWindowExpression(e: Expression): Boolean = {
e.find(_.isInstanceOf[WindowExpression]).isDefined
}

def expressionToIgnoreNulls(e: Expression, source: String): Boolean = e match {
case BooleanLiteral(ignoreNulls) => ignoreNulls
case _ => throw QueryCompilationErrors.invalidIgnoreNullsParameter(source, e)
}
}
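
Lead, Lag, and NthValue below unwrap their flag through this helper, making the two constructor forms equivalent. A hedged sketch (assumes this patch on the classpath):

    import org.apache.spark.sql.catalyst.expressions.{Lead, Literal}

    val viaExpression = new Lead(Literal(1), Literal(1), Literal(null), Literal(true))
    val direct = Lead(Literal(1), Literal(1), Literal(null), ignoreNulls = true)
    assert(viaExpression == direct)
    // A non-literal flag raises INVALID_PARAMETER_VALUE.BOOLEAN instead.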

case class WindowExpression(
@@ -525,6 +530,9 @@ case class Lead(
def this(input: Expression, offset: Expression, default: Expression) =
this(input, offset, default, false)

def this(input: Expression, offset: Expression, default: Expression, ignoreNulls: Expression) =
this(input, offset, default, WindowExpression.expressionToIgnoreNulls(ignoreNulls, "lead"))

def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))

def this(input: Expression) = this(input, Literal(1))
@@ -579,6 +587,9 @@ case class Lag(
def this(input: Expression, inputOffset: Expression, default: Expression) =
this(input, inputOffset, default, false)

def this(input: Expression, offset: Expression, default: Expression, ignoreNulls: Expression) =
this(input, offset, default, WindowExpression.expressionToIgnoreNulls(ignoreNulls, "lag"))

def this(input: Expression, inputOffset: Expression) = this(input, inputOffset, Literal(null))

def this(input: Expression) = this(input, Literal(1))
@@ -726,6 +737,8 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
extends AggregateWindowFunction with OffsetWindowFunction with ImplicitCastInputTypes
with BinaryLike[Expression] with QueryErrorsBase {

def this(input: Expression, offset: Expression, ignoreNulls: Expression) =
this(input, offset, WindowExpression.expressionToIgnoreNulls(ignoreNulls, "nth_value"))
def this(child: Expression, offset: Expression) = this(child, offset, false)

override lazy val default = Literal.create(null, input.dataType)
@@ -175,6 +175,35 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
"functionName" -> toSQLId(funcName)))
}

def invalidRandomSeedParameter(functionName: String, invalidValue: Expression): Throwable = {
invalidParameter("LONG", functionName, "seed", invalidValue)
}

def invalidReverseParameter(invalidValue: Expression): Throwable = {
invalidParameter("BOOLEAN", "collect_top_k", "reverse", invalidValue)
}

def invalidNumParameter(invalidValue: Expression): Throwable = {
invalidParameter("INTEGER", "collect_top_k", "num", invalidValue)
}

def invalidIgnoreNullsParameter(functionName: String, invalidValue: Expression): Throwable = {
invalidParameter("Boolean", functionName, "ignoreNulls", invalidValue)
}

def invalidParameter(
subClass: String,
functionName: String,
parameter: String,
invalidValue: Expression): Throwable = {
new AnalysisException(
errorClass = "INVALID_PARAMETER_VALUE." + subClass,
messageParameters = Map(
"functionName" -> toSQLId(functionName),
"parameter" -> toSQLId(parameter),
"invalidValue" -> toSQLExpr(invalidValue)))
}
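
These helpers map onto the INVALID_PARAMETER_VALUE subclasses added to error-conditions.json above. Rendered, an invalid seed would surface roughly as follows (the wrapper text comes from the parent error condition, which this diff does not show):

    // org.apache.spark.sql.AnalysisException:
    // [INVALID_PARAMETER_VALUE.LONG] The value of parameter(s) `seed` in `shuffle`
    // is invalid: expects a long literal, but got "s".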

def nullDataSourceOption(option: String): Throwable = {
new AnalysisException(
errorClass = "NULL_DATA_SOURCE_OPTION",