Skip to content

Commit

Permalink
Add PreserveSig support to ComInterfaceGenerator (#85941)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkoritzinsky committed May 10, 2023
1 parent 32db631 commit 40a1cea
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -448,46 +449,49 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
// Create the stub.
var signatureContext = SignatureContext.Create(symbol, DefaultMarshallingInfoParser.Create(environment, generatorDiagnostics, symbol, new InteropAttributeCompilationData(), generatedComAttribute), environment, typeof(VtableIndexStubGenerator).Assembly);

// Search for the element information for the managed return value.
// We need to transform it such that any return type is converted to an out parameter at the end of the parameter list.
ImmutableArray<TypePositionInfo> returnSwappedSignatureElements = signatureContext.ElementTypeInformation;
for (int i = 0; i < returnSwappedSignatureElements.Length; ++i)
if (!symbol.MethodImplementationFlags.HasFlag(MethodImplAttributes.PreserveSig))
{
if (returnSwappedSignatureElements[i].IsManagedReturnPosition)
// Search for the element information for the managed return value.
// We need to transform it such that any return type is converted to an out parameter at the end of the parameter list.
ImmutableArray<TypePositionInfo> returnSwappedSignatureElements = signatureContext.ElementTypeInformation;
for (int i = 0; i < returnSwappedSignatureElements.Length; ++i)
{
if (returnSwappedSignatureElements[i].ManagedType == SpecialTypeInfo.Void)
if (returnSwappedSignatureElements[i].IsManagedReturnPosition)
{
// Return type is void, just remove the element from the signature list.
// We don't introduce an out parameter.
returnSwappedSignatureElements = returnSwappedSignatureElements.RemoveAt(i);
}
else
{
// Convert the current element into an out parameter on the native signature
// while keeping it at the return position in the managed signature.
var managedSignatureAsNativeOut = returnSwappedSignatureElements[i] with
if (returnSwappedSignatureElements[i].ManagedType == SpecialTypeInfo.Void)
{
RefKind = RefKind.Out,
RefKindSyntax = SyntaxKind.OutKeyword,
ManagedIndex = TypePositionInfo.ReturnIndex,
NativeIndex = symbol.Parameters.Length
};
returnSwappedSignatureElements = returnSwappedSignatureElements.SetItem(i, managedSignatureAsNativeOut);
// Return type is void, just remove the element from the signature list.
// We don't introduce an out parameter.
returnSwappedSignatureElements = returnSwappedSignatureElements.RemoveAt(i);
}
else
{
// Convert the current element into an out parameter on the native signature
// while keeping it at the return position in the managed signature.
var managedSignatureAsNativeOut = returnSwappedSignatureElements[i] with
{
RefKind = RefKind.Out,
RefKindSyntax = SyntaxKind.OutKeyword,
ManagedIndex = TypePositionInfo.ReturnIndex,
NativeIndex = symbol.Parameters.Length
};
returnSwappedSignatureElements = returnSwappedSignatureElements.SetItem(i, managedSignatureAsNativeOut);
}
break;
}
break;
}
}

signatureContext = signatureContext with
{
// Add the HRESULT return value in the native signature.
// This element does not have any influence on the managed signature, so don't assign a managed index.
ElementTypeInformation = returnSwappedSignatureElements.Add(
new TypePositionInfo(SpecialTypeInfo.Int32, new ManagedHResultExceptionMarshallingInfo())
{
NativeIndex = TypePositionInfo.ReturnIndex
})
};
signatureContext = signatureContext with
{
// Add the HRESULT return value in the native signature.
// This element does not have any influence on the managed signature, so don't assign a managed index.
ElementTypeInformation = returnSwappedSignatureElements.Add(
new TypePositionInfo(SpecialTypeInfo.Int32, new ManagedHResultExceptionMarshallingInfo())
{
NativeIndex = TypePositionInfo.ReturnIndex
})
};
}

var containingSyntaxContext = new ContainingSyntaxContext(syntax);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Testing;
using Microsoft.Interop;
using Xunit;

using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.VtableIndexStubGenerator>;
using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.CodeAnalysis.Testing.EmptySourceGeneratorProvider>;

namespace ComInterfaceGenerator.Unit.Tests
{
public class CallingConventionForwarding
public class TargetSignatureTests
{
[Fact]
public async Task NoSpecifiedCallConvForwardsDefault()
Expand All @@ -32,7 +34,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType
}
""";

await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (compilation, signature) =>
await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (compilation, signature) =>
{
Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
Assert.Empty(signature.UnmanagedCallingConventionTypes);
Expand All @@ -56,7 +58,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType
}
""";

await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
{
Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
Assert.Equal(newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvSuppressGCTransition"), Assert.Single(signature.UnmanagedCallingConventionTypes), SymbolEqualityComparer.Default);
Expand All @@ -80,7 +82,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType
}
""";

await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
{
Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
Assert.Empty(signature.UnmanagedCallingConventionTypes);
Expand All @@ -105,7 +107,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType
}
""";

await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
{
Assert.Equal(SignatureCallingConvention.CDecl, signature.CallingConvention);
Assert.Empty(signature.UnmanagedCallingConventionTypes);
Expand All @@ -130,7 +132,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType
}
""";

await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
{
Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
Assert.Equal(new[]
Expand Down Expand Up @@ -162,7 +164,7 @@ partial interface INativeAPI : IUnmanagedInterfaceType
}
""";

await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
{
Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
Assert.Equal(new[]
Expand All @@ -176,24 +178,105 @@ await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signa
});
}

private static async Task VerifySourceGeneratorAsync(string source, string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
[Fact]
public async Task ComInterfaceMethodFunctionPointerReturnsInt()
{
CallingConventionForwardingTest test = new(interfaceName, methodName, signatureValidator)
string source = $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
partial interface IComInterface
{
void Method();
}
""";

await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) =>
{
Assert.Equal(SpecialType.System_Int32, signature.ReturnType.SpecialType);
});
}

[Fact]
public async Task ComInterfaceMethodFunctionPointerReturnTypeChangedToOutParameter()
{
string source = $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
partial interface IComInterface
{
long Method();
}
""";

await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) =>
{
Assert.Equal(SpecialType.System_Int32, signature.ReturnType.SpecialType);
Assert.Equal(2, signature.Parameters.Length);
Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Void)), signature.Parameters[0].Type, SymbolEqualityComparer.Default);
Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Int64)), signature.Parameters[^1].Type, SymbolEqualityComparer.Default);
});
}

[Fact]
public async Task ComInterfaceMethodPreserveSigFunctionPointerReturnTypePreserved()
{
string source = $$"""
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

[GeneratedComInterface]
[Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
partial interface IComInterface
{
[PreserveSig]
long Method();
}
""";

await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) =>
{
Assert.Equal(SpecialType.System_Int64, signature.ReturnType.SpecialType);
Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Void)), Assert.Single(signature.Parameters).Type, SymbolEqualityComparer.Default);
});
}

private static async Task VerifyVirtualMethodIndexGeneratorAsync(string source, string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
{
VirtualMethodIndexTargetSignatureTest test = new(interfaceName, methodName, signatureValidator)
{
TestCode = source,
TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck
};

await test.RunAsync();
}
private static async Task VerifyComInterfaceGeneratorAsync(string source, string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
{
ComInterfaceTargetSignatureTest test = new(interfaceName, methodName, signatureValidator)
{
TestCode = source,
TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck
};

class CallingConventionForwardingTest : VerifyCS.Test
await test.RunAsync();
}

private abstract class TargetSignatureTestBase : VerifyCS.Test
{
private readonly Action<Compilation, IMethodSymbol> _signatureValidator;
private readonly string _interfaceName;
private readonly string _methodName;

public CallingConventionForwardingTest(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
protected TargetSignatureTestBase(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
: base(referenceAncillaryInterop: true)
{
_signatureValidator = signatureValidator;
Expand All @@ -205,12 +288,14 @@ protected override void VerifyFinalCompilation(Compilation compilation)
{
_signatureValidator(compilation, FindFunctionPointerInvocationSignature(compilation));
}

protected abstract INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface);
private IMethodSymbol FindFunctionPointerInvocationSignature(Compilation compilation)
{
INamedTypeSymbol? userDefinedInterface = compilation.Assembly.GetTypeByMetadataName(_interfaceName);
Assert.NotNull(userDefinedInterface);

INamedTypeSymbol generatedInterfaceImplementation = Assert.Single(userDefinedInterface.GetTypeMembers("Native"));
INamedTypeSymbol generatedInterfaceImplementation = FindImplementationInterface(compilation, userDefinedInterface);

IMethodSymbol methodImplementation = Assert.Single(generatedInterfaceImplementation.GetMembers($"global::{_interfaceName}.{_methodName}").OfType<IMethodSymbol>());

Expand All @@ -223,5 +308,38 @@ private IMethodSymbol FindFunctionPointerInvocationSignature(Compilation compila
return Assert.Single(body.Descendants().OfType<IFunctionPointerInvocationOperation>()).GetFunctionPointerSignature();
}
}

private sealed class VirtualMethodIndexTargetSignatureTest : TargetSignatureTestBase
{
public VirtualMethodIndexTargetSignatureTest(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
: base(interfaceName, methodName, signatureValidator)
{
}

protected override IEnumerable<Type> GetSourceGenerators() => new[] { typeof(VtableIndexStubGenerator) };

protected override INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface) => Assert.Single(userDefinedInterface.GetTypeMembers("Native"));
}

private sealed class ComInterfaceTargetSignatureTest : TargetSignatureTestBase
{
public ComInterfaceTargetSignatureTest(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator) : base(interfaceName, methodName, signatureValidator)
{
}
protected override IEnumerable<Type> GetSourceGenerators() => new[] { typeof(Microsoft.Interop.ComInterfaceGenerator) };

protected override INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface)
{
INamedTypeSymbol? iUnknownDerivedAttributeType = compilation.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute`2");

Assert.NotNull(iUnknownDerivedAttributeType);

AttributeData iUnknownDerivedAttribute = Assert.Single(
userDefinedInterface.GetAttributes(),
attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, iUnknownDerivedAttributeType));

return (INamedTypeSymbol)iUnknownDerivedAttribute.AttributeClass!.TypeArguments[1];
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
namespace Microsoft.Interop.UnitTests.Verifiers
{
public static class CSharpSourceGeneratorVerifier<TSourceGenerator>
where TSourceGenerator : IIncrementalGenerator, new()
where TSourceGenerator : new()
{
public static DiagnosticResult Diagnostic(string diagnosticId)
=> new DiagnosticResult(diagnosticId, DiagnosticSeverity.Error);
Expand Down

0 comments on commit 40a1cea

Please sign in to comment.