From 3dd11dbccef4660d9af7346c41c740d70fa517f2 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Wed, 5 Jan 2022 12:10:55 -0800 Subject: [PATCH] Query: Avoid stackoverflow in lifting group by aggregate term Correlate the scalar subquery with parent SelectExpression Resolves #27094 --- .../SqlExpressions/SelectExpression.Helper.cs | 15 ++--- .../Query/SqlExpressions/SelectExpression.cs | 19 +++++- .../Query/NorthwindGroupByQueryTestBase.cs | 54 +++++++++++++++++ .../Query/SimpleQueryTestBase.cs | 60 +++++++++++++++++++ .../NorthwindGroupByQuerySqlServerTest.cs | 59 ++++++++++++++++++ .../Query/SimpleQuerySqlServerTest.cs | 36 +++++++++++ .../Query/NorthwindGroupByQuerySqliteTest.cs | 12 ++++ 7 files changed, 243 insertions(+), 12 deletions(-) diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs index ed599055dc7..ee8146c9a95 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs @@ -795,7 +795,8 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor Tags = selectExpression.Tags, _usedAliases = selectExpression._usedAliases.ToHashSet(), _projectionMapping = newProjectionMappings, - _groupingCorrelationPredicate = groupingCorrelationPredicate + _groupingCorrelationPredicate = groupingCorrelationPredicate, + _groupingParentSelectExpressionId = selectExpression._groupingParentSelectExpressionId }; newSelectExpression._tptLeftJoinTables.AddRange(selectExpression._tptLeftJoinTables); @@ -869,7 +870,9 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio && subquery._groupBy.Count == 0 && subquery.Predicate != null && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled) - || subquery.Predicate.Equals(subquery._groupingCorrelationPredicate))) + || subquery.Predicate.Equals(subquery._groupingCorrelationPredicate)) + && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled2) && enabled2) + || subquery._groupingParentSelectExpressionId == _selectExpression._groupingParentSelectExpressionId)) { var initialTableCounts = 0; var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count); @@ -897,7 +900,7 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio // We only replace columns from initial tables. // Additional tables may have been added to outer from other terms which may end up matching on table alias var columnExpressionReplacingExpressionVisitor = - AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled2) && enabled2 + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled3) && enabled3 ? new ColumnExpressionReplacingExpressionVisitor( subquery, _selectExpression._tableReferences) : new ColumnExpressionReplacingExpressionVisitor( @@ -924,12 +927,6 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio } } - if (expression is SelectExpression innerSelectExpression - && innerSelectExpression.GroupBy.Count > 0) - { - expression = new GroupByAggregateLiftingExpressionVisitor(innerSelectExpression).Visit(innerSelectExpression); - } - return base.Visit(expression); } diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index fddea73a0f7..208b73cf9f6 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -62,6 +62,7 @@ public sealed partial class SelectExpression : TableExpressionBase private readonly List _aliasForClientProjections = new(); private SqlExpression? _groupingCorrelationPredicate; + private Guid? _groupingParentSelectExpressionId; private CloningExpressionVisitor? _cloningExpressionVisitor; private SelectExpression( @@ -1255,6 +1256,11 @@ public GroupByShaperExpression ApplyGrouping( // We generate the cloned expression before changing identifier for this SelectExpression // because we are going to erase grouping for cloned expression. + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled2) && enabled2)) + { + _groupingParentSelectExpressionId = Guid.NewGuid(); + + } var clonedSelectExpression = Clone(); var correlationPredicate = groupByTerms.Zip(clonedSelectExpression._groupBy) .Select(e => sqlExpressionFactory.Equal(e.First, e.Second)) @@ -1487,13 +1493,17 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi Predicate = Predicate, Having = Having, Offset = Offset, - Limit = Limit + Limit = Limit, + _groupingParentSelectExpressionId = _groupingParentSelectExpressionId, + _groupingCorrelationPredicate = _groupingCorrelationPredicate }; Offset = null; Limit = null; IsDistinct = false; Predicate = null; Having = null; + _groupingCorrelationPredicate = null; + _groupingParentSelectExpressionId = null; _groupBy.Clear(); _orderings.Clear(); _tables.Clear(); @@ -2808,7 +2818,8 @@ private SqlRemappingVisitor PushdownIntoSubqueryInternal() Predicate = Predicate, Having = Having, Offset = Offset, - Limit = Limit + Limit = Limit, + _groupingParentSelectExpressionId = _groupingParentSelectExpressionId }; subquery._usedAliases = _usedAliases; _tables.Clear(); @@ -3273,6 +3284,7 @@ private void AddTable(TableExpressionBase tableExpressionBase, TableReferenceExp tableReferenceExpression.Alias = uniqueAlias; tableExpressionBase = (TableExpressionBase)new AliasUniquefier(_usedAliases).Visit(tableExpressionBase); + _tables.Add(tableExpressionBase); _tableReferences.Add(tableReferenceExpression); } @@ -3469,7 +3481,8 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) IsDistinct = IsDistinct, Tags = Tags, _usedAliases = _usedAliases, - _groupingCorrelationPredicate = groupingCorrelationPredicate + _groupingCorrelationPredicate = groupingCorrelationPredicate, + _groupingParentSelectExpressionId = _groupingParentSelectExpressionId }; newSelectExpression._tptLeftJoinTables.AddRange(_tptLeftJoinTables); diff --git a/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs index 4a289cd229c..92f4b87780c 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs @@ -3415,6 +3415,60 @@ public virtual Task AsEnumerable_in_subquery_for_GroupBy(bool async) entryCount: 15); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_aggregate_from_multiple_query_in_same_projection(bool async) + { + return AssertQuery( + async, + ss => ss.Set().GroupBy(e => e.CustomerID) + .Select(g => new + { + g.Key, + A = ss.Set().Where(e => e.City == "Seattle").GroupBy(e => e.City) + .Select(g2 => new { g2.Key, C = g2.Count() + g.Count() }) + .OrderBy(e => 1) + .FirstOrDefault() + }), + elementSorter: e => e.Key); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_aggregate_from_multiple_query_in_same_projection_2(bool async) + { + return AssertQuery( + async, + ss => ss.Set().GroupBy(e => e.CustomerID) + .Select(g => new + { + g.Key, + A = ss.Set().Where(e => e.City == "Seattle").GroupBy(e => e.City) + .Select(g2 => g2.Count() + g.Min(e => e.OrderID)) + .OrderBy(e => 1) + .FirstOrDefault() + }), + elementSorter: e => e.Key); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_aggregate_from_multiple_query_in_same_projection_3(bool async) + { + return AssertQuery( + async, + ss => ss.Set().GroupBy(e => e.CustomerID) + .Select(g => new + { + g.Key, + A = ss.Set().Where(e => e.City == "Seattle").GroupBy(e => e.City) + .Select(g2 => g2.Count() + g.Count() ) + .OrderBy(e => e) + .FirstOrDefault() + }), + elementSorter: e => e.Key); + } + #endregion #region GroupByAndDistinctWithCorrelatedCollection diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index 4633013d764..f0fdfd2057e 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -726,5 +726,65 @@ protected class TimeSheet public int? OrderId { get; set; } public Order Order { get; set; } } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Aggregate_over_subquery_in_group_by_projection_2(bool async) + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + var query = from t in context.Table + group t.Id by t.Value into tg + select new + { + A = tg.Key, + B = context.Table.Where(t => t.Value == tg.Max() * 6).Max(t => (int?)t.Id), + }; + + var orders = async + ? await query.ToListAsync() + : query.ToList(); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Group_by_aggregate_in_subquery_projection_after_group_by(bool async) + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + var query = from t in context.Table + group t.Id by t.Value into tg + select new + { + A = tg.Key, + B = tg.Sum(), + C = (from t in context.Table + group t.Id by t.Value into tg2 + select tg.Sum() + tg2.Sum() + ).OrderBy(e => 1).FirstOrDefault() + }; + + var orders = async + ? await query.ToListAsync() + : query.ToList(); + } + + protected class Context27094 : DbContext + { + public Context27094(DbContextOptions options) + : base(options) + { + } + + public DbSet Table { get; set; } + } + + protected class Table + { + public int Id { get; set; } + public int? Value { get; set; } + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs index 0c065afcc42..e9161ea5c7d 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs @@ -2808,6 +2808,65 @@ WHERE [c].[CustomerID] LIKE N'F%' ORDER BY [c].[CustomerID], [t2].[CustomerID0]"); } + public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection(bool async) + { + await base.GroupBy_aggregate_from_multiple_query_in_same_projection(async); + + AssertSql( + @"SELECT [t].[CustomerID], [t0].[Key], [t0].[C], [t0].[c0] +FROM ( + SELECT [o].[CustomerID] + FROM [Orders] AS [o] + GROUP BY [o].[CustomerID] +) AS [t] +OUTER APPLY ( + SELECT TOP(1) [e].[City] AS [Key], COUNT(*) + ( + SELECT COUNT(*) + FROM [Orders] AS [o0] + WHERE ([t].[CustomerID] = [o0].[CustomerID]) OR ([t].[CustomerID] IS NULL AND [o0].[CustomerID] IS NULL)) AS [C], 1 AS [c0] + FROM [Employees] AS [e] + WHERE [e].[City] = N'Seattle' + GROUP BY [e].[City] + ORDER BY (SELECT 1) +) AS [t0]"); + } + + public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection_2(bool async) + { + await base.GroupBy_aggregate_from_multiple_query_in_same_projection_2(async); + + AssertSql( + @"SELECT [o].[CustomerID] AS [Key], COALESCE(( + SELECT TOP(1) COUNT(*) + MIN([o].[OrderID]) + FROM [Employees] AS [e] + WHERE [e].[City] = N'Seattle' + GROUP BY [e].[City] + ORDER BY (SELECT 1)), 0) AS [A] +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + + public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection_3(bool async) + { + await base.GroupBy_aggregate_from_multiple_query_in_same_projection_3(async); + + AssertSql( + @"SELECT [o].[CustomerID] AS [Key], COALESCE(( + SELECT TOP(1) COUNT(*) + ( + SELECT COUNT(*) + FROM [Orders] AS [o0] + WHERE ([o].[CustomerID] = [o0].[CustomerID]) OR ([o].[CustomerID] IS NULL AND [o0].[CustomerID] IS NULL)) + FROM [Employees] AS [e] + WHERE [e].[City] = N'Seattle' + GROUP BY [e].[City] + ORDER BY COUNT(*) + ( + SELECT COUNT(*) + FROM [Orders] AS [o0] + WHERE ([o].[CustomerID] = [o0].[CustomerID]) OR ([o].[CustomerID] IS NULL AND [o0].[CustomerID] IS NULL))), 0) AS [A] +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + public override async Task GroupBy_scalar_aggregate_in_set_operation(bool async) { await base.GroupBy_scalar_aggregate_in_set_operation(async); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index 106bae7a553..d36f5a37fe8 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -177,5 +177,41 @@ FROM [Order] AS [o] WHERE ([o].[Number] <> N'A1') OR [o].[Number] IS NULL GROUP BY [o].[CustomerId], [o].[Number]"); } + + public override async Task Aggregate_over_subquery_in_group_by_projection_2(bool async) + { + await base.Aggregate_over_subquery_in_group_by_projection_2(async); + + AssertSql( + @"SELECT [t].[Value] AS [A], ( + SELECT MAX([t0].[Id]) + FROM [Table] AS [t0] + WHERE ([t0].[Value] = (( + SELECT MAX([t1].[Id]) + FROM [Table] AS [t1] + WHERE ([t].[Value] = [t1].[Value]) OR ([t].[Value] IS NULL AND [t1].[Value] IS NULL)) * 6)) OR ([t0].[Value] IS NULL AND ( + SELECT MAX([t1].[Id]) + FROM [Table] AS [t1] + WHERE ([t].[Value] = [t1].[Value]) OR ([t].[Value] IS NULL AND [t1].[Value] IS NULL)) IS NULL)) AS [B] +FROM [Table] AS [t] +GROUP BY [t].[Value]"); + } + + public override async Task Group_by_aggregate_in_subquery_projection_after_group_by(bool async) + { + await base.Group_by_aggregate_in_subquery_projection_after_group_by(async); + + AssertSql( + @"SELECT [t].[Value] AS [A], COALESCE(SUM([t].[Id]), 0) AS [B], COALESCE(( + SELECT TOP(1) ( + SELECT COALESCE(SUM([t1].[Id]), 0) + FROM [Table] AS [t1] + WHERE ([t].[Value] = [t1].[Value]) OR ([t].[Value] IS NULL AND [t1].[Value] IS NULL)) + COALESCE(SUM([t0].[Id]), 0) + FROM [Table] AS [t0] + GROUP BY [t0].[Value] + ORDER BY (SELECT 1)), 0) AS [C] +FROM [Table] AS [t] +GROUP BY [t].[Value]"); + } } } diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs index 6039d0f7ca4..8af9b688414 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs @@ -3,6 +3,7 @@ using System; using System.Threading.Tasks; +using Microsoft.Data.Sqlite; using Microsoft.EntityFrameworkCore.Sqlite.Internal; using Microsoft.EntityFrameworkCore.TestUtilities; using Xunit; @@ -76,5 +77,16 @@ public override async Task Complex_query_with_group_by_in_subquery5(bool async) public override async Task Odata_groupby_empty_key(bool async) => await Assert.ThrowsAsync(() => base.Odata_groupby_empty_key(async)); + + public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection(bool async) + => Assert.Equal( + SqliteStrings.ApplyNotSupported, + (await Assert.ThrowsAsync( + () => base.GroupBy_aggregate_from_multiple_query_in_same_projection(async))).Message); + + public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection_3(bool async) + => await Assert.ThrowsAsync(() => base.GroupBy_aggregate_from_multiple_query_in_same_projection_3(async)); + + } }