Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DBFunctions - Add support for instance methods. #9755

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,8 @@ protected virtual void FindDbFunctions([NotNull] ModelBuilder modelBuilder, [Not
Check.NotNull(modelBuilder, nameof(modelBuilder));
Check.NotNull(context, nameof(context));

var functions = context.GetType().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy)
.Where(
mi => mi.IsStatic
&& mi.IsPublic
var functions = context.GetType().GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static | BindingFlags.FlattenHierarchy)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have logic for selecting only this set of BindingFlags? It includes static protected methods from base types in search but not instance protected methods.
Also should we only support public methods or protected methods too. @anpete @divega

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The BindingFlags.Public will filter out the protected static methods as well as protected instance methods.

I think we should only bring in public methods since users shouldn't be putting their queries directly in their DbContext class - imho. Therefore there is no need for protected methods.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed #9956

.Where(mi => mi.IsPublic
&& mi.GetCustomAttributes(typeof(DbFunctionAttribute)).Any());

foreach (var function in functions)
Expand Down
8 changes: 5 additions & 3 deletions src/EFCore.Relational/Metadata/Internal/DbFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ private DbFunction(
throw new ArgumentException(RelationalStrings.DbFunctionGenericMethodNotSupported(methodInfo.DisplayName()));
}

if (!methodInfo.IsStatic)
if (!methodInfo.IsStatic
&& !typeof(DbContext).IsAssignableFrom(methodInfo.DeclaringType))
{
throw new ArgumentException(RelationalStrings.DbFunctionMethodMustBeStatic(methodInfo.DisplayName()));
throw new ArgumentException(
RelationalStrings.DbFunctionInvalidInstanceType(methodInfo.DisplayName(), methodInfo.DeclaringType.ShortDisplayName()));
}

if (methodInfo.ReturnType == null
Expand Down Expand Up @@ -102,7 +104,7 @@ public static IEnumerable<IDbFunction> GetDbFunctions([NotNull] IModel model, [N
}

private static string BuildAnnotationName(string annotationPrefix, MethodBase methodBase)
=> $@"{annotationPrefix}{methodBase.Name}({string.Join(",", methodBase.GetParameters().Select(p => p.ParameterType.Name))})";
=> $@"{annotationPrefix}{methodBase.DeclaringType.ShortDisplayName()}{methodBase.Name}({string.Join(",", methodBase.GetParameters().Select(p => p.ParameterType.Name))})";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is changing annotation name for DbFunction. Since everything is runtime, it would not break anything particular. But hand-crafted model can break may be.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is actually fixing a bug I ran into. If you declare two static methods on different types with the same name then the annotation name was colliding. We need the type name to differentiate the methods.

See the new unit test DbFunctions_with_duplicate_names_and_parameters_on_different_types_dont_collide


/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
Expand Down
16 changes: 8 additions & 8 deletions src/EFCore.Relational/Properties/RelationalStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions src/EFCore.Relational/Properties/RelationalStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,15 @@
<data name="DbFunctionInvalidParameterType" xml:space="preserve">
<value>The parameter '{parameter}' for the DbFunction '{function}' has an invalid type '{type}'. Ensure the parameter type can be mapped by the current provider.</value>
</data>
<data name="DbFunctionMethodMustBeStatic" xml:space="preserve">
<value>The DbFunction '{function}' must be a static method. Non-static methods are not supported.</value>
</data>
<data name="DbFunctionGenericMethodNotSupported" xml:space="preserve">
<value>The DbFunction '{function}' is generic. Generic methods are not supported.</value>
</data>
<data name="DbFunctionExpressionIsNotMethodCall" xml:space="preserve">
<value>The provided DbFunction expression '{expression}' is invalid. The expression should be a lambda expression containing a single method call to the target static method. Default values can be provided as arguments if required. E.g. () =&gt; SomeClass.SomeMethod(null, 0)</value>
</data>
<data name="DbFunctionInvalidInstanceType" xml:space="preserve">
<value>The DbFunction '{function}' defined on type '{type}' must be either a static method or an instance method defined on a DbContext subclass. Instance methods on other types are not supported.</value>
</data>
<data name="ConflictingAmbientTransaction" xml:space="preserve">
<value>An ambient transaction has been detected. The ambient transaction needs to be completed before beginning a transaction on this connection.</value>
</data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
Check.NotNull(methodCallExpression, nameof(methodCallExpression));

var operand = Visit(methodCallExpression.Object);
var operand = _queryModelVisitor.QueryCompilationContext.Model.Relational().FindDbFunction(methodCallExpression.Method) != null
? methodCallExpression.Object
: Visit(methodCallExpression.Object);

if (operand != null
|| methodCallExpression.Object == null)
Expand Down
111 changes: 104 additions & 7 deletions test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal;
Expand All @@ -15,9 +16,21 @@ namespace Microsoft.EntityFrameworkCore.Metadata
{
public class DbFunctionMetadataTests
{
public class MyNonDbContext
{
public int NonStatic()
{
throw new Exception();
}

public static int DuplicateNameTest()
{
throw new Exception();
}
}

public class MyBaseContext : DbContext
{
[DbFunction]
public static void Foo()
{
}
Expand All @@ -29,11 +42,22 @@ public static void Skip2()
private static void Skip()
{
}

[DbFunction]
public static int StaticBase()
{
throw new Exception();
}

[DbFunction]
public int NonStaticBase()
{
throw new Exception();
}
}

public class MyDerivedContext : MyBaseContext
{
[DbFunction]
public static void Bar()
{
}
Expand All @@ -47,8 +71,20 @@ private static void Skip4()
}

[DbFunction]
public void NonStatic()
public static int StaticDerived()
{
throw new Exception();
}

[DbFunction]
public int NonStaticDerived()
{
throw new Exception();
}

public static int DuplicateNameTest()
{
throw new Exception();
}
}

Expand Down Expand Up @@ -92,16 +128,77 @@ public static int MethodH<T>(T a, string b)
}

[Fact]
public virtual void Detects_non_static_function_on_dbcontext()
public virtual void DbFunctions_with_duplicate_names_and_parameters_on_different_types_dont_collide()
{
var modelBuilder = GetModelBuilder();

var methodInfo
var Dup1methodInfo
= typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStatic), new Type[] { });
.GetRuntimeMethod(nameof(MyDerivedContext.DuplicateNameTest), new Type[] { });

var Dup2methodInfo
= typeof(MyNonDbContext)
.GetRuntimeMethod(nameof(MyNonDbContext.DuplicateNameTest), new Type[] { });

var dbFunc1 = modelBuilder.HasDbFunction(Dup1methodInfo).HasName("Dup1").Metadata;
var dbFunc2 = modelBuilder.HasDbFunction(Dup2methodInfo).HasName("Dup2").Metadata;

Assert.Equal("Dup1", dbFunc1.FunctionName);
Assert.Equal("Dup2", dbFunc2.FunctionName);
}

[Fact]
public virtual void Finds_dbFunctions_on_dbContext()
{
var modelBuilder = GetModelBuilder();

var customizer = new RelationalModelCustomizer(new ModelCustomizerDependencies(new DbSetFinder()));

customizer.Customize(modelBuilder, new MyDerivedContext());

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyBaseContext.NonStaticBase), new Type[] { })));

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyBaseContext)
.GetRuntimeMethod(nameof(MyBaseContext.StaticBase), new Type[] { })));

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticDerived), new Type[] { })));

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticDerived), new Type[] { })));
}

[Fact]
public virtual void Non_static_function_on_dbcontext_does_not_throw()
{
var modelBuilder = GetModelBuilder();

var methodInfo
= typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticBase), new Type[] { });

var dbFunc = modelBuilder.HasDbFunction(methodInfo).Metadata;

Assert.Equal("NonStaticBase", dbFunc.FunctionName);
Assert.Equal(typeof(int), dbFunc.MethodInfo.ReturnType);
}

[Fact]
public virtual void Non_static_function_on_non_dbcontext_throws()
{
var modelBuilder = GetModelBuilder();

var methodInfo
= typeof(MyNonDbContext)
.GetRuntimeMethod(nameof(MyNonDbContext.NonStatic), new Type[] { });

Assert.Equal(
RelationalStrings.DbFunctionMethodMustBeStatic("MyDerivedContext.NonStatic"),
RelationalStrings.DbFunctionInvalidInstanceType(methodInfo.DisplayName(), typeof(MyNonDbContext).ShortDisplayName()),
Assert.Throws<ArgumentException>(() => modelBuilder.HasDbFunction(methodInfo)).Message);
}

Expand Down
Loading