Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[SPARK-45507][SQL] Correctness fix for nested correlated scalar subqueries with COUNT aggregates #43341

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,31 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
case ScalarSubquery(sub, children, exprId, conditions, hint, mayHaveCountBugOld)
if children.nonEmpty =>
val (newPlan, newCond) = decorrelate(sub, plan)
val mayHaveCountBug = if (mayHaveCountBugOld.isEmpty) {

def mayHaveCountBugAgg(a: Aggregate): Boolean = {
a.groupingExpressions.isEmpty && a.aggregateExpressions.exists(_.exists {
case a: AggregateExpression => a.aggregateFunction.defaultResult.isDefined
case _ => false
})
}

// We want to handle count bug for scalar subqueries, except for the cases where the
// subquery is a simple top level Aggregate which can have a count bug (note: the below
// logic also takes into account nested COUNTs). This is because for these cases, we don't
// want to introduce redundant left outer joins in [[DecorrelateInnerQuery]], when the
// necessary left outer join will be added in [[RewriteCorrelatedScalarSubquery]].
val shouldHandleCountBug = !(sub match {
case agg: Aggregate => mayHaveCountBugAgg(agg) && !agg.exists {
case lowerAgg: Aggregate => mayHaveCountBugAgg(lowerAgg)
case _ => false
}
case _ => false
})
val (newPlan, newCond) = decorrelate(sub, plan, shouldHandleCountBug)
val mayHaveCountBug = if (shouldHandleCountBug) {
// Count bug was already handled in the above decorrelate function call.
false
} else if (mayHaveCountBugOld.isEmpty) {
// Check whether the pre-rewrite subquery had empty groupingExpressions. If yes, it may
// be subject to the COUNT bug. If it has non-empty groupingExpressions, there is
// no COUNT bug.
Expand Down
125 changes: 125 additions & 0 deletions sql/core/src/test/resources/sql-tests/analyzer-results/count.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,128 @@ org.apache.spark.sql.AnalysisException
"targetString" : "testData"
}
}


-- !query
CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2)
-- !query analysis
CreateViewCommand `spark_catalog`.`default`.`t1`, [(a1,None), (a2,None)], values (0, 1), (1, 2), false, true, PersistedView, true
+- LocalRelation [col1#x, col2#x]


-- !query
CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3)
-- !query analysis
CreateViewCommand `spark_catalog`.`default`.`t2`, [(b1,None), (b2,None)], values (0, 2), (0, 3), false, true, PersistedView, true
+- LocalRelation [col1#x, col2#x]


-- !query
CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3)
-- !query analysis
CreateViewCommand `spark_catalog`.`default`.`t3`, [(c1,None), (c2,None)], values (0, 2), (0, 3), false, true, PersistedView, true
+- LocalRelation [col1#x, col2#x]


-- !query
select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc
-- !query analysis
Sort [a#xL DESC NULLS LAST], true
+- Project [scalar-subquery#x [a1#x] AS a#xL]
: +- Aggregate [sum(cnt#xL) AS sum(cnt)#xL]
: +- SubqueryAlias __auto_generated_subquery_name
: +- Aggregate [count(1) AS cnt#xL]
: +- Filter (outer(a1#x) = b1#x)
: +- SubqueryAlias spark_catalog.default.t2
: +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
: +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc
-- !query analysis
Sort [a#xL DESC NULLS LAST], true
+- Project [scalar-subquery#x [a1#x] AS a#xL]
: +- Aggregate [count(1) AS count(1)#xL]
: +- SubqueryAlias __auto_generated_subquery_name
: +- Aggregate [count(1) AS cnt#xL]
: +- Filter (outer(a1#x) = b1#x)
: +- SubqueryAlias spark_catalog.default.t2
: +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
: +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select (
select SUM(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r
on l.cnt = r.cnt
) a from t1 order by a desc
-- !query analysis
Sort [a#xL DESC NULLS LAST], true
+- Project [scalar-subquery#x [a1#x && a1#x] AS a#xL]
: +- Aggregate [sum((cnt#xL + cnt#xL)) AS sum((cnt + cnt))#xL]
: +- Join Inner, (cnt#xL = cnt#xL)
: :- SubqueryAlias l
: : +- Filter (cnt#xL = cast(0 as bigint))
: : +- Aggregate [count(1) AS cnt#xL]
: : +- Filter (outer(a1#x) = b1#x)
: : +- SubqueryAlias spark_catalog.default.t2
: : +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
: : +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- SubqueryAlias r
: +- Filter (cnt#xL = cast(0 as bigint))
: +- Aggregate [count(1) AS cnt#xL]
: +- Filter (outer(a1#x) = c1#x)
: +- SubqueryAlias spark_catalog.default.t3
: +- View (`spark_catalog`.`default`.`t3`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+- LocalRelation [col1#x, col2#x]


-- !query
select (
select sum(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1) r
on l.cnt = r.cnt
) a from t1 order by a desc
-- !query analysis
Sort [a#xL DESC NULLS LAST], true
+- Project [scalar-subquery#x [a1#x && a1#x] AS a#xL]
: +- Aggregate [sum((cnt#xL + cnt#xL)) AS sum((cnt + cnt))#xL]
: +- Join Inner, (cnt#xL = cnt#xL)
: :- SubqueryAlias l
: : +- Aggregate [count(1) AS cnt#xL]
: : +- Filter (outer(a1#x) = b1#x)
: : +- SubqueryAlias spark_catalog.default.t2
: : +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x])
: : +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x]
: : +- LocalRelation [col1#x, col2#x]
: +- SubqueryAlias r
: +- Aggregate [count(1) AS cnt#xL]
: +- Filter (outer(a1#x) = c1#x)
: +- SubqueryAlias spark_catalog.default.t3
: +- View (`spark_catalog`.`default`.`t3`, [c1#x,c2#x])
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
: +- LocalRelation [col1#x, col2#x]
+- SubqueryAlias spark_catalog.default.t1
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x])
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x]
+- LocalRelation [col1#x, col2#x]
26 changes: 26 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/count.sql
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,29 @@ SELECT count(testData.*) FROM testData;
-- count with a single tblName.* as parameter
set spark.sql.legacy.allowStarWithSingleTableIdentifierInCount=false;
SELECT count(testData.*) FROM testData;

CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2);
CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3);
CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3);

-- test for count bug in correlated scalar subqueries
select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc;

-- test for count bug in correlated scalar subqueries with nested counts
select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc;

-- test for count bug in correlated scalar subqueries with multiple count aggregates
select (
select SUM(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r
on l.cnt = r.cnt
) a from t1 order by a desc;

-- same as above, without HAVING clause
select (
select sum(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1) r
on l.cnt = r.cnt
) a from t1 order by a desc;
70 changes: 70 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/count.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,73 @@ org.apache.spark.sql.AnalysisException
"targetString" : "testData"
}
}


-- !query
CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2)
-- !query schema
struct<>
-- !query output



-- !query
CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3)
-- !query schema
struct<>
-- !query output



-- !query
CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3)
-- !query schema
struct<>
-- !query output



-- !query
select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc
-- !query schema
struct<a:bigint>
-- !query output
2
0


-- !query
select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc
-- !query schema
struct<a:bigint>
-- !query output
1
1


-- !query
select (
select SUM(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r
on l.cnt = r.cnt
) a from t1 order by a desc
-- !query schema
struct<a:bigint>
-- !query output
0
NULL


-- !query
select (
select sum(l.cnt + r.cnt)
from (select count(*) cnt from t2 where t1.a1 = t2.b1) l
join (select count(*) cnt from t3 where t1.a1 = t3.c1) r
on l.cnt = r.cnt
) a from t1 order by a desc
-- !query schema
struct<a:bigint>
-- !query output
4
0