-
Notifications
You must be signed in to change notification settings - Fork 28.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-45507][SQL] Correctness fix for nested correlated scalar subqu…
…eries with COUNT aggregates ### What changes were proposed in this pull request? We want to use the count bug handling in `DecorrelateInnerQuery` to detect potential count bugs in scalar subqueries. It is always safe to use `DecorrelateInnerQuery` to handle count bugs, but for efficiency reasons, like for the common case of COUNT on top of the scalar subquery, we would like to avoid an extra left outer join. This PR therefore introduces a simple check to detect such cases before `decorrelate()` - if true, then don't do count bug handling in `decorrelate()`, and vice-versa. ### Why are the changes needed? This PR fixes correctness issues for correlated scalar subqueries pertaining to the COUNT bug. Examples can be found in the JIRA ticket. ### Does this PR introduce _any_ user-facing change? Yes, results will change. ### How was this patch tested? Added SQL end-to-end tests in `count.sql` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43341 from andylam-db/multiple-count-bug. Authored-by: Andy Lam <andy.lam@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
- Loading branch information
1 parent
1e94415
commit 8a972c2
Showing
5 changed files
with
372 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
166 changes: 166 additions & 0 deletions
166
...-tests/analyzer-results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
-- Automatically generated by SQLQueryTestSuite | ||
-- !query | ||
CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2) | ||
-- !query analysis | ||
CreateViewCommand `spark_catalog`.`default`.`t1`, [(a1,None), (a2,None)], values (0, 1), (1, 2), false, true, PersistedView, true | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3) | ||
-- !query analysis | ||
CreateViewCommand `spark_catalog`.`default`.`t2`, [(b1,None), (b2,None)], values (0, 2), (0, 3), false, true, PersistedView, true | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3) | ||
-- !query analysis | ||
CreateViewCommand `spark_catalog`.`default`.`t3`, [(c1,None), (c2,None)], values (0, 2), (0, 3), false, true, PersistedView, true | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
set spark.sql.optimizer.decorrelateInnerQuery.enabled=true | ||
-- !query analysis | ||
SetCommand (spark.sql.optimizer.decorrelateInnerQuery.enabled,Some(true)) | ||
|
||
|
||
-- !query | ||
set spark.sql.legacy.scalarSubqueryCountBugBehavior=false | ||
-- !query analysis | ||
SetCommand (spark.sql.legacy.scalarSubqueryCountBugBehavior,Some(false)) | ||
|
||
|
||
-- !query | ||
select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc | ||
-- !query analysis | ||
Sort [a#xL DESC NULLS LAST], true | ||
+- Project [scalar-subquery#x [a1#x] AS a#xL] | ||
: +- Aggregate [sum(cnt#xL) AS sum(cnt)#xL] | ||
: +- SubqueryAlias __auto_generated_subquery_name | ||
: +- Aggregate [count(1) AS cnt#xL] | ||
: +- Filter (outer(a1#x) = b1#x) | ||
: +- SubqueryAlias spark_catalog.default.t2 | ||
: +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x]) | ||
: +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x] | ||
: +- LocalRelation [col1#x, col2#x] | ||
+- SubqueryAlias spark_catalog.default.t1 | ||
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x]) | ||
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x] | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc | ||
-- !query analysis | ||
Sort [a#xL DESC NULLS LAST], true | ||
+- Project [scalar-subquery#x [a1#x] AS a#xL] | ||
: +- Aggregate [count(1) AS count(1)#xL] | ||
: +- SubqueryAlias __auto_generated_subquery_name | ||
: +- Aggregate [count(1) AS cnt#xL] | ||
: +- Filter (outer(a1#x) = b1#x) | ||
: +- SubqueryAlias spark_catalog.default.t2 | ||
: +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x]) | ||
: +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x] | ||
: +- LocalRelation [col1#x, col2#x] | ||
+- SubqueryAlias spark_catalog.default.t1 | ||
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x]) | ||
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x] | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
select ( | ||
select SUM(l.cnt + r.cnt) | ||
from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l | ||
join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r | ||
on l.cnt = r.cnt | ||
) a from t1 order by a desc | ||
-- !query analysis | ||
Sort [a#xL DESC NULLS LAST], true | ||
+- Project [scalar-subquery#x [a1#x && a1#x] AS a#xL] | ||
: +- Aggregate [sum((cnt#xL + cnt#xL)) AS sum((cnt + cnt))#xL] | ||
: +- Join Inner, (cnt#xL = cnt#xL) | ||
: :- SubqueryAlias l | ||
: : +- Filter (cnt#xL = cast(0 as bigint)) | ||
: : +- Aggregate [count(1) AS cnt#xL] | ||
: : +- Filter (outer(a1#x) = b1#x) | ||
: : +- SubqueryAlias spark_catalog.default.t2 | ||
: : +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x]) | ||
: : +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x] | ||
: : +- LocalRelation [col1#x, col2#x] | ||
: +- SubqueryAlias r | ||
: +- Filter (cnt#xL = cast(0 as bigint)) | ||
: +- Aggregate [count(1) AS cnt#xL] | ||
: +- Filter (outer(a1#x) = c1#x) | ||
: +- SubqueryAlias spark_catalog.default.t3 | ||
: +- View (`spark_catalog`.`default`.`t3`, [c1#x,c2#x]) | ||
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] | ||
: +- LocalRelation [col1#x, col2#x] | ||
+- SubqueryAlias spark_catalog.default.t1 | ||
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x]) | ||
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x] | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
select ( | ||
select sum(l.cnt + r.cnt) | ||
from (select count(*) cnt from t2 where t1.a1 = t2.b1) l | ||
join (select count(*) cnt from t3 where t1.a1 = t3.c1) r | ||
on l.cnt = r.cnt | ||
) a from t1 order by a desc | ||
-- !query analysis | ||
Sort [a#xL DESC NULLS LAST], true | ||
+- Project [scalar-subquery#x [a1#x && a1#x] AS a#xL] | ||
: +- Aggregate [sum((cnt#xL + cnt#xL)) AS sum((cnt + cnt))#xL] | ||
: +- Join Inner, (cnt#xL = cnt#xL) | ||
: :- SubqueryAlias l | ||
: : +- Aggregate [count(1) AS cnt#xL] | ||
: : +- Filter (outer(a1#x) = b1#x) | ||
: : +- SubqueryAlias spark_catalog.default.t2 | ||
: : +- View (`spark_catalog`.`default`.`t2`, [b1#x,b2#x]) | ||
: : +- Project [cast(col1#x as int) AS b1#x, cast(col2#x as int) AS b2#x] | ||
: : +- LocalRelation [col1#x, col2#x] | ||
: +- SubqueryAlias r | ||
: +- Aggregate [count(1) AS cnt#xL] | ||
: +- Filter (outer(a1#x) = c1#x) | ||
: +- SubqueryAlias spark_catalog.default.t3 | ||
: +- View (`spark_catalog`.`default`.`t3`, [c1#x,c2#x]) | ||
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x] | ||
: +- LocalRelation [col1#x, col2#x] | ||
+- SubqueryAlias spark_catalog.default.t1 | ||
+- View (`spark_catalog`.`default`.`t1`, [a1#x,a2#x]) | ||
+- Project [cast(col1#x as int) AS a1#x, cast(col2#x as int) AS a2#x] | ||
+- LocalRelation [col1#x, col2#x] | ||
|
||
|
||
-- !query | ||
reset spark.sql.optimizer.decorrelateInnerQuery.enabled | ||
-- !query analysis | ||
ResetCommand spark.sql.optimizer.decorrelateInnerQuery.enabled | ||
|
||
|
||
-- !query | ||
reset spark.sql.legacy.scalarSubqueryCountBugBehavior | ||
-- !query analysis | ||
ResetCommand spark.sql.legacy.scalarSubqueryCountBugBehavior | ||
|
||
|
||
-- !query | ||
DROP VIEW t1 | ||
-- !query analysis | ||
DropTableCommand `spark_catalog`.`default`.`t1`, false, true, false | ||
|
||
|
||
-- !query | ||
DROP VIEW t2 | ||
-- !query analysis | ||
DropTableCommand `spark_catalog`.`default`.`t2`, false, true, false | ||
|
||
|
||
-- !query | ||
DROP VIEW t3 | ||
-- !query analysis | ||
DropTableCommand `spark_catalog`.`default`.`t3`, false, true, false |
34 changes: 34 additions & 0 deletions
34
.../resources/sql-tests/inputs/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2); | ||
CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3); | ||
CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3); | ||
|
||
set spark.sql.optimizer.decorrelateInnerQuery.enabled=true; | ||
set spark.sql.legacy.scalarSubqueryCountBugBehavior=false; | ||
|
||
-- test for count bug in nested aggregates in correlated scalar subqueries | ||
select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc; | ||
|
||
-- test for count bug in nested counts in correlated scalar subqueries | ||
select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc; | ||
|
||
-- test for count bug in correlated scalar subqueries with nested aggregates with multiple counts | ||
select ( | ||
select SUM(l.cnt + r.cnt) | ||
from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l | ||
join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r | ||
on l.cnt = r.cnt | ||
) a from t1 order by a desc; | ||
|
||
-- same as above, without HAVING clause | ||
select ( | ||
select sum(l.cnt + r.cnt) | ||
from (select count(*) cnt from t2 where t1.a1 = t2.b1) l | ||
join (select count(*) cnt from t3 where t1.a1 = t3.c1) r | ||
on l.cnt = r.cnt | ||
) a from t1 order by a desc; | ||
|
||
reset spark.sql.optimizer.decorrelateInnerQuery.enabled; | ||
reset spark.sql.legacy.scalarSubqueryCountBugBehavior; | ||
DROP VIEW t1; | ||
DROP VIEW t2; | ||
DROP VIEW t3; |
125 changes: 125 additions & 0 deletions
125
...urces/sql-tests/results/subquery/scalar-subquery/nested-scalar-subquery-count-bug.sql.out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
-- Automatically generated by SQLQueryTestSuite | ||
-- !query | ||
CREATE OR REPLACE VIEW t1(a1, a2) as values (0, 1), (1, 2) | ||
-- !query schema | ||
struct<> | ||
-- !query output | ||
|
||
|
||
|
||
-- !query | ||
CREATE OR REPLACE VIEW t2(b1, b2) as values (0, 2), (0, 3) | ||
-- !query schema | ||
struct<> | ||
-- !query output | ||
|
||
|
||
|
||
-- !query | ||
CREATE OR REPLACE VIEW t3(c1, c2) as values (0, 2), (0, 3) | ||
-- !query schema | ||
struct<> | ||
-- !query output | ||
|
||
|
||
|
||
-- !query | ||
set spark.sql.optimizer.decorrelateInnerQuery.enabled=true | ||
-- !query schema | ||
struct<key:string,value:string> | ||
-- !query output | ||
spark.sql.optimizer.decorrelateInnerQuery.enabled true | ||
|
||
|
||
-- !query | ||
set spark.sql.legacy.scalarSubqueryCountBugBehavior=false | ||
-- !query schema | ||
struct<key:string,value:string> | ||
-- !query output | ||
spark.sql.legacy.scalarSubqueryCountBugBehavior false | ||
|
||
|
||
-- !query | ||
select ( select sum(cnt) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc | ||
-- !query schema | ||
struct<a:bigint> | ||
-- !query output | ||
2 | ||
0 | ||
|
||
|
||
-- !query | ||
select ( select count(*) from (select count(*) cnt from t2 where t1.a1 = t2.b1) ) a from t1 order by a desc | ||
-- !query schema | ||
struct<a:bigint> | ||
-- !query output | ||
1 | ||
1 | ||
|
||
|
||
-- !query | ||
select ( | ||
select SUM(l.cnt + r.cnt) | ||
from (select count(*) cnt from t2 where t1.a1 = t2.b1 having cnt = 0) l | ||
join (select count(*) cnt from t3 where t1.a1 = t3.c1 having cnt = 0) r | ||
on l.cnt = r.cnt | ||
) a from t1 order by a desc | ||
-- !query schema | ||
struct<a:bigint> | ||
-- !query output | ||
0 | ||
NULL | ||
|
||
|
||
-- !query | ||
select ( | ||
select sum(l.cnt + r.cnt) | ||
from (select count(*) cnt from t2 where t1.a1 = t2.b1) l | ||
join (select count(*) cnt from t3 where t1.a1 = t3.c1) r | ||
on l.cnt = r.cnt | ||
) a from t1 order by a desc | ||
-- !query schema | ||
struct<a:bigint> | ||
-- !query output | ||
4 | ||
0 | ||
|
||
|
||
-- !query | ||
reset spark.sql.optimizer.decorrelateInnerQuery.enabled | ||
-- !query schema | ||
struct<> | ||
-- !query output | ||
|
||
|
||
|
||
-- !query | ||
reset spark.sql.legacy.scalarSubqueryCountBugBehavior | ||
-- !query schema | ||
struct<> | ||
-- !query output | ||
|
||
|
||
|
||
-- !query | ||
DROP VIEW t1 | ||
-- !query schema | ||
struct<> | ||
-- !query output | ||
|
||
|
||
|
||
-- !query | ||
DROP VIEW t2 | ||
-- !query schema | ||
struct<> | ||
-- !query output | ||
|
||
|
||
|
||
-- !query | ||
DROP VIEW t3 | ||
-- !query schema | ||
struct<> | ||
-- !query output | ||
|