From 80223bb190aedf9cfa51ebb84057dd214438270e Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 27 Jul 2024 22:19:29 -0400 Subject: [PATCH] [SPARK-48985][CONNECT] Connect Compatible Expression Constructors ### What changes were proposed in this pull request? There are a number of hard coded expressions in the SparkConnectPlanner. Most of these expressions are hardcoded because they are missing a proper constructor, or because they are not registered in the FunctionRegistry. The Column API has a similar problem. This PR fixes most of these exceptions. ### Why are the changes needed? Reduce the number of hard coded expressions in the SparkConnectPlanner and the Column API. This will make it significantly easier to create an implementation agnostic Column API. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47464 from hvanhovell/SPARK-48985. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../resources/error/error-conditions.json | 15 ++++ .../explain-results/function_window.explain | 2 +- .../connect/planner/SparkConnectPlanner.scala | 77 ------------------- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../analysis/ResolveTimeWindows.scala | 11 ++- .../catalyst/expressions/SessionWindow.scala | 10 --- .../expressions/collectionOperations.scala | 3 + .../expressions/conditionalExpressions.scala | 8 +- .../spark/sql/catalyst/expressions/misc.scala | 2 + .../sql/catalyst/expressions/predicates.scala | 4 + .../expressions/randomExpressions.scala | 9 +++ .../expressions/regexpExpressions.scala | 15 +++- .../expressions/windowExpressions.scala | 13 ++++ .../sql/errors/QueryCompilationErrors.scala | 29 +++++++ .../scala/org/apache/spark/sql/Column.scala | 2 +- .../org/apache/spark/sql/functions.scala | 60 +++++---------- .../spark/sql/StringFunctionsSuite.scala | 12 ++- 17 files changed, 135 insertions(+), 139 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index ad57d7ab49320..4f1172d6968e2 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2612,6 +2612,11 @@ "expects an integer value in [0, ), but got ." ] }, + "BOOLEAN" : { + "message" : [ + "expects a boolean literal, but got ." + ] + }, "CHARSET" : { "message" : [ "expects one of the , but got ." @@ -2622,11 +2627,21 @@ "expects one of the units without quotes YEAR, QUARTER, MONTH, WEEK, DAY, DAYOFYEAR, HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND, but got the string literal ." ] }, + "INTEGER" : { + "message" : [ + "expects an integer literal, but got ." + ] + }, "LENGTH" : { "message" : [ "Expects `length` greater than or equal to 0, but got ." ] }, + "LONG" : { + "message" : [ + "expects a long literal, but got ." + ] + }, "NULL" : { "message" : [ "expects a non-NULL value." 
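To make the pattern described above concrete: the hardcoded planner cases can be removed because the underlying expressions now expose auxiliary constructors that accept `Expression`-typed arguments and reduce them to the primitive values their primary constructors expect (see the `Shuffle`, `Uuid`, `Like`, `Lag`/`Lead`/`NthValue` changes later in this patch). The following is a minimal, self-contained sketch of that constructor pattern in plain Scala; `ShuffleLike`, `ExprToSeed` and the tiny `Expr` hierarchy are hypothetical stand-ins for illustration only, not the real catalyst classes.

```scala
// Hypothetical sketch (plain Scala, no Spark dependency) of the constructor pattern
// this patch adds: the primary constructor keeps its primitive parameter, and an
// auxiliary constructor accepts an expression and reduces it to that primitive,
// failing fast when the argument is not a literal.
sealed trait Expr
case class LongLit(value: Long) extends Expr
case class ColumnRef(name: String) extends Expr

object ExprToSeed {
  // Mirrors the role of ExpressionWithRandomSeed.expressionToSeed in the patch:
  // only a literal long is accepted as a seed.
  def apply(e: Expr, source: String): Option[Long] = e match {
    case LongLit(seed) => Some(seed)
    case other =>
      throw new IllegalArgumentException(s"$source expects a long literal seed, got $other")
  }
}

// With the auxiliary constructor, a generic resolver that only has a list of child
// expressions can instantiate this class directly, instead of needing a hardcoded
// `case "shuffle" => ...` branch that unpacks the seed by hand.
case class ShuffleLike(child: Expr, randomSeed: Option[Long] = None) {
  def this(child: Expr, seed: Expr) = this(child, ExprToSeed(seed, "shuffle"))
}

object ConstructorPatternDemo extends App {
  println(new ShuffleLike(ColumnRef("xs"), LongLit(42L)))  // ShuffleLike(ColumnRef(xs),Some(42))
  // new ShuffleLike(ColumnRef("xs"), ColumnRef("seed"))   // would throw: seed must be a literal
}
```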
diff --git a/connect/common/src/test/resources/query-tests/explain-results/function_window.explain b/connect/common/src/test/resources/query-tests/explain-results/function_window.explain index 6adefaa786538..01b2a8907033e 100644 --- a/connect/common/src/test/resources/query-tests/explain-results/function_window.explain +++ b/connect/common/src/test/resources/query-tests/explain-results/function_window.explain @@ -1,4 +1,4 @@ -Project [window#0 AS window#0] +Project [window#0] +- Project [named_struct(start, knownnullable(precisetimestampconversion(((precisetimestampconversion(t#0, TimestampType, LongType) - CASE WHEN (((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) < cast(0 as bigint)) THEN (((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) + 1000000) ELSE ((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) END) - 0), LongType, TimestampType)), end, knownnullable(precisetimestampconversion((((precisetimestampconversion(t#0, TimestampType, LongType) - CASE WHEN (((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) < cast(0 as bigint)) THEN (((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) + 1000000) ELSE ((precisetimestampconversion(t#0, TimestampType, LongType) - 0) % 1000000) END) - 0) + 1000000), LongType, TimestampType))) AS window#0, d#0, t#0, s#0, x#0L, wt#0] +- Filter isnotnull(t#0) +- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index e790a25ec97f1..405e245b8c7bb 100644 --- a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1838,44 +1838,6 @@ class SparkConnectPlanner( .Product(transformExpression(fun.getArgumentsList.asScala.head)) .toAggregateExpression()) - case "when" if fun.getArgumentsCount > 0 => - val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) - Some(CaseWhen.createFromParser(children)) - - case "in" if fun.getArgumentsCount > 0 => - val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression) - Some(In(children.head, children.tail)) - - case "nth_value" if fun.getArgumentsCount == 3 => - // NthValue does not have a constructor which accepts Expression typed 'ignoreNulls' - val children = fun.getArgumentsList.asScala.map(transformExpression) - val ignoreNulls = extractBoolean(children(2), "ignoreNulls") - Some(NthValue(children(0), children(1), ignoreNulls)) - - case "like" if fun.getArgumentsCount == 3 => - // Like does not have a constructor which accepts Expression typed 'escapeChar' - val children = fun.getArgumentsList.asScala.map(transformExpression) - val escapeChar = extractString(children(2), "escapeChar") - Some(Like(children(0), children(1), escapeChar.charAt(0))) - - case "ilike" if fun.getArgumentsCount == 3 => - // ILike does not have a constructor which accepts Expression typed 'escapeChar' - val children = fun.getArgumentsList.asScala.map(transformExpression) - val escapeChar = extractString(children(2), "escapeChar") - Some(ILike(children(0), children(1), escapeChar.charAt(0))) - - case "lag" if fun.getArgumentsCount == 4 => - // Lag does not have a constructor which accepts Expression typed 'ignoreNulls' - val children = 
fun.getArgumentsList.asScala.map(transformExpression) - val ignoreNulls = extractBoolean(children(3), "ignoreNulls") - Some(Lag(children.head, children(1), children(2), ignoreNulls)) - - case "lead" if fun.getArgumentsCount == 4 => - // Lead does not have a constructor which accepts Expression typed 'ignoreNulls' - val children = fun.getArgumentsList.asScala.map(transformExpression) - val ignoreNulls = extractBoolean(children(3), "ignoreNulls") - Some(Lead(children.head, children(1), children(2), ignoreNulls)) - case "bloom_filter_agg" if fun.getArgumentsCount == 3 => // [col, expectedNumItems: Long, numBits: Long] val children = fun.getArgumentsList.asScala.map(transformExpression) @@ -1893,33 +1855,6 @@ class SparkConnectPlanner( val unit = extractString(children(0), "unit") Some(TimestampAdd(unit, children(1), children(2))) - case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) => - val children = fun.getArgumentsList.asScala.map(transformExpression) - val timeCol = children.head - val windowDuration = extractString(children(1), "windowDuration") - var slideDuration = windowDuration - if (fun.getArgumentsCount >= 3) { - slideDuration = extractString(children(2), "slideDuration") - } - var startTime = "0 second" - if (fun.getArgumentsCount == 4) { - startTime = extractString(children(3), "startTime") - } - Some( - Alias(TimeWindow(timeCol, windowDuration, slideDuration, startTime), "window")( - nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY))) - - case "session_window" if fun.getArgumentsCount == 2 => - val children = fun.getArgumentsList.asScala.map(transformExpression) - val timeCol = children.head - val sessionWindow = children.last match { - case Literal(s, StringType) if s != null => SessionWindow(timeCol, s.toString) - case other => SessionWindow(timeCol, other) - } - Some( - Alias(sessionWindow, "session_window")(nonInheritableMetadataKeys = - Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY))) - case "bucket" if fun.getArgumentsCount == 2 => val children = fun.getArgumentsList.asScala.map(transformExpression) (children.head, children.last) match { @@ -2067,18 +2002,6 @@ class SparkConnectPlanner( val (msgName, desc, options) = extractProtobufArgs(children.toSeq) Some(CatalystDataToProtobuf(children(0), msgName, desc, options)) - case "uuid" if fun.getArgumentsCount == 1 => - // Uuid does not have a constructor which accepts Expression typed 'seed' - val children = fun.getArgumentsList.asScala.map(transformExpression) - val seed = extractLong(children(0), "seed") - Some(Uuid(Some(seed))) - - case "shuffle" if fun.getArgumentsCount == 2 => - // Shuffle does not have a constructor which accepts Expression typed 'seed' - val children = fun.getArgumentsList.asScala.map(transformExpression) - val seed = extractLong(children(1), "seed") - Some(Shuffle(children(0), Some(seed))) - case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 739fac1f33fdd..48123254a8fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -384,7 +384,7 @@ object FunctionRegistry { expression[Rand]("random", true, Some("3.0.0")), expression[Randn]("randn"), expression[Stack]("stack"), - expression[CaseWhen]("when"), + CaseWhen.registryEntry, // math functions 
expression[Acos]("acos"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala index a6688f2766214..e506a3629db17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala @@ -231,6 +231,8 @@ object SessionWindowing extends Rule[LogicalPlan] { val sessionStart = PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType) val gapDuration = session.gapDuration match { + case expr if expr.dataType == CalendarIntervalType => + expr case expr if Cast.canCast(expr.dataType, CalendarIntervalType) => Cast(expr, CalendarIntervalType) case other => @@ -257,10 +259,11 @@ object SessionWindowing extends Rule[LogicalPlan] { case s: SessionWindow => sessionAttr } - val filterByTimeRange = session.gapDuration match { - case Literal(interval: CalendarInterval, CalendarIntervalType) => - interval == null || interval.months + interval.days + interval.microseconds <= 0 - case _ => true + val filterByTimeRange = if (gapDuration.foldable) { + val interval = gapDuration.eval().asInstanceOf[CalendarInterval] + interval == null || interval.months + interval.days + interval.microseconds <= 0 + } else { + true } // As same as tumbling window, we add a filter to filter out nulls. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala index 021f119e0a1a6..e39ff458ddc87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -18,9 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees.TreePattern.{SESSION_WINDOW, TreePattern} -import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** * Represent the session window. 
@@ -105,12 +103,4 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extend object SessionWindow { val marker = "spark.sessionWindow" - - def apply( - timeColumn: Expression, - gapDuration: String): SessionWindow = { - SessionWindow(timeColumn, - Literal(IntervalUtils.safeStringToInterval(UTF8String.fromString(gapDuration)), - CalendarIntervalType)) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 708a060ceda5d..7e9bf989e9cfb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1263,6 +1263,9 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends U def this(child: Expression) = this(child, None) + def this(child: Expression, seed: Expression) = + this(child, ExpressionWithRandomSeed.expressionToSeed(seed, "shuffle")) + override def stateful: Boolean = true override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 046d4cbcd5be3..ca74fefb9c032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -391,6 +392,7 @@ case class CaseWhen( /** Factory methods for CaseWhen. 
*/ object CaseWhen { + def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = { CaseWhen(branches, Option(elseValue)) } @@ -408,6 +410,10 @@ object CaseWhen { val elseValue = if (branches.size % 2 != 0) Some(branches.last) else None CaseWhen(cases, elseValue) } + + val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = { + ("when", (FunctionRegistryBase.expressionInfo[CaseWhen]("when", None), createFromParser)) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index e9fa362de14cd..6629f724c4dda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -244,6 +244,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non def this() = this(None) + def this(seed: Expression) = this(ExpressionWithRandomSeed.expressionToSeed(seed, "UUID")) + override def seedExpression: Expression = randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed) override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 312493c949911..748dcf688a09e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -445,6 +445,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { require(list != null, "list should not be null") + def this(expressions: Seq[Expression]) = { + this(expressions.head, expressions.tail) + } + override def checkInputDataTypes(): TypeCheckResult = { val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType, ignoreNullability = true)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index db78415a0cc54..f5db972a28643 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, TreePattern} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ import org.apache.spark.util.random.XORShiftRandom @@ -71,6 +72,14 @@ trait ExpressionWithRandomSeed extends Expression { def withNewSeed(seed: Long): Expression } +private[catalyst] object ExpressionWithRandomSeed { + def expressionToSeed(e: Expression, source: String): Option[Long] = e match { + case LongLiteral(seed) => Some(seed) + case Literal(null, _) => None + case _ => throw QueryCompilationErrors.invalidRandomSeedParameter(source, e) + } +} + /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). 
*/ // scalastyle:off line.size.limit @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 297c709c6d7d9..970397c76a1cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils} -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -82,6 +82,13 @@ abstract class StringRegexExpression extends BinaryExpression } } +private[catalyst] object StringRegexExpression { + def expressionToEscapeChar(e: Expression): Char = e match { + case StringLiteral(v) if v.length == 1 => v.charAt(0) + case _ => throw QueryCompilationErrors.invalidEscapeChar(e) + } +} + // scalastyle:off line.contains.tab line.size.limit /** * Simple RegEx pattern matching function @@ -137,6 +144,9 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) def this(left: Expression, right: Expression) = this(left, right, '\\') + def this(left: Expression, right: Expression, escapeChar: Expression) = + this(left, right, StringRegexExpression.expressionToEscapeChar(escapeChar)) + override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() @@ -259,6 +269,9 @@ case class ILike( escapeChar: Char) extends RuntimeReplaceable with ImplicitCastInputTypes with BinaryLike[Expression] { + def this(left: Expression, right: Expression, escapeChar: Expression) = + this(left, right, StringRegexExpression.expressionToEscapeChar(escapeChar)) + override lazy val replacement: Expression = Like(Lower(left), Lower(right), escapeChar) def this(left: Expression, right: Expression) = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 5881c456f6e86..4ff5c696a55d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -333,6 +333,11 @@ object WindowExpression { def hasWindowExpression(e: Expression): Boolean = { e.find(_.isInstanceOf[WindowExpression]).isDefined } + + def expressionToIngnoreNulls(e: Expression, source: String): Boolean = e match { + case BooleanLiteral(ignoreNulls) => ignoreNulls + case _ => throw QueryCompilationErrors.invalidIgnoreNullsParameter(source, e) + } } case class WindowExpression( @@ -525,6 +530,9 @@ case class Lead( def this(input: Expression, offset: Expression, default: Expression) = this(input, offset, default, false) + def this(input: Expression, offset: Expression, 
default: Expression, ignoreNulls: Expression) = + this(input, offset, default, WindowExpression.expressionToIngnoreNulls(ignoreNulls, "lead")) + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) def this(input: Expression) = this(input, Literal(1)) @@ -579,6 +587,9 @@ case class Lag( def this(input: Expression, inputOffset: Expression, default: Expression) = this(input, inputOffset, default, false) + def this(input: Expression, offset: Expression, default: Expression, ignoreNulls: Expression) = + this(input, offset, default, WindowExpression.expressionToIngnoreNulls(ignoreNulls, "lag")) + def this(input: Expression, inputOffset: Expression) = this(input, inputOffset, Literal(null)) def this(input: Expression) = this(input, Literal(1)) @@ -726,6 +737,8 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean) extends AggregateWindowFunction with OffsetWindowFunction with ImplicitCastInputTypes with BinaryLike[Expression] with QueryErrorsBase { + def this(input: Expression, offset: Expression, ignoreNulls: Expression) = + this(input, offset, WindowExpression.expressionToIngnoreNulls(ignoreNulls, "nth_value")) def this(child: Expression, offset: Expression) = this(child, offset, false) override lazy val default = Literal.create(null, input.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index f4169d8054fef..d1acc0dd70759 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -175,6 +175,35 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "functionName" -> toSQLId(funcName))) } + def invalidRandomSeedParameter(functionName: String, invalidValue: Expression): Throwable = { + invalidParameter("LONG", functionName, "seed", invalidValue) + } + + def invalidReverseParameter(invalidValue: Expression): Throwable = { + invalidParameter("BOOLEAN", "collect_top_k", "reverse", invalidValue) + } + + def invalidNumParameter(invalidValue: Expression): Throwable = { + invalidParameter("INTEGER", "collect_top_k", "num", invalidValue) + } + + def invalidIgnoreNullsParameter(functionName: String, invalidValue: Expression): Throwable = { + invalidParameter("Boolean", functionName, "ignoreNulls", invalidValue) + } + + def invalidParameter( + subClass: String, + functionName: String, + parameter: String, + invalidValue: Expression): Throwable = { + new AnalysisException( + errorClass = "INVALID_PARAMETER_VALUE." 
+ subClass, + messageParameters = Map( + "functionName" -> toSQLId(functionName), + "parameter" -> toSQLId(parameter), + "invalidValue" -> toSQLExpr(invalidValue))) + } + def nullDataSourceOption(option: String): Throwable = { new AnalysisException( errorClass = "NULL_DATA_SOURCE_OPTION", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 2f383f45f1f2e..3108f1886c299 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -824,7 +824,7 @@ class Column(val expr: Expression) extends Logging { * @since 1.5.0 */ @scala.annotation.varargs - def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + def isin(list: Any*): Column = Column.fn("in", this +: list.map(lit): _*) /** * A boolean expression that is evaluated to true if the value of this expression is contained diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 88303b1979a7b..0e62e05900a54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -527,7 +527,7 @@ object functions { * @since 2.0.0 */ def first(e: Column, ignoreNulls: Boolean): Column = - Column.fn("first", false, ignoreNulls, e) + Column.fn("first", false, e, lit(ignoreNulls)) /** * Aggregate function: returns the first value of a column in a group. @@ -791,7 +791,7 @@ object functions { * @since 2.0.0 */ def last(e: Column, ignoreNulls: Boolean): Column = - Column.fn("last", false, ignoreNulls, e) + Column.fn("last", false, e, lit(ignoreNulls)) /** * Aggregate function: returns the last value of the column in a group. @@ -1485,7 +1485,7 @@ object functions { * @since 3.2.0 */ def lag(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = - Column.fn("lag", false, ignoreNulls, e, lit(offset), lit(defaultValue)) + Column.fn("lag", false, e, lit(offset), lit(defaultValue), lit(ignoreNulls)) /** * Window function: returns the value that is `offset` rows after the current row, and @@ -1552,7 +1552,7 @@ object functions { * @since 3.2.0 */ def lead(e: Column, offset: Int, defaultValue: Any, ignoreNulls: Boolean): Column = - Column.fn("lead", false, ignoreNulls, e, lit(offset), lit(defaultValue)) + Column.fn("lead", false, e, lit(offset), lit(defaultValue), lit(ignoreNulls)) /** * Window function: returns the value that is the `offset`th row of the window frame @@ -1567,7 +1567,7 @@ object functions { * @since 3.1.0 */ def nth_value(e: Column, offset: Int, ignoreNulls: Boolean): Column = - Column.fn("nth_value", false, ignoreNulls, e, lit(offset)) + Column.fn("nth_value", false, e, lit(offset), lit(ignoreNulls)) /** * Window function: returns the value that is the `offset`th row of the window frame @@ -1856,7 +1856,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def rand(seed: Long): Column = withExpr { Rand(seed) } + def rand(seed: Long): Column = Column.fn("rand", lit(seed)) /** * Generate a random column with independent and identically distributed (i.i.d.) samples @@ -1878,7 +1878,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def randn(seed: Long): Column = withExpr { Randn(seed) } + def randn(seed: Long): Column = Column.fn("randn", lit(seed)) /** * Generate a column with independent and identically distributed (i.i.d.) 
samples from @@ -3373,7 +3373,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def uuid(): Column = withExpr { Uuid(Some(Utils.random.nextLong)) } + def uuid(): Column = Column.fn("uuid", lit(Utils.random.nextLong)) /** * Returns an encrypted value of `input` using AES in given `mode` with the specified `padding`. @@ -4729,14 +4729,8 @@ object functions { * @group predicate_funcs * @since 3.5.0 */ - def like(str: Column, pattern: Column, escapeChar: Column): Column = withExpr { - escapeChar.expr match { - case StringLiteral(v) if v.length == 1 => - Like(str.expr, pattern.expr, v.charAt(0)) - case _ => - throw QueryCompilationErrors.invalidEscapeChar(escapeChar.expr) - } - } + def like(str: Column, pattern: Column, escapeChar: Column): Column = + Column.fn("like", str, pattern, escapeChar) /** * Returns true if str matches `pattern` with `escapeChar`('\'), null if any arguments are null, @@ -4754,14 +4748,8 @@ object functions { * @group predicate_funcs * @since 3.5.0 */ - def ilike(str: Column, pattern: Column, escapeChar: Column): Column = withExpr { - escapeChar.expr match { - case StringLiteral(v) if v.length == 1 => - ILike(str.expr, pattern.expr, v.charAt(0)) - case _ => - throw QueryCompilationErrors.invalidEscapeChar(escapeChar.expr) - } - } + def ilike(str: Column, pattern: Column, escapeChar: Column): Column = + Column.fn("ilike", str, pattern, escapeChar) /** * Returns true if str matches `pattern` with `escapeChar`('\') case-insensitively, null if any @@ -5705,11 +5693,8 @@ object functions { * @group datetime_funcs * @since 3.2.0 */ - def session_window(timeColumn: Column, gapDuration: String): Column = { - withExpr { - SessionWindow(timeColumn.expr, gapDuration) - }.as("session_window") - } + def session_window(timeColumn: Column, gapDuration: String): Column = + session_window(timeColumn, lit(gapDuration)) /** * Generates session window given a timestamp specifying column. @@ -5743,7 +5728,7 @@ object functions { * @since 3.2.0 */ def session_window(timeColumn: Column, gapDuration: Column): Column = - Column.fn("session_window", timeColumn, gapDuration).as("session_window") + Column.fn("session_window", timeColumn, gapDuration) /** * Converts the number of seconds from the Unix epoch (1970-01-01T00:00:00Z) @@ -7076,7 +7061,7 @@ object functions { * @group array_funcs * @since 2.4.0 */ - def shuffle(e: Column): Column = withExpr { Shuffle(e.expr, Some(Utils.random.nextLong)) } + def shuffle(e: Column): Column = Column.fn("shuffle", e, lit(Utils.random.nextLong)) /** * Returns a reversed string or an array with reverse order of elements. 
@@ -7415,7 +7400,7 @@ object functions { * @group xml_funcs * @since 4.0.0 */ - def schema_of_xml(xml: Column): Column = withExpr(new SchemaOfXml(xml.expr)) + def schema_of_xml(xml: Column): Column = Column.fn("schema_of_xml", xml) // scalastyle:off line.size.limit @@ -7434,9 +7419,8 @@ object functions { * @since 4.0.0 */ // scalastyle:on line.size.limit - def schema_of_xml(xml: Column, options: java.util.Map[String, String]): Column = { - withExpr(SchemaOfXml(xml.expr, options.asScala.toMap)) - } + def schema_of_xml(xml: Column, options: java.util.Map[String, String]): Column = + fnWithOptions("schema_of_xml", options.asScala.iterator, xml) // scalastyle:off line.size.limit @@ -8434,8 +8418,7 @@ object functions { */ @scala.annotation.varargs @deprecated("Use call_udf") - def callUDF(udfName: String, cols: Column*): Column = - call_function(Seq(udfName), cols: _*) + def callUDF(udfName: String, cols: Column*): Column = call_udf(udfName, cols: _*) /** * Call an user-defined function. @@ -8453,8 +8436,7 @@ object functions { * @since 3.2.0 */ @scala.annotation.varargs - def call_udf(udfName: String, cols: Column*): Column = - call_function(Seq(udfName), cols: _*) + def call_udf(udfName: String, cols: Column*): Column = Column.fn(udfName, cols: _*) /** * Call a SQL function. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index fd2661003a151..e594410014a29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -1046,7 +1046,8 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { df1.select(like(col("a"), col("b"), lit(618))).collect() }, errorClass = "INVALID_ESCAPE_CHAR", - parameters = Map("sqlExpr" -> "\"618\"") + parameters = Map("sqlExpr" -> "\"618\""), + context = ExpectedContext("like", getCurrentClassCallSitePattern) ) checkError( @@ -1054,7 +1055,8 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { df1.select(ilike(col("a"), col("b"), lit(618))).collect() }, errorClass = "INVALID_ESCAPE_CHAR", - parameters = Map("sqlExpr" -> "\"618\"") + parameters = Map("sqlExpr" -> "\"618\""), + context = ExpectedContext("ilike", getCurrentClassCallSitePattern) ) // scalastyle:off @@ -1064,7 +1066,8 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { df1.select(like(col("a"), col("b"), lit("中国"))).collect() }, errorClass = "INVALID_ESCAPE_CHAR", - parameters = Map("sqlExpr" -> "\"中国\"") + parameters = Map("sqlExpr" -> "\"中国\""), + context = ExpectedContext("like", getCurrentClassCallSitePattern) ) checkError( @@ -1072,7 +1075,8 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { df1.select(ilike(col("a"), col("b"), lit("中国"))).collect() }, errorClass = "INVALID_ESCAPE_CHAR", - parameters = Map("sqlExpr" -> "\"中国\"") + parameters = Map("sqlExpr" -> "\"中国\""), + context = ExpectedContext("ilike", getCurrentClassCallSitePattern) ) // scalastyle:on }
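As a usage-level sanity check of the `functions.scala` changes above: the public signatures are untouched, only the lowering moved from direct expression construction to `Column.fn` and function-registry resolution. Below is a minimal sketch of calls that exercise the rewritten paths; it assumes a local `SparkSession`, the data and column names are made up, and the snippet is illustrative rather than part of this patch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object FunctionsLoweringSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
  import spark.implicits._

  val df = Seq(("abc", "a_c"), ("axc", "a#_c")).toDF("s", "p")

  df.select(
      like($"s", $"p", lit("#")),   // escapeChar is now passed through as an expression
      $"s".isin("abc", "xyz"),      // isin is now lowered via Column.fn("in", ...)
      rand(42L),                    // the seed becomes a literal child of "rand"
      uuid(),                       // likewise for uuid()'s random seed
      shuffle(split($"s", ""))      // likewise for shuffle's random seed
    )
    .show(truncate = false)

  spark.stop()
}
```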