Skip to content

Commit

Permalink
Fir code generation for enums/ranges (#2038)
Browse files Browse the repository at this point in the history
Fixes #2027

Co-authored-by: Shay Rojansky <roji@roji.org>
  • Loading branch information
Kislov Sergey and roji committed Oct 14, 2021
1 parent eb32971 commit ce5604c
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 14 deletions.
33 changes: 19 additions & 14 deletions src/EFCore.PG/Design/Internal/NpgsqlAnnotationCodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,23 @@ private static readonly MethodInfo _modelHasPostgresExtensionMethodInfo1

private static readonly MethodInfo _modelHasPostgresExtensionMethodInfo2
= typeof(NpgsqlModelBuilderExtensions).GetRequiredRuntimeMethod(
nameof(NpgsqlModelBuilderExtensions.HasPostgresExtension), typeof(ModelBuilder), typeof(string), typeof(string),
typeof(string));
nameof(NpgsqlModelBuilderExtensions.HasPostgresExtension), typeof(ModelBuilder), typeof(string), typeof(string), typeof(string));

private static readonly MethodInfo _modelHasPostgresEnumMethodInfo
private static readonly MethodInfo _modelHasPostgresEnumMethodInfo1
= typeof(NpgsqlModelBuilderExtensions).GetRequiredRuntimeMethod(
nameof(NpgsqlModelBuilderExtensions.HasPostgresEnum), typeof(ModelBuilder), typeof(string), typeof(string[]));

private static readonly MethodInfo _modelHasPostgresEnumMethodInfo2
= typeof(NpgsqlModelBuilderExtensions).GetRequiredRuntimeMethod(
nameof(NpgsqlModelBuilderExtensions.HasPostgresEnum), typeof(ModelBuilder), typeof(string), typeof(string), typeof(string[]));

private static readonly MethodInfo _modelHasPostgresRangeMethodInfo
private static readonly MethodInfo _modelHasPostgresRangeMethodInfo1
= typeof(NpgsqlModelBuilderExtensions).GetRequiredRuntimeMethod(
nameof(NpgsqlModelBuilderExtensions.HasPostgresRange), typeof(ModelBuilder), typeof(string), typeof(string));

private static readonly MethodInfo _modelHasPostgresRangeMethodInfo2
= typeof(NpgsqlModelBuilderExtensions).GetRequiredRuntimeMethod(
nameof(NpgsqlModelBuilderExtensions.HasPostgresRange), typeof(ModelBuilder), typeof(string), typeof(string),typeof(string), typeof(string),typeof(string), typeof(string),typeof(string));

private static readonly MethodInfo _modelUseSerialColumnsMethodInfo
= typeof(NpgsqlModelBuilderExtensions).GetRequiredRuntimeMethod(
Expand Down Expand Up @@ -186,28 +193,26 @@ public override IReadOnlyList<MethodCallCodeFragment> GenerateFluentApiCalls(
{
var enumTypeDef = new PostgresEnum(model, annotation.Name);

return enumTypeDef.Schema == "public"
? new MethodCallCodeFragment(_modelHasPostgresEnumMethodInfo, enumTypeDef.Name, enumTypeDef.Labels)
: new MethodCallCodeFragment(_modelHasPostgresEnumMethodInfo, enumTypeDef.Schema, enumTypeDef.Name, enumTypeDef.Labels);
return enumTypeDef.Schema is null
? new MethodCallCodeFragment(_modelHasPostgresEnumMethodInfo1, enumTypeDef.Name, enumTypeDef.Labels)
: new MethodCallCodeFragment(_modelHasPostgresEnumMethodInfo2, enumTypeDef.Schema, enumTypeDef.Name, enumTypeDef.Labels);
}

if (annotation.Name.StartsWith(NpgsqlAnnotationNames.RangePrefix, StringComparison.Ordinal))
{
var rangeTypeDef = new PostgresRange(model, annotation.Name);

if (rangeTypeDef.CanonicalFunction is null &&
if (rangeTypeDef.Schema is null &&
rangeTypeDef.CanonicalFunction is null &&
rangeTypeDef.SubtypeOpClass is null &&
rangeTypeDef.Collation is null &&
rangeTypeDef.SubtypeDiff is null)
{
return new MethodCallCodeFragment(_modelHasPostgresRangeMethodInfo,
rangeTypeDef.Schema == "public" ? null : rangeTypeDef.Schema,
rangeTypeDef.Name,
rangeTypeDef.Subtype);
return new MethodCallCodeFragment(_modelHasPostgresRangeMethodInfo1, rangeTypeDef.Name, rangeTypeDef.Subtype);
}

return new MethodCallCodeFragment(_modelHasPostgresRangeMethodInfo,
rangeTypeDef.Schema == "public" ? null : rangeTypeDef.Schema,
return new MethodCallCodeFragment(_modelHasPostgresRangeMethodInfo2,
rangeTypeDef.Schema,
rangeTypeDef.Name,
rangeTypeDef.Subtype,
rangeTypeDef.CanonicalFunction,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,124 @@ public void Extension_with_null_schema()
Assert.Collection(result.Arguments, name => Assert.Equal("postgis", name));
}

[ConditionalFact]
public void Enum()
{
var generator = CreateGenerator();
var modelBuilder = new ModelBuilder(NpgsqlConventionSetBuilder.Build());

var enumLabels = new[] { "someValue1", "someValue2" };
modelBuilder.HasPostgresEnum("some_enum", enumLabels);

var model = (IModel)modelBuilder.Model;
var annotations = model.GetAnnotations().ToDictionary(a => a.Name, a => a);
var result = generator.GenerateFluentApiCalls(model, annotations)
.Single(c => c.Method == nameof(NpgsqlModelBuilderExtensions.HasPostgresEnum));

Assert.Collection(result.Arguments,
name => Assert.Equal("some_enum", name),
labels => Assert.Equal(enumLabels, labels));
}

[ConditionalFact]
public void Enum_with_schema()
{
var generator = CreateGenerator();
var modelBuilder = new ModelBuilder(NpgsqlConventionSetBuilder.Build());

var enumLabels = new[] { "someValue1", "someValue2" };
modelBuilder.HasPostgresEnum("some_schema", "some_enum", enumLabels);

var model = (IModel)modelBuilder.Model;
var annotations = model.GetAnnotations().ToDictionary(a => a.Name, a => a);
var result = generator.GenerateFluentApiCalls(model, annotations)
.Single(c => c.Method == nameof(NpgsqlModelBuilderExtensions.HasPostgresEnum));

Assert.Collection(result.Arguments,
schema => Assert.Equal("some_schema", schema),
name => Assert.Equal("some_enum", name),
labels => Assert.Equal(enumLabels, labels));
}

[ConditionalFact]
public void Enum_with_null_schema()
{
var generator = CreateGenerator();
var modelBuilder = new ModelBuilder(NpgsqlConventionSetBuilder.Build());

var enumLabels = new[] { "someValue1", "someValue2" };
modelBuilder.HasPostgresEnum(schema: null, "some_enum", enumLabels);

var model = (IModel)modelBuilder.Model;
var annotations = model.GetAnnotations().ToDictionary(a => a.Name, a => a);
var result = generator.GenerateFluentApiCalls(model, annotations)
.Single(c => c.Method == nameof(NpgsqlModelBuilderExtensions.HasPostgresEnum));

Assert.Collection(result.Arguments,
name => Assert.Equal("some_enum", name),
labels => Assert.Equal(enumLabels, labels));
}

[ConditionalFact]
public void Range()
{
var generator = CreateGenerator();
var modelBuilder = new ModelBuilder(NpgsqlConventionSetBuilder.Build());

modelBuilder.HasPostgresRange("some_range", "some_subtype");

var model = (IModel)modelBuilder.Model;
var annotations = model.GetAnnotations().ToDictionary(a => a.Name, a => a);
var result = generator.GenerateFluentApiCalls(model, annotations)
.Single(c => c.Method == nameof(NpgsqlModelBuilderExtensions.HasPostgresRange));

Assert.Collection(result.Arguments,
name => Assert.Equal("some_range", name),
subtype => Assert.Equal("some_subtype", subtype));
}

[ConditionalFact]
public void Range_with_schema()
{
var generator = CreateGenerator();
var modelBuilder = new ModelBuilder(NpgsqlConventionSetBuilder.Build());

modelBuilder.HasPostgresRange("some_schema", "some_range", "some_subtype");

var model = (IModel)modelBuilder.Model;
var annotations = model.GetAnnotations().ToDictionary(a => a.Name, a => a);
var result = generator.GenerateFluentApiCalls(model, annotations)
.Single(c => c.Method == nameof(NpgsqlModelBuilderExtensions.HasPostgresRange));

Assert.Collection(result.Arguments,
schema => Assert.Equal("some_schema", schema),
name => Assert.Equal("some_range", name),
subtype => Assert.Equal("some_subtype", subtype),
canonicalFunction => Assert.Null(canonicalFunction),
subtypeOpClass => Assert.Null(subtypeOpClass),
collation => Assert.Null(collation),
subtypeDiff => Assert.Null(subtypeDiff));
}


[ConditionalFact]
public void Range_with_null_schema()
{
var generator = CreateGenerator();
var modelBuilder = new ModelBuilder(NpgsqlConventionSetBuilder.Build());

modelBuilder.HasPostgresRange(schema: null, "some_range", "some_subtype");

var model = (IModel)modelBuilder.Model;
var annotations = model.GetAnnotations().ToDictionary(a => a.Name, a => a);
var result = generator.GenerateFluentApiCalls(model, annotations)
.Single(c => c.Method == nameof(NpgsqlModelBuilderExtensions.HasPostgresRange));

Assert.Collection(result.Arguments,
name => Assert.Equal("some_range", name),
subtype => Assert.Equal("some_subtype", subtype));
}

private NpgsqlAnnotationCodeGenerator CreateGenerator()
=> new(new AnnotationCodeGeneratorDependencies(
new NpgsqlTypeMappingSource(
Expand Down

0 comments on commit ce5604c

Please sign in to comment.