Skip to content

Commit

Permalink
Query: Match correct predicate structure to convert apply to join
Browse files Browse the repository at this point in the history
Resolves #26756

For property access on optional dependents sharing column with principal, we generate CaseExpression. In order to match them during join key search, we updated our recursive function to match shape of the test in CaseExpression which incorrectly matched similar structure outside of case block causing somewhat wrong join key to be extracted. While join key in itself could work, when generating partitions out of it due to paging operation, it gives incorrect results.
The fix is to match the special structure of CaseExpression.Test separately. Also made the key comparison match more robust by only allowing column or case block to appear in condition. Other structures don't represent part of join key in a comparison.
Null checks for join keys are handled separately so that we only remove null checks which are indeed used in comparison with other columns in other operations.
  • Loading branch information
smitpatel committed Jan 26, 2022
1 parent d8ff3c4 commit 17b2756
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 24 deletions.
63 changes: 52 additions & 11 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2385,16 +2385,16 @@ static void GetPartitions(SelectExpression selectExpression, SqlExpression sqlEx
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThan
|| sqlBinaryExpression.OperatorType == ExpressionType.LessThanOrEqual)))
{
if (IsContainedColumn(outer, sqlBinaryExpression.Left)
&& IsContainedColumn(inner, sqlBinaryExpression.Right))
if (IsContainedSql(outer, sqlBinaryExpression.Left)
&& IsContainedSql(inner, sqlBinaryExpression.Right))
{
outerColumnExpressions.Add(sqlBinaryExpression.Left);

return sqlBinaryExpression;
}

if (IsContainedColumn(outer, sqlBinaryExpression.Right)
&& IsContainedColumn(inner, sqlBinaryExpression.Left))
if (IsContainedSql(outer, sqlBinaryExpression.Right)
&& IsContainedSql(inner, sqlBinaryExpression.Left))
{
outerColumnExpressions.Add(sqlBinaryExpression.Right);

Expand All @@ -2410,14 +2410,14 @@ static void GetPartitions(SelectExpression selectExpression, SqlExpression sqlEx
// null checks are considered part of join key
if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual)
{
if (IsContainedColumn(outer, sqlBinaryExpression.Left)
if (IsContainedSql(outer, sqlBinaryExpression.Left)
&& sqlBinaryExpression.Right is SqlConstantExpression rightConstant
&& rightConstant.Value == null)
{
return sqlBinaryExpression;
}

if (IsContainedColumn(outer, sqlBinaryExpression.Right)
if (IsContainedSql(outer, sqlBinaryExpression.Right)
&& sqlBinaryExpression.Left is SqlConstantExpression leftConstant
&& leftConstant.Value == null)
{
Expand All @@ -2430,8 +2430,29 @@ static void GetPartitions(SelectExpression selectExpression, SqlExpression sqlEx
return null;
}

static bool IsContainedColumn(SelectExpression selectExpression, SqlExpression sqlExpression)
static bool IsContainedSql(SelectExpression selectExpression, SqlExpression sqlExpression)
{
if (!AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue26756", out var enabled) || !enabled)
{
switch (sqlExpression)
{
case ColumnExpression columnExpression:
return selectExpression.ContainsTableReference(columnExpression);

case CaseExpression caseExpression
when caseExpression.ElseResult == null
&& caseExpression.Operand == null
&& caseExpression.WhenClauses.Count == 1
&& caseExpression.WhenClauses[0].Result is ColumnExpression resultColumn:
// We check condition in a separate function to avoid matching structure of condition outside of case block
return IsContainedCondition(selectExpression, caseExpression.WhenClauses[0].Test)
&& selectExpression.ContainsTableReference(resultColumn);

default:
return false;
}
}

switch (sqlExpression)
{
case ColumnExpression columnExpression:
Expand All @@ -2445,21 +2466,41 @@ static bool IsContainedColumn(SelectExpression selectExpression, SqlExpression s
when sqlBinaryExpression.OperatorType == ExpressionType.AndAlso
|| sqlBinaryExpression.OperatorType == ExpressionType.OrElse
|| sqlBinaryExpression.OperatorType == ExpressionType.NotEqual:
return IsContainedColumn(selectExpression, sqlBinaryExpression.Left)
&& IsContainedColumn(selectExpression, sqlBinaryExpression.Right);
return IsContainedSql(selectExpression, sqlBinaryExpression.Left)
&& IsContainedSql(selectExpression, sqlBinaryExpression.Right);

case CaseExpression caseExpression
when caseExpression.ElseResult == null
&& caseExpression.Operand == null
&& caseExpression.WhenClauses.Count == 1:
return IsContainedColumn(selectExpression, caseExpression.WhenClauses[0].Test)
&& IsContainedColumn(selectExpression, caseExpression.WhenClauses[0].Result);
return IsContainedSql(selectExpression, caseExpression.WhenClauses[0].Test)
&& IsContainedSql(selectExpression, caseExpression.WhenClauses[0].Result);

default:
return false;
}
}

static bool IsContainedCondition(SelectExpression selectExpression, SqlExpression condition)
{
if (condition is not SqlBinaryExpression
{ OperatorType: ExpressionType.AndAlso or ExpressionType.OrElse or ExpressionType.NotEqual } sqlBinaryExpression)
{
return false;
}

if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual)
{
// We don't check left/right inverted because we generate this.
return sqlBinaryExpression.Right is SqlConstantExpression { Value: null }
&& sqlBinaryExpression.Left is ColumnExpression column
&& selectExpression.ContainsTableReference(column);
}

return IsContainedCondition(selectExpression, sqlBinaryExpression.Left)
&& IsContainedCondition(selectExpression, sqlBinaryExpression.Right);
}

static void PopulateInnerKeyColumns(
IEnumerable<TableExpressionBase> tables,
SqlExpression joinPredicate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ public virtual async Task Subquery_first_member_compared_to_null(bool async)
Assert.Single(result);
}

[ConditionalTheory(Skip = "Issue#26756")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task SelectMany_where_Select(bool async)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6051,7 +6051,7 @@ public override async Task Null_checks_in_correlated_predicate_are_correctly_tra
AssertSql(
@"SELECT [t].[Id], [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOfBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank]
FROM [Tags] AS [t]
LEFT JOIN [Gears] AS [g] ON ([t].[GearNickName] = [g].[Nickname]) AND ([t].[GearSquadId] = [g].[SquadId])
LEFT JOIN [Gears] AS [g] ON (([t].[GearNickName] = [g].[Nickname]) AND ([t].[GearSquadId] = [g].[SquadId])) AND ([t].[Note] IS NOT NULL)
ORDER BY [t].[Id], [g].[Nickname]");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,7 @@ LEFT JOIN (
SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[CustomerID] LIKE N'A%'
) AS [t0] ON ([t].[City] = [t0].[City]) OR (([t].[City] IS NULL) AND ([t0].[City] IS NULL))
) AS [t0] ON [t].[City] = [t0].[City]
ORDER BY [t].[City]");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,15 @@ public override async Task Include_query(bool async)
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region], [t0].[OrderID], [t0].[CustomerID], [t0].[EmployeeID], [t0].[OrderDate], [t0].[CustomerID0]
FROM [Customers] AS [c]
LEFT JOIN (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate], [t].[CustomerID] AS [CustomerID0], [t].[CompanyName]
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate], [t].[CustomerID] AS [CustomerID0]
FROM [Orders] AS [o]
LEFT JOIN (
SELECT [c0].[CustomerID], [c0].[CompanyName]
FROM [Customers] AS [c0]
WHERE (@__ef_filter__TenantPrefix_0 = N'') OR (([c0].[CompanyName] IS NOT NULL) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CustomerID] IS NOT NULL
) AS [t0] ON ([t0].[CompanyName] IS NOT NULL) AND ([c].[CustomerID] = [t0].[CustomerID])
WHERE ([t].[CustomerID] IS NOT NULL) AND ([t].[CompanyName] IS NOT NULL)
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
WHERE (@__ef_filter__TenantPrefix_0 = N'') OR (([c].[CompanyName] IS NOT NULL) AND (LEFT([c].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0))
ORDER BY [c].[CustomerID], [t0].[OrderID]");
}
Expand Down Expand Up @@ -197,15 +197,15 @@ public override async Task Navs_query(bool async)
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
INNER JOIN (
SELECT [o].[OrderID], [o].[CustomerID], [t].[CompanyName]
SELECT [o].[OrderID], [o].[CustomerID]
FROM [Orders] AS [o]
LEFT JOIN (
SELECT [c0].[CustomerID], [c0].[CompanyName]
FROM [Customers] AS [c0]
WHERE (@__ef_filter__TenantPrefix_0 = N'') OR (([c0].[CompanyName] IS NOT NULL) AND (LEFT([c0].[CompanyName], LEN(@__ef_filter__TenantPrefix_0)) = @__ef_filter__TenantPrefix_0))
) AS [t] ON [o].[CustomerID] = [t].[CustomerID]
WHERE [t].[CustomerID] IS NOT NULL
) AS [t0] ON ([t0].[CompanyName] IS NOT NULL) AND ([c].[CustomerID] = [t0].[CustomerID])
WHERE ([t].[CustomerID] IS NOT NULL) AND ([t].[CompanyName] IS NOT NULL)
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
INNER JOIN (
SELECT [o0].[OrderID], [o0].[Discount]
FROM [Order Details] AS [o0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,12 @@ FROM [Parents] AS [p]
INNER JOIN (
SELECT [t].[ParentId], [t].[SomeNullableDateTime], [t].[SomeOtherNullableDateTime]
FROM (
SELECT [c].[ParentId], [c].[SomeNullableDateTime], [c].[SomeOtherNullableDateTime], ROW_NUMBER() OVER(PARTITION BY [c].[ParentId], [c].[SomeNullableDateTime] ORDER BY [c].[SomeInteger]) AS [row]
FROM [Child] AS [c]
SELECT [c].[ParentId], [c].[SomeNullableDateTime], [c].[SomeOtherNullableDateTime], ROW_NUMBER() OVER(PARTITION BY [c].[ParentId] ORDER BY [c].[SomeInteger]) AS [row]
FROM [Child26744] AS [c]
WHERE [c].[SomeNullableDateTime] IS NULL
) AS [t]
WHERE [t].[row] <= 1
) AS [t0] ON ([p].[Id] = [t0].[ParentId]) AND [t0].[SomeNullableDateTime] IS NULL
) AS [t0] ON [p].[Id] = [t0].[ParentId]
WHERE [t0].[SomeOtherNullableDateTime] IS NOT NULL");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7085,7 +7085,7 @@ WHEN [o].[Nickname] IS NOT NULL THEN N'Officer'
END AS [Discriminator]
FROM [Gears] AS [g]
LEFT JOIN [Officers] AS [o] ON ([g].[Nickname] = [o].[Nickname]) AND ([g].[SquadId] = [o].[SquadId])
) AS [t0] ON ([t].[GearNickName] = [t0].[Nickname]) AND ([t].[GearSquadId] = [t0].[SquadId])
) AS [t0] ON (([t].[GearNickName] = [t0].[Nickname]) AND ([t].[GearSquadId] = [t0].[SquadId])) AND ([t].[Note] IS NOT NULL)
ORDER BY [t].[Id], [t0].[Nickname]");
}

Expand Down

0 comments on commit 17b2756

Please sign in to comment.