From 9ab01ec9699ea0bb49ac4353743b42b62bd0bab1 Mon Sep 17 00:00:00 2001 From: JoshLove-msft <54595583+JoshLove-msft@users.noreply.github.com> Date: Wed, 9 Oct 2024 10:46:11 -0700 Subject: [PATCH] Handle nullable parameter types for method suppression (#4642) Fixes https://github.com/microsoft/typespec/issues/4641 --- .../src/Providers/NamedTypeSymbolProvider.cs | 198 +--------------- .../src/Providers/TypeProvider.cs | 5 +- .../Utilities/NamedTypeSymbolExtensions.cs | 36 --- .../src/Utilities/TypeSymbolExtensions.cs | 224 ++++++++++++++++++ .../ClientCustomizationTests.cs | 25 +- .../CanRemoveMethods/MockInputClient.cs | 1 + .../MockInputClient.cs | 1 + 7 files changed, 258 insertions(+), 232 deletions(-) delete mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Utilities/NamedTypeSymbolExtensions.cs create mode 100644 packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Utilities/TypeSymbolExtensions.cs diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/NamedTypeSymbolProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/NamedTypeSymbolProvider.cs index 92ba1d107f..dc5a15b153 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/NamedTypeSymbolProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/NamedTypeSymbolProvider.cs @@ -14,7 +14,6 @@ namespace Microsoft.Generator.CSharp.Providers { public sealed class NamedTypeSymbolProvider : TypeProvider { - private const string GlobalPrefix = "global::"; private INamedTypeSymbol _namedTypeSymbol; public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol) @@ -28,7 +27,7 @@ public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol) protected override string BuildName() => _namedTypeSymbol.Name; - protected override string GetNamespace() => GetFullyQualifiedNameFromDisplayString(_namedTypeSymbol.ContainingNamespace); + protected override string GetNamespace() => _namedTypeSymbol.ContainingNamespace.GetFullyQualifiedNameFromDisplayString(); public IEnumerable GetAttributes() => _namedTypeSymbol.GetAttributes(); @@ -90,7 +89,7 @@ protected override FieldProvider[] BuildFields() var fieldProvider = new FieldProvider( modifiers, - GetCSharpType(fieldSymbol.Type), + fieldSymbol.Type.GetCSharpType(), fieldSymbol.Name, this, GetSymbolXmlDoc(fieldSymbol, "summary")) @@ -112,7 +111,7 @@ protected override PropertyProvider[] BuildProperties() var propertyProvider = new PropertyProvider( GetSymbolXmlDoc(propertySymbol, "summary"), GetAccessModifier(propertySymbol.DeclaredAccessibility), - GetCSharpType(propertySymbol.Type), + propertySymbol.Type.GetCSharpType(), propertySymbol.Name, new AutoPropertyBody(propertySymbol.SetMethod is not null), this) @@ -179,7 +178,7 @@ private ParameterProvider ConvertToParameterProvider(IMethodSymbol methodSymbol, return new ParameterProvider( parameterSymbol.Name, FormattableStringHelpers.FromString(GetParameterXmlDocumentation(methodSymbol, parameterSymbol)) ?? FormattableStringHelpers.Empty, - GetCSharpType(parameterSymbol.Type)); + parameterSymbol.Type.GetCSharpType()); } private void AddAdditionalModifiers(IMethodSymbol methodSymbol, ref MethodSignatureModifiers modifiers) @@ -264,197 +263,12 @@ private static XDocument ParseXml(ISymbol docsSymbol, string xmlDocumentation) private CSharpType? GetNullableCSharpType(ITypeSymbol typeSymbol) { - var fullyQualifiedName = GetFullyQualifiedName(typeSymbol); + var fullyQualifiedName = typeSymbol.GetFullyQualifiedName(); if (fullyQualifiedName == "System.Void") { return null; } - return GetCSharpType(typeSymbol); - } - - private CSharpType GetCSharpType(ITypeSymbol typeSymbol) - { - var fullyQualifiedName = GetFullyQualifiedName(typeSymbol); - var namedTypeSymbol = typeSymbol as INamedTypeSymbol; - - Type? type = LoadFrameworkType(fullyQualifiedName); - - if (type is null) - { - return ConstructCSharpTypeFromSymbol(typeSymbol, fullyQualifiedName, namedTypeSymbol); - } - - CSharpType result = new CSharpType(type); - if (namedTypeSymbol is not null && namedTypeSymbol.IsGenericType && !result.IsNullable) - { - return result.MakeGenericType([.. namedTypeSymbol.TypeArguments.Select(GetCSharpType)]); - } - - return result; - } - - private static Type? LoadFrameworkType(string fullyQualifiedName) - { - return fullyQualifiedName switch - { - // Special case for types that would not be defined in corlib, but should still be considered framework types. - "System.BinaryData" => typeof(BinaryData), - _ => System.Type.GetType(fullyQualifiedName) - }; - } - - private CSharpType ConstructCSharpTypeFromSymbol( - ITypeSymbol typeSymbol, - string fullyQualifiedName, - INamedTypeSymbol? namedTypeSymbol) - { - var typeArg = namedTypeSymbol?.TypeArguments.FirstOrDefault(); - bool isValueType = typeSymbol.IsValueType; - bool isEnum = typeSymbol.TypeKind == TypeKind.Enum; - bool isNullable = typeSymbol.NullableAnnotation == NullableAnnotation.Annotated; - bool isNullableUnknownType = isNullable && typeArg?.TypeKind == TypeKind.Error; - string name = isNullableUnknownType ? fullyQualifiedName : typeSymbol.Name; - string[] pieces = fullyQualifiedName.Split('.'); - - // handle nullables - if (isNullable) - { - // System.Nullable`1[T] -> T - name = typeArg != null ? GetFullyQualifiedName(typeArg) : fullyQualifiedName; - pieces = name.Split('.'); - } - - return new CSharpType( - name, - string.Join('.', pieces.Take(pieces.Length - 1)), - isValueType, - isNullable, - typeSymbol.ContainingType is not null ? GetCSharpType(typeSymbol.ContainingType) : null, - namedTypeSymbol is not null && !isNullableUnknownType ? [.. namedTypeSymbol.TypeArguments.Select(GetCSharpType)] : [], - typeSymbol.DeclaredAccessibility == Accessibility.Public, - isValueType && !isEnum, - baseType: typeSymbol.BaseType is not null && typeSymbol.BaseType.TypeKind != TypeKind.Error && !isNullableUnknownType - ? GetCSharpType(typeSymbol.BaseType) - : null, - underlyingEnumType: namedTypeSymbol is not null && namedTypeSymbol.EnumUnderlyingType is not null - ? GetCSharpType(namedTypeSymbol.EnumUnderlyingType).FrameworkType - : null); - } - - private static string GetFullyQualifiedName(ITypeSymbol typeSymbol) - { - // Handle special cases for built-in types - switch (typeSymbol.SpecialType) - { - case SpecialType.System_Object: - return "System.Object"; - case SpecialType.System_Void: - return "System.Void"; - case SpecialType.System_Boolean: - return "System.Boolean"; - case SpecialType.System_Char: - return "System.Char"; - case SpecialType.System_SByte: - return "System.SByte"; - case SpecialType.System_Byte: - return "System.Byte"; - case SpecialType.System_Int16: - return "System.Int16"; - case SpecialType.System_UInt16: - return "System.UInt16"; - case SpecialType.System_Int32: - return "System.Int32"; - case SpecialType.System_UInt32: - return "System.UInt32"; - case SpecialType.System_Int64: - return "System.Int64"; - case SpecialType.System_UInt64: - return "System.UInt64"; - case SpecialType.System_Decimal: - return "System.Decimal"; - case SpecialType.System_Single: - return "System.Single"; - case SpecialType.System_Double: - return "System.Double"; - case SpecialType.System_String: - return "System.String"; - case SpecialType.System_DateTime: - return "System.DateTime"; - } - - // Handle array types - if (typeSymbol is IArrayTypeSymbol arrayTypeSymbol) - { - return GetFullyQualifiedName(arrayTypeSymbol.ElementType) + "[]"; - } - - // Handle generic types - if (typeSymbol is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGenericType) - { - // Handle nullable types - if (typeSymbol.NullableAnnotation == NullableAnnotation.Annotated && !IsCollectionType(namedTypeSymbol)) - { - const string nullableTypeName = "System.Nullable"; - var argTypeSymbol = namedTypeSymbol.TypeArguments.FirstOrDefault(); - - if (argTypeSymbol != null) - { - if (argTypeSymbol.TypeKind == TypeKind.Error) - { - return GetFullyQualifiedName(argTypeSymbol); - } - - string[] typeArguments = [.. namedTypeSymbol.TypeArguments.Select(arg => "[" + GetFullyQualifiedName(arg) + "]")]; - return $"{nullableTypeName}`{namedTypeSymbol.TypeArguments.Length}[{string.Join(", ", typeArguments)}]"; - } - } - else if (namedTypeSymbol.TypeArguments.Length > 0 && !IsCollectionType(namedTypeSymbol)) - { - return GetNonNullableGenericTypeName(namedTypeSymbol); - } - - var typeNameSpan = namedTypeSymbol.ConstructedFrom.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).AsSpan(); - var start = typeNameSpan.IndexOf(':') + 2; - var end = typeNameSpan.IndexOf('<'); - typeNameSpan = typeNameSpan.Slice(start, end - start); - return $"{typeNameSpan}`{namedTypeSymbol.TypeArguments.Length}"; - } - - // Default to fully qualified name - return GetFullyQualifiedNameFromDisplayString(typeSymbol); - } - - private static string GetNonNullableGenericTypeName(INamedTypeSymbol namedTypeSymbol) - { - string[] typeArguments = [.. namedTypeSymbol.TypeArguments.Select(GetFullyQualifiedName)]; - var fullName = namedTypeSymbol.ConstructedFrom.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - - // Remove the type arguments from the fully qualified name - var typeArgumentStartIndex = fullName.IndexOf('<'); - var genericTypeName = typeArgumentStartIndex >= 0 ? fullName.Substring(0, typeArgumentStartIndex) : fullName; - - // Remove global:: prefix - if (genericTypeName.StartsWith(GlobalPrefix, StringComparison.Ordinal)) - { - genericTypeName = genericTypeName.Substring(GlobalPrefix.Length); - } - - return $"{genericTypeName}`{namedTypeSymbol.TypeArguments.Length}[{string.Join(", ", typeArguments)}]"; - } - - private static bool IsCollectionType(INamedTypeSymbol typeSymbol) - { - // Check if the type implements IEnumerable, ICollection, or IEnumerable - return typeSymbol.AllInterfaces.Any(i => - i.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T || - i.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_ICollection_T || - i.OriginalDefinition.SpecialType == SpecialType.System_Collections_IEnumerable); - } - - private static string GetFullyQualifiedNameFromDisplayString(ISymbol typeSymbol) - { - var fullyQualifiedName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - return fullyQualifiedName.StartsWith(GlobalPrefix, StringComparison.Ordinal) ? fullyQualifiedName.Substring(GlobalPrefix.Length) : fullyQualifiedName; + return typeSymbol.GetCSharpType(); } } } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs index 8cb696aec6..6c302d9140 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Providers/TypeProvider.cs @@ -424,7 +424,7 @@ private static bool IsMatch(TypeProvider enclosingType, MethodSignatureBase sign } else if (attribute.ConstructorArguments[1].Kind != TypedConstantKind.Array) { - parameterTypes = [(ISymbol?) attribute.ConstructorArguments[1].Value]; + parameterTypes = attribute.ConstructorArguments[1..].Select(a => (ISymbol?) a.Value).ToArray(); } else { @@ -437,7 +437,8 @@ private static bool IsMatch(TypeProvider enclosingType, MethodSignatureBase sign for (int i = 0; i < parameterTypes.Length; i++) { - if (parameterTypes[i]?.Name != signature.Parameters[i].Type.Name) + var parameterType = ((ITypeSymbol)parameterTypes[i]!).GetCSharpType(); + if (parameterType.Name != signature.Parameters[i].Type.Name || parameterType.IsNullable != signature.Parameters[i].Type.IsNullable) { return false; } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Utilities/NamedTypeSymbolExtensions.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Utilities/NamedTypeSymbolExtensions.cs deleted file mode 100644 index 8b6c5be3e4..0000000000 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Utilities/NamedTypeSymbolExtensions.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -using System.Linq; -using Microsoft.CodeAnalysis; -using Microsoft.Generator.CSharp.Primitives; - -namespace Microsoft.Generator.CSharp -{ - internal static class NamedTypeSymbolExtensions - { - public static bool IsSameType(this INamedTypeSymbol symbol, CSharpType type) - { - if (type.IsValueType && type.IsNullable) - { - if (symbol.ConstructedFrom.SpecialType != SpecialType.System_Nullable_T) - return false; - return IsSameType((INamedTypeSymbol)symbol.TypeArguments.Single(), type.WithNullable(false)); - } - - if (symbol.ContainingNamespace.ToString() != type.Namespace || symbol.Name != type.Name || symbol.TypeArguments.Length != type.Arguments.Count) - { - return false; - } - - for (int i = 0; i < type.Arguments.Count; ++i) - { - if (!IsSameType((INamedTypeSymbol)symbol.TypeArguments[i], type.Arguments[i])) - { - return false; - } - } - return true; - } - } -} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Utilities/TypeSymbolExtensions.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Utilities/TypeSymbolExtensions.cs new file mode 100644 index 0000000000..1efa2c8704 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/src/Utilities/TypeSymbolExtensions.cs @@ -0,0 +1,224 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.Generator.CSharp.Primitives; + +namespace Microsoft.Generator.CSharp +{ + internal static class TypeSymbolExtensions + { + private const string GlobalPrefix = "global::"; + + public static bool IsSameType(this INamedTypeSymbol symbol, CSharpType type) + { + if (type.IsValueType && type.IsNullable) + { + if (symbol.ConstructedFrom.SpecialType != SpecialType.System_Nullable_T) + return false; + return IsSameType((INamedTypeSymbol)symbol.TypeArguments.Single(), type.WithNullable(false)); + } + + if (symbol.ContainingNamespace.ToString() != type.Namespace || symbol.Name != type.Name || symbol.TypeArguments.Length != type.Arguments.Count) + { + return false; + } + + for (int i = 0; i < type.Arguments.Count; ++i) + { + if (!IsSameType((INamedTypeSymbol)symbol.TypeArguments[i], type.Arguments[i])) + { + return false; + } + } + return true; + } + + public static CSharpType GetCSharpType(this ITypeSymbol typeSymbol) + { + var fullyQualifiedName = GetFullyQualifiedName(typeSymbol); + var namedTypeSymbol = typeSymbol as INamedTypeSymbol; + + Type? type = LoadFrameworkType(fullyQualifiedName); + + if (type is null) + { + return ConstructCSharpTypeFromSymbol(typeSymbol, fullyQualifiedName, namedTypeSymbol); + } + + CSharpType result = new CSharpType(type); + if (namedTypeSymbol is not null && namedTypeSymbol.IsGenericType && !result.IsNullable) + { + return result.MakeGenericType([.. namedTypeSymbol.TypeArguments.Select(GetCSharpType)]); + } + + return result; + } + + public static string GetFullyQualifiedName(this ITypeSymbol typeSymbol) + { + // Handle special cases for built-in types + switch (typeSymbol.SpecialType) + { + case SpecialType.System_Object: + return "System.Object"; + case SpecialType.System_Void: + return "System.Void"; + case SpecialType.System_Boolean: + return "System.Boolean"; + case SpecialType.System_Char: + return "System.Char"; + case SpecialType.System_SByte: + return "System.SByte"; + case SpecialType.System_Byte: + return "System.Byte"; + case SpecialType.System_Int16: + return "System.Int16"; + case SpecialType.System_UInt16: + return "System.UInt16"; + case SpecialType.System_Int32: + return "System.Int32"; + case SpecialType.System_UInt32: + return "System.UInt32"; + case SpecialType.System_Int64: + return "System.Int64"; + case SpecialType.System_UInt64: + return "System.UInt64"; + case SpecialType.System_Decimal: + return "System.Decimal"; + case SpecialType.System_Single: + return "System.Single"; + case SpecialType.System_Double: + return "System.Double"; + case SpecialType.System_String: + return "System.String"; + case SpecialType.System_DateTime: + return "System.DateTime"; + } + + // Handle array types + if (typeSymbol is IArrayTypeSymbol arrayTypeSymbol) + { + return GetFullyQualifiedName(arrayTypeSymbol.ElementType) + "[]"; + } + + // Handle generic types + if (typeSymbol is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGenericType) + { + // Handle nullable types + if (typeSymbol.NullableAnnotation == NullableAnnotation.Annotated && !IsCollectionType(namedTypeSymbol)) + { + const string nullableTypeName = "System.Nullable"; + var argTypeSymbol = namedTypeSymbol.TypeArguments.FirstOrDefault(); + + if (argTypeSymbol != null) + { + if (argTypeSymbol.TypeKind == TypeKind.Error) + { + return GetFullyQualifiedName(argTypeSymbol); + } + + string[] typeArguments = [.. namedTypeSymbol.TypeArguments.Select(arg => "[" + GetFullyQualifiedName(arg) + "]")]; + return $"{nullableTypeName}`{namedTypeSymbol.TypeArguments.Length}[{string.Join(", ", typeArguments)}]"; + } + } + else if (namedTypeSymbol.TypeArguments.Length > 0 && !IsCollectionType(namedTypeSymbol)) + { + return GetNonNullableGenericTypeName(namedTypeSymbol); + } + + var typeNameSpan = namedTypeSymbol.ConstructedFrom.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat).AsSpan(); + var start = typeNameSpan.IndexOf(':') + 2; + var end = typeNameSpan.IndexOf('<'); + typeNameSpan = typeNameSpan.Slice(start, end - start); + return $"{typeNameSpan}`{namedTypeSymbol.TypeArguments.Length}"; + } + + // Default to fully qualified name + return GetFullyQualifiedNameFromDisplayString(typeSymbol); + } + + public static string GetFullyQualifiedNameFromDisplayString(this ISymbol typeSymbol) + { + var fullyQualifiedName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + return fullyQualifiedName.StartsWith(GlobalPrefix, StringComparison.Ordinal) ? fullyQualifiedName.Substring(GlobalPrefix.Length) : fullyQualifiedName; + } + + private static Type? LoadFrameworkType(string fullyQualifiedName) + { + return fullyQualifiedName switch + { + // Special case for types that would not be defined in corlib, but should still be considered framework types. + "System.BinaryData" => typeof(BinaryData), + _ => System.Type.GetType(fullyQualifiedName) + }; + } + + private static CSharpType ConstructCSharpTypeFromSymbol( + ITypeSymbol typeSymbol, + string fullyQualifiedName, + INamedTypeSymbol? namedTypeSymbol) + { + var typeArg = namedTypeSymbol?.TypeArguments.FirstOrDefault(); + bool isValueType = typeSymbol.IsValueType; + bool isEnum = typeSymbol.TypeKind == TypeKind.Enum; + bool isNullable = typeSymbol.NullableAnnotation == NullableAnnotation.Annotated; + bool isNullableUnknownType = isNullable && typeArg?.TypeKind == TypeKind.Error; + string name = isNullableUnknownType ? fullyQualifiedName : typeSymbol.Name; + string[] pieces = fullyQualifiedName.Split('.'); + + // handle nullables + if (isNullable) + { + // System.Nullable`1[T] -> T + name = typeArg != null ? GetFullyQualifiedName(typeArg) : fullyQualifiedName; + pieces = name.Split('.'); + } + + return new CSharpType( + name, + string.Join('.', pieces.Take(pieces.Length - 1)), + isValueType, + isNullable, + typeSymbol.ContainingType is not null ? GetCSharpType(typeSymbol.ContainingType) : null, + namedTypeSymbol is not null && !isNullableUnknownType ? [.. namedTypeSymbol.TypeArguments.Select(GetCSharpType)] : [], + typeSymbol.DeclaredAccessibility == Accessibility.Public, + isValueType && !isEnum, + baseType: typeSymbol.BaseType is not null && typeSymbol.BaseType.TypeKind != TypeKind.Error && !isNullableUnknownType + ? GetCSharpType(typeSymbol.BaseType) + : null, + underlyingEnumType: namedTypeSymbol is not null && namedTypeSymbol.EnumUnderlyingType is not null + ? GetCSharpType(namedTypeSymbol.EnumUnderlyingType).FrameworkType + : null); + } + + private static bool IsCollectionType(INamedTypeSymbol typeSymbol) + { + // Check if the type implements IEnumerable, ICollection, or IEnumerable + return typeSymbol.AllInterfaces.Any(i => + i.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T || + i.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_ICollection_T || + i.OriginalDefinition.SpecialType == SpecialType.System_Collections_IEnumerable); + } + + private static string GetNonNullableGenericTypeName(INamedTypeSymbol namedTypeSymbol) + { + string[] typeArguments = [.. namedTypeSymbol.TypeArguments.Select(GetFullyQualifiedName)]; + var fullName = namedTypeSymbol.ConstructedFrom.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + // Remove the type arguments from the fully qualified name + var typeArgumentStartIndex = fullName.IndexOf('<'); + var genericTypeName = typeArgumentStartIndex >= 0 ? fullName.Substring(0, typeArgumentStartIndex) : fullName; + + // Remove global:: prefix + if (genericTypeName.StartsWith(GlobalPrefix, StringComparison.Ordinal)) + { + genericTypeName = genericTypeName.Substring(GlobalPrefix.Length); + } + + return $"{genericTypeName}`{namedTypeSymbol.TypeArguments.Length}[{string.Join(", ", typeArguments)}]"; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/ClientCustomizationTests.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/ClientCustomizationTests.cs index a5b9143661..078591e3cd 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/ClientCustomizationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/ClientCustomizationTests.cs @@ -46,7 +46,17 @@ public async Task CanRemoveMethods() [ new ParameterProvider("param1", $"", typeof(string)), new ParameterProvider("param2", $"", typeof(int))]), - Snippet.ThrowExpression(Snippet.Null), client) + Snippet.ThrowExpression(Snippet.Null), client), + new MethodProvider(new MethodSignature( + "Method4", + $"", + MethodSignatureModifiers.Public, + null, + $"", + [ + new ParameterProvider("param1", $"", typeof(string)), + new ParameterProvider("param2", $"", typeof(int?))]), + Snippet.ThrowExpression(Snippet.Null), client), }; client.MethodProviders = methods; @@ -97,6 +107,17 @@ public async Task DoesNotRemoveMethodsThatDoNotMatch() [ new ParameterProvider("param1", $"", typeof(string)), new ParameterProvider("param2", $"", typeof(int))]), + Snippet.ThrowExpression(Snippet.Null), client), + // Nullability of one of the parameters doesn't match + new MethodProvider(new MethodSignature( + "Method4", + $"", + MethodSignatureModifiers.Public, + null, + $"", + [ + new ParameterProvider("param1", $"", typeof(string)), + new ParameterProvider("param2", $"", typeof(int?))]), Snippet.ThrowExpression(Snippet.Null), client) }; @@ -109,7 +130,7 @@ public async Task DoesNotRemoveMethodsThatDoNotMatch() var csharpGen = new CSharpGen(); await csharpGen.ExecuteAsync(); - Assert.AreEqual(3, plugin.Object.OutputLibrary.TypeProviders.Single(t => t.Name == "MockInputClient").Methods.Count); + Assert.AreEqual(4, plugin.Object.OutputLibrary.TypeProviders.Single(t => t.Name == "MockInputClient").Methods.Count); } [Test] diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ClientCustomizationTests/CanRemoveMethods/MockInputClient.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ClientCustomizationTests/CanRemoveMethods/MockInputClient.cs index 65e9152519..6f648fee0a 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ClientCustomizationTests/CanRemoveMethods/MockInputClient.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ClientCustomizationTests/CanRemoveMethods/MockInputClient.cs @@ -6,6 +6,7 @@ namespace Sample; [CodeGenSuppress("Method1")] [CodeGenSuppress("Method2", typeof(bool)] [CodeGenSuppress("Method3", typeof(string), typeof(int))] +[CodeGenSuppress("Method4", typeof(string), typeof(int?))] public partial class MockInputClient { } diff --git a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ClientCustomizationTests/DoesNotRemoveMethodsThatDoNotMatch/MockInputClient.cs b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ClientCustomizationTests/DoesNotRemoveMethodsThatDoNotMatch/MockInputClient.cs index 550401ac55..a33808a7ce 100644 --- a/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ClientCustomizationTests/DoesNotRemoveMethodsThatDoNotMatch/MockInputClient.cs +++ b/packages/http-client-csharp/generator/Microsoft.Generator.CSharp/test/Providers/ModelProviders/TestData/ClientCustomizationTests/DoesNotRemoveMethodsThatDoNotMatch/MockInputClient.cs @@ -5,6 +5,7 @@ namespace Sample; [CodeGenSuppress("Method1")] [CodeGenSuppress("Method2", typeof(bool)] [CodeGenSuppress("Method3", typeof(string), typeof(int), typeof(bool))] +[CodeGenSuppress("Method4", typeof(string), typeof(int)] public partial class MockInputClient { }