diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientOptionsProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientOptionsProvider.cs index 6c2a811561..a936f139f3 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientOptionsProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientOptionsProvider.cs @@ -145,7 +145,5 @@ protected override PropertyProvider[] BuildProperties() return [.. properties]; } - - protected override TypeSignatureModifiers GetDeclarationModifiers() => GetCustomCodeModifiers(); } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientProvider.cs index 246aea4b10..cc3cfbb3e0 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/ClientProvider.cs @@ -413,8 +413,6 @@ protected override MethodProvider[] BuildMethods() return [.. methods]; } - protected override TypeSignatureModifiers GetDeclarationModifiers() => GetCustomCodeModifiers(); - private ParameterProvider BuildClientEndpointParameter() { var endpointParam = _inputClient.Parameters.FirstOrDefault(p => p.IsEndpoint); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs index 9f0b5acd36..ca326c3f89 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp.ClientModel/src/Providers/RestClientProvider.cs @@ -121,8 +121,6 @@ protected override MethodProvider[] BuildMethods() return [.. methods]; } - protected override TypeSignatureModifiers GetDeclarationModifiers() => GetCustomCodeModifiers(); - private bool IsCreateRequest(MethodProvider method) { var span = method.Signature.Name.AsSpan(); diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/PostProcessing/PostProcessor.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/PostProcessing/PostProcessor.cs index 96167f06c8..47044a2dcb 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/PostProcessing/PostProcessor.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/PostProcessing/PostProcessor.cs @@ -176,13 +176,16 @@ private async Task RemoveMethodsFromModelFactoryAsync(Project project, // find the GENERATED document of model factory (we may have the customized document of this for overloads) Document? modelFactoryGeneratedDocument = null; // the nodes corresponding to the model factory symbol has never been changed therefore the nodes inside the cache are still usable - foreach (var declarationNode in definitions.DeclaredNodesCache[modelFactorySymbol]) + if (definitions.DeclaredNodesCache.TryGetValue(modelFactorySymbol, out var nodes)) { - var document = project.GetDocument(declarationNode.SyntaxTree); - if (document != null && GeneratedCodeWorkspace.IsGeneratedDocument(document)) + foreach (var declarationNode in nodes) { - modelFactoryGeneratedDocument = document; - break; + var document = project.GetDocument(declarationNode.SyntaxTree); + if (document != null && GeneratedCodeWorkspace.IsGeneratedDocument(document)) + { + modelFactoryGeneratedDocument = document; + break; + } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ExtensibleEnumProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ExtensibleEnumProvider.cs index 26d7a7ee38..28d6386a1c 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ExtensibleEnumProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ExtensibleEnumProvider.cs @@ -26,13 +26,8 @@ internal ExtensibleEnumProvider(InputEnumType input, TypeProvider? declaringType _allowedValues = input.Values; // extensible enums are implemented as readonly structs _modifiers = TypeSignatureModifiers.Partial | TypeSignatureModifiers.ReadOnly | TypeSignatureModifiers.Struct; - var customCodeModifiers = GetCustomCodeModifiers(); - if (customCodeModifiers != TypeSignatureModifiers.None) - { - _modifiers |= customCodeModifiers; - } - else if (input.Accessibility == "internal") + if (input.Accessibility == "internal") { _modifiers |= TypeSignatureModifiers.Internal; } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/FixedEnumProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/FixedEnumProvider.cs index 4a8e192839..ca8653c45f 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/FixedEnumProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/FixedEnumProvider.cs @@ -27,12 +27,7 @@ internal FixedEnumProvider(InputEnumType input, TypeProvider? declaringType) : b // fixed enums are implemented by enum in C# _modifiers = TypeSignatureModifiers.Enum; - var customCodeModifiers = GetCustomCodeModifiers(); - if (customCodeModifiers != TypeSignatureModifiers.None) - { - _modifiers |= customCodeModifiers; - } - else if (input.Accessibility == "internal") + if (input.Accessibility == "internal") { _modifiers |= TypeSignatureModifiers.Internal; } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelFactoryProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelFactoryProvider.cs index 14f83c1a7a..215f320dd7 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelFactoryProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelFactoryProvider.cs @@ -52,7 +52,7 @@ protected override string BuildName() protected override string BuildRelativeFilePath() => Path.Combine("src", "Generated", $"{Name}.cs"); protected override TypeSignatureModifiers GetDeclarationModifiers() - => TypeSignatureModifiers.Static | TypeSignatureModifiers.Public | TypeSignatureModifiers.Class | TypeSignatureModifiers.Partial; + => TypeSignatureModifiers.Public | TypeSignatureModifiers.Static | TypeSignatureModifiers.Partial | TypeSignatureModifiers.Class; protected override string GetNamespace() => CodeModelPlugin.Instance.Configuration.ModelNamespace; diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelProvider.cs index 7828063ac7..030367ebff 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/ModelProvider.cs @@ -104,7 +104,7 @@ protected override TypeProvider[] BuildSerializationProviders() protected override TypeSignatureModifiers GetDeclarationModifiers() { - var customCodeModifiers = GetCustomCodeModifiers(); + var customCodeModifiers = CustomCodeView?.DeclarationModifiers ?? TypeSignatureModifiers.None; var isStruct = false; // the information of if this model should be a struct comes from two sources: // 1. the customied code diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs index 6c302d9140..8ae0df855e 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs @@ -82,11 +82,14 @@ public TypeSignatureModifiers DeclarationModifiers protected virtual TypeSignatureModifiers GetDeclarationModifiers() => TypeSignatureModifiers.None; - protected TypeSignatureModifiers GetCustomCodeModifiers() => CustomCodeView?.DeclarationModifiers ?? TypeSignatureModifiers.None; - private TypeSignatureModifiers GetDeclarationModifiersInternal() { var modifiers = GetDeclarationModifiers(); + var customModifiers = CustomCodeView?.DeclarationModifiers ?? TypeSignatureModifiers.None; + if (customModifiers != TypeSignatureModifiers.None) + { + modifiers |= customModifiers; + } // we default to public when no accessibility modifier is provided if (!modifiers.HasFlag(TypeSignatureModifiers.Internal) && !modifiers.HasFlag(TypeSignatureModifiers.Public) && !modifiers.HasFlag(TypeSignatureModifiers.Private)) { diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs index 2d7e173abb..da8a90375b 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/ModelFactoriesCustomizationTests.cs @@ -3,7 +3,9 @@ using System.Linq; using System.Threading.Tasks; +using Microsoft.CodeAnalysis; using Microsoft.Generator.CSharp.Input; +using Microsoft.Generator.CSharp.Primitives; using Microsoft.Generator.CSharp.Providers; using Microsoft.Generator.CSharp.Tests.Common; using NUnit.Framework; @@ -37,6 +39,10 @@ public async Task CanReplaceModelMethod() var modelFactory = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider); Assert.IsNotNull(modelFactory); + // The model factory should be public + Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + ValidateModelFactoryCommon(modelFactory); + // The model factory method should be replaced var modelFactoryMethods = modelFactory!.Methods; Assert.AreEqual(1, modelFactoryMethods.Count); @@ -71,9 +77,50 @@ public async Task DoesNotReplaceMethodIfNotCustomized() var modelFactory = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider); Assert.IsNotNull(modelFactory); + // The model factory should be public + Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public)); + ValidateModelFactoryCommon(modelFactory); + // The model factory method should not be replaced var modelFactoryMethods = modelFactory!.Methods; Assert.AreEqual(1, modelFactoryMethods.Count); } + + private static void ValidateModelFactoryCommon(TypeProvider modelFactory) + { + Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Static)); + Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Partial)); + Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Class)); + } + + [Test] + public async Task CanChangeAccessibilityOfModelFactory() + { + var plugin = await MockHelpers.LoadMockPluginAsync( + inputModelTypes: [ + InputFactory.Model( + "mockInputModel", + properties: + [ + InputFactory.Property("Prop1", InputPrimitiveType.String), + InputFactory.Property("OptionalBool", InputPrimitiveType.Boolean, isRequired: false) + ]), + InputFactory.Model( + "otherModel", + properties: [InputFactory.Property("Prop2", InputPrimitiveType.String)]), + ], + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + var csharpGen = new CSharpGen(); + + await csharpGen.ExecuteAsync(); + + // Find the model factory provider + var modelFactory = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ModelFactoryProvider); + Assert.IsNotNull(modelFactory); + + // The model factory should be internal + Assert.IsTrue(modelFactory!.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Internal)); + ValidateModelFactoryCommon(modelFactory); + } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/CanChangeAccessibilityOfModelFactory/SampleNamespaceModelFactory.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/CanChangeAccessibilityOfModelFactory/SampleNamespaceModelFactory.cs new file mode 100644 index 0000000000..30350730fe --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelFactories/TestData/ModelFactoriesCustomizationTests/CanChangeAccessibilityOfModelFactory/SampleNamespaceModelFactory.cs @@ -0,0 +1,9 @@ +using Microsoft.Generator.CSharp.Customization; +using System; + +namespace Sample.Models +{ + internal static partial class SampleNamespaceModelFactory + { + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/ClientCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/ClientCustomizationTests.cs index 078591e3cd..2b4c8bca50 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/ClientCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/ClientCustomizationTests.cs @@ -17,6 +17,11 @@ public class ClientCustomizationTests public async Task CanRemoveMethods() { var client = new ClientTypeProvider(); + var outputLibrary = new ClientOutputLibrary(client); + var plugin = await MockHelpers.LoadMockPluginAsync( + createOutputLibrary: () => outputLibrary, + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + var methods = new[] { new MethodProvider(new MethodSignature( @@ -60,11 +65,7 @@ public async Task CanRemoveMethods() }; client.MethodProviders = methods; - var outputLibrary = new ClientOutputLibrary(client); - var plugin = await MockHelpers.LoadMockPluginAsync( - createOutputLibrary: () => outputLibrary, - compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); var csharpGen = new CSharpGen(); await csharpGen.ExecuteAsync(); @@ -75,6 +76,11 @@ public async Task CanRemoveMethods() public async Task DoesNotRemoveMethodsThatDoNotMatch() { var client = new ClientTypeProvider(); + var outputLibrary = new ClientOutputLibrary(client); + var plugin = await MockHelpers.LoadMockPluginAsync( + createOutputLibrary: () => outputLibrary, + compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); + var methods = new[] { // Method name doesn't match @@ -122,11 +128,7 @@ public async Task DoesNotRemoveMethodsThatDoNotMatch() }; client.MethodProviders = methods; - var outputLibrary = new ClientOutputLibrary(client); - var plugin = await MockHelpers.LoadMockPluginAsync( - createOutputLibrary: () => outputLibrary, - compilation: async () => await Helpers.GetCompilationFromDirectoryAsync()); var csharpGen = new CSharpGen(); await csharpGen.ExecuteAsync();