Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Allow final GroupBy navigation #29205

Merged
merged 2 commits into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -454,14 +454,14 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
{
// This could be group by entity type
if (remappedKeySelector is not EntityShaperExpression
{ ValueBufferExpression: ProjectionBindingExpression })
{ ValueBufferExpression: ProjectionBindingExpression pbe } ese)
{
// ValueBufferExpression can be JsonQuery, ProjectionBindingExpression, EntityProjection
// We only allow ProjectionBindingExpression which represents a regular entity
return null;
}

translatedKey = remappedKeySelector;
translatedKey = ese.Update(((SelectExpression)pbe.QueryExpression).GetProjection(pbe));
}

if (elementSelector != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
switch (extensionExpression)
{
case RelationalEntityShaperExpression entityShaperExpression
when entityShaperExpression.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression:
when !_inline && entityShaperExpression.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression:
{
if (!_variableShaperMapping.TryGetValue(entityShaperExpression.ValueBufferExpression, out var accessor))
{
Expand Down Expand Up @@ -484,6 +484,29 @@ protected override Expression VisitExtension(Expression extensionExpression)
return accessor;
}

case RelationalEntityShaperExpression entityShaperExpression
when _inline && entityShaperExpression.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression:
{
if (entityShaperExpression.EntityType.GetMappingStrategy() == RelationalAnnotationNames.TpcMappingStrategy)
{
var concreteTypes = entityShaperExpression.EntityType.GetDerivedTypesInclusive().Where(e => !e.IsAbstract())
.ToArray();
// Single concrete TPC entity type won't have discriminator column.
// We store the value here and inject it directly rather than reading from server.
if (concreteTypes.Length == 1)
{
_singleEntityTypeDiscriminatorValues[
(ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression]
= concreteTypes[0].ShortName();
}
}

var entityMaterializationExpression = _parentVisitor.InjectEntityMaterializers(entityShaperExpression);
entityMaterializationExpression = Visit(entityMaterializationExpression);

return entityMaterializationExpression;
}

case CollectionResultExpression collectionResultExpression
when collectionResultExpression.Navigation is INavigation navigation
&& GetProjectionIndex(collectionResultExpression.ProjectionBindingExpression)
Expand Down
66 changes: 51 additions & 15 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -867,37 +867,43 @@ public Expression ApplyProjection(
if (shaperExpression is RelationalGroupByShaperExpression groupByShaper)
{
// We need to add key to projection and generate key selector in terms of projectionBindings
var projectionBindingMap = new Dictionary<SqlExpression, ProjectionBindingExpression>();
var projectionBindingMap = new Dictionary<SqlExpression, Expression>();
var keySelector = AddGroupByKeySelectorToProjection(
this, newClientProjections, projectionBindingMap, groupByShaper.KeySelector);
var (keyIdentifier, keyIdentifierValueComparers) = GetIdentifierAccessor(projectionBindingMap, _identifier);
var (keyIdentifier, keyIdentifierValueComparers) = GetIdentifierAccessor(
this, newClientProjections, projectionBindingMap, _identifier);
_identifier.Clear();
_identifier.AddRange(_preGroupByIdentifier!);
_preGroupByIdentifier!.Clear();

static Expression AddGroupByKeySelectorToProjection(
Expression AddGroupByKeySelectorToProjection(
SelectExpression selectExpression,
List<Expression> clientProjectionList,
Dictionary<SqlExpression, ProjectionBindingExpression> projectionBindingMap,
Dictionary<SqlExpression, Expression> projectionBindingMap,
Expression keySelector)
{
switch (keySelector)
{
case SqlExpression sqlExpression:
{
var index = selectExpression.AddToProjection(sqlExpression);
var clientProjectionToAdd = Constant(index);
var existingIndex = clientProjectionList.FindIndex(
e => ExpressionEqualityComparer.Instance.Equals(e, clientProjectionToAdd));
if (existingIndex == -1)
{
clientProjectionList.Add(Constant(index));
clientProjectionList.Add(clientProjectionToAdd);
existingIndex = clientProjectionList.Count - 1;
}

var projectionBindingExpression = new ProjectionBindingExpression(
selectExpression, existingIndex, sqlExpression.Type.MakeNullable());
var projectionBindingExpression = sqlExpression.Type.IsNullableType()
? (Expression)new ProjectionBindingExpression(selectExpression, existingIndex, sqlExpression.Type)
: Convert(new ProjectionBindingExpression(
selectExpression, existingIndex, sqlExpression.Type.MakeNullable()),
sqlExpression.Type);
projectionBindingMap[sqlExpression] = projectionBindingExpression;
return projectionBindingExpression;
}

case NewExpression newExpression:
var newArguments = new Expression[newExpression.Arguments.Count];
Expand Down Expand Up @@ -936,25 +942,57 @@ static Expression AddGroupByKeySelectorToProjection(
AddGroupByKeySelectorToProjection(
selectExpression, clientProjectionList, projectionBindingMap, unaryExpression.Operand));

case EntityShaperExpression entityShaperExpression
when entityShaperExpression.ValueBufferExpression is EntityProjectionExpression entityProjectionExpression:
{
var clientProjectionToAdd = AddEntityProjection(entityProjectionExpression);
var existingIndex = clientProjectionList.FindIndex(
e => ExpressionEqualityComparer.Instance.Equals(e, clientProjectionToAdd));
if (existingIndex == -1)
{
clientProjectionList.Add(clientProjectionToAdd);
existingIndex = clientProjectionList.Count - 1;
}

return entityShaperExpression.Update(
new ProjectionBindingExpression(selectExpression, existingIndex, typeof(ValueBuffer)));
}

default:
throw new InvalidOperationException(
RelationalStrings.InvalidKeySelectorForGroupBy(keySelector, keySelector.GetType()));
}
}

static (Expression, IReadOnlyList<ValueComparer>) GetIdentifierAccessor(
Dictionary<SqlExpression, ProjectionBindingExpression> projectionBindingMap,
SelectExpression selectExpression,
List<Expression> clientProjectionList,
Dictionary<SqlExpression, Expression> projectionBindingMap,
IEnumerable<(ColumnExpression Column, ValueComparer Comparer)> identifyingProjection)
{
var updatedExpressions = new List<Expression>();
var comparers = new List<ValueComparer>();
foreach (var (column, comparer) in identifyingProjection)
{
var projectionBindingExpression = projectionBindingMap[column];
if (!projectionBindingMap.TryGetValue(column, out var mappedExpresssion))
{
var index = selectExpression.AddToProjection(column);
var clientProjectionToAdd = Constant(index);
var existingIndex = clientProjectionList.FindIndex(
e => ExpressionEqualityComparer.Instance.Equals(e, clientProjectionToAdd));
if (existingIndex == -1)
{
clientProjectionList.Add(clientProjectionToAdd);
existingIndex = clientProjectionList.Count - 1;
}

mappedExpresssion = new ProjectionBindingExpression(selectExpression, existingIndex, column.Type.MakeNullable());
}

updatedExpressions.Add(
projectionBindingExpression.Type.IsValueType
? Convert(projectionBindingExpression, typeof(object))
: projectionBindingExpression);
mappedExpresssion.Type.IsValueType
? Convert(mappedExpresssion, typeof(object))
: mappedExpresssion);
comparers.Add(comparer);
}

Expand Down Expand Up @@ -2005,9 +2043,7 @@ private static void PopulateGroupByTerms(
break;

case EntityShaperExpression entityShaperExpression
when entityShaperExpression.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression:
var entityProjectionExpression = (EntityProjectionExpression)((SelectExpression)projectionBindingExpression.QueryExpression)
.GetProjection(projectionBindingExpression);
when entityShaperExpression.ValueBufferExpression is EntityProjectionExpression entityProjectionExpression:
foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType))
{
PopulateGroupByTerms(entityProjectionExpression.BindProperty(property), groupByTerms, groupByAliases, name: null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ public override void Value_conversion_on_enum_collection_contains()
CoreStrings.TranslationFailed("")[47..],
Assert.Throws<InvalidOperationException>(() => base.Value_conversion_on_enum_collection_contains()).Message);

public override void GroupBy_converted_enum()
{
Assert.Contains(
CoreStrings.TranslationFailed("")[21..],
Assert.Throws<InvalidOperationException>(() => base.GroupBy_converted_enum()).Message);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.InMemory.Internal;
using static Microsoft.EntityFrameworkCore.DbLoggerCategory;

namespace Microsoft.EntityFrameworkCore;

public class CustomConvertersInMemoryTest : CustomConvertersTestBase<CustomConvertersInMemoryTest.CustomConvertersInMemoryFixture>
Expand Down Expand Up @@ -36,6 +39,13 @@ public override void Collection_property_as_scalar_Count_member()
public override void Collection_enum_as_string_Contains()
=> base.Collection_enum_as_string_Contains();

public override void GroupBy_converted_enum()
{
Assert.Contains(
CoreStrings.TranslationFailedWithDetails("", InMemoryStrings.NonComposedGroupByNotSupported)[21..],
Assert.Throws<InvalidOperationException>(() => base.GroupBy_converted_enum()).Message);
}

public class CustomConvertersInMemoryFixture : CustomConvertersFixtureBase
{
public override bool StrictEquality
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ public override Task Final_GroupBy_property_entity(bool async)
() => base.Final_GroupBy_property_entity(async),
InMemoryStrings.NonComposedGroupByNotSupported);

public override Task Final_GroupBy_entity(bool async)
=> AssertTranslationFailedWithDetails(
() => base.Final_GroupBy_entity(async),
InMemoryStrings.NonComposedGroupByNotSupported);

public override Task Final_GroupBy_property_entity_non_nullable(bool async)
=> AssertTranslationFailedWithDetails(
() => base.Final_GroupBy_property_entity_non_nullable(async),
InMemoryStrings.NonComposedGroupByNotSupported);

public override Task Final_GroupBy_property_anonymous_type(bool async)
=> AssertTranslationFailedWithDetails(
() => base.Final_GroupBy_property_anonymous_type(async),
Expand Down
39 changes: 38 additions & 1 deletion test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,36 @@ public enum HoldingEnum
Value2
}

[ConditionalFact]
public virtual void GroupBy_converted_enum()
{
using var context = CreateContext();
var result = context.Set<Entity>().GroupBy(e => e.SomeEnum).ToList();

Assert.Collection(result,
t =>
{
Assert.Equal(SomeEnum.No, t.Key);
Assert.Single(t);
},
t =>
{
Assert.Equal(SomeEnum.Yes, t.Key);
Assert.Equal(2, t.Count());
});
}

public class Entity
{
public int Id { get; set; }
public SomeEnum SomeEnum { get; set; }
}
public enum SomeEnum
{
Yes,
No
}

public abstract class CustomConvertersFixtureBase : BuiltInDataTypesFixtureBase
{
protected override string StoreName
Expand Down Expand Up @@ -1340,6 +1370,12 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
v => new List<Layout>(v)));

modelBuilder.Entity<HolderClass>().HasData(new HolderClass { Id = 1, HoldingEnum = HoldingEnum.Value2 });

modelBuilder.Entity<Entity>().Property(e => e.SomeEnum).HasConversion(e => e.ToString(), e => Enum.Parse<SomeEnum>(e));
modelBuilder.Entity<Entity>().HasData(
new Entity { Id = 1, SomeEnum = SomeEnum.Yes },
new Entity { Id = 2, SomeEnum = SomeEnum.No },
new Entity { Id = 3, SomeEnum = SomeEnum.Yes });
}

private static class StringToDictionarySerializer
Expand Down Expand Up @@ -1376,7 +1412,8 @@ public static List<Layout> Deserialize(string s)
list.Add(
new Layout
{
Height = int.Parse(parts[0]), Width = int.Parse(parts[1]),
Height = int.Parse(parts[0]),
Width = int.Parse(parts[1]),
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2596,6 +2596,26 @@ public virtual Task Final_GroupBy_property_entity(bool async)
elementAsserter: (e, a) => AssertGrouping(e, a),
entryCount: 91);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Final_GroupBy_entity(bool async)
=> AssertQuery(
async,
ss => ss.Set<Order>().Where(e => e.OrderID < 10500).GroupBy(c => c.Customer),
elementSorter: e => e.Key.CustomerID,
elementAsserter: (e, a) => AssertGrouping(e, a),
entryCount: 328);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Final_GroupBy_property_entity_non_nullable(bool async)
=> AssertQuery(
async,
ss => ss.Set<OrderDetail>().Where(e => e.OrderID < 10500).GroupBy(c => c.OrderID),
elementSorter: e => e.Key,
elementAsserter: (e, a) => AssertGrouping(e, a),
entryCount: 664);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Final_GroupBy_property_anonymous_type(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ public virtual void Columns_have_expected_data_types()
DateTimeEnclosure.Id ---> [int] [Precision = 10 Scale = 0]
EmailTemplate.Id ---> [uniqueidentifier]
EmailTemplate.TemplateType ---> [int] [Precision = 10 Scale = 0]
Entity.Id ---> [int] [Precision = 10 Scale = 0]
Entity.SomeEnum ---> [nvarchar] [MaxLength = -1]
EntityWithValueWrapper.Id ---> [int] [Precision = 10 Scale = 0]
EntityWithValueWrapper.Wrapper ---> [nullable nvarchar] [MaxLength = -1]
HolderClass.HoldingEnum ---> [int] [Precision = 10 Scale = 0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3142,6 +3142,29 @@ FROM [Customers] AS [c]
ORDER BY [c].[City]");
}

public override async Task Final_GroupBy_entity(bool async)
{
await base.Final_GroupBy_entity(async);

AssertSql(
@"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region], [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
LEFT JOIN [Customers] AS [c] ON [o].[CustomerID] = [c].[CustomerID]
WHERE [o].[OrderID] < 10500
ORDER BY [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]");
}

public override async Task Final_GroupBy_property_entity_non_nullable(bool async)
{
await base.Final_GroupBy_property_entity_non_nullable(async);

AssertSql(
@"SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
WHERE [o].[OrderID] < 10500
ORDER BY [o].[OrderID]");
}

public override async Task Final_GroupBy_property_anonymous_type(bool async)
{
await base.Final_GroupBy_property_anonymous_type(async);
Expand Down