Skip to content

Commit

Permalink
DBFunctions - Add support for instance methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
pmiddleton authored and smitpatel committed Oct 3, 2017
1 parent 29c5538 commit 0edbe21
Show file tree
Hide file tree
Showing 7 changed files with 1,030 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,8 @@ protected virtual void FindDbFunctions([NotNull] ModelBuilder modelBuilder, [Not
Check.NotNull(modelBuilder, nameof(modelBuilder));
Check.NotNull(context, nameof(context));

var functions = context.GetType().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy)
.Where(
mi => mi.IsStatic
&& mi.IsPublic
var functions = context.GetType().GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static | BindingFlags.FlattenHierarchy)
.Where(mi => mi.IsPublic
&& mi.GetCustomAttributes(typeof(DbFunctionAttribute)).Any());

foreach (var function in functions)
Expand Down
8 changes: 5 additions & 3 deletions src/EFCore.Relational/Metadata/Internal/DbFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ private DbFunction(
throw new ArgumentException(RelationalStrings.DbFunctionGenericMethodNotSupported(methodInfo.DisplayName()));
}

if (!methodInfo.IsStatic)
if (!methodInfo.IsStatic
&& !typeof(DbContext).IsAssignableFrom(methodInfo.DeclaringType))
{
throw new ArgumentException(RelationalStrings.DbFunctionMethodMustBeStatic(methodInfo.DisplayName()));
throw new ArgumentException(
RelationalStrings.DbFunctionInvalidInstanceType(methodInfo.DisplayName(), methodInfo.DeclaringType.ShortDisplayName()));
}

if (methodInfo.ReturnType == null
Expand Down Expand Up @@ -102,7 +104,7 @@ public static IEnumerable<IDbFunction> GetDbFunctions([NotNull] IModel model, [N
}

private static string BuildAnnotationName(string annotationPrefix, MethodBase methodBase)
=> $@"{annotationPrefix}{methodBase.Name}({string.Join(",", methodBase.GetParameters().Select(p => p.ParameterType.Name))})";
=> $@"{annotationPrefix}{methodBase.DeclaringType.ShortDisplayName()}{methodBase.Name}({string.Join(",", methodBase.GetParameters().Select(p => p.ParameterType.Name))})";

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
Expand Down
16 changes: 8 additions & 8 deletions src/EFCore.Relational/Properties/RelationalStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions src/EFCore.Relational/Properties/RelationalStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,15 @@
<data name="DbFunctionInvalidParameterType" xml:space="preserve">
<value>The parameter '{parameter}' for the DbFunction '{function}' has an invalid type '{type}'. Ensure the parameter type can be mapped by the current provider.</value>
</data>
<data name="DbFunctionMethodMustBeStatic" xml:space="preserve">
<value>The DbFunction '{function}' must be a static method. Non-static methods are not supported.</value>
</data>
<data name="DbFunctionGenericMethodNotSupported" xml:space="preserve">
<value>The DbFunction '{function}' is generic. Generic methods are not supported.</value>
</data>
<data name="DbFunctionExpressionIsNotMethodCall" xml:space="preserve">
<value>The provided DbFunction expression '{expression}' is invalid. The expression should be a lambda expression containing a single method call to the target static method. Default values can be provided as arguments if required. E.g. () =&gt; SomeClass.SomeMethod(null, 0)</value>
</data>
<data name="DbFunctionInvalidInstanceType" xml:space="preserve">
<value>The DbFunction '{function}' defined on type '{type}' must be either a static method or an instance method defined on a DbContext subclass. Instance methods on other types are not supported.</value>
</data>
<data name="ConflictingAmbientTransaction" xml:space="preserve">
<value>An ambient transaction has been detected. The ambient transaction needs to be completed before beginning a transaction on this connection.</value>
</data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
Check.NotNull(methodCallExpression, nameof(methodCallExpression));

var operand = Visit(methodCallExpression.Object);
var operand = _queryModelVisitor.QueryCompilationContext.Model.Relational().FindDbFunction(methodCallExpression.Method) != null
? methodCallExpression.Object
: Visit(methodCallExpression.Object);

if (operand != null
|| methodCallExpression.Object == null)
Expand Down
111 changes: 104 additions & 7 deletions test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal;
Expand All @@ -15,9 +16,21 @@ namespace Microsoft.EntityFrameworkCore.Metadata
{
public class DbFunctionMetadataTests
{
public class MyNonDbContext
{
public int NonStatic()
{
throw new Exception();
}

public static int DuplicateNameTest()
{
throw new Exception();
}
}

public class MyBaseContext : DbContext
{
[DbFunction]
public static void Foo()
{
}
Expand All @@ -29,11 +42,22 @@ public static void Skip2()
private static void Skip()
{
}

[DbFunction]
public static int StaticBase()
{
throw new Exception();
}

[DbFunction]
public int NonStaticBase()
{
throw new Exception();
}
}

public class MyDerivedContext : MyBaseContext
{
[DbFunction]
public static void Bar()
{
}
Expand All @@ -47,8 +71,20 @@ private static void Skip4()
}

[DbFunction]
public void NonStatic()
public static int StaticDerived()
{
throw new Exception();
}

[DbFunction]
public int NonStaticDerived()
{
throw new Exception();
}

public static int DuplicateNameTest()
{
throw new Exception();
}
}

Expand Down Expand Up @@ -92,16 +128,77 @@ public static int MethodH<T>(T a, string b)
}

[Fact]
public virtual void Detects_non_static_function_on_dbcontext()
public virtual void DbFunctions_with_duplicate_names_and_parameters_on_different_types_dont_collide()
{
var modelBuilder = GetModelBuilder();

var methodInfo
var Dup1methodInfo
= typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStatic), new Type[] { });
.GetRuntimeMethod(nameof(MyDerivedContext.DuplicateNameTest), new Type[] { });

var Dup2methodInfo
= typeof(MyNonDbContext)
.GetRuntimeMethod(nameof(MyNonDbContext.DuplicateNameTest), new Type[] { });

var dbFunc1 = modelBuilder.HasDbFunction(Dup1methodInfo).HasName("Dup1").Metadata;
var dbFunc2 = modelBuilder.HasDbFunction(Dup2methodInfo).HasName("Dup2").Metadata;

Assert.Equal("Dup1", dbFunc1.FunctionName);
Assert.Equal("Dup2", dbFunc2.FunctionName);
}

[Fact]
public virtual void Finds_dbFunctions_on_dbContext()
{
var modelBuilder = GetModelBuilder();

var customizer = new RelationalModelCustomizer(new ModelCustomizerDependencies(new DbSetFinder()));

customizer.Customize(modelBuilder, new MyDerivedContext());

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyBaseContext.NonStaticBase), new Type[] { })));

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyBaseContext)
.GetRuntimeMethod(nameof(MyBaseContext.StaticBase), new Type[] { })));

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticDerived), new Type[] { })));

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticDerived), new Type[] { })));
}

[Fact]
public virtual void Non_static_function_on_dbcontext_does_not_throw()
{
var modelBuilder = GetModelBuilder();

var methodInfo
= typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticBase), new Type[] { });

var dbFunc = modelBuilder.HasDbFunction(methodInfo).Metadata;

Assert.Equal("NonStaticBase", dbFunc.FunctionName);
Assert.Equal(typeof(int), dbFunc.MethodInfo.ReturnType);
}

[Fact]
public virtual void Non_static_function_on_non_dbcontext_throws()
{
var modelBuilder = GetModelBuilder();

var methodInfo
= typeof(MyNonDbContext)
.GetRuntimeMethod(nameof(MyNonDbContext.NonStatic), new Type[] { });

Assert.Equal(
RelationalStrings.DbFunctionMethodMustBeStatic("MyDerivedContext.NonStatic"),
RelationalStrings.DbFunctionInvalidInstanceType(methodInfo.DisplayName(), typeof(MyNonDbContext).ShortDisplayName()),
Assert.Throws<ArgumentException>(() => modelBuilder.HasDbFunction(methodInfo)).Message);
}

Expand Down
Loading

0 comments on commit 0edbe21

Please sign in to comment.