Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #9956 - Add support for private and protected methods in DbFunctions #10040

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/EFCore.Relational/Infrastructure/RelationalModelCustomizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,20 @@ protected virtual void FindDbFunctions([NotNull] ModelBuilder modelBuilder, [Not
Check.NotNull(modelBuilder, nameof(modelBuilder));
Check.NotNull(context, nameof(context));

var functions = context.GetType().GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static | BindingFlags.FlattenHierarchy)
.Where(mi => mi.IsPublic
&& mi.GetCustomAttributes(typeof(DbFunctionAttribute)).Any());
var contextType = context.GetType();

foreach (var function in functions)
{
modelBuilder.HasDbFunction(function);
while(contextType != typeof(DbContext))
{
var functions = contextType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance
| BindingFlags.Static | BindingFlags.DeclaredOnly)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BindingFlags.DeclaredOnly to search only the methods declared on the Type, not methods that were simply inherited.

I believe this should be different. When working with multi-context scenarios (using different contexts for read/write), user may want to just put method on common base instead of the exact context instance.

Also FlattenHierarchy should be there to find static members from base classes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The while loop will walk up the inheritance chain pulling methods until it gets to DbContext.

We need to do this because base class private methods are not returned when methods are pulled from derived types.

FlattenHierarchy is disabled to prevent duplicate static methods from being returned during the inheritance walk.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. didn't realize the loop.

.Where(mi => mi.GetCustomAttributes(typeof(DbFunctionAttribute)).Any());

foreach (var function in functions)
{
modelBuilder.HasDbFunction(function);
}

contextType = contextType.BaseType;
}
}

Expand Down
146 changes: 106 additions & 40 deletions test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ public static int DuplicateNameTest()

public class MyBaseContext : DbContext
{
public static readonly string[] FunctionNames =
{
nameof(MyBaseContext.StaticPublicBase),
nameof(MyBaseContext.StaticProtectedBase),
nameof(MyBaseContext.StaticPrivateBase),
nameof(MyBaseContext.StaticInteranlBase),
nameof(MyBaseContext.StaticProtectedInteralBase),
nameof(MyBaseContext.InstancePublicBase),
nameof(MyBaseContext.InstanceProtectedBase),
nameof(MyBaseContext.InstancePrivateBase),
nameof(MyBaseContext.InstanceInteranlBase),
nameof(MyBaseContext.InstanceProtectedInteralBase),
};

public static void Foo()
{
}
Expand All @@ -42,22 +56,57 @@ public static void Skip2()
private static void Skip()
{
}

[DbFunction]
public static int StaticBase()
{
throw new Exception();
}
public static int StaticPublicBase() => throw new Exception();

[DbFunction]
public int NonStaticBase()
{
throw new Exception();
}
protected static int StaticProtectedBase() => throw new Exception();

[DbFunction]
private static int StaticPrivateBase() => throw new Exception();

[DbFunction]
internal static int StaticInteranlBase() => throw new Exception();

[DbFunction]
protected internal static int StaticProtectedInteralBase() => throw new Exception();

[DbFunction]
public int InstancePublicBase() => throw new Exception();

[DbFunction]
protected int InstanceProtectedBase() => throw new Exception();

[DbFunction]
private int InstancePrivateBase() => throw new Exception();

[DbFunction]
internal int InstanceInteranlBase() => throw new Exception();

[DbFunction]
protected internal int InstanceProtectedInteralBase() => throw new Exception();

[DbFunction]
public virtual int VirtualBase() => throw new Exception();
}

public class MyDerivedContext : MyBaseContext
{
public new static readonly string[] FunctionNames =
{
nameof(MyDerivedContext.StaticPublicDerived),
nameof(MyDerivedContext.StaticProtectedDerived),
nameof(MyDerivedContext.StaticPrivateDerived),
nameof(MyDerivedContext.StaticInteranlDerived),
nameof(MyDerivedContext.StaticProtectedInteralDerived),
nameof(MyDerivedContext.InstancePublicDerived),
nameof(MyDerivedContext.InstanceProtectedDerived),
nameof(MyDerivedContext.InstancePrivateDerived),
nameof(MyDerivedContext.InstanceInteranlDerived),
nameof(MyDerivedContext.InstanceProtectedInteralDerived),
};

public static void Bar()
{
}
Expand All @@ -70,22 +119,43 @@ private static void Skip4()
{
}

[DbFunction]
public static int StaticDerived()
public static int DuplicateNameTest()
{
throw new Exception();
}

[DbFunction]
public int NonStaticDerived()
{
throw new Exception();
}
public static int StaticPublicDerived() => throw new Exception();

public static int DuplicateNameTest()
{
throw new Exception();
}
[DbFunction]
protected static int StaticProtectedDerived() => throw new Exception();

[DbFunction]
private static int StaticPrivateDerived() => throw new Exception();

[DbFunction]
internal static int StaticInteranlDerived() => throw new Exception();

[DbFunction]
protected internal static int StaticProtectedInteralDerived() => throw new Exception();

[DbFunction]
public int InstancePublicDerived() => throw new Exception();

[DbFunction]
protected int InstanceProtectedDerived() => throw new Exception();

[DbFunction]
private int InstancePrivateDerived() => throw new Exception();

[DbFunction]
internal int InstanceInteranlDerived() => throw new Exception();

[DbFunction]
protected internal int InstanceProtectedInteralDerived() => throw new Exception();

[DbFunction]
public override int VirtualBase() => throw new Exception();
}

public static MethodInfo MethodAmi = typeof(TestMethods).GetRuntimeMethod(nameof(TestMethods.MethodA), new[] { typeof(string), typeof(int) });
Expand Down Expand Up @@ -132,16 +202,16 @@ public virtual void DbFunctions_with_duplicate_names_and_parameters_on_different
{
var modelBuilder = GetModelBuilder();

var Dup1methodInfo
var dup1methodInfo
= typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.DuplicateNameTest), new Type[] { });

var Dup2methodInfo
var dup2methodInfo
= typeof(MyNonDbContext)
.GetRuntimeMethod(nameof(MyNonDbContext.DuplicateNameTest), new Type[] { });

var dbFunc1 = modelBuilder.HasDbFunction(Dup1methodInfo).HasName("Dup1").Metadata;
var dbFunc2 = modelBuilder.HasDbFunction(Dup2methodInfo).HasName("Dup2").Metadata;
var dbFunc1 = modelBuilder.HasDbFunction(dup1methodInfo).HasName("Dup1").Metadata;
var dbFunc2 = modelBuilder.HasDbFunction(dup2methodInfo).HasName("Dup2").Metadata;

Assert.Equal("Dup1", dbFunc1.FunctionName);
Assert.Equal("Dup2", dbFunc2.FunctionName);
Expand All @@ -156,35 +226,31 @@ public virtual void Finds_dbFunctions_on_dbContext()

customizer.Customize(modelBuilder, new MyDerivedContext());

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyBaseContext.NonStaticBase), new Type[] { })));

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyBaseContext)
.GetRuntimeMethod(nameof(MyBaseContext.StaticBase), new Type[] { })));

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticDerived), new Type[] { })));
foreach (var function in MyBaseContext.FunctionNames)
{
Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyBaseContext).GetMethod(function, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)));
}

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticDerived), new Type[] { })));
foreach (var function in MyDerivedContext.FunctionNames)
{
Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext).GetMethod(function, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)));
}
}

[Fact]
public virtual void Non_static_function_on_dbcontext_does_not_throw()
{
var modelBuilder = GetModelBuilder();

var methodInfo
var methodInfo
= typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticBase), new Type[] { });
.GetRuntimeMethod(nameof(MyDerivedContext.InstancePublicBase), new Type[] { });

var dbFunc = modelBuilder.HasDbFunction(methodInfo).Metadata;

Assert.Equal("NonStaticBase", dbFunc.FunctionName);
Assert.Equal("InstancePublicBase", dbFunc.FunctionName);
Assert.Equal(typeof(int), dbFunc.MethodInfo.ReturnType);
}

Expand Down