From 8a972c2fe8730e41193986a273aa92b234e5beb8 Mon Sep 17 00:00:00 2001 From: Andy Lam Date: Thu, 19 Oct 2023 10:36:32 +0800 Subject: [PATCH] [SPARK-45507][SQL] Correctness fix for nested correlated scalar subqueries with COUNT aggregates ### What changes were proposed in this pull request? We want to use the count bug handling in `DecorrelateInnerQuery` to detect potential count bugs in scalar subqueries. it It is always safe to use `DecorrelateInnerQuery` to handle count bugs, but for efficiency reasons, like for the common case of COUNT on top of the scalar subquery, we would like to avoid an extra left outer join. This PR therefore introduces a simple check to detect such cases before `decorrelate()` - if true, then don't do count bug handling in `decorrelate()`, and vice-versa. ### Why are the changes needed? This PR fixes correctness issues for correlated scalar subqueries pertaining to the COUNT bug. Examples can be found in the JIRA ticket. ### Does this PR introduce _any_ user-facing change? Yes, results will change. ### How was this patch tested? Added SQL end-to-end tests in `count.sql` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43341 from andylam-db/multiple-count-bug. Authored-by: Andy Lam Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/subquery.scala | 44 ++++- .../apache/spark/sql/internal/SQLConf.scala | 9 + .../nested-scalar-subquery-count-bug.sql.out | 166 ++++++++++++++++++ .../nested-scalar-subquery-count-bug.sql | 34 ++++ .../nested-scalar-subquery-count-bug.sql.out | 125 +++++++++++++ 5 files changed, 372 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 5b95ee1df1be9..1f1a16e909371 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -426,17 +426,49 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { case ScalarSubquery(sub, children, exprId, conditions, hint, mayHaveCountBugOld) if children.nonEmpty => - val (newPlan, newCond) = decorrelate(sub, plan) - val mayHaveCountBug = if (mayHaveCountBugOld.isEmpty) { + + def mayHaveCountBugAgg(a: Aggregate): Boolean = { + a.groupingExpressions.isEmpty && a.aggregateExpressions.exists(_.exists { + case a: AggregateExpression => a.aggregateFunction.defaultResult.isDefined + case _ => false + }) + } + + // The below logic controls handling count bug for scalar subqueries in + // [[DecorrelateInnerQuery]], and if we don't handle it here, we handle it in + // [[RewriteCorrelatedScalarSubquery#constructLeftJoins]]. Note that handling it in + // [[DecorrelateInnerQuery]] is always correct, and turning it off to handle it in + // constructLeftJoins is an optimization, so that additional, redundant left outer joins are + // not introduced. + val handleCountBugInDecorrelate = SQLConf.get.decorrelateInnerQueryEnabled && + !conf.getConf(SQLConf.LEGACY_SCALAR_SUBQUERY_COUNT_BUG_HANDLING) && !(sub match { + // Handle count bug only if there exists lower level Aggs with count bugs. It does not + // matter if the top level agg is count bug vulnerable or not, because: + // 1. If the top level agg is count bug vulnerable, it can be handled in + // constructLeftJoins, unless there are lower aggs that are count bug vulnerable. + // E.g. COUNT(COUNT + COUNT) + // 2. If the top level agg is not count bug vulnerable, it can be count bug vulnerable if + // there are lower aggs that are count bug vulnerable. E.g. SUM(COUNT) + case agg: Aggregate => !agg.child.exists { + case lowerAgg: Aggregate => mayHaveCountBugAgg(lowerAgg) + case _ => false + } + case _ => false + }) + val (newPlan, newCond) = decorrelate(sub, plan, handleCountBugInDecorrelate) + val mayHaveCountBug = if (mayHaveCountBugOld.isDefined) { + // For idempotency, we must save this variable the first time this rule is run, because + // decorrelation introduces a GROUP BY is if one wasn't already present. + mayHaveCountBugOld.get + } else if (handleCountBugInDecorrelate) { + // Count bug was already handled in the above decorrelate function call. + false + } else { // Check whether the pre-rewrite subquery had empty groupingExpressions. If yes, it may // be subject to the COUNT bug. If it has non-empty groupingExpressions, there is // no COUNT bug. val (topPart, havingNode, aggNode) = splitSubquery(sub) (aggNode.isDefined && aggNode.get.groupingExpressions.isEmpty) - } else { - // For idempotency, we must save this variable the first time this rule is run, because - // decorrelation introduces a GROUP BY is if one wasn't already present. - mayHaveCountBugOld.get } ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), hint, Some(mayHaveCountBug)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a2401c4917c10..e66eadaa914ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4507,6 +4507,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_SCALAR_SUBQUERY_COUNT_BUG_HANDLING = + buildConf("spark.sql.legacy.scalarSubqueryCountBugBehavior") + .internal() + .doc("When set to true, restores legacy behavior of potential incorrect count bug " + + "handling for scalar subqueries.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out new file mode 100644 index 0000000000000..aec952887db9b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out @@ -0,0 +1,166 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2) +-- !query analysis +CreateViewCommand `spark_catalog`.`default`.`t1`, [(a1,None), (a2,None)], values (0, 1), (1, 2), false, true, PersistedView, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3) +-- !query analysis +CreateViewCommand `spark_catalog`.`default`.`t2`, [(b1,None), (b2,None)], values (0, 2), (0, 3), false, true, PersistedView, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3) +-- !query analysis +CreateViewCommand `spark_catalog`.`default`.`t3`, [(c1,None), (c2,None)], values (0, 2), (0, 3), false, true, PersistedView, true + +- LocalRelation [col1#x, col2#x] + + +-- !query +set spark.sql.optimizer.decorrelateInnerQuery.enabled=true +-- !query analysis +SetCommand (spark.sql.optimizer.decorrelateInnerQuery.enabled,Some(true)) + + +-- !query +set spark.sql.legacy.scalarSubqueryCountBugBehavior=false +-- !query analysis +SetCommand (spark.sql.legacy.scalarSubqueryCountBugBehavior,Some(false)) + + +-- !query +select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc +-- !query analysis +Sort [a#xL DESC NULLS LAST], true ++- Project [scalar-subquery#x [a1#x] AS a#xL] + : +- Aggregate [sum(cnt#xL) AS sum(cnt)#xL] + : +- SubqueryAlias __auto_generated_subquery_name + : +- Aggregate [count(1) AS cnt#xL] + : +- Filter (outer(a1#x) = b1#x) + : +- SubqueryAlias spark_catalog.default.t2 + : +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x]) + : +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias spark_catalog.default.t1 + +- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x]) + +- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc +-- !query analysis +Sort [a#xL DESC NULLS LAST], true ++- Project [scalar-subquery#x [a1#x] AS a#xL] + : +- Aggregate [count(1) AS count(1)#xL] + : +- SubqueryAlias __auto_generated_subquery_name + : +- Aggregate [count(1) AS cnt#xL] + : +- Filter (outer(a1#x) = b1#x) + : +- SubqueryAlias spark_catalog.default.t2 + : +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x]) + : +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias spark_catalog.default.t1 + +- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x]) + +- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select ( + select SUM(l.cnt + r.cnt) + from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l + join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r + on l.cnt = r.cnt +) a from t1 order by a desc +-- !query analysis +Sort [a#xL DESC NULLS LAST], true ++- Project [scalar-subquery#x [a1#x && a1#x] AS a#xL] + : +- Aggregate [sum((cnt#xL + cnt#xL)) AS sum((cnt + cnt))#xL] + : +- Join Inner, (cnt#xL = cnt#xL) + : :- SubqueryAlias l + : : +- Filter (cnt#xL = cast(0 as bigint)) + : : +- Aggregate [count(1) AS cnt#xL] + : : +- Filter (outer(a1#x) = b1#x) + : : +- SubqueryAlias spark_catalog.default.t2 + : : +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x]) + : : +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x] + : : +- LocalRelation [col1#x, col2#x] + : +- SubqueryAlias r + : +- Filter (cnt#xL = cast(0 as bigint)) + : +- Aggregate [count(1) AS cnt#xL] + : +- Filter (outer(a1#x) = c1#x) + : +- SubqueryAlias spark_catalog.default.t3 + : +- View (`spark_catalog`.`default`.`t3`, [c1#x,c2#x]) + : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias spark_catalog.default.t1 + +- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x]) + +- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select ( + select sum(l.cnt + r.cnt) + from (select count(*) cnt from t2 where t1.a1 = t2.b1) l + join (select count(*) cnt from t3 where t1.a1 = t3.c1) r + on l.cnt = r.cnt +) a from t1 order by a desc +-- !query analysis +Sort [a#xL DESC NULLS LAST], true ++- Project [scalar-subquery#x [a1#x && a1#x] AS a#xL] + : +- Aggregate [sum((cnt#xL + cnt#xL)) AS sum((cnt + cnt))#xL] + : +- Join Inner, (cnt#xL = cnt#xL) + : :- SubqueryAlias l + : : +- Aggregate [count(1) AS cnt#xL] + : : +- Filter (outer(a1#x) = b1#x) + : : +- SubqueryAlias spark_catalog.default.t2 + : : +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x]) + : : +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x] + : : +- LocalRelation [col1#x, col2#x] + : +- SubqueryAlias r + : +- Aggregate [count(1) AS cnt#xL] + : +- Filter (outer(a1#x) = c1#x) + : +- SubqueryAlias spark_catalog.default.t3 + : +- View (`spark_catalog`.`default`.`t3`, [c1#x,c2#x]) + : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias spark_catalog.default.t1 + +- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x]) + +- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +reset spark.sql.optimizer.decorrelateInnerQuery.enabled +-- !query analysis +ResetCommand spark.sql.optimizer.decorrelateInnerQuery.enabled + + +-- !query +reset spark.sql.legacy.scalarSubqueryCountBugBehavior +-- !query analysis +ResetCommand spark.sql.legacy.scalarSubqueryCountBugBehavior + + +-- !query +DROP VIEW t1 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`t1`, false, true, false + + +-- !query +DROP VIEW t2 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`t2`, false, true, false + + +-- !query +DROP VIEW t3 +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`t3`, false, true, false diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql new file mode 100644 index 0000000000000..86476389a8577 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql @@ -0,0 +1,34 @@ +CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2); +CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3); +CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3); + +set spark.sql.optimizer.decorrelateInnerQuery.enabled=true; +set spark.sql.legacy.scalarSubqueryCountBugBehavior=false; + +-- test for count bug in nested aggregates in correlated scalar subqueries +select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc; + +-- test for count bug in nested counts in correlated scalar subqueries +select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc; + +-- test for count bug in correlated scalar subqueries with nested aggregates with multiple counts +select ( + select SUM(l.cnt + r.cnt) + from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l + join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r + on l.cnt = r.cnt +) a from t1 order by a desc; + +-- same as above, without HAVING clause +select ( + select sum(l.cnt + r.cnt) + from (select count(*) cnt from t2 where t1.a1 = t2.b1) l + join (select count(*) cnt from t3 where t1.a1 = t3.c1) r + on l.cnt = r.cnt +) a from t1 order by a desc; + +reset spark.sql.optimizer.decorrelateInnerQuery.enabled; +reset spark.sql.legacy.scalarSubqueryCountBugBehavior; +DROP VIEW t1; +DROP VIEW t2; +DROP VIEW t3; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out new file mode 100644 index 0000000000000..c524d315bafc9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out @@ -0,0 +1,125 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3) +-- !query schema +struct<> +-- !query output + + + +-- !query +set spark.sql.optimizer.decorrelateInnerQuery.enabled=true +-- !query schema +struct +-- !query output +spark.sql.optimizer.decorrelateInnerQuery.enabled true + + +-- !query +set spark.sql.legacy.scalarSubqueryCountBugBehavior=false +-- !query schema +struct +-- !query output +spark.sql.legacy.scalarSubqueryCountBugBehavior false + + +-- !query +select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc +-- !query schema +struct +-- !query output +2 +0 + + +-- !query +select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc +-- !query schema +struct +-- !query output +1 +1 + + +-- !query +select ( + select SUM(l.cnt + r.cnt) + from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l + join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r + on l.cnt = r.cnt +) a from t1 order by a desc +-- !query schema +struct +-- !query output +0 +NULL + + +-- !query +select ( + select sum(l.cnt + r.cnt) + from (select count(*) cnt from t2 where t1.a1 = t2.b1) l + join (select count(*) cnt from t3 where t1.a1 = t3.c1) r + on l.cnt = r.cnt +) a from t1 order by a desc +-- !query schema +struct +-- !query output +4 +0 + + +-- !query +reset spark.sql.optimizer.decorrelateInnerQuery.enabled +-- !query schema +struct<> +-- !query output + + + +-- !query +reset spark.sql.legacy.scalarSubqueryCountBugBehavior +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW t1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW t2 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW t3 +-- !query schema +struct<> +-- !query output +