Skip to content

Commit

Permalink
Query: Convert single result subquery comparison to null to Any opera…
Browse files Browse the repository at this point in the history
…tion (#27284)

Resolves #26744
A better fix for #18476

Initial fix for #18476 assumed that whenever we have single result operation compared to null, it will only be true if the result of single result is default when sequence is empty. This was correct for the query in the issue tracker which had anonymous type projection.
Anonymous type is never null as long as there is data, it can be only null value when default is invoked i.e. empty sequence.
Hence we added optimization for that but it didn't restrict to just anonymous type.
For entity type projection when entity is not nullable, the same logic holds true. This helped us translate queries which wouldn't work with entity equality due to composite key from a subquery.
But optimization was incorrect for the result which can be null (nullable scalar or nullable entity) as an non-empty sequence can have first result to be null which can match.

The improved fix avoids doing the unrestricted optimization during preprocessing phase. Instead we moved the logic to translation phase where we can evaluate the shape of the projection coming out subquery. Now we only apply optimization for non-nullable entity and anonymous type. Scalar comparison will work by comparing to null and nullable entity will work if entity equality covers it. It will start throwing error if composite key though earlier version possibly generated wrong results for it.
  • Loading branch information
smitpatel committed Feb 2, 2022
1 parent 06e2f6e commit b7e5c13
Show file tree
Hide file tree
Showing 14 changed files with 426 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ public class InMemoryExpressionTranslatingExpressionVisitor : ExpressionVisitor
{
private const string RuntimeParameterPrefix = QueryCompilationContext.QueryParameterPrefix + "entity_equality_";

private static readonly List<MethodInfo> _singleResultMethodInfos = new()
{
QueryableMethods.FirstWithPredicate,
QueryableMethods.FirstWithoutPredicate,
QueryableMethods.FirstOrDefaultWithPredicate,
QueryableMethods.FirstOrDefaultWithoutPredicate,
QueryableMethods.SingleWithPredicate,
QueryableMethods.SingleWithoutPredicate,
QueryableMethods.SingleOrDefaultWithPredicate,
QueryableMethods.SingleOrDefaultWithoutPredicate,
QueryableMethods.LastWithPredicate,
QueryableMethods.LastWithoutPredicate,
QueryableMethods.LastOrDefaultWithPredicate,
QueryableMethods.LastOrDefaultWithoutPredicate
//QueryableMethodProvider.ElementAtMethodInfo,
//QueryableMethodProvider.ElementAtOrDefaultMethodInfo
};

private static readonly MemberInfo _valueBufferIsEmpty = typeof(ValueBuffer).GetMember(nameof(ValueBuffer.IsEmpty))[0];

private static readonly MethodInfo _parameterValueExtractor =
Expand Down Expand Up @@ -161,6 +179,51 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right));
}

if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue26744", out var enabled) && enabled))
{
if ((binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual)
&& (binaryExpression.Left.IsNullConstantExpression() || binaryExpression.Right.IsNullConstantExpression()))
{
var nonNullExpression = binaryExpression.Left.IsNullConstantExpression() ? binaryExpression.Right : binaryExpression.Left;
if (nonNullExpression is MethodCallExpression nonNullMethodCallExpression
&& nonNullMethodCallExpression.Method.DeclaringType == typeof(Queryable)
&& nonNullMethodCallExpression.Method.IsGenericMethod
&& _singleResultMethodInfos.Contains(nonNullMethodCallExpression.Method.GetGenericMethodDefinition()))
{
var source = nonNullMethodCallExpression.Arguments[0];
if (nonNullMethodCallExpression.Arguments.Count == 2)
{
source = Expression.Call(
QueryableMethods.Where.MakeGenericMethod(source.Type.GetSequenceType()),
source,
nonNullMethodCallExpression.Arguments[1]);
}

var translatedSubquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(source);
if (translatedSubquery != null)
{
var projection = translatedSubquery.ShaperExpression;
if (projection is NewExpression
|| RemoveConvert(projection) is EntityShaperExpression { IsNullable: false })
{
var anySubquery = Expression.Call(
QueryableMethods.AnyWithoutPredicate.MakeGenericMethod(translatedSubquery.Type.GetSequenceType()),
translatedSubquery);

return Visit(binaryExpression.NodeType == ExpressionType.Equal
? Expression.Not(anySubquery)
: anySubquery);
}

static Expression RemoveConvert(Expression e)
=> e is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unary
? RemoveConvert(unary.Operand)
: e;
}
}
}
}

var newLeft = Visit(binaryExpression.Left);
var newRight = Visit(binaryExpression.Right);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,24 @@ public class RelationalSqlTranslatingExpressionVisitor : ExpressionVisitor
{
private const string RuntimeParameterPrefix = QueryCompilationContext.QueryParameterPrefix + "entity_equality_";

private static readonly List<MethodInfo> _singleResultMethodInfos = new()
{
QueryableMethods.FirstWithPredicate,
QueryableMethods.FirstWithoutPredicate,
QueryableMethods.FirstOrDefaultWithPredicate,
QueryableMethods.FirstOrDefaultWithoutPredicate,
QueryableMethods.SingleWithPredicate,
QueryableMethods.SingleWithoutPredicate,
QueryableMethods.SingleOrDefaultWithPredicate,
QueryableMethods.SingleOrDefaultWithoutPredicate,
QueryableMethods.LastWithPredicate,
QueryableMethods.LastWithoutPredicate,
QueryableMethods.LastOrDefaultWithPredicate,
QueryableMethods.LastOrDefaultWithoutPredicate
//QueryableMethodProvider.ElementAtMethodInfo,
//QueryableMethodProvider.ElementAtOrDefaultMethodInfo
};

private static readonly MethodInfo _parameterValueExtractor =
typeof(RelationalSqlTranslatingExpressionVisitor).GetRequiredDeclaredMethod(nameof(ParameterValueExtractor));

Expand Down Expand Up @@ -324,6 +342,51 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
right = rightOperand!;
}

if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue26744", out var enabled) && enabled))
{
if ((binaryExpression.NodeType == ExpressionType.Equal || binaryExpression.NodeType == ExpressionType.NotEqual)
&& (left.IsNullConstantExpression() || right.IsNullConstantExpression()))
{
var nonNullExpression = left.IsNullConstantExpression() ? right : left;
if (nonNullExpression is MethodCallExpression nonNullMethodCallExpression
&& nonNullMethodCallExpression.Method.DeclaringType == typeof(Queryable)
&& nonNullMethodCallExpression.Method.IsGenericMethod
&& _singleResultMethodInfos.Contains(nonNullMethodCallExpression.Method.GetGenericMethodDefinition()))
{
var source = nonNullMethodCallExpression.Arguments[0];
if (nonNullMethodCallExpression.Arguments.Count == 2)
{
source = Expression.Call(
QueryableMethods.Where.MakeGenericMethod(source.Type.GetSequenceType()),
source,
nonNullMethodCallExpression.Arguments[1]);
}

var translatedSubquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(source);
if (translatedSubquery != null)
{
var projection = translatedSubquery.ShaperExpression;
if (projection is NewExpression
|| RemoveConvert(projection) is EntityShaperExpression { IsNullable: false })
{
var anySubquery = Expression.Call(
QueryableMethods.AnyWithoutPredicate.MakeGenericMethod(translatedSubquery.Type.GetSequenceType()),
translatedSubquery);

return Visit(binaryExpression.NodeType == ExpressionType.Equal
? Expression.Not(anySubquery)
: anySubquery);
}

static Expression RemoveConvert(Expression e)
=> e is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unary
? RemoveConvert(unary.Operand)
: e;
}
}
}
}

var visitedLeft = Visit(left);
var visitedRight = Visit(right);

Expand Down
45 changes: 24 additions & 21 deletions src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,30 +77,33 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
}
}

if (binaryExpression.NodeType == ExpressionType.Equal
|| binaryExpression.NodeType == ExpressionType.NotEqual)
if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue26744", out var enabled) && enabled)
{
var leftNullConstant = IsNullConstant(left);
var rightNullConstant = IsNullConstant(right);
if (leftNullConstant || rightNullConstant)
if (binaryExpression.NodeType == ExpressionType.Equal
|| binaryExpression.NodeType == ExpressionType.NotEqual)
{
var nonNullExpression = leftNullConstant ? right : left;
if (nonNullExpression is MethodCallExpression methodCallExpression
&& methodCallExpression.Method.DeclaringType == typeof(Queryable)
&& methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo genericMethod
&& _singleResultMethodInfos.Contains(genericMethod))
var leftNullConstant = IsNullConstant(left);
var rightNullConstant = IsNullConstant(right);
if (leftNullConstant || rightNullConstant)
{
var result = Expression.Call(
(methodCallExpression.Arguments.Count == 2
? QueryableMethods.AnyWithPredicate
: QueryableMethods.AnyWithoutPredicate)
.MakeGenericMethod(methodCallExpression.Type),
methodCallExpression.Arguments);

return binaryExpression.NodeType == ExpressionType.Equal
? Expression.Not(result)
: result;
var nonNullExpression = leftNullConstant ? right : left;
if (nonNullExpression is MethodCallExpression methodCallExpression
&& methodCallExpression.Method.DeclaringType == typeof(Queryable)
&& methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() is MethodInfo genericMethod
&& _singleResultMethodInfos.Contains(genericMethod))
{
var result = Expression.Call(
(methodCallExpression.Arguments.Count == 2
? QueryableMethods.AnyWithPredicate
: QueryableMethods.AnyWithoutPredicate)
.MakeGenericMethod(methodCallExpression.Type),
methodCallExpression.Arguments);

return binaryExpression.NodeType == ExpressionType.Equal
? Expression.Not(result)
: result;
}
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9138,6 +9138,24 @@ public virtual Task Where_equals_method_on_nullable_with_object_overload(bool as
ss => ss.Set<Mission>().Where(m => m.Rating.Equals(null)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_equality_to_null_with_composite_key(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(s => s.Members.OrderBy(e => e.Nickname).FirstOrDefault() == null));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_equality_to_null_without_composite_key(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Gear>().Where(s => s.Weapons.OrderBy(e => e.Name).FirstOrDefault() == null));
}

protected GearsOfWarContext CreateContext()
=> Fixture.CreateContext();

Expand Down
97 changes: 97 additions & 0 deletions test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -889,5 +889,102 @@ public class ChildFilter2
public string Filter2 { get; set; }
public string Value2 { get; set; }
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Subquery_first_member_compared_to_null(bool async)
{
var contextFactory = await InitializeAsync<Context26744>(seed: c => c.Seed());
using var context = contextFactory.CreateContext();

var query = context.Parents
.Where(p => p.Children.Any(c => c.SomeNullableDateTime == null)
&& p.Children.Where(c => c.SomeNullableDateTime == null)
.OrderBy(c => c.SomeInteger)
.First().SomeOtherNullableDateTime != null)
.Select(p => p.Children.Where(c => c.SomeNullableDateTime == null)
.OrderBy(c => c.SomeInteger)
.First().SomeOtherNullableDateTime);

var result = async
? await query.ToListAsync()
: query.ToList();

Assert.Single(result);
}

[ConditionalTheory(Skip = "Issue#26756")]
[MemberData(nameof(IsAsyncData))]
public virtual async Task SelectMany_where_Select(bool async)
{
var contextFactory = await InitializeAsync<Context26744>(seed: c => c.Seed());
using var context = contextFactory.CreateContext();

var query = context.Parents
.SelectMany(p => p.Children
.Where(c => c.SomeNullableDateTime == null)
.OrderBy(c => c.SomeInteger)
.Take(1))
.Where(c => c.SomeOtherNullableDateTime != null)
.Select(c => c.SomeNullableDateTime);

var result = async
? await query.ToListAsync()
: query.ToList();

Assert.Single(result);
}

protected class Context26744 : DbContext
{
public Context26744(DbContextOptions options)
: base(options)
{
}

public DbSet<Parent26744> Parents { get; set; }
public void Seed()
{
Add(new Parent26744
{
Children = new List<Child26744>
{
new Child26744
{
SomeInteger = 1,
SomeOtherNullableDateTime = new DateTime(2000, 11, 18)
}
}
});

Add(new Parent26744
{
Children = new List<Child26744>
{
new Child26744
{
SomeInteger = 1,
}
}
});

SaveChanges();
}
}

protected class Parent26744
{
public int Id { get; set; }
public List<Child26744> Children { get; set; }
}

protected class Child26744
{
public int Id { get; set; }
public int SomeInteger { get; set; }
public DateTime? SomeNullableDateTime { get; set; }
public DateTime? SomeOtherNullableDateTime { get; set; }
public Parent26744 Parent { get; set; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3213,10 +3213,11 @@ public override async Task Member_pushdown_with_multiple_collections(bool async)
@"SELECT (
SELECT TOP(1) [l0].[Name]
FROM [LevelThree] AS [l0]
WHERE EXISTS (
SELECT 1
WHERE ((
SELECT TOP(1) [l1].[Id]
FROM [LevelTwo] AS [l1]
WHERE [l].[Id] = [l1].[OneToMany_Optional_Inverse2Id]) AND (((
WHERE [l].[Id] = [l1].[OneToMany_Optional_Inverse2Id]
ORDER BY [l1].[Id]) IS NOT NULL) AND (((
SELECT TOP(1) [l2].[Id]
FROM [LevelTwo] AS [l2]
WHERE [l].[Id] = [l2].[OneToMany_Optional_Inverse2Id]
Expand Down Expand Up @@ -3699,10 +3700,10 @@ public override async Task Multiple_collection_FirstOrDefault_followed_by_member
@"SELECT [l].[Id], (
SELECT TOP(1) [l0].[Name]
FROM [LevelThree] AS [l0]
WHERE EXISTS (
SELECT 1
WHERE ((
SELECT TOP(1) [l1].[Id]
FROM [LevelTwo] AS [l1]
WHERE ([l].[Id] = [l1].[OneToMany_Optional_Inverse2Id]) AND ([l1].[Name] = N'L2 02')) AND (((
WHERE ([l].[Id] = [l1].[OneToMany_Optional_Inverse2Id]) AND ([l1].[Name] = N'L2 02')) IS NOT NULL) AND (((
SELECT TOP(1) [l2].[Id]
FROM [LevelTwo] AS [l2]
WHERE ([l].[Id] = [l2].[OneToMany_Optional_Inverse2Id]) AND ([l2].[Name] = N'L2 02')) = [l0].[OneToMany_Optional_Inverse3Id]) OR (((
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8223,6 +8223,32 @@ OFFSET @__p_0 ROWS FETCH NEXT @__p_1 ROWS ONLY
ORDER BY [t0].[Nickname], [t0].[SquadId], [t0].[HasSoulPatch0]");
}

public override async Task Where_subquery_equality_to_null_with_composite_key(bool async)
{
await base.Where_subquery_equality_to_null_with_composite_key(async);

AssertSql(
@"SELECT [s].[Id], [s].[Banner], [s].[Banner5], [s].[InternalNumber], [s].[Name]
FROM [Squads] AS [s]
WHERE NOT (EXISTS (
SELECT 1
FROM [Gears] AS [g]
WHERE [s].[Id] = [g].[SquadId]))");
}

public override async Task Where_subquery_equality_to_null_without_composite_key(bool async)
{
await base.Where_subquery_equality_to_null_without_composite_key(async);

AssertSql(
@"SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOfBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank]
FROM [Gears] AS [g]
WHERE NOT (EXISTS (
SELECT 1
FROM [Weapons] AS [w]
WHERE [g].[FullName] = [w].[OwnerFullName]))");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}
Expand Down
Loading

0 comments on commit b7e5c13

Please sign in to comment.