Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-45760][SQL][FOLLOWUP] Inline With inside conditional branches #43978

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,11 @@ trait ConditionalExpression extends Expression {
*/
def alwaysEvaluatedInputs: Seq[Expression]

/**
* Return a copy of itself with a new `alwaysEvaluatedInputs`.
*/
def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): ConditionalExpression

/**
* Return groups of branches. For each group, at least one branch will be hit at runtime,
* so that we can eagerly evaluate the common expressions of a group.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
*/
override def alwaysEvaluatedInputs: Seq[Expression] = predicate :: Nil

override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): If = {
copy(predicate = alwaysEvaluatedInputs.head)
}

override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue))

final override val nodePatterns : Seq[TreePattern] = Seq(IF)
Expand Down Expand Up @@ -165,8 +169,15 @@ case class CaseWhen(

final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN)

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
super.legacyWithNewChildren(newChildren)
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CaseWhen = {
if (newChildren.length % 2 == 0) {
copy(branches = newChildren.grouped(2).map { case Seq(a, b) => (a, b) }.toSeq)
} else {
copy(
branches = newChildren.dropRight(1).grouped(2).map { case Seq(a, b) => (a, b) }.toSeq,
elseValue = newChildren.lastOption)
}
}

// both then and else expressions should be considered.
@transient
Expand Down Expand Up @@ -213,6 +224,10 @@ case class CaseWhen(
*/
override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil

override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): CaseWhen = {
withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1))
}

override def branchGroups: Seq[Seq[Expression]] = {
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
// because a subexpression in conditions will be run no matter which condition is matched
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ case class Coalesce(children: Seq[Expression])
*/
override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil

override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): Coalesce = {
withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1))
}

override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) {
// If there is only one child, the first child is already covered by
// `alwaysEvaluatedInputs` and we should exclude it here.
Expand Down Expand Up @@ -290,6 +294,10 @@ case class NaNvl(left: Expression, right: Expression)
*/
override def alwaysEvaluatedInputs: Seq[Expression] = left :: Nil

override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): NaNvl = {
copy(left = alwaysEvaluatedInputs.head)
}

override def branchGroups: Seq[Seq[Expression]] = Seq(children)

override def eval(input: InternalRow): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef, CommonExpressionRef, Expression, With}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
Expand All @@ -35,56 +35,82 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
var newChildren = p.children
var newPlan: LogicalPlan = p.transformExpressionsUp {
case With(child, defs) =>
val refToExpr = mutable.HashMap.empty[Long, Expression]
val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias])
val inputPlans = p.children.toArray
var newPlan: LogicalPlan = p.mapExpressions { expr =>
rewriteWithExprAndInputPlans(expr, inputPlans)
}
newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
if (p.output == newPlan.output) {
newPlan
} else {
Project(p.output, newPlan)
}
}
}

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
if (CollapseProject.isCheap(child)) {
refToExpr(id) = child
} else {
val childProjectionIndex = newChildren.indexWhere(
c => child.references.subsetOf(c.outputSet)
)
if (childProjectionIndex == -1) {
// When we cannot rewrite the common expressions, force to inline them so that the
// query can still run. This can happen if the join condition contains `With` and
// the common expression references columns from both join sides.
// TODO: things can go wrong if the common expression is nondeterministic. We
// don't fix it for now to match the old buggy behavior when certain
// `RuntimeReplaceable` did not use the `With` expression.
// TODO: we should calculate the ref count and also inline the common expression
// if it's ref count is 1.
refToExpr(id) = child
} else {
val alias = Alias(child, s"_common_expr_$index")()
childProjections(childProjectionIndex) += alias
refToExpr(id) = alias.toAttribute
}
}
}
private def rewriteWithExprAndInputPlans(
e: Expression,
inputPlans: Array[LogicalPlan]): Expression = {
if (!e.containsPattern(WITH_EXPRESSION)) return e
e match {
case w: With =>
// Rewrite nested With expression in CommonExpressionDef first.
val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have "manual" recursion (instead of transformExpressionsUp()), shall we deal with nested Withs in w.child too?

Copy link
Contributor

@peter-toth peter-toth Nov 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the current logic seems to behave correctly if there is an inner With in an outer With's child and the inner has a definition with a reference to an outer definition . (The previous transformExpressionsUp() had issues in that case.) But the rule is not idempotent now, so maybe we should recurse into w.child after replacing CommonExpressionRefs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good catch! It seems doesn't matter when to recurse into w.child, either before replacing CommonExpressionRef or after is fine?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe before is better, as the expression tree may be much larger after replacing CommonExpressionRef

Copy link
Contributor

@peter-toth peter-toth Nov 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure. E.g. if we have With(With(x + x, Seq(x = y + y)), Seq(y = a + 1)) where x and y are references and a is an attribute and we would recurse into With(x + x, Seq(x = y + y)) before replacing the y references to actual attributes, that aliases a + 1, then the childProjectionIndex calculation for y + y won't find the right child, will it? But an UT covering this case would be good. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh correlated nested With! I'm not sure if we want to support it or not... But at least we should fail if we don't want to support it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we may need a test for that (either supported or failed if not).

val refToExpr = mutable.HashMap.empty[Long, Expression]
val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])

newChildren = newChildren.zip(childProjections).map { case (child, projections) =>
if (projections.nonEmpty) {
Project(child.output ++ projections, child)
} else {
child
}
defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
if (CollapseProject.isCheap(child)) {
refToExpr(id) = child
} else {
val childProjectionIndex = inputPlans.indexWhere(
c => child.references.subsetOf(c.outputSet)
)
if (childProjectionIndex == -1) {
// When we cannot rewrite the common expressions, force to inline them so that the
// query can still run. This can happen if the join condition contains `With` and
// the common expression references columns from both join sides.
// TODO: things can go wrong if the common expression is nondeterministic. We
// don't fix it for now to match the old buggy behavior when certain
// `RuntimeReplaceable` did not use the `With` expression.
// TODO: we should calculate the ref count and also inline the common expression
// if it's ref count is 1.
refToExpr(id) = child
} else {
val alias = Alias(child, s"_common_expr_$index")()
childProjections(childProjectionIndex) += alias
refToExpr(id) = alias.toAttribute
}
}
}

for (i <- inputPlans.indices) {
val projectList = childProjections(i)
if (projectList.nonEmpty) {
inputPlans(i) = Project(inputPlans(i).output ++ projectList, inputPlans(i))
}
}

w.child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
case ref: CommonExpressionRef => refToExpr(ref.id)
}

case c: ConditionalExpression =>
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
rewriteWithExprAndInputPlans(_, inputPlans))
Comment on lines +110 to +111
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is dealing with common expressions only in always evaluated input e.g., predicate of If.

How about common expressions shared between predicate and branches?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about it before. The problem is that it's hard to update the original ConditionalExpression with the new shared common expressions. alwaysEvaluatedInputs is static so that I can let every ConditionalExpression to implement a method to update it.

val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
// Use transformUp to handle nested With.
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
case With(child, defs) =>
// For With in the conditional branches, they may not be evaluated at all and we can't
// pull the common expressions into a project which will always be evaluated. Inline it.
Comment on lines +115 to +117
Copy link
Member

@viirya viirya Nov 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, for specific conditional expression, e.g., If, we can still extract common expression shared on both branches which will be evaluated for sure?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as https://github.com/apache/spark/pull/43978/files#r1403392772 .

It's easy to find these common expressions shared on both branches, but it's hard to put them back to If. I think it's better to do it when we make it into a general rule that find shared common expressions and create With to deduplicate.

val refToExpr = defs.map(d => d.id -> d.child).toMap
child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
case ref: CommonExpressionRef => refToExpr(ref.id)
}
}

newPlan = newPlan.withNewChildren(newChildren)
if (p.output == newPlan.output) {
newPlan
} else {
Project(p.output, newPlan)
}
case other => other.mapChildren(rewriteWithExprAndInputPlans(_, inputPlans))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CommonExpressionDef, CommonExpressionRef, With}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, CommonExpressionDef, CommonExpressionRef, With}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand Down Expand Up @@ -154,4 +154,27 @@ class RewriteWithExpressionSuite extends PlanTest {
)
)
}

test("WITH expression inside conditional expression") {
val a = testRelation.output.head
val commonExprDef = CommonExpressionDef(a + a)
val ref = new CommonExpressionRef(commonExprDef)
val expr = Coalesce(Seq(a, With(ref * ref, Seq(commonExprDef))))
val inlinedExpr = Coalesce(Seq(a, (a + a) * (a + a)))
val plan = testRelation.select(expr.as("col"))
// With in the conditional branches is always inlined.
comparePlans(Optimizer.execute(plan), testRelation.select(inlinedExpr.as("col")))

val expr2 = Coalesce(Seq(With(ref * ref, Seq(commonExprDef)), a))
val plan2 = testRelation.select(expr2.as("col"))
val commonExprName = "_common_expr_0"
// With in the always-evaluated branches can still be optimized.
comparePlans(
Optimizer.execute(plan2),
testRelation
.select((testRelation.output :+ (a + a).as(commonExprName)): _*)
.select(Coalesce(Seq(($"$commonExprName" * $"$commonExprName"), a)).as("col"))
.analyze
)
}
}