Skip to content

Commit

Permalink
Allow cancellation token to be passed as part of params array in Find…
Browse files Browse the repository at this point in the history
…Async

Fixes #22667
  • Loading branch information
ajcvickers committed Jul 7, 2022
1 parent 40cf1d0 commit e4b38db
Show file tree
Hide file tree
Showing 6 changed files with 392 additions and 221 deletions.
56 changes: 43 additions & 13 deletions src/EFCore/Internal/EntityFinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,18 @@ public EntityFinder(
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual TEntity? Find(object?[]? keyValues)
=> keyValues == null || keyValues.Any(v => v == null)
? null
: (FindTracked(keyValues!, out var keyProperties)
?? _queryRoot.FirstOrDefault(BuildLambda(keyProperties, new ValueBuffer(keyValues))));
{
if (keyValues == null
|| keyValues.Any(v => v == null))
{
return default;
}

var (key, processedKeyValues, _) = ValidateKeyPropertiesAndExtractCancellationToken(keyValues!, async: false, default);

return FindTracked(key, processedKeyValues)
?? _queryRoot.FirstOrDefault(BuildLambda(key.Properties, new ValueBuffer(processedKeyValues)));
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -74,11 +82,13 @@ public EntityFinder(
return default;
}

var tracked = FindTracked(keyValues!, out var keyProperties);
var (key, processedKeyValues, ct) = ValidateKeyPropertiesAndExtractCancellationToken(keyValues!, async: true, cancellationToken);

var tracked = FindTracked(key, processedKeyValues);
return tracked != null
? new ValueTask<TEntity?>(tracked)
: new ValueTask<TEntity?>(
_queryRoot.FirstOrDefaultAsync(BuildLambda(keyProperties, new ValueBuffer(keyValues)), cancellationToken));
_queryRoot.FirstOrDefaultAsync(BuildLambda(key.Properties, new ValueBuffer(processedKeyValues)), ct));
}

/// <summary>
Expand All @@ -95,12 +105,14 @@ public EntityFinder(
return default;
}

var tracked = FindTracked(keyValues!, out var keyProperties);
var (key, processedKeyValues, ct) = ValidateKeyPropertiesAndExtractCancellationToken(keyValues!, async: true, cancellationToken);

var tracked = FindTracked(key, processedKeyValues);
return tracked != null
? new ValueTask<object?>(tracked)
: new ValueTask<object?>(
_queryRoot.FirstOrDefaultAsync(
BuildObjectLambda(keyProperties, new ValueBuffer(keyValues)), cancellationToken));
BuildObjectLambda(key.Properties, new ValueBuffer(processedKeyValues)), ct));
}

/// <summary>
Expand Down Expand Up @@ -259,23 +271,41 @@ private static IReadOnlyList<IProperty> GetLoadProperties(INavigation navigation
? navigation.ForeignKey.PrincipalKey.Properties
: navigation.ForeignKey.Properties;

private TEntity? FindTracked(object[] keyValues, out IReadOnlyList<IProperty> keyProperties)
private (IKey Key, object[] KeyValues,CancellationToken CancellationToken) ValidateKeyPropertiesAndExtractCancellationToken(
object[] keyValues,
bool async,
CancellationToken cancellationToken)
{
var key = _entityType.FindPrimaryKey()!;
keyProperties = key.Properties;
var keyPropertiesCount = key.Properties.Count;

if (keyProperties.Count != keyValues.Length)
if (keyPropertiesCount != keyValues.Length)
{
if (keyProperties.Count == 1)
if (async
&& keyPropertiesCount == keyValues.Length - 1
&& keyValues[keyPropertiesCount] is CancellationToken ct)
{
var newValues = new object[keyPropertiesCount];
Array.Copy(keyValues, newValues, keyPropertiesCount);
return (key, newValues, ct);
}

if (keyPropertiesCount == 1)
{
throw new ArgumentException(
CoreStrings.FindNotCompositeKey(typeof(TEntity).ShortDisplayName(), keyValues.Length));
}

throw new ArgumentException(
CoreStrings.FindValueCountMismatch(typeof(TEntity).ShortDisplayName(), keyProperties.Count, keyValues.Length));
CoreStrings.FindValueCountMismatch(typeof(TEntity).ShortDisplayName(), keyPropertiesCount, keyValues.Length));
}

return (key, keyValues, cancellationToken);
}

private TEntity? FindTracked(IKey key, object[] keyValues)
{
var keyProperties = key.Properties;
for (var i = 0; i < keyValues.Length; i++)
{
var valueType = keyValues[i].GetType();
Expand Down
26 changes: 7 additions & 19 deletions test/EFCore.Cosmos.FunctionalTests/FindCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ protected FindCosmosTest(FindCosmosFixture fixture)
[ConditionalFact(Skip = "#25886")]
public override void Find_base_type_using_derived_set_tracked() { }

[ConditionalFact(Skip = "#25886")]
public override Task Find_base_type_using_derived_set_tracked_async()
[ConditionalTheory(Skip = "#25886")]
public override Task Find_base_type_using_derived_set_tracked_async(CancellationType cancellationType)
=> Task.CompletedTask;

[ConditionalFact(Skip = "#25886")]
public override void Find_derived_using_base_set_type_from_store() { }

[ConditionalFact(Skip = "#25886")]
public override Task Find_derived_using_base_set_type_from_store_async()
[ConditionalTheory(Skip = "#25886")]
public override Task Find_derived_using_base_set_type_from_store_async(CancellationType cancellationType)
=> Task.CompletedTask;

public class FindCosmosTestSet : FindCosmosTest
Expand All @@ -32,11 +32,7 @@ public FindCosmosTestSet(FindCosmosFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> context.Set<TEntity>().Find(keyValues);

protected override ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> context.Set<TEntity>().FindAsync(keyValues);
protected override TestFinder Finder { get; } = new FindViaSetFinder();
}

public class FindCosmosTestContext : FindCosmosTest
Expand All @@ -46,11 +42,7 @@ public FindCosmosTestContext(FindCosmosFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> context.Find<TEntity>(keyValues);

protected override ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> context.FindAsync<TEntity>(keyValues);
protected override TestFinder Finder { get; } = new FindViaContextFinder();
}

public class FindCosmosTestNonGeneric : FindCosmosTest
Expand All @@ -60,11 +52,7 @@ public FindCosmosTestNonGeneric(FindCosmosFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> (TEntity)context.Find(typeof(TEntity), keyValues);

protected override async ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> (TEntity)await context.FindAsync(typeof(TEntity), keyValues);
protected override TestFinder Finder { get; } = new FindViaNonGenericContextFinder();
}

public class FindCosmosFixture : FindFixtureBase
Expand Down
18 changes: 3 additions & 15 deletions test/EFCore.InMemory.FunctionalTests/FindInMemoryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ public FindInMemoryTestSet(FindInMemoryFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> context.Set<TEntity>().Find(keyValues);

protected override ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> context.Set<TEntity>().FindAsync(keyValues);
protected override TestFinder Finder { get; } = new FindViaSetFinder();
}

public class FindInMemoryTestContext : FindInMemoryTest
Expand All @@ -31,11 +27,7 @@ public FindInMemoryTestContext(FindInMemoryFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> context.Find<TEntity>(keyValues);

protected override ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> context.FindAsync<TEntity>(keyValues);
protected override TestFinder Finder { get; } = new FindViaContextFinder();
}

public class FindInMemoryTestNonGeneric : FindInMemoryTest
Expand All @@ -45,11 +37,7 @@ public FindInMemoryTestNonGeneric(FindInMemoryFixture fixture)
{
}

protected override TEntity Find<TEntity>(DbContext context, params object[] keyValues)
=> (TEntity)context.Find(typeof(TEntity), keyValues);

protected override async ValueTask<TEntity> FindAsync<TEntity>(DbContext context, params object[] keyValues)
=> (TEntity)await context.FindAsync(typeof(TEntity), keyValues);
protected override TestFinder Finder { get; } = new FindViaNonGenericContextFinder();
}

public class FindInMemoryFixture : FindFixtureBase
Expand Down
Loading

0 comments on commit e4b38db

Please sign in to comment.