Skip to content

Commit

Permalink
Fix bugs with rows affected and value propagation for sprocs
Browse files Browse the repository at this point in the history
Fixes #28997
  • Loading branch information
roji committed Sep 7, 2022
1 parent 706a3e8 commit 5217cfe
Show file tree
Hide file tree
Showing 19 changed files with 236 additions and 267 deletions.
35 changes: 22 additions & 13 deletions src/EFCore.Relational/Infrastructure/RelationalModelValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,9 @@ private static void ValidateSproc(IStoredProcedure sproc, string mappingStrategy
var storeGeneratedProperties = storeObjectIdentifier.StoreObjectType switch
{
StoreObjectType.InsertStoredProcedure
=> properties.Where(p => (p.Value.ValueGenerated & ValueGenerated.OnAdd) != 0).ToDictionary(p => p.Key, p => p.Value),
=> properties.Where(p => p.Value.ValueGenerated.HasFlag(ValueGenerated.OnAdd)).ToDictionary(p => p.Key, p => p.Value),
StoreObjectType.UpdateStoredProcedure
=> properties.Where(p => (p.Value.ValueGenerated & ValueGenerated.OnUpdate) != 0).ToDictionary(p => p.Key, p => p.Value),
=> properties.Where(p => p.Value.ValueGenerated.HasFlag(ValueGenerated.OnUpdate)).ToDictionary(p => p.Key, p => p.Value),
_ => new Dictionary<string, IProperty>()
};

Expand Down Expand Up @@ -574,18 +574,27 @@ private static void ValidateSproc(IStoredProcedure sproc, string mappingStrategy
}
}

var missedConcurrencyToken = originalValueProperties.Values.FirstOrDefault(p => p.IsConcurrencyToken);
if (missedConcurrencyToken != null
&& storeObjectIdentifier.StoreObjectType != StoreObjectType.InsertStoredProcedure
&& (sproc.IsRowsAffectedReturned
|| sproc.FindRowsAffectedParameter() != null
|| sproc.FindRowsAffectedResultColumn() != null))
if (sproc.IsRowsAffectedReturned
|| sproc.FindRowsAffectedParameter() != null
|| sproc.FindRowsAffectedResultColumn() != null)
{
throw new InvalidOperationException(
RelationalStrings.StoredProcedureConcurrencyTokenNotMapped(
entityType.DisplayName(),
storeObjectIdentifier.DisplayName(),
missedConcurrencyToken.Name));
if (originalValueProperties.Values.FirstOrDefault(p => p.IsConcurrencyToken) is { } missedConcurrencyToken
&& storeObjectIdentifier.StoreObjectType != StoreObjectType.InsertStoredProcedure)
{
throw new InvalidOperationException(
RelationalStrings.StoredProcedureConcurrencyTokenNotMapped(
entityType.DisplayName(),
storeObjectIdentifier.DisplayName(),
missedConcurrencyToken.Name));
}

if (sproc.ResultColumns.Any(c => c != sproc.FindRowsAffectedResultColumn()))
{
throw new InvalidOperationException(
RelationalStrings.StoredProcedureRowsAffectedWithResultColumns(
entityType.DisplayName(),
storeObjectIdentifier.DisplayName()));
}
}
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/EFCore.Relational/Properties/RelationalStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,9 @@
<data name="StoredProcedureRowsAffectedReturnConflictingParameter" xml:space="preserve">
<value>The stored procedure '{sproc}' cannot be configured to return the rows affected because a rows affected parameter or a rows affected result column for this stored procedure already exists.</value>
</data>
<data name="StoredProcedureRowsAffectedWithResultColumns" xml:space="preserve">
<value>The entity type '{entityType}' is mapped to the stored procedure '{sproc}' which returns both result columns and a rows affected value. If the stored procedure returns result columns, a rows affected value isn't needed and can be safely removed.</value>
</data>
<data name="StoredProcedureTableSharing" xml:space="preserve">
<value>Both entity type '{entityType1}' and '{entityType2}' were configured to use '{sproc}', stored procedure sharing is not supported. Specify different names for the corresponding stored procedures.</value>
</data>
Expand Down
100 changes: 50 additions & 50 deletions src/EFCore.Relational/Update/AffectedCountModificationCommandBatch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ protected AffectedCountModificationCommandBatch(ModificationCommandBatchFactoryD
protected override void Consume(RelationalDataReader reader)
{
Check.DebugAssert(
CommandResultSet.Count == ModificationCommands.Count,
$"CommandResultSet.Count of {CommandResultSet.Count} != ModificationCommands.Count of {ModificationCommands.Count}");
ResultSetMappings.Count == ModificationCommands.Count,
$"CommandResultSet.Count of {ResultSetMappings.Count} != ModificationCommands.Count of {ModificationCommands.Count}");

var commandIndex = 0;

Expand All @@ -44,10 +44,9 @@ protected override void Consume(RelationalDataReader reader)
bool? onResultSet = null;
var hasOutputParameters = false;

for (; commandIndex < CommandResultSet.Count; commandIndex++)
for (; commandIndex < ResultSetMappings.Count; commandIndex++)
{
var resultSetMapping = CommandResultSet[commandIndex];
var command = ModificationCommands[commandIndex];
var resultSetMapping = ResultSetMappings[commandIndex];

if (resultSetMapping.HasFlag(ResultSetMapping.HasResultRow))
{
Expand All @@ -56,9 +55,9 @@ protected override void Consume(RelationalDataReader reader)
throw new InvalidOperationException(RelationalStrings.MissingResultSetWhenSaving);
}

var lastHandledCommandIndex = command.RequiresResultPropagation
? ConsumeResultSetWithPropagation(commandIndex, reader)
: ConsumeResultSetWithoutPropagation(commandIndex, reader);
var lastHandledCommandIndex = resultSetMapping.HasFlag(ResultSetMapping.ResultSetWithRowsAffectedOnly)
? ConsumeResultSetWithRowsAffectedOnly(commandIndex, reader)
: ConsumeResultSet(commandIndex, reader);

Check.DebugAssert(resultSetMapping.HasFlag(ResultSetMapping.LastInResultSet)
? lastHandledCommandIndex == commandIndex
Expand Down Expand Up @@ -88,12 +87,12 @@ protected override void Consume(RelationalDataReader reader)
IReadOnlyModificationCommand command;

for (commandIndex = 0;
commandIndex < CommandResultSet.Count;
commandIndex < ResultSetMappings.Count;
commandIndex++, parameterCounter += command.StoreStoredProcedure!.Parameters.Count)
{
command = ModificationCommands[commandIndex];

if (!CommandResultSet[commandIndex].HasFlag(ResultSetMapping.HasOutputParameters))
if (!ResultSetMappings[commandIndex].HasFlag(ResultSetMapping.HasOutputParameters))
{
continue;
}
Expand Down Expand Up @@ -124,10 +123,7 @@ protected override void Consume(RelationalDataReader reader)
}
}

if (command.RequiresResultPropagation)
{
command.PropagateOutputParameters(reader.DbCommand.Parameters, parameterCounter);
}
command.PropagateOutputParameters(reader.DbCommand.Parameters, parameterCounter);
}
}
}
Expand All @@ -152,8 +148,8 @@ protected override async Task ConsumeAsync(
CancellationToken cancellationToken = default)
{
Check.DebugAssert(
CommandResultSet.Count == ModificationCommands.Count,
$"CommandResultSet.Count of {CommandResultSet.Count} != ModificationCommands.Count of {ModificationCommands.Count}");
ResultSetMappings.Count == ModificationCommands.Count,
$"CommandResultSet.Count of {ResultSetMappings.Count} != ModificationCommands.Count of {ModificationCommands.Count}");

var commandIndex = 0;

Expand All @@ -162,10 +158,9 @@ protected override async Task ConsumeAsync(
bool? onResultSet = null;
var hasOutputParameters = false;

for (; commandIndex < CommandResultSet.Count; commandIndex++)
for (; commandIndex < ResultSetMappings.Count; commandIndex++)
{
var resultSetMapping = CommandResultSet[commandIndex];
var command = ModificationCommands[commandIndex];
var resultSetMapping = ResultSetMappings[commandIndex];

if (resultSetMapping.HasFlag(ResultSetMapping.HasResultRow))
{
Expand All @@ -174,9 +169,9 @@ protected override async Task ConsumeAsync(
throw new InvalidOperationException(RelationalStrings.MissingResultSetWhenSaving);
}

var lastHandledCommandIndex = command.RequiresResultPropagation
? await ConsumeResultSetWithPropagationAsync(commandIndex, reader, cancellationToken).ConfigureAwait(false)
: await ConsumeResultSetWithoutPropagationAsync(commandIndex, reader, cancellationToken).ConfigureAwait(false);
var lastHandledCommandIndex = resultSetMapping.HasFlag(ResultSetMapping.ResultSetWithRowsAffectedOnly)
? await ConsumeResultSetWithRowsAffectedOnlyAsync(commandIndex, reader, cancellationToken).ConfigureAwait(false)
: await ConsumeResultSetAsync(commandIndex, reader, cancellationToken).ConfigureAwait(false);

Check.DebugAssert(resultSetMapping.HasFlag(ResultSetMapping.LastInResultSet)
? lastHandledCommandIndex == commandIndex
Expand Down Expand Up @@ -206,12 +201,12 @@ protected override async Task ConsumeAsync(
IReadOnlyModificationCommand command;

for (commandIndex = 0;
commandIndex < CommandResultSet.Count;
commandIndex < ResultSetMappings.Count;
commandIndex++, parameterCounter += command.StoreStoredProcedure!.Parameters.Count)
{
command = ModificationCommands[commandIndex];

if (!CommandResultSet[commandIndex].HasFlag(ResultSetMapping.HasOutputParameters))
if (!ResultSetMappings[commandIndex].HasFlag(ResultSetMapping.HasOutputParameters))
{
continue;
}
Expand Down Expand Up @@ -243,10 +238,7 @@ await ThrowAggregateUpdateConcurrencyExceptionAsync(
}
}

if (command.RequiresResultPropagation)
{
command.PropagateOutputParameters(reader.DbCommand.Parameters, parameterCounter);
}
command.PropagateOutputParameters(reader.DbCommand.Parameters, parameterCounter);
}
}
}
Expand All @@ -266,7 +258,7 @@ await ThrowAggregateUpdateConcurrencyExceptionAsync(
/// <param name="startCommandIndex">The ordinal of the first command being consumed.</param>
/// <param name="reader">The data reader.</param>
/// <returns>The ordinal of the next result set that must be consumed.</returns>
protected virtual int ConsumeResultSetWithPropagation(int startCommandIndex, RelationalDataReader reader)
protected virtual int ConsumeResultSet(int startCommandIndex, RelationalDataReader reader)
{
var commandIndex = startCommandIndex;
var rowsAffected = 0;
Expand All @@ -275,8 +267,8 @@ protected virtual int ConsumeResultSetWithPropagation(int startCommandIndex, Rel
if (!reader.Read())
{
var expectedRowsAffected = rowsAffected + 1;
while (++commandIndex < CommandResultSet.Count
&& CommandResultSet[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet))
while (++commandIndex < ResultSetMappings.Count
&& ResultSetMappings[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet))
{
expectedRowsAffected++;
}
Expand All @@ -285,22 +277,24 @@ protected virtual int ConsumeResultSetWithPropagation(int startCommandIndex, Rel
}
else
{
var resultSetMapping = CommandResultSet[commandIndex];
var resultSetMapping = ResultSetMappings[commandIndex];

var command = ModificationCommands[
resultSetMapping.HasFlag(ResultSetMapping.IsPositionalResultMappingEnabled)
? startCommandIndex + reader.DbDataReader.GetInt32(reader.DbDataReader.FieldCount - 1)
: commandIndex];

Check.DebugAssert(command.RequiresResultPropagation, "RequiresResultPropagation is false");
Check.DebugAssert(
!resultSetMapping.HasFlag(ResultSetMapping.ResultSetWithRowsAffectedOnly),
"!resultSetMapping.HasFlag(ResultSetMapping.ResultSetWithRowsAffectedOnly)");

command.PropagateResults(reader);
}

rowsAffected++;
}
while (++commandIndex < CommandResultSet.Count
&& CommandResultSet[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet));
while (++commandIndex < ResultSetMappings.Count
&& ResultSetMappings[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet));

return commandIndex - 1;
}
Expand All @@ -317,7 +311,7 @@ protected virtual int ConsumeResultSetWithPropagation(int startCommandIndex, Rel
/// The task contains the ordinal of the next command that must be consumed.
/// </returns>
/// <exception cref="OperationCanceledException">If the <see cref="CancellationToken" /> is canceled.</exception>
protected virtual async Task<int> ConsumeResultSetWithPropagationAsync(
protected virtual async Task<int> ConsumeResultSetAsync(
int startCommandIndex,
RelationalDataReader reader,
CancellationToken cancellationToken)
Expand All @@ -329,8 +323,8 @@ protected virtual async Task<int> ConsumeResultSetWithPropagationAsync(
if (!await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
{
var expectedRowsAffected = rowsAffected + 1;
while (++commandIndex < CommandResultSet.Count
&& CommandResultSet[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet))
while (++commandIndex < ResultSetMappings.Count
&& ResultSetMappings[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet))
{
expectedRowsAffected++;
}
Expand All @@ -340,22 +334,24 @@ await ThrowAggregateUpdateConcurrencyExceptionAsync(
}
else
{
var resultSetMapping = CommandResultSet[commandIndex];
var resultSetMapping = ResultSetMappings[commandIndex];

var command = ModificationCommands[
resultSetMapping.HasFlag(ResultSetMapping.IsPositionalResultMappingEnabled)
? startCommandIndex + reader.DbDataReader.GetInt32(reader.DbDataReader.FieldCount - 1)
: commandIndex];

Check.DebugAssert(command.RequiresResultPropagation, "RequiresResultPropagation is false");
Check.DebugAssert(
!resultSetMapping.HasFlag(ResultSetMapping.ResultSetWithRowsAffectedOnly),
"!resultSetMapping.HasFlag(ResultSetMapping.ResultSetWithRowsAffectedOnly)");

command.PropagateResults(reader);
}

rowsAffected++;
}
while (++commandIndex < CommandResultSet.Count
&& CommandResultSet[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet));
while (++commandIndex < ResultSetMappings.Count
&& ResultSetMappings[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet));

return commandIndex - 1;
}
Expand All @@ -367,13 +363,15 @@ await ThrowAggregateUpdateConcurrencyExceptionAsync(
/// <param name="commandIndex">The ordinal of the command being consumed.</param>
/// <param name="reader">The data reader.</param>
/// <returns>The ordinal of the next command that must be consumed.</returns>
protected virtual int ConsumeResultSetWithoutPropagation(int commandIndex, RelationalDataReader reader)
protected virtual int ConsumeResultSetWithRowsAffectedOnly(int commandIndex, RelationalDataReader reader)
{
var expectedRowsAffected = 1;
while (++commandIndex < CommandResultSet.Count
&& CommandResultSet[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet))
while (++commandIndex < ResultSetMappings.Count
&& ResultSetMappings[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet))
{
Check.DebugAssert(!ModificationCommands[commandIndex].RequiresResultPropagation, "RequiresResultPropagation is true");
Check.DebugAssert(
ResultSetMappings[commandIndex].HasFlag(ResultSetMapping.ResultSetWithRowsAffectedOnly),
"ResultSetMappings[commandIndex].HasFlag(ResultSetMapping.ResultSetWithRowsAffectedOnly)");

expectedRowsAffected++;
}
Expand Down Expand Up @@ -406,16 +404,18 @@ protected virtual int ConsumeResultSetWithoutPropagation(int commandIndex, Relat
/// The task contains the ordinal of the next command that must be consumed.
/// </returns>
/// <exception cref="OperationCanceledException">If the <see cref="CancellationToken" /> is canceled.</exception>
protected virtual async Task<int> ConsumeResultSetWithoutPropagationAsync(
protected virtual async Task<int> ConsumeResultSetWithRowsAffectedOnlyAsync(
int commandIndex,
RelationalDataReader reader,
CancellationToken cancellationToken)
{
var expectedRowsAffected = 1;
while (++commandIndex < CommandResultSet.Count
&& CommandResultSet[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet))
while (++commandIndex < ResultSetMappings.Count
&& ResultSetMappings[commandIndex - 1].HasFlag(ResultSetMapping.NotLastInResultSet))
{
Check.DebugAssert(!ModificationCommands[commandIndex].RequiresResultPropagation, "RequiresResultPropagation is true");
Check.DebugAssert(
ResultSetMappings[commandIndex].HasFlag(ResultSetMapping.ResultSetWithRowsAffectedOnly),
"ResultSetMappings[commandIndex].HasFlag(ResultSetMapping.ResultSetWithRowsAffectedOnly)");

expectedRowsAffected++;
}
Expand Down
Loading

0 comments on commit 5217cfe

Please sign in to comment.