Skip to content

Commit

Permalink
[SPARK-45760][SQL][FOLLOWUP] Inline With inside conditional branches
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This is a followup of apache#43623 to fix a regression. A `With` expression inside a conditional branch may not be evaluated at all, so we should not pull its common expressions out into a `Project`; we should just inline them instead.

### Why are the changes needed?

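To avoid a performance regression: a common expression that is pulled out into a `Project` is always evaluated, even when the conditional branch that uses it is never taken.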

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

new test

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#43978 from cloud-fan/with.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
cloud-fan committed Nov 28, 2023
1 parent f2ea75f commit a6cda23
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,11 @@ trait ConditionalExpression extends Expression {
*/
def alwaysEvaluatedInputs: Seq[Expression]

/**
 * Return a copy of itself with a new `alwaysEvaluatedInputs`.
 *
 * @param alwaysEvaluatedInputs replacements for the expressions returned by
 *                              `alwaysEvaluatedInputs`, given in the same order
 * @return a copy of this conditional expression that uses the given always-evaluated inputs
 */
def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): ConditionalExpression

/**
* Return groups of branches. For each group, at least one branch will be hit at runtime,
* so that we can eagerly evaluate the common expressions of a group.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
*/
override def alwaysEvaluatedInputs: Seq[Expression] = predicate :: Nil

/**
 * Returns a copy of this `If` whose predicate is the first element of
 * `alwaysEvaluatedInputs`; `trueValue` and `falseValue` are kept as they are.
 */
override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): If = {
  // The predicate is the only always-evaluated input of `If`.
  val newPredicate = alwaysEvaluatedInputs.head
  copy(predicate = newPredicate)
}

// `trueValue` and `falseValue` together make up a single branch group.
override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue))

final override val nodePatterns : Seq[TreePattern] = Seq(IF)
Expand Down Expand Up @@ -165,8 +169,15 @@ case class CaseWhen(

final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN)

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
super.legacyWithNewChildren(newChildren)
/**
 * Rebuilds this `CaseWhen` from a flat list of children. Consecutive pairs of children become
 * the branches. When the list has an odd length, its last element becomes the else value;
 * when the length is even, `elseValue` is left untouched.
 */
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CaseWhen = {
  // Pairs up a flat, even-length child list into branch tuples.
  def pairUp(flat: IndexedSeq[Expression]): Seq[(Expression, Expression)] =
    flat.grouped(2).map { case Seq(a, b) => (a, b) }.toSeq

  val hasTrailingElse = newChildren.length % 2 != 0
  if (hasTrailingElse) {
    copy(branches = pairUp(newChildren.dropRight(1)), elseValue = newChildren.lastOption)
  } else {
    copy(branches = pairUp(newChildren))
  }
}

// both then and else expressions should be considered.
@transient
Expand Down Expand Up @@ -213,6 +224,10 @@ case class CaseWhen(
*/
override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil

/**
 * Returns a copy of this `CaseWhen` in which the leading child is replaced by the given
 * expressions; every child after the first one is carried over unchanged.
 */
override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): CaseWhen = {
  val replacedInputs = alwaysEvaluatedInputs.toIndexedSeq
  // Everything after the first child is kept as is.
  val remainingChildren = children.drop(1)
  withNewChildrenInternal(replacedInputs ++ remainingChildren)
}

override def branchGroups: Seq[Seq[Expression]] = {
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
// because a subexpression in conditions will be run no matter which condition is matched
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ case class Coalesce(children: Seq[Expression])
*/
override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil

/**
 * Returns a copy of this `Coalesce` whose first child is replaced by the given expressions,
 * followed by the current children from the second one onwards.
 */
override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): Coalesce = {
  // Only the first child is always evaluated; the tail is reused untouched.
  val updatedChildren = alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1)
  withNewChildrenInternal(updatedChildren)
}

override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) {
// If there is only one child, the first child is already covered by
// `alwaysEvaluatedInputs` and we should exclude it here.
Expand Down Expand Up @@ -290,6 +294,10 @@ case class NaNvl(left: Expression, right: Expression)
*/
override def alwaysEvaluatedInputs: Seq[Expression] = left :: Nil

/**
 * Returns a copy of this `NaNvl` whose `left` operand is the first element of
 * `alwaysEvaluatedInputs`; `right` is kept as it is.
 */
override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): NaNvl = {
  // `left` is the only always-evaluated input of `NaNvl`.
  val newLeft = alwaysEvaluatedInputs.head
  copy(left = newLeft)
}

// All children together make up a single branch group.
override def branchGroups: Seq[Seq[Expression]] = Seq(children)

override def eval(input: InternalRow): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef, CommonExpressionRef, Expression, With}
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
Expand All @@ -35,56 +36,92 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
var newChildren = p.children
var newPlan: LogicalPlan = p.transformExpressionsUp {
case With(child, defs) =>
val refToExpr = mutable.HashMap.empty[Long, Expression]
val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias])
val inputPlans = p.children.toArray
var newPlan: LogicalPlan = p.mapExpressions { expr =>
rewriteWithExprAndInputPlans(expr, inputPlans)
}
newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
if (p.output == newPlan.output) {
newPlan
} else {
Project(p.output, newPlan)
}
}
}

private def rewriteWithExprAndInputPlans(
e: Expression,
inputPlans: Array[LogicalPlan]): Expression = {
if (!e.containsPattern(WITH_EXPRESSION)) return e
e match {
case w: With =>
// Rewrite nested With expressions first
val child = rewriteWithExprAndInputPlans(w.child, inputPlans)
val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans))
val refToExpr = mutable.HashMap.empty[Long, Expression]
val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
if (child.containsPattern(COMMON_EXPR_REF)) {
throw SparkException.internalError(
"Common expression definition cannot reference other Common expression definitions")
}

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
if (CollapseProject.isCheap(child)) {
refToExpr(id) = child
} else {
val childProjectionIndex = newChildren.indexWhere(
c => child.references.subsetOf(c.outputSet)
)
if (childProjectionIndex == -1) {
// When we cannot rewrite the common expressions, force to inline them so that the
// query can still run. This can happen if the join condition contains `With` and
// the common expression references columns from both join sides.
// TODO: things can go wrong if the common expression is nondeterministic. We
// don't fix it for now to match the old buggy behavior when certain
// `RuntimeReplaceable` did not use the `With` expression.
// TODO: we should calculate the ref count and also inline the common expression
// if it's ref count is 1.
refToExpr(id) = child
} else {
val alias = Alias(child, s"_common_expr_$index")()
childProjections(childProjectionIndex) += alias
refToExpr(id) = alias.toAttribute
}
}
if (CollapseProject.isCheap(child)) {
refToExpr(id) = child
} else {
val childProjectionIndex = inputPlans.indexWhere(
c => child.references.subsetOf(c.outputSet)
)
if (childProjectionIndex == -1) {
// When we cannot rewrite the common expressions, force to inline them so that the
// query can still run. This can happen if the join condition contains `With` and
// the common expression references columns from both join sides.
// TODO: things can go wrong if the common expression is nondeterministic. We
// don't fix it for now to match the old buggy behavior when certain
// `RuntimeReplaceable` did not use the `With` expression.
// TODO: we should calculate the ref count and also inline the common expression
// if its ref count is 1.
refToExpr(id) = child
} else {
val alias = Alias(child, s"_common_expr_$index")()
childProjections(childProjectionIndex) += alias
refToExpr(id) = alias.toAttribute
}
}
}

for (i <- inputPlans.indices) {
val projectList = childProjections(i)
if (projectList.nonEmpty) {
inputPlans(i) = Project(inputPlans(i).output ++ projectList, inputPlans(i))
}
}

newChildren = newChildren.zip(childProjections).map { case (child, projections) =>
if (projections.nonEmpty) {
Project(child.output ++ projections, child)
} else {
child
}
child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
case ref: CommonExpressionRef =>
if (!refToExpr.contains(ref.id)) {
throw SparkException.internalError("Undefined common expression id " + ref.id)
}
refToExpr(ref.id)
}

case c: ConditionalExpression =>
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
rewriteWithExprAndInputPlans(_, inputPlans))
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
// Use transformUp to handle nested With.
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
case With(child, defs) =>
// For With in the conditional branches, they may not be evaluated at all and we can't
// pull the common expressions into a project which will always be evaluated. Inline it.
val refToExpr = defs.map(d => d.id -> d.child).toMap
child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
case ref: CommonExpressionRef => refToExpr(ref.id)
}
}

newPlan = newPlan.withNewChildren(newChildren)
if (p.output == newPlan.output) {
newPlan
} else {
Project(p.output, newPlan)
}
case other => other.mapChildren(rewriteWithExprAndInputPlans(_, inputPlans))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CommonExpressionDef, CommonExpressionRef, With}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, CommonExpressionDef, CommonExpressionRef, With}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand Down Expand Up @@ -57,7 +58,7 @@ class RewriteWithExpressionSuite extends PlanTest {
)
}

test("nested WITH expression") {
test("nested WITH expression in the definition expression") {
val a = testRelation.output.head
val commonExprDef = CommonExpressionDef(a + a)
val ref = new CommonExpressionRef(commonExprDef)
Expand Down Expand Up @@ -85,6 +86,57 @@ class RewriteWithExpressionSuite extends PlanTest {
)
}

test("nested WITH expression in the main expression") {
  // Inner WITH: `a + a` is defined once and referenced twice.
  val a = testRelation.output.head
  val innerDef = CommonExpressionDef(a + a)
  val innerRef = new CommonExpressionRef(innerDef)
  val innerWith = With(innerRef + innerRef, Seq(innerDef))

  // Outer WITH: `b + b` is the common expression; the inner WITH is part of its main expression.
  val b = testRelation.output.last
  val outerDef = CommonExpressionDef(b + b)
  val outerRef = new CommonExpressionRef(outerDef)
  val outerWith = With(outerRef * outerRef + innerWith, Seq(outerDef))

  val plan = testRelation.select(outerWith.as("col"))

  // Each WITH has a single definition (index 0), so both aliases share the same name.
  val innerAlias = (a + a).as("_common_expr_0")
  val outerAlias = (b + b).as("_common_expr_0")
  val innerAttr = innerAlias.toAttribute
  val outerAttr = outerAlias.toAttribute
  val expectedExpr = outerAttr * outerAttr + (innerAttr + innerAttr)

  val optimized = Optimizer.execute(plan)
  // One Project per WITH, the inner one closest to the relation, followed by the final select.
  val expectedPlan = testRelation
    .select((testRelation.output :+ innerAlias): _*)
    .select((testRelation.output :+ innerAttr :+ outerAlias): _*)
    .select(expectedExpr.as("col"))
    .analyze
  comparePlans(optimized, expectedPlan)
}

test("correlated nested WITH expression is not supported") {
  // Common expression of the outer WITH; the inner WITH expressions below refer to it.
  val b = testRelation.output.last
  val outerCommonExprDef = CommonExpressionDef(b + b)
  val outerRef = new CommonExpressionRef(outerCommonExprDef)

  val a = testRelation.output.head
  // Case 1: the inner expression definition references the outer expression.
  val commonExprDef1 = CommonExpressionDef(a + a + outerRef)
  val ref1 = new CommonExpressionRef(commonExprDef1)
  val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1))

  val outerExpr1 = With(outerRef + innerExpr1, Seq(outerCommonExprDef))
  intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr1.as("col"))))

  // Case 2: only the inner main expression references the outer expression. The inner WITH
  // must be defined with `commonExprDef2`, whose definition has no such reference, so that the
  // failure comes from the main expression and not from a correlated definition.
  val commonExprDef2 = CommonExpressionDef(a + a)
  val ref2 = new CommonExpressionRef(commonExprDef2)
  val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))

  val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
  intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr2.as("col"))))
}

test("WITH expression in filter") {
val a = testRelation.output.head
val commonExprDef = CommonExpressionDef(a + a)
Expand Down Expand Up @@ -154,4 +206,27 @@ class RewriteWithExpressionSuite extends PlanTest {
)
)
}

test("WITH expression inside conditional expression") {
  val a = testRelation.output.head
  val commonExprDef = CommonExpressionDef(a + a)
  val ref = new CommonExpressionRef(commonExprDef)
  // Builds a fresh WITH that multiplies the common expression `a + a` by itself.
  def squaredWith: With = With(ref * ref, Seq(commonExprDef))

  // Case 1: the WITH is the second Coalesce argument, i.e. a conditional branch.
  val branchExpr = Coalesce(Seq(a, squaredWith))
  val expectedInlined = Coalesce(Seq(a, (a + a) * (a + a)))
  val branchPlan = testRelation.select(branchExpr.as("col"))
  // With in the conditional branches is always inlined.
  val optimizedBranchPlan = Optimizer.execute(branchPlan)
  comparePlans(optimizedBranchPlan, testRelation.select(expectedInlined.as("col")))

  // Case 2: the WITH is the first Coalesce argument, i.e. an always-evaluated input.
  val leadingExpr = Coalesce(Seq(squaredWith, a))
  val leadingPlan = testRelation.select(leadingExpr.as("col"))
  val commonExprName = "_common_expr_0"
  // With in the always-evaluated branches can still be optimized.
  val optimizedLeadingPlan = Optimizer.execute(leadingPlan)
  val expectedLeadingPlan = testRelation
    .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
    .select(Coalesce(Seq(($"$commonExprName" * $"$commonExprName"), a)).as("col"))
    .analyze
  comparePlans(optimizedLeadingPlan, expectedLeadingPlan)
}
}

0 comments on commit a6cda23

Please sign in to comment.