[SPARK-44219][SQL] Adds extra per-rule validations for optimization rewrites

### What changes were proposed in this pull request?

Adds per-rule validation checks for the following:

1. Aggregate expressions in Aggregate plans are valid.
2. Grouping keys in Aggregate plans cannot be of type Map.
3. No dangling attribute references have been generated.

This validation is enabled by default for all tests, and can be enabled selectively elsewhere via the `spark.sql.planChangeValidation=true` flag; a usage sketch follows.
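As an illustrative sketch (not part of this commit), this is how the flag might be turned on outside the test suites. Only the flag name comes from this PR; the session setup, app name, and query are assumptions:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical local session; only spark.sql.planChangeValidation is from this PR.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("plan-change-validation-demo")
  .config("spark.sql.planChangeValidation", "true")
  .getOrCreate()

// With the flag set, the plan is re-validated after each optimizer rule, so a
// rewrite that produces an invalid plan fails fast with an error describing it.
spark.sql("SELECT id, count(*) AS c FROM range(10) GROUP BY id").show()
```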

### Why are the changes needed?
Extra validation of optimizer rewrites: a rule that produces an invalid plan is caught immediately after it runs, instead of surfacing as a harder-to-trace failure later in planning or execution.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit tests

Closes #41763 from YannisSismanis/SC-130139_followup.

Authored-by: Yannis Sismanis <yannis.sismanis@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
YannisSismanis authored and cloud-fan committed Oct 6, 2023
1 parent 94666e9 commit 2ce1a87
Showing 4 changed files with 412 additions and 75 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -431,77 +431,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
             messageParameters = Map.empty)
         }
 
-      case Aggregate(groupingExprs, aggregateExprs, _) =>
-        def checkValidAggregateExpression(expr: Expression): Unit = expr match {
-          case expr: AggregateExpression =>
-            val aggFunction = expr.aggregateFunction
-            aggFunction.children.foreach { child =>
-              child.foreach {
-                case expr: AggregateExpression =>
-                  expr.failAnalysis(
-                    errorClass = "NESTED_AGGREGATE_FUNCTION",
-                    messageParameters = Map.empty)
-                case other => // OK
-              }
-
-              if (!child.deterministic) {
-                child.failAnalysis(
-                  errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION",
-                  messageParameters = Map("sqlExpr" -> toSQLExpr(expr)))
-              }
-            }
-          case _: Attribute if groupingExprs.isEmpty =>
-            operator.failAnalysis(
-              errorClass = "MISSING_GROUP_BY",
-              messageParameters = Map.empty)
-          case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
-            throw QueryCompilationErrors.columnNotInGroupByClauseError(e)
-          case s: ScalarSubquery
-              if s.children.nonEmpty && !groupingExprs.exists(_.semanticEquals(s)) =>
-            s.failAnalysis(
-              errorClass = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION",
-              messageParameters = Map("sqlExpr" -> toSQLExpr(s)))
-          case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
-          // There should be no Window in Aggregate - this case will fail later check anyway.
-          // Perform this check for special case of lateral column alias, when the window
-          // expression is not eligible to propagate to upper plan because it is not valid,
-          // containing non-group-by or non-aggregate-expressions.
-          case WindowExpression(function, spec) =>
-            function.children.foreach(checkValidAggregateExpression)
-            checkValidAggregateExpression(spec)
-          case e => e.children.foreach(checkValidAggregateExpression)
-        }
-
-        def checkValidGroupingExprs(expr: Expression): Unit = {
-          if (expr.exists(_.isInstanceOf[AggregateExpression])) {
-            expr.failAnalysis(
-              errorClass = "GROUP_BY_AGGREGATE",
-              messageParameters = Map("sqlExpr" -> expr.sql))
-          }
-
-          // Check if the data type of expr is orderable.
-          if (!RowOrdering.isOrderable(expr.dataType)) {
-            expr.failAnalysis(
-              errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
-              messageParameters = Map(
-                "sqlExpr" -> toSQLExpr(expr),
-                "dataType" -> toSQLType(expr.dataType)))
-          }
-
-          if (!expr.deterministic) {
-            // This is just a sanity check, our analysis rule PullOutNondeterministic should
-            // already pull out those nondeterministic expressions and evaluate them in
-            // a Project node.
-            throw SparkException.internalError(
-              msg = s"Non-deterministic expression '${toSQLExpr(expr)}' should not appear in " +
-                "grouping expression.",
-              context = expr.origin.getQueryContext,
-              summary = expr.origin.context.summary)
-          }
-        }
-
-        groupingExprs.foreach(checkValidGroupingExprs)
-        aggregateExprs.foreach(checkValidAggregateExpression)
+      case a: Aggregate => ExprUtils.assertValidAggregation(a)
 
       case CollectMetrics(name, metrics, _, _) =>
         if (name == null || name.isEmpty) {
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -20,8 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
 import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
 import java.util.Locale
 
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
@@ -140,4 +144,77 @@ object ExprUtils extends QueryErrorsBase {
       TypeCheckSuccess
     }
   }
+
+  def assertValidAggregation(a: Aggregate): Unit = {
+    def checkValidAggregateExpression(expr: Expression): Unit = expr match {
+      case expr: AggregateExpression =>
+        val aggFunction = expr.aggregateFunction
+        aggFunction.children.foreach { child =>
+          child.foreach {
+            case expr: AggregateExpression =>
+              expr.failAnalysis(
+                errorClass = "NESTED_AGGREGATE_FUNCTION",
+                messageParameters = Map.empty)
+            case other => // OK
+          }
+
+          if (!child.deterministic) {
+            child.failAnalysis(
+              errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION",
+              messageParameters = Map("sqlExpr" -> toSQLExpr(expr)))
+          }
+        }
+      case _: Attribute if a.groupingExpressions.isEmpty =>
+        a.failAnalysis(
+          errorClass = "MISSING_GROUP_BY",
+          messageParameters = Map.empty)
+      case e: Attribute if !a.groupingExpressions.exists(_.semanticEquals(e)) =>
+        throw QueryCompilationErrors.columnNotInGroupByClauseError(e)
+      case s: ScalarSubquery
+          if s.children.nonEmpty && !a.groupingExpressions.exists(_.semanticEquals(s)) =>
+        s.failAnalysis(
+          errorClass = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION",
+          messageParameters = Map("sqlExpr" -> toSQLExpr(s)))
+      case e if a.groupingExpressions.exists(_.semanticEquals(e)) => // OK
+      // There should be no Window in Aggregate - this case will fail later check anyway.
+      // Perform this check for special case of lateral column alias, when the window
+      // expression is not eligible to propagate to upper plan because it is not valid,
+      // containing non-group-by or non-aggregate-expressions.
+      case WindowExpression(function, spec) =>
+        function.children.foreach(checkValidAggregateExpression)
+        checkValidAggregateExpression(spec)
+      case e => e.children.foreach(checkValidAggregateExpression)
+    }
+
+    def checkValidGroupingExprs(expr: Expression): Unit = {
+      if (expr.exists(_.isInstanceOf[AggregateExpression])) {
+        expr.failAnalysis(
+          errorClass = "GROUP_BY_AGGREGATE",
+          messageParameters = Map("sqlExpr" -> expr.sql))
+      }
+
+      // Check if the data type of expr is orderable.
+      if (!RowOrdering.isOrderable(expr.dataType)) {
+        expr.failAnalysis(
+          errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
+          messageParameters = Map(
+            "sqlExpr" -> toSQLExpr(expr),
+            "dataType" -> toSQLType(expr.dataType)))
+      }
+
+      if (!expr.deterministic) {
+        // This is just a sanity check, our analysis rule PullOutNondeterministic should
+        // already pull out those nondeterministic expressions and evaluate them in
+        // a Project node.
+        throw SparkException.internalError(
+          msg = s"Non-deterministic expression '${toSQLExpr(expr)}' should not appear in " +
+            "grouping expression.",
+          context = expr.origin.getQueryContext,
+          summary = expr.origin.context.summary)
+      }
+    }
+
+    a.groupingExpressions.foreach(checkValidGroupingExprs)
+    a.aggregateExpressions.foreach(checkValidAggregateExpression)
+  }
 }
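For a sense of what `assertValidAggregation` rejects, here are queries that trip each of its checks. This is an illustrative sketch, not part of the commit: it assumes a SparkSession named `spark` and a table `t(a INT, b INT)`, and each statement fails analysis with the named error class rather than returning rows.

```scala
// NESTED_AGGREGATE_FUNCTION: an aggregate nested inside another aggregate.
spark.sql("SELECT sum(max(a)) FROM t GROUP BY b").collect()

// MISSING_GROUP_BY: a bare column reference with no GROUP BY clause.
spark.sql("SELECT a, sum(b) FROM t").collect()

// GROUP_BY_AGGREGATE: an aggregate function used as a grouping expression.
spark.sql("SELECT count(*) FROM t GROUP BY sum(a)").collect()

// GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE: map values cannot be grouping keys.
spark.sql("SELECT count(*) FROM t GROUP BY map(a, b)").collect()
```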
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{AliasAwareQueryOutputOrdering, QueryPlan}
@@ -26,7 +27,7 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, U
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{MapType, StructType}
 
 
 abstract class LogicalPlan
@@ -325,6 +326,62 @@ object LogicalPlanIntegrity {
       LogicalPlanIntegrity.hasUniqueExprIdsForOutput(plan))
   }
 
+  /**
+   * This method validates there are no dangling attribute references.
+   * Returns an error message if the check does not pass, or None if it does pass.
+   */
+  def validateNoDanglingReferences(plan: LogicalPlan): Option[String] = {
+    plan.collectFirst {
+      // DML commands and multi instance relations (like InMemoryRelation caches)
+      // have different output semantics than typical queries.
+      case _: Command => None
+      case _: MultiInstanceRelation => None
+      case n if canGetOutputAttrs(n) =>
+        if (n.missingInput.nonEmpty) {
+          Some(s"Aliases ${n.missingInput.mkString(", ")} are dangling " +
+            s"in the references for plan:\n ${n.treeString}")
+        } else {
+          None
+        }
+    }.flatten
+  }
+
+  /**
+   * Validate that the grouping key types in Aggregate plans are valid.
+   * Returns an error message if the check fails, or None if it succeeds.
+   */
+  def validateGroupByTypes(plan: LogicalPlan): Option[String] = {
+    plan.collectFirst {
+      case a @ Aggregate(groupingExprs, _, _) =>
+        val badExprs = groupingExprs.filter(_.dataType.isInstanceOf[MapType]).map(_.toString)
+        if (badExprs.nonEmpty) {
+          Some(s"Grouping expressions ${badExprs.mkString(", ")} cannot be of type Map " +
+            s"for plan:\n ${a.treeString}")
+        } else {
+          None
+        }
+    }.flatten
+  }
+
+  /**
+   * Validate that the aggregation expressions in Aggregate plans are valid.
+   * Returns an error message if the check fails, or None if it succeeds.
+   */
+  def validateAggregateExpressions(plan: LogicalPlan): Option[String] = {
+    plan.collectFirst {
+      case a: Aggregate =>
+        try {
+          ExprUtils.assertValidAggregation(a)
+          None
+        } catch {
+          case e: AnalysisException =>
+            Some(s"Aggregate: ${a.toString} is not a valid aggregate expression: " +
+              s"${e.getSimpleMessage}")
+        }
+    }.flatten
+  }
+
+
   /**
    * Validate the structural integrity of an optimized plan.
    * For example, we can check after the execution of each rule that each plan:
@@ -337,7 +394,7 @@
   def validateOptimizedPlan(
       previousPlan: LogicalPlan,
       currentPlan: LogicalPlan): Option[String] = {
-    if (!currentPlan.resolved) {
+    var validation = if (!currentPlan.resolved) {
       Some("The plan becomes unresolved: " + currentPlan.treeString + "\nThe previous plan: " +
         previousPlan.treeString)
     } else if (currentPlan.exists(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty)) {
@@ -353,6 +410,13 @@
         }
       }
     }
+    validation = validation
+      .orElse(LogicalPlanIntegrity.validateNoDanglingReferences(currentPlan))
+      .orElse(LogicalPlanIntegrity.validateGroupByTypes(currentPlan))
+      .orElse(LogicalPlanIntegrity.validateAggregateExpressions(currentPlan))
+      .map(err => s"${err}\nPrevious schema:${previousPlan.output.mkString(", ")}" +
+        s"\nPrevious plan: ${previousPlan.treeString}")
+    validation
   }
 }
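The orElse chain in validateOptimizedPlan is a small Option-based validation pattern: each check returns Some(errorMessage) on failure and None on success, and because Option.orElse takes its argument by name, later (potentially more expensive) checks only run when everything before them passed. A minimal standalone sketch of the pattern, with made-up checks rather than Spark code:

```scala
// Each validator returns Some(errorMessage) on failure, None on success.
def checkPositive(x: Int): Option[String] =
  if (x <= 0) Some(s"$x is not positive") else None

def checkEven(x: Int): Option[String] =
  if (x % 2 != 0) Some(s"$x is not even") else None

// orElse short-circuits: the first failing check supplies the message, and
// map decorates it with shared context, mirroring validateOptimizedPlan.
def validate(x: Int): Option[String] =
  checkPositive(x)
    .orElse(checkEven(x))
    .map(err => s"$err (validation failed)")

assert(validate(4).isEmpty)
assert(validate(-2).contains("-2 is not positive (validation failed)"))
assert(validate(3).contains("3 is not even (validation failed)"))
```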
