Skip to content

Commit

Permalink
Query: Add TableReferenceExpression as a bridge to ColumnExpression f…
Browse files Browse the repository at this point in the history
…or referential integrity

Part of #17337
  • Loading branch information
smitpatel committed Mar 19, 2021
1 parent 92a1b91 commit 48b9bc3
Show file tree
Hide file tree
Showing 18 changed files with 1,005 additions and 778 deletions.
18 changes: 6 additions & 12 deletions src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ private static readonly PropertyInfo _valueBufferCountMemberInfo
private readonly List<Expression> _clientProjectionExpressions = new();
private readonly List<MethodCallExpression> _projectionMappingExpressions = new();

private readonly IDictionary<EntityProjectionExpression, IDictionary<IProperty, int>> _entityProjectionCache
= new Dictionary<EntityProjectionExpression, IDictionary<IProperty, int>>();
private readonly Dictionary<EntityProjectionExpression, IReadOnlyDictionary<IProperty, int>> _entityProjectionCache = new();

private readonly ParameterExpression _valueBufferParameter;

Expand Down Expand Up @@ -319,17 +318,12 @@ EntityProjectionExpression UpdateEntityProjection(EntityProjectionExpression ent
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual IDictionary<IProperty, int> AddToProjection(EntityProjectionExpression entityProjectionExpression)
public virtual IReadOnlyDictionary<IProperty, int> AddToProjection(EntityProjectionExpression entityProjectionExpression)
{
if (!_entityProjectionCache.TryGetValue(entityProjectionExpression, out var indexMap))
var indexMap = new Dictionary<IProperty, int>();
foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType))
{
indexMap = new Dictionary<IProperty, int>();
foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType))
{
indexMap[property] = AddToProjection(entityProjectionExpression.BindProperty(property));
}

_entityProjectionCache[entityProjectionExpression] = indexMap;
indexMap[property] = AddToProjection(entityProjectionExpression.BindProperty(property));
}

return indexMap;
Expand Down Expand Up @@ -1032,7 +1026,7 @@ public ShaperRemappingExpressionVisitor(IDictionary<ProjectionMember, Expression
&& projectionBindingExpression.ProjectionMember != null)
{
var mappingValue = ((ConstantExpression)_projectionMapping[projectionBindingExpression.ProjectionMember]).Value;
return mappingValue is IDictionary<IProperty, int> indexMap
return mappingValue is IReadOnlyDictionary<IProperty, int> indexMap
? new ProjectionBindingExpression(projectionBindingExpression.QueryExpression, indexMap)
: mappingValue is int index
? new ProjectionBindingExpression(
Expand Down
8 changes: 3 additions & 5 deletions src/EFCore.Relational/Query/EntityProjectionExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ namespace Microsoft.EntityFrameworkCore.Query
/// </summary>
public class EntityProjectionExpression : Expression
{
private readonly IDictionary<IProperty, ColumnExpression> _propertyExpressionMap = new Dictionary<IProperty, ColumnExpression>();

private readonly IDictionary<INavigation, EntityShaperExpression> _ownedNavigationMap
= new Dictionary<INavigation, EntityShaperExpression>();
private readonly IReadOnlyDictionary<IProperty, ColumnExpression> _propertyExpressionMap = new Dictionary<IProperty, ColumnExpression>();
private readonly Dictionary<INavigation, EntityShaperExpression> _ownedNavigationMap = new();

/// <summary>
/// Creates a new instance of the <see cref="EntityProjectionExpression" /> class.
Expand All @@ -49,7 +47,7 @@ public EntityProjectionExpression(IEntityType entityType, TableExpressionBase in
/// <param name="discriminatorExpression"> A <see cref="SqlExpression" /> to generate discriminator for each concrete entity type in hierarchy. </param>
public EntityProjectionExpression(
IEntityType entityType,
IDictionary<IProperty, ColumnExpression> propertyExpressionMap,
IReadOnlyDictionary<IProperty, ColumnExpression> propertyExpressionMap,
SqlExpression? discriminatorExpression = null)
{
Check.NotNull(entityType, nameof(entityType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ private static readonly MethodInfo _getParameterValueMethodInfo
private SelectExpression _selectExpression;
private SqlExpression[] _existingProjections;
private bool _clientEval;
private Dictionary<EntityProjectionExpression, ProjectionBindingExpression>? _entityProjectionCache;

private readonly IDictionary<ProjectionMember, Expression> _projectionMapping
= new Dictionary<ProjectionMember, Expression>();
private readonly Dictionary<ProjectionMember, Expression> _projectionMapping = new();

private readonly Stack<ProjectionMember> _projectionMembers = new();

Expand Down Expand Up @@ -77,6 +77,7 @@ public virtual Expression Translate(SelectExpression selectExpression, Expressio
if (result == QueryCompilationContext.NotTranslatedExpression)
{
_clientEval = true;
_entityProjectionCache = new();

expandedExpression = _queryableMethodTranslatingExpressionVisitor.ExpandWeakEntities(_selectExpression, expression);
_existingProjections = _selectExpression.Projection.Select(e => e.Expression).ToArray();
Expand Down Expand Up @@ -334,9 +335,14 @@ protected override Expression VisitExtension(Expression extensionExpression)

if (_clientEval)
{
return entityShaperExpression.Update(
new ProjectionBindingExpression(
_selectExpression, _selectExpression.AddToProjection(entityProjectionExpression)));
if (!_entityProjectionCache!.TryGetValue(entityProjectionExpression, out var entityProjectionBinding))
{
entityProjectionBinding = new ProjectionBindingExpression(
_selectExpression, _selectExpression.AddToProjection(entityProjectionExpression));
_entityProjectionCache[entityProjectionExpression] = entityProjectionBinding;
}

return entityShaperExpression.Update(entityProjectionBinding);
}

_projectionMapping[_projectionMembers.Peek()] = entityProjectionExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ protected override Expression VisitExtension(Expression extensionExpression)
{
return VisitExtension(
_sqlExpressionFactory.Case(
caseExpression.WhenClauses.Union(nestedCaseExpression.WhenClauses).ToList(),
caseExpression.WhenClauses.Union<CaseWhenClause>(
nestedCaseExpression.WhenClauses, ReferenceEqualityComparer.Instance).ToList(),
nestedCaseExpression.ElseResult));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
Expand Down Expand Up @@ -56,6 +57,27 @@ private sealed class ScopedVisitor : ExpressionVisitor
private readonly ISet<TableExpressionBase> _visitedTableExpressionBases
= new HashSet<TableExpressionBase>(LegacyReferenceEqualityComparer.Instance);

public Expression EntryPoint(Expression expression)
{
var result = Visit(expression);

foreach (var group in _usedAliases.GroupBy(e => e[0..1]))
{
if (group.Count() == 1)
{
continue;
}

var numbers = group.OrderBy(e => e).Skip(1).Select(e => int.Parse(e)).OrderBy(e => e).ToList();
if (numbers.Count - 1 != numbers[^1])
{
throw new InvalidTimeZoneException();
}
}

return result;
}

[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
Expand All @@ -64,33 +86,17 @@ private readonly ISet<TableExpressionBase> _visitedTableExpressionBases
&& !_visitedTableExpressionBases.Contains(tableExpressionBase)
&& tableExpressionBase.Alias != null)
{
tableExpressionBase.Alias = GenerateUniqueAlias(tableExpressionBase.Alias);
if (_usedAliases.Contains(tableExpressionBase.Alias))
{
throw new InvalidOperationException("Duplicate alias");
}
_usedAliases.Add(tableExpressionBase.Alias);

_visitedTableExpressionBases.Add(tableExpressionBase);
}

return visitedExpression;
}

private string GenerateUniqueAlias(string currentAlias)
{
if (!_usedAliases.Contains(currentAlias))
{
_usedAliases.Add(currentAlias);
return currentAlias;
}

var counter = 0;
var uniqueAlias = currentAlias;

while (_usedAliases.Contains(uniqueAlias))
{
uniqueAlias = currentAlias + counter++;
}

_usedAliases.Add(uniqueAlias);

return uniqueAlias;
}
}
}
}
4 changes: 2 additions & 2 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ private bool IsNonComposedSetOperation(SelectExpression selectExpression)
&& selectExpression.Projection.Count == setOperation.Source1.Projection.Count
&& selectExpression.Projection.Select(
(pe, index) => pe.Expression is ColumnExpression column
&& string.Equals(column.Table.Alias, setOperation.Alias, StringComparison.OrdinalIgnoreCase)
&& string.Equals(column.TableAlias, setOperation.Alias, StringComparison.OrdinalIgnoreCase)
&& string.Equals(
column.Name, setOperation.Source1.Projection[index].Alias, StringComparison.OrdinalIgnoreCase))
.All(e => e);
Expand Down Expand Up @@ -332,7 +332,7 @@ protected override Expression VisitColumn(ColumnExpression columnExpression)
Check.NotNull(columnExpression, nameof(columnExpression));

_relationalCommandBuilder
.Append(_sqlGenerationHelper.DelimitIdentifier(columnExpression.Table.Alias!))
.Append(_sqlGenerationHelper.DelimitIdentifier(columnExpression.TableAlias))
.Append(".")
.Append(_sqlGenerationHelper.DelimitIdentifier(columnExpression.Name));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1441,13 +1441,8 @@ outerKey is NewArrayExpression newArrayExpression
|| (entityType.FindDiscriminatorProperty() == null
&& navigation.DeclaringEntityType.IsStrictlyDerivedFrom(entityShaperExpression.EntityType));

var propertyExpressions = GetPropertyExpressionFromSameTable(
targetEntityType, table, _selectExpression, identifyingColumn, principalNullable);
if (propertyExpressions != null)
{
innerShaper = new RelationalEntityShaperExpression(
targetEntityType, new EntityProjectionExpression(targetEntityType, propertyExpressions), true);
}
innerShaper = _selectExpression.GenerateWeakEntityShaper(
targetEntityType, table, identifyingColumn.Name, identifyingColumn.Table, principalNullable);
}

if (innerShaper == null)
Expand Down Expand Up @@ -1479,10 +1474,9 @@ outerKey is NewArrayExpression newArrayExpression
var joinPredicate = _sqlTranslator.Translate(Expression.Equal(outerKey, innerKey))!;
_selectExpression.AddLeftJoin(innerSelectExpression, joinPredicate);
var leftJoinTable = ((LeftJoinExpression)_selectExpression.Tables.Last()).Table;
var propertyExpressions = GetPropertyExpressionsFromJoinedTable(targetEntityType, table, leftJoinTable);

innerShaper = new RelationalEntityShaperExpression(
targetEntityType, new EntityProjectionExpression(targetEntityType, propertyExpressions), true);
innerShaper = _selectExpression.GenerateWeakEntityShaper(
targetEntityType, table, null, leftJoinTable, makeNullable: true)!;
}

entityProjectionExpression.AddNavigationBinding(navigation, innerShaper);
Expand All @@ -1495,80 +1489,6 @@ private static Expression AddConvertToObject(Expression expression)
=> expression.Type.IsValueType
? Expression.Convert(expression, typeof(object))
: expression;

private static IDictionary<IProperty, ColumnExpression>? GetPropertyExpressionFromSameTable(
IEntityType entityType,
ITableBase table,
SelectExpression selectExpression,
ColumnExpression identifyingColumn,
bool nullable)
{
if (identifyingColumn.Table is TableExpression tableExpression)
{
if (!string.Equals(tableExpression.Name, table.Name, StringComparison.OrdinalIgnoreCase))
{
// Fetch the table for the type which is defining the navigation since dependent would be in that table
tableExpression = selectExpression.Tables
.Select(t => (t as InnerJoinExpression)?.Table ?? (t as LeftJoinExpression)?.Table ?? t)
.Cast<TableExpression>()
.First(t => t.Name == table.Name && t.Schema == table.Schema);
}

var propertyExpressions = new Dictionary<IProperty, ColumnExpression>();
foreach (var property in entityType
.GetAllBaseTypes().Concat(entityType.GetDerivedTypesInclusive())
.SelectMany(t => t.GetDeclaredProperties()))
{
propertyExpressions[property] = new ColumnExpression(
property, table.FindColumn(property)!, tableExpression, nullable || !property.IsPrimaryKey());
}

return propertyExpressions;
}

if (identifyingColumn.Table is SelectExpression subquery)
{
var subqueryIdentifyingColumn = (ColumnExpression)subquery.Projection
.Single(e => string.Equals(e.Alias, identifyingColumn.Name, StringComparison.OrdinalIgnoreCase))
.Expression;

var subqueryPropertyExpressions = GetPropertyExpressionFromSameTable(
entityType, table, subquery, subqueryIdentifyingColumn, nullable);

if (subqueryPropertyExpressions == null)
{
return null;
}

var newPropertyExpressions = new Dictionary<IProperty, ColumnExpression>();
foreach (var item in subqueryPropertyExpressions)
{
newPropertyExpressions[item.Key] = new ColumnExpression(
subquery.Projection[subquery.AddToProjection(item.Value)], subquery);
}

return newPropertyExpressions;
}

return null;
}

private static IDictionary<IProperty, ColumnExpression> GetPropertyExpressionsFromJoinedTable(
IEntityType entityType,
ITableBase table,
TableExpressionBase tableExpression)
{
var propertyExpressions = new Dictionary<IProperty, ColumnExpression>();
foreach (var property in entityType
.GetAllBaseTypes().Concat(entityType.GetDerivedTypesInclusive())
.SelectMany(t => t.GetDeclaredProperties()))
{
propertyExpressions[property] = new ColumnExpression(
property, table.FindColumn(property)!, tableExpression, nullable: true);
}

return propertyExpressions;
}
}

private ShapedQueryExpression TranslateTwoParameterSelector(ShapedQueryExpression source, LambdaExpression resultSelector)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ private static readonly MethodInfo _collectionAccessorAddMethodInfo
private readonly ReaderColumn[]? _readerColumns;

// States to materialize only once
private readonly IDictionary<Expression, Expression> _variableShaperMapping = new Dictionary<Expression, Expression>();
private readonly Dictionary<Expression, Expression> _variableShaperMapping = new(ReferenceEqualityComparer.Instance);

// There are always entity variables to avoid materializing same entity twice
private readonly List<ParameterExpression> _variables = new();
Expand Down
Loading

0 comments on commit 48b9bc3

Please sign in to comment.