[SPARK-45507][SQL] Correctness fix for nested correlated scalar subqueries with COUNT aggregates #43341

@@ -360,17 +360,49 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
     plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
       case ScalarSubquery(sub, children, exprId, conditions, hint, mayHaveCountBugOld)
           if children.nonEmpty =>
-        val (newPlan, newCond) = decorrelate(sub, plan)
-        val mayHaveCountBug = if (mayHaveCountBugOld.isEmpty) {
+
+        def mayHaveCountBugAgg(a: Aggregate): Boolean = {
+          a.groupingExpressions.isEmpty && a.aggregateExpressions.exists(_.exists {
+            case a: AggregateExpression => a.aggregateFunction.defaultResult.isDefined
+            case _ => false
+          })
+        }
+
+        // The logic below controls where the count bug is handled for scalar subqueries:
+        // in [[DecorrelateInnerQuery]], or, if we don't handle it there, in
+        // [[RewriteCorrelatedScalarSubquery#constructLeftJoins]]. Handling it in
+        // [[DecorrelateInnerQuery]] is always correct; turning it off and handling it in
+        // constructLeftJoins is an optimization that avoids introducing additional,
+        // redundant left outer joins.
+        val handleCountBugInDecorrelate = SQLConf.get.decorrelateInnerQueryEnabled &&
+          !conf.getConf(SQLConf.LEGACY_SCALAR_SUBQUERY_COUNT_BUG_HANDLING) && !(sub match {
+            // Handle the count bug here only if there exist lower-level Aggregates that are
+            // count-bug vulnerable. Whether the top-level Aggregate is itself vulnerable
+            // does not matter, because:
+            // 1. If the top-level Aggregate is count-bug vulnerable, it can be handled in
+            //    constructLeftJoins, unless there are lower Aggregates that are also
+            //    vulnerable, e.g. COUNT(COUNT + COUNT).
+            // 2. If the top-level Aggregate is not count-bug vulnerable itself, the subquery
+            //    can still be count-bug vulnerable when there are lower Aggregates that are,
+            //    e.g. SUM(COUNT).
+            case agg: Aggregate => !agg.child.exists {
+              case lowerAgg: Aggregate => mayHaveCountBugAgg(lowerAgg)
+              case _ => false
+            }
+            case _ => false
+          })
+        val (newPlan, newCond) = decorrelate(sub, plan, handleCountBugInDecorrelate)
+        val mayHaveCountBug = if (mayHaveCountBugOld.isDefined) {
+          // For idempotency, we must save this variable the first time this rule is run,
+          // because decorrelation introduces a GROUP BY if one wasn't already present.
+          mayHaveCountBugOld.get
+        } else if (handleCountBugInDecorrelate) {
+          // The count bug was already handled in the decorrelate call above.
+          false
+        } else {
           // Check whether the pre-rewrite subquery had empty groupingExpressions. If yes,
           // it may be subject to the COUNT bug. If it has non-empty groupingExpressions,
           // there is no COUNT bug.
           val (topPart, havingNode, aggNode) = splitSubquery(sub)
           (aggNode.isDefined && aggNode.get.groupingExpressions.isEmpty)
-        } else {
-          // For idempotency, we must save this variable the first time this rule is run,
-          // because decorrelation introduces a GROUP BY if one wasn't already present.
-          mayHaveCountBugOld.get
         }
         ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions),
           hint, Some(mayHaveCountBug))
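
As a hedged illustration of what this check distinguishes (using the t1/t2 views defined in the new tests below): a lower-level COUNT aggregate with no GROUP BY has a defaultResult, so the rewrite above forces count-bug handling into DecorrelateInnerQuery, while a subquery whose only aggregate is at the top level is left to constructLeftJoins.

-- SUM over a lower COUNT: the lower Aggregate is count-bug vulnerable, so
-- handleCountBugInDecorrelate is true (case 2 in the comment above).
select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) from t1;

-- Only the top-level COUNT is count-bug vulnerable; handleCountBugInDecorrelate is
-- false and the count bug is handled in constructLeftJoins instead.
select ( select count(*) from t2 where t1.a1 = t2.b1 ) from t1;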
@@ -4468,6 +4468,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
+
+  val LEGACY_SCALAR_SUBQUERY_COUNT_BUG_HANDLING =
+    buildConf("spark.sql.legacy.scalarSubqueryCountBugBehavior")
+      .internal()
+      .doc("When set to true, restores the legacy behavior of potentially incorrect count " +
+        "bug handling for scalar subqueries.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)

  /**
   * Holds information about keys that have been deprecated.
   *
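
The flag defaults to false, so the corrected behavior is on by default and the legacy behavior remains available as an escape hatch. A minimal sketch of opting back in, using the key defined above:

set spark.sql.legacy.scalarSubqueryCountBugBehavior=true;  -- restore pre-fix behavior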
@@ -0,0 +1,166 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2)
-- !query analysis
CreateViewCommand `spark_catalog`.`default`.`t1`, [(a1,None), (a2,None)], values (0, 1), (1, 2), false, true, PersistedView, true
+- LocalRelation [col1#x, col2#x]


-- !query
CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3)
-- !query analysis
CreateViewCommand `spark_catalog`.`default`.`t2`, [(b1,None), (b2,None)], values (0, 2), (0, 3), false, true, PersistedView, true
+- LocalRelation [col1#x, col2#x]


-- !query
CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3)
-- !query analysis
CreateViewCommand `spark_catalog`.`default`.`t3`, [(c1,None), (c2,None)], values (0, 2), (0, 3), false, true, PersistedView, true
+- LocalRelation [col1#x, col2#x]


-- !query
set spark.sql.optimizer.decorrelateInnerQuery.enabled=true
-- !query analysis
SetCommand (spark.sql.optimizer.decorrelateInnerQuery.enabled,Some(true))


-- !query
set spark.sql.legacy.scalarSubqueryCountBugBehavior=false
-- !query analysis
SetCommand (spark.sql.legacy.scalarSubqueryCountBugBehavior,Some(false))


-- !query
select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc
-- !query analysis
Sort [a#xL DESC NULLS LAST], true
+- Project [scalar-subquery#x [a1#x] AS a#xL]
: +- Aggregate [sum(cnt#xL) AS sum(cnt)#xL]
: +- SubqueryAlias __auto_generated_subquery_name
: +- Aggregate [count(1) AS cnt#xL]
: +- Filter (outer(a1#x) = b1#x)
: +- SubqueryAlias spark_catalog.default.t2
: +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
: +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc
-- !query analysis
Sort [a#xL DESC NULLS LAST], true
+- Project [scalar-subquery#x [a1#x] AS a#xL]
: +- Aggregate [count(1) AS count(1)#xL]
: +- SubqueryAlias __auto_generated_subquery_name
: +- Aggregate [count(1) AS cnt#xL]
: +- Filter (outer(a1#x) = b1#x)
: +- SubqueryAlias spark_catalog.default.t2
: +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
: +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select (
select SUM(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r
on l.cnt = r.cnt
) a from t1 order by a desc
-- !query analysis
Sort [a#xL DESC NULLS LAST], true
+- Project [scalar-subquery#x [a1#x && a1#x] AS a#xL]
: +- Aggregate [sum((cnt#xL + cnt#xL)) AS sum((cnt + cnt))#xL]
: +- Join Inner, (cnt#xL = cnt#xL)
: :- SubqueryAlias l
: : +- Filter (cnt#xL = cast(0 as bigint))
: : +- Aggregate [count(1) AS cnt#xL]
: : +- Filter (outer(a1#x) = b1#x)
: : +- SubqueryAlias spark_catalog.default.t2
: : +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
: : +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- SubqueryAlias r
: +- Filter (cnt#xL = cast(0 as bigint))
: +- Aggregate [count(1) AS cnt#xL]
: +- Filter (outer(a1#x) = c1#x)
: +- SubqueryAlias spark_catalog.default.t3
: +- View (`spark_catalog`.`default`.`t3`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select (
select sum(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1) r
on l.cnt = r.cnt
) a from t1 order by a desc
-- !query analysis
Sort [a#xL DESC NULLS LAST], true
+- Project [scalar-subquery#x [a1#x && a1#x] AS a#xL]
: +- Aggregate [sum((cnt#xL + cnt#xL)) AS sum((cnt + cnt))#xL]
: +- Join Inner, (cnt#xL = cnt#xL)
: :- SubqueryAlias l
: : +- Aggregate [count(1) AS cnt#xL]
: : +- Filter (outer(a1#x) = b1#x)
: : +- SubqueryAlias spark_catalog.default.t2
: : +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
: : +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- SubqueryAlias r
: +- Aggregate [count(1) AS cnt#xL]
: +- Filter (outer(a1#x) = c1#x)
: +- SubqueryAlias spark_catalog.default.t3
: +- View (`spark_catalog`.`default`.`t3`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
reset spark.sql.optimizer.decorrelateInnerQuery.enabled
-- !query analysis
ResetCommand spark.sql.optimizer.decorrelateInnerQuery.enabled


-- !query
reset spark.sql.legacy.scalarSubqueryCountBugBehavior
-- !query analysis
ResetCommand spark.sql.legacy.scalarSubqueryCountBugBehavior


-- !query
DROP VIEW t1
-- !query analysis
DropTableCommand `spark_catalog`.`default`.`t1`, false, true, false


-- !query
DROP VIEW t2
-- !query analysis
DropTableCommand `spark_catalog`.`default`.`t2`, false, true, false


-- !query
DROP VIEW t3
-- !query analysis
DropTableCommand `spark_catalog`.`default`.`t3`, false, true, false
@@ -0,0 +1,34 @@
CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2);
CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3);
CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3);

set spark.sql.optimizer.decorrelateInnerQuery.enabled=true;
set spark.sql.legacy.scalarSubqueryCountBugBehavior=false;

-- test for count bug in nested aggregates in correlated scalar subqueries
select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc;

-- test for count bug in nested counts in correlated scalar subqueries
select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc;

-- test for count bug in correlated scalar subqueries with nested aggregates with multiple counts
select (
select SUM(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r
on l.cnt = r.cnt
) a from t1 order by a desc;

-- same as above, without HAVING clause
select (
select sum(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1) r
on l.cnt = r.cnt
) a from t1 order by a desc;

reset spark.sql.optimizer.decorrelateInnerQuery.enabled;
reset spark.sql.legacy.scalarSubqueryCountBugBehavior;
DROP VIEW t1;
DROP VIEW t2;
DROP VIEW t3;
@@ -0,0 +1,125 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2)
-- !query schema
struct<>
-- !query output



-- !query
CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3)
-- !query schema
struct<>
-- !query output



-- !query
CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3)
-- !query schema
struct<>
-- !query output



-- !query
set spark.sql.optimizer.decorrelateInnerQuery.enabled=true
-- !query schema
struct<key:string,value:string>
-- !query output
spark.sql.optimizer.decorrelateInnerQuery.enabled true


-- !query
set spark.sql.legacy.scalarSubqueryCountBugBehavior=false
-- !query schema
struct<key:string,value:string>
-- !query output
spark.sql.legacy.scalarSubqueryCountBugBehavior false


-- !query
select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc
-- !query schema
struct<a:bigint>
-- !query output
2
0
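
Worked by hand against the view data (t1 holds (0, 1) and (1, 2); t2 holds (0, 2) and (0, 3)): for a1 = 0 the correlated count is 2, so sum(cnt) = 2; for a1 = 1 no t2 rows match, the count is 0, and sum(cnt) = 0. Substituting the outer reference with literals gives the same two results:

select sum(cnt) from (select count(*) cnt from t2 where 0 = t2.b1);  -- 2
select sum(cnt) from (select count(*) cnt from t2 where 1 = t2.b1);  -- 0

Under the count bug the a1 = 1 row would plausibly surface as NULL rather than 0, since the left-outer-join rewrite loses the empty group; the 0 here is the point of the fix.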


-- !query
select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc
-- !query schema
struct<a:bigint>
-- !query output
1
1
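
Here the inner derived table always produces exactly one row (a global COUNT with no GROUP BY), so the outer count(*) is 1 for every t1 row — including a1 = 1, where the inner count is 0 but the row still exists.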


-- !query
select (
select SUM(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r
on l.cnt = r.cnt
) a from t1 order by a desc
-- !query schema
struct<a:bigint>
-- !query output
0
NULL
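
These results follow from the HAVING filters: for a1 = 0 both counts are 2, HAVING cnt = 0 removes both sides, the join is empty, and SUM over an empty input is NULL; for a1 = 1 both counts are 0, the rows pass HAVING, the join matches on 0 = 0, and SUM(0 + 0) = 0. Sorted descending, that yields 0 then NULL.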


-- !query
select (
select sum(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1) r
on l.cnt = r.cnt
) a from t1 order by a desc
-- !query schema
struct<a:bigint>
-- !query output
4
0
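
Same arithmetic without the HAVING clause: for a1 = 0 both counts are 2, the join matches on 2 = 2, and sum(2 + 2) = 4; for a1 = 1 both counts are 0 and sum(0 + 0) = 0. Substituting a1 = 1 directly shows the empty-group case:

select sum(l.cnt + r.cnt)
from (select count(*) cnt from t2 where 1 = t2.b1) l
join (select count(*) cnt from t3 where 1 = t3.c1) r
on l.cnt = r.cnt;  -- 0

A count-bug-afflicted plan would plausibly return NULL here instead, because NULL counts fail the l.cnt = r.cnt join condition.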


-- !query
reset spark.sql.optimizer.decorrelateInnerQuery.enabled
-- !query schema
struct<>
-- !query output



-- !query
reset spark.sql.legacy.scalarSubqueryCountBugBehavior
-- !query schema
struct<>
-- !query output



-- !query
DROP VIEW t1
-- !query schema
struct<>
-- !query output



-- !query
DROP VIEW t2
-- !query schema
struct<>
-- !query output



-- !query
DROP VIEW t3
-- !query schema
struct<>
-- !query output