Skip to content

Commit

Permalink
ExecuteUpdate: Correctly identify when we need to cause a subquery join
Browse files Browse the repository at this point in the history
Resolves #28823
  • Loading branch information
smitpatel committed Aug 23, 2022
1 parent 59a3605 commit b0a0cd2
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 11 deletions.
5 changes: 4 additions & 1 deletion src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,10 @@ protected override Expression VisitUpdate(UpdateExpression updateExpression)
&& selectExpression.Orderings.Count == 0
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Projection.Count == 0
&& selectExpression.Tables.All(e => !(e is LeftJoinExpression || e is OuterApplyExpression)))
&& (selectExpression.Tables.Count == 1
|| !ReferenceEquals(selectExpression.Tables[0], updateExpression.Table)
|| selectExpression.Tables[1] is InnerJoinExpression
|| selectExpression.Tables[1] is CrossJoinExpression))
{
_relationalCommandBuilder.Append("UPDATE ");
Visit(updateExpression.Table);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1424,14 +1424,15 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate(
EntityShaperExpression entityShaperExpression,
[NotNullWhen(true)] out TableExpression? tableExpression)
{
tableExpression = null;
if (selectExpression.Offset == null
&& selectExpression.Limit == null
// If entity type has primary key then Distinct is no-op
&& (!selectExpression.IsDistinct || entityShaperExpression.EntityType.FindPrimaryKey() != null)
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0
&& selectExpression.Tables.All(e => !(e is LeftJoinExpression || e is OuterApplyExpression)))
&& selectExpression.Tables.Count > 0)
{
TableExpressionBase table;
if (selectExpression.Tables.Count == 1)
Expand All @@ -1444,6 +1445,15 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate(
var entityProjectionExpression = (EntityProjectionExpression)selectExpression.GetProjection(projectionBindingExpression);
var column = entityProjectionExpression.BindProperty(entityShaperExpression.EntityType.GetProperties().First());
table = column.Table;
if (ReferenceEquals(selectExpression.Tables[0], table))
{
// If the table we are looking for it first table, then we need to verify if we can lift the next table in FROM clause
var secondTable = selectExpression.Tables[1];
if (secondTable is not InnerJoinExpression and not CrossJoinExpression)
{
return false;
}
}
if (table is JoinExpressionBase joinExpressionBase)
{
table = joinExpressionBase.Table;
Expand All @@ -1457,7 +1467,6 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate(
}
}

tableExpression = null;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,50 @@ from o in ss.Set<Order>().Where(o => o.OrderID < 10300 && o.OrderDate.Value.Year
rowsAffectedCount: 8,
(b, a) => Assert.All(a, c => Assert.Equal("Updated", c.ContactName)));
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_with_cross_join_left_join_set_constant(bool async)
=> AssertUpdate(
async,
ss => from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
from c2 in ss.Set<Customer>().Where(c => c.City.StartsWith("S"))
join o in ss.Set<Order>().Where(o => o.OrderID < 10300)
on c.CustomerID equals o.CustomerID into grouping
from o in grouping.DefaultIfEmpty()
select new { c, c2, o },
e => e.c,
s => s.SetProperty(c => c.c.ContactName, c => "Updated"),
rowsAffectedCount: 8,
(b, a) => Assert.All(a, c => Assert.Equal("Updated", c.ContactName)));
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_with_cross_join_cross_apply_set_constant(bool async)
=> AssertUpdate(
async,
ss => from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
from c2 in ss.Set<Customer>().Where(c => c.City.StartsWith("S"))
from o in ss.Set<Order>().Where(o => o.OrderID < 10300 && o.OrderDate.Value.Year < c.ContactName.Length)
select new { c, o },
e => e.c,
s => s.SetProperty(c => c.c.ContactName, c => "Updated"),
rowsAffectedCount: 0,
(b, a) => Assert.All(a, c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_with_cross_join_outer_apply_set_constant(bool async)
=> AssertUpdate(
async,
ss => from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
from c2 in ss.Set<Customer>().Where(c => c.City.StartsWith("S"))
from o in ss.Set<Order>().Where(o => o.OrderID < 10300 && o.OrderDate.Value.Year < c.ContactName.Length).DefaultIfEmpty()
select new { c, c2, o },
e => e.c,
s => s.SetProperty(c => c.c.ContactName, c => "Updated"),
rowsAffectedCount: 8,
(b, a) => Assert.All(a, c => Assert.Equal("Updated", c.ContactName)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Update_FromSql_set_constant(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,69 @@ FROM [Orders] AS [o]
WHERE [c].[CustomerID] LIKE N'F%'");
}

public override async Task Update_with_cross_join_left_join_set_constant(bool async)
{
await base.Update_with_cross_join_left_join_set_constant(async);

AssertExecuteUpdateSql(
@"UPDATE [c]
SET [c].[ContactName] = N'Updated'
FROM [Customers] AS [c]
CROSS JOIN (
SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%')
) AS [t]
LEFT JOIN (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] < 10300
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
WHERE [c].[CustomerID] LIKE N'F%'");
}

public override async Task Update_with_cross_join_cross_apply_set_constant(bool async)
{
await base.Update_with_cross_join_cross_apply_set_constant(async);

AssertExecuteUpdateSql(
@"UPDATE [c]
SET [c].[ContactName] = N'Updated'
FROM [Customers] AS [c]
CROSS JOIN (
SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%')
) AS [t]
CROSS APPLY (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int)
) AS [t0]
WHERE [c].[CustomerID] LIKE N'F%'");
}

public override async Task Update_with_cross_join_outer_apply_set_constant(bool async)
{
await base.Update_with_cross_join_outer_apply_set_constant(async);

AssertExecuteUpdateSql(
@"UPDATE [c]
SET [c].[ContactName] = N'Updated'
FROM [Customers] AS [c]
CROSS JOIN (
SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%')
) AS [t]
OUTER APPLY (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int)
) AS [t0]
WHERE [c].[CustomerID] LIKE N'F%'");
}

public override async Task Update_FromSql_set_constant(bool async)
{
await base.Update_FromSql_set_constant(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -840,14 +840,9 @@ public override async Task Update_Where_using_navigation_2_set_constant(bool asy
AssertExecuteUpdateSql(
@"UPDATE ""Order Details"" AS ""o""
SET ""Quantity"" = CAST(1 AS INTEGER)
FROM (
SELECT ""o0"".""OrderID"", ""o0"".""ProductID"", ""o0"".""Discount"", ""o0"".""Quantity"", ""o0"".""UnitPrice"", ""o1"".""OrderID"" AS ""OrderID0"", ""c"".""CustomerID""
FROM ""Order Details"" AS ""o0""
INNER JOIN ""Orders"" AS ""o1"" ON ""o0"".""OrderID"" = ""o1"".""OrderID""
LEFT JOIN ""Customers"" AS ""c"" ON ""o1"".""CustomerID"" = ""c"".""CustomerID""
WHERE ""c"".""City"" = 'Seattle'
) AS ""t""
WHERE ""o"".""OrderID"" = ""t"".""OrderID"" AND ""o"".""ProductID"" = ""t"".""ProductID""");
FROM ""Orders"" AS ""o0""
LEFT JOIN ""Customers"" AS ""c"" ON ""o0"".""CustomerID"" = ""c"".""CustomerID""
WHERE ""o"".""OrderID"" = ""o0"".""OrderID"" AND ""c"".""City"" = 'Seattle'");
}

public override async Task Update_Where_SelectMany_set_null(bool async)
Expand Down Expand Up @@ -1097,6 +1092,36 @@ public override async Task Update_with_outer_apply_set_constant(bool async)
SqliteStrings.ApplyNotSupported,
(await Assert.ThrowsAsync<InvalidOperationException>(() => base.Update_with_outer_apply_set_constant(async))).Message);

public override async Task Update_with_cross_join_left_join_set_constant(bool async)
{
await base.Update_with_cross_join_left_join_set_constant(async);

AssertExecuteUpdateSql(
@"UPDATE ""Customers"" AS ""c""
SET ""ContactName"" = 'Updated'
FROM (
SELECT ""c0"".""CustomerID"", ""c0"".""Address"", ""c0"".""City"", ""c0"".""CompanyName"", ""c0"".""ContactName"", ""c0"".""ContactTitle"", ""c0"".""Country"", ""c0"".""Fax"", ""c0"".""Phone"", ""c0"".""PostalCode"", ""c0"".""Region""
FROM ""Customers"" AS ""c0""
WHERE ""c0"".""City"" IS NOT NULL AND (""c0"".""City"" LIKE 'S%')
) AS ""t""
LEFT JOIN (
SELECT ""o"".""OrderID"", ""o"".""CustomerID"", ""o"".""EmployeeID"", ""o"".""OrderDate""
FROM ""Orders"" AS ""o""
WHERE ""o"".""OrderID"" < 10300
) AS ""t0"" ON ""c"".""CustomerID"" = ""t0"".""CustomerID""
WHERE ""c"".""CustomerID"" LIKE 'F%'");
}

public override async Task Update_with_cross_join_cross_apply_set_constant(bool async)
=> Assert.Equal(
SqliteStrings.ApplyNotSupported,
(await Assert.ThrowsAsync<InvalidOperationException>(() => base.Update_with_cross_join_cross_apply_set_constant(async))).Message);

public override async Task Update_with_cross_join_outer_apply_set_constant(bool async)
=> Assert.Equal(
SqliteStrings.ApplyNotSupported,
(await Assert.ThrowsAsync<InvalidOperationException>(() => base.Update_with_cross_join_outer_apply_set_constant(async))).Message);

public override async Task Update_FromSql_set_constant(bool async)
{
await base.Update_FromSql_set_constant(async);
Expand Down

0 comments on commit b0a0cd2

Please sign in to comment.