Skip to content

Commit

Permalink
Support customizing model factory accessibility (#4643)
Browse files Browse the repository at this point in the history
Fixes #4634
  • Loading branch information
JoshLove-msft authored Oct 9, 2024
1 parent 9ab01ec commit 8f50ec6
Show file tree
Hide file tree
Showing 12 changed files with 83 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,5 @@ protected override PropertyProvider[] BuildProperties()

return [.. properties];
}

protected override TypeSignatureModifiers GetDeclarationModifiers() => GetCustomCodeModifiers();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,6 @@ protected override MethodProvider[] BuildMethods()
return [.. methods];
}

protected override TypeSignatureModifiers GetDeclarationModifiers() => GetCustomCodeModifiers();

private ParameterProvider BuildClientEndpointParameter()
{
var endpointParam = _inputClient.Parameters.FirstOrDefault(p => p.IsEndpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ protected override MethodProvider[] BuildMethods()
return [.. methods];
}

protected override TypeSignatureModifiers GetDeclarationModifiers() => GetCustomCodeModifiers();

private bool IsCreateRequest(MethodProvider method)
{
var span = method.Signature.Name.AsSpan();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,16 @@ private async Task<Project> RemoveMethodsFromModelFactoryAsync(Project project,
// find the GENERATED document of model factory (we may have the customized document of this for overloads)
Document? modelFactoryGeneratedDocument = null;
// the nodes corresponding to the model factory symbol has never been changed therefore the nodes inside the cache are still usable
foreach (var declarationNode in definitions.DeclaredNodesCache[modelFactorySymbol])
if (definitions.DeclaredNodesCache.TryGetValue(modelFactorySymbol, out var nodes))
{
var document = project.GetDocument(declarationNode.SyntaxTree);
if (document != null && GeneratedCodeWorkspace.IsGeneratedDocument(document))
foreach (var declarationNode in nodes)
{
modelFactoryGeneratedDocument = document;
break;
var document = project.GetDocument(declarationNode.SyntaxTree);
if (document != null && GeneratedCodeWorkspace.IsGeneratedDocument(document))
{
modelFactoryGeneratedDocument = document;
break;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,8 @@ internal ExtensibleEnumProvider(InputEnumType input, TypeProvider? declaringType
_allowedValues = input.Values;
// extensible enums are implemented as readonly structs
_modifiers = TypeSignatureModifiers.Partial | TypeSignatureModifiers.ReadOnly | TypeSignatureModifiers.Struct;
var customCodeModifiers = GetCustomCodeModifiers();

if (customCodeModifiers != TypeSignatureModifiers.None)
{
_modifiers |= customCodeModifiers;
}
else if (input.Accessibility == "internal")
if (input.Accessibility == "internal")
{
_modifiers |= TypeSignatureModifiers.Internal;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@ internal FixedEnumProvider(InputEnumType input, TypeProvider? declaringType) : b
// fixed enums are implemented by enum in C#
_modifiers = TypeSignatureModifiers.Enum;

var customCodeModifiers = GetCustomCodeModifiers();
if (customCodeModifiers != TypeSignatureModifiers.None)
{
_modifiers |= customCodeModifiers;
}
else if (input.Accessibility == "internal")
if (input.Accessibility == "internal")
{
_modifiers |= TypeSignatureModifiers.Internal;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ protected override string BuildName()
protected override string BuildRelativeFilePath() => Path.Combine("src", "Generated", $"{Name}.cs");

protected override TypeSignatureModifiers GetDeclarationModifiers()
=> TypeSignatureModifiers.Static | TypeSignatureModifiers.Public | TypeSignatureModifiers.Class | TypeSignatureModifiers.Partial;
=> TypeSignatureModifiers.Public | TypeSignatureModifiers.Static | TypeSignatureModifiers.Partial | TypeSignatureModifiers.Class;

protected override string GetNamespace() => CodeModelPlugin.Instance.Configuration.ModelNamespace;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ protected override TypeProvider[] BuildSerializationProviders()

protected override TypeSignatureModifiers GetDeclarationModifiers()
{
var customCodeModifiers = GetCustomCodeModifiers();
var customCodeModifiers = CustomCodeView?.DeclarationModifiers ?? TypeSignatureModifiers.None;
var isStruct = false;
// the information of if this model should be a struct comes from two sources:
// 1. the customied code
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,14 @@ public TypeSignatureModifiers DeclarationModifiers

protected virtual TypeSignatureModifiers GetDeclarationModifiers() => TypeSignatureModifiers.None;

protected TypeSignatureModifiers GetCustomCodeModifiers() => CustomCodeView?.DeclarationModifiers ?? TypeSignatureModifiers.None;

private TypeSignatureModifiers GetDeclarationModifiersInternal()
{
var modifiers = GetDeclarationModifiers();
var customModifiers = CustomCodeView?.DeclarationModifiers ?? TypeSignatureModifiers.None;
if (customModifiers != TypeSignatureModifiers.None)
{
modifiers |= customModifiers;
}
// we default to public when no accessibility modifier is provided
if (!modifiers.HasFlag(TypeSignatureModifiers.Internal) && !modifiers.HasFlag(TypeSignatureModifiers.Public) && !modifiers.HasFlag(TypeSignatureModifiers.Private))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.Generator.CSharp.Input;
using Microsoft.Generator.CSharp.Primitives;
using Microsoft.Generator.CSharp.Providers;
using Microsoft.Generator.CSharp.Tests.Common;
using NUnit.Framework;
Expand Down Expand Up @@ -37,6 +39,10 @@ public async Task CanReplaceModelMethod()
var modelFactory = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider);
Assert.IsNotNull(modelFactory);

// The model factory should be public
Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public));
ValidateModelFactoryCommon(modelFactory);

// The model factory method should be replaced
var modelFactoryMethods = modelFactory!.Methods;
Assert.AreEqual(1, modelFactoryMethods.Count);
Expand Down Expand Up @@ -71,9 +77,50 @@ public async Task DoesNotReplaceMethodIfNotCustomized()
var modelFactory = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider);
Assert.IsNotNull(modelFactory);

// The model factory should be public
Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public));
ValidateModelFactoryCommon(modelFactory);

// The model factory method should not be replaced
var modelFactoryMethods = modelFactory!.Methods;
Assert.AreEqual(1, modelFactoryMethods.Count);
}

private static void ValidateModelFactoryCommon(TypeProvider modelFactory)
{
Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static));
Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Partial));
Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Class));
}

[Test]
public async Task CanChangeAccessibilityOfModelFactory()
{
var plugin = await MockHelpers.LoadMockPluginAsync(
inputModelTypes: [
InputFactory.Model(
"mockInputModel",
properties:
[
InputFactory.Property("Prop1", InputPrimitiveType.String),
InputFactory.Property("OptionalBool", InputPrimitiveType.Boolean, isRequired: false)
]),
InputFactory.Model(
"otherModel",
properties: [InputFactory.Property("Prop2", InputPrimitiveType.String)]),
],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());
var csharpGen = new CSharpGen();

await csharpGen.ExecuteAsync();

// Find the model factory provider
var modelFactory = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider);
Assert.IsNotNull(modelFactory);

// The model factory should be internal
Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal));
ValidateModelFactoryCommon(modelFactory);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Microsoft.Generator.CSharp.Customization;
using System;

namespace Sample.Models
{
internal static partial class SampleNamespaceModelFactory
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ public class ClientCustomizationTests
public async Task CanRemoveMethods()
{
var client = new ClientTypeProvider();
var outputLibrary = new ClientOutputLibrary(client);
var plugin = await MockHelpers.LoadMockPluginAsync(
createOutputLibrary: () => outputLibrary,
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var methods = new[]
{
new MethodProvider(new MethodSignature(
Expand Down Expand Up @@ -60,11 +65,7 @@ public async Task CanRemoveMethods()

};
client.MethodProviders = methods;
var outputLibrary = new ClientOutputLibrary(client);

var plugin = await MockHelpers.LoadMockPluginAsync(
createOutputLibrary: () => outputLibrary,
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());
var csharpGen = new CSharpGen();
await csharpGen.ExecuteAsync();

Expand All @@ -75,6 +76,11 @@ public async Task CanRemoveMethods()
public async Task DoesNotRemoveMethodsThatDoNotMatch()
{
var client = new ClientTypeProvider();
var outputLibrary = new ClientOutputLibrary(client);
var plugin = await MockHelpers.LoadMockPluginAsync(
createOutputLibrary: () => outputLibrary,
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var methods = new[]
{
// Method name doesn't match
Expand Down Expand Up @@ -122,11 +128,7 @@ public async Task DoesNotRemoveMethodsThatDoNotMatch()

};
client.MethodProviders = methods;
var outputLibrary = new ClientOutputLibrary(client);

var plugin = await MockHelpers.LoadMockPluginAsync(
createOutputLibrary: () => outputLibrary,
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());
var csharpGen = new CSharpGen();
await csharpGen.ExecuteAsync();

Expand Down

0 comments on commit 8f50ec6

Please sign in to comment.