From 36a85b00695007652d8a2d5e1c5e9d99b7247505 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Wed, 29 May 2024 20:35:38 -0700 Subject: [PATCH] Propagate the 'this' keyword in LibraryImports (#102793) Since Roslyn just lowers the extension methods to a call to the static method there's no reason we shouldn't support creating extension methods that are also LibraryImport methods. This adds the IsExplicitThis property to TypePositionInfo and copies this to the generated signature if it is true. --- .../Marshalling/MarshallerHelpers.cs | 5 + .../TypePositionInfo.cs | 7 +- .../ArrayTests.cs | 283 ++++++++++++++++++ .../BlittableStructTests.cs | 21 ++ .../CodeSnippets.cs | 9 + .../Compiles.cs | 3 +- 6 files changed, 325 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs index 91028ff5a6ebd..25e09c52a6008 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshallerHelpers.cs @@ -426,6 +426,11 @@ public static SyntaxTokenList GetManagedParameterModifiers(TypePositionInfo type } } + if (typeInfo.IsExplicitThis) + { + tokens = tokens.Add(Token(SyntaxKind.ThisKeyword)); + } + return tokens; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypePositionInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypePositionInfo.cs index 49ebf92f600ec..69654c34c4f75 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypePositionInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypePositionInfo.cs @@ -3,9 +3,10 @@ using System; using System.Collections.Generic; - +using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Microsoft.Interop @@ -77,6 +78,7 @@ public static int IncrementIndex(int index) public int ManagedIndex { get; init; } = UnsetIndex; public int NativeIndex { get; init; } = UnsetIndex; + public bool IsExplicitThis { get; init; } public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, MarshallingInfo marshallingInfo, Compilation compilation) { @@ -88,7 +90,8 @@ public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, RefKind = paramSymbol.RefKind, ByValueContentsMarshalKind = byValueContentsMarshalKind, ByValueMarshalAttributeLocations = (inLocation, outLocation), - ScopedKind = paramSymbol.ScopedKind + ScopedKind = paramSymbol.ScopedKind, + IsExplicitThis = ((ParameterSyntax)paramSymbol.DeclaringSyntaxReferences[0].GetSyntax()).Modifiers.Any(SyntaxKind.ThisKeyword) }; return typeInfo; diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.cs index e082fc4b25417..c0d7398b039c1 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/ArrayTests.cs @@ -115,6 +115,73 @@ public static partial BoolStruct[] NegateBools( } } + public static partial class ArrayNativeExtensions + { + // The first parameter of a 'ref' extension method must be a value type or a generic type constrained to struct. + // The first 'in' or 'ref readonly' parameter of the extension method must be a concrete (non-generic) value type. + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int Sum(this int[] values, int numValues); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int Sum(this ref int values, int numValues); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "sum_char_array", StringMarshalling = StringMarshalling.Utf16)] + public static partial int SumChars(this char[] chars, int numElements); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "fill_char_array", StringMarshalling = StringMarshalling.Utf16)] + public static partial void FillChars([Out] this char[] chars, int length, ushort start); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] + public static partial int SumStringLengths([MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr)] this string[] strArray); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "reverse_strings_return")] + [return: MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr, SizeParamIndex = 1)] + public static partial string[] ReverseStrings_Return([MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr)] this string[] strArray, out int numElements); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "reverse_strings_out")] + public static partial void ReverseStrings_Out([MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr)] this string[] strArray, out int numElements, [MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr, SizeParamIndex = 1)] out string[] res); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "get_long_bytes")] + [return: MarshalAs(UnmanagedType.LPArray, SizeConst = sizeof(long))] + public static partial byte[] GetLongBytes(this long l); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "fill_range_array")] + [return: MarshalAs(UnmanagedType.U1)] + public static partial bool FillRangeArray([Out] this IntStructWrapper[] array, int length, int start); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "double_values")] + public static partial void DoubleValues([In, Out] this IntStructWrapper[] array, int length); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "and_bool_struct_array")] + [return: MarshalAs(UnmanagedType.U1)] + public static partial bool AndAllMembers(this BoolStruct[] pArray, int length); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "negate_bool_struct_array_out")] + public static partial void NegateBools( + this BoolStruct[] boolStruct, + int numValues, + [MarshalUsing(CountElementName = "numValues")] out BoolStruct[] pBoolStructOut); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "negate_bool_struct_array_return")] + [return: MarshalUsing(CountElementName = "numValues")] + public static partial BoolStruct[] NegateBools( + this BoolStruct[] boolStruct, + int numValues); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "transpose_matrix")] + [return: MarshalUsing(CountElementName = "numColumns")] + [return: MarshalUsing(CountElementName = "numRows", ElementIndirectionDepth = 1)] + public static partial int[][] TransposeMatrix(this int[][] matrix, int[] numRows, int numColumns); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "sum_int_ptr_array")] + public static unsafe partial int Sum(this int*[] values, int numValues); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "return_duplicate_int_ptr_array")] + [return: MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)] + public static unsafe partial int*[] ReturnDuplicate(this int*[] values, int numValues); + } + public class ArrayTests { private int[] GetIntArray() => new[] { 1, 5, 79, 165, 32, 3 }; @@ -126,6 +193,13 @@ public void IntArray_ByValue() Assert.Equal(array.Sum(), NativeExportsNE.Arrays.Sum(array, array.Length)); } + [Fact] + public void IntArray_ByValue_This() + { + int[] array = GetIntArray(); + Assert.Equal(array.Sum(), array.Sum(array.Length)); + } + [Fact] public void IntArray_RefToFirstElement() { @@ -133,6 +207,13 @@ public void IntArray_RefToFirstElement() Assert.Equal(array.Sum(), NativeExportsNE.Arrays.Sum(ref array[0], array.Length)); } + [Fact] + public void IntArray_RefToFirstElement_This() + { + int[] array = GetIntArray(); + Assert.Equal(array.Sum(), array[0].Sum(array.Length)); + } + [Fact] public void NullIntArray_ByValue() { @@ -147,6 +228,13 @@ public void ZeroLengthArray_MarshalledAsNonNull() Assert.Equal(0, NativeExportsNE.Arrays.Sum(array, array.Length)); } + [Fact] + public void ZeroLengthArray_MarshalledAsNonNull_This() + { + var array = new int[0]; + Assert.Equal(0, array.Sum(array.Length)); + } + [Fact] public void IntArray_In() { @@ -170,6 +258,13 @@ public void CharArray_ByValue() Assert.Equal(array.Sum(c => c), NativeExportsNE.Arrays.SumChars(array, array.Length)); } + [Fact] + public void CharArray_ByValue_This() + { + char[] array = CharacterTests.CharacterMappings().Select(o => (char)o[0]).ToArray(); + Assert.Equal(array.Sum(c => c), array.SumChars(array.Length)); + } + [Fact] public void CharArray_Ref() { @@ -219,6 +314,22 @@ public unsafe void PointerArray_ByValue() } } + [Fact] + public unsafe void PointerArray_ByValue_This() + { + int[] array = GetIntArray(); + fixed (int* arrayPointer = array) + { + int*[] pointerArray = new int*[array.Length]; + for (int i = 0; i < array.Length; i++) + { + pointerArray[i] = &arrayPointer[i]; + } + + Assert.Equal(array.Sum(), pointerArray.Sum(pointerArray.Length)); + } + } + [Fact] public unsafe void PointerArray_In() { @@ -281,6 +392,28 @@ public unsafe void PointerArray_Return() } } + [Fact] + public unsafe void PointerArray_Return_This() + { + int[] array = GetIntArray(); + fixed (int* arrayPointer = array) + { + int*[] pointerArray = new int*[array.Length]; + for (int i = 0; i < array.Length; i++) + { + pointerArray[i] = &arrayPointer[i]; + } + + int*[] res = pointerArray.ReturnDuplicate(pointerArray.Length); + Assert.Equal(pointerArray.Length, res.Length); + for (int i = 0; i < pointerArray.Length; i++) + { + Assert.Equal((IntPtr)pointerArray[i], (IntPtr)res[i]); + Assert.Equal(*pointerArray[i], *res[i]); + } + } + } + private static string[] GetStringArray() { return new[] @@ -301,12 +434,26 @@ public void ArrayWithElementMarshalling_ByValue() Assert.Equal(strings.Sum(str => str?.Length ?? 0), NativeExportsNE.Arrays.SumStringLengths(strings)); } + [Fact] + public void ArrayWithElementMarshalling_ByValue_This() + { + var strings = GetStringArray(); + Assert.Equal(strings.Sum(str => str?.Length ?? 0), strings.SumStringLengths()); + } + [Fact] public void NullArrayWithElementMarshalling_ByValue() { Assert.Equal(0, NativeExportsNE.Arrays.SumStringLengths(null)); } + [Fact] + public void NullArrayWithElementMarshalling_ByValue_This() + { + string[] strings = null; + Assert.Equal(0, strings.SumStringLengths()); + } + [Fact] public void ArrayWithElementMarshalling_Ref() { @@ -329,6 +476,18 @@ public void ArrayWithElementMarshalling_Return() Assert.Equal(expectedStrings, res); } + [Fact] + public void ArrayWithElementMarshalling_Return_This() + { + var strings = GetStringArray(); + var expectedStrings = strings.Select(s => ReverseChars(s)).ToArray(); + Assert.Equal(expectedStrings, strings.ReverseStrings_Return(out _)); + + string[] res; + strings.ReverseStrings_Out(out _, out res); + Assert.Equal(expectedStrings, res); + } + [Fact] public void NullArrayWithElementMarshalling_Ref() { @@ -349,6 +508,17 @@ public void NullArrayWithElementMarshalling_Return() Assert.Null(res); } + [Fact] + public void NullArrayWithElementMarshalling_Return_This() + { + string[] strings = null; + Assert.Null(strings.ReverseStrings_Return(out _)); + + string[] res; + strings.ReverseStrings_Out(out _, out res); + Assert.Null(res); + } + [Fact] public void ConstantSizeArray() { @@ -357,6 +527,14 @@ public void ConstantSizeArray() Assert.Equal(longVal, MemoryMarshal.Read(NativeExportsNE.Arrays.GetLongBytes(longVal))); } + [Fact] + public void ConstantSizeArray_This() + { + var longVal = 0x12345678ABCDEF10L; + + Assert.Equal(longVal, MemoryMarshal.Read(longVal.GetLongBytes())); + } + [Fact] public void DynamicSizedArrayWithConstantComponent() { @@ -400,6 +578,39 @@ public void Array_ByValueOut() } } + [Fact] + public void Array_ByValueOut_This() + { + { + var testArray = new IntStructWrapper[10]; + int start = 5; + + testArray.FillRangeArray(testArray.Length, start); + Assert.Equal(Enumerable.Range(start, testArray.Length), testArray.Select(wrapper => wrapper.Value)); + + // Any items not populated by the invoke target should be initialized to default + testArray = new IntStructWrapper[10]; + int lengthToFill = testArray.Length / 2; + testArray.FillRangeArray(lengthToFill, start); + Assert.Equal(Enumerable.Range(start, lengthToFill), testArray[..lengthToFill].Select(wrapper => wrapper.Value)); + Assert.All(testArray[lengthToFill..], wrapper => Assert.Equal(0, wrapper.Value)); + } + { + var testArray = new char[10]; + ushort start = 65; + + testArray.FillChars(testArray.Length, start); + Assert.Equal(Enumerable.Range(start, testArray.Length), testArray.Select(c => (int)c)); + + // Any items not populated by the invoke target should be initialized to default + testArray = new char[10]; + int lengthToFill = testArray.Length / 2; + testArray.FillChars(lengthToFill, start); + Assert.Equal(Enumerable.Range(start, lengthToFill), testArray[..lengthToFill].Select(c => (int)c)); + Assert.All(testArray[lengthToFill..], c => Assert.Equal(0, c)); + } + } + [Fact] public void Array_ByValueInOut() { @@ -412,6 +623,18 @@ public void Array_ByValueInOut() Assert.Equal(testValues.Select(wrapper => wrapper.Value * 2), testArray.Select(wrapper => wrapper.Value)); } + [Fact] + public void Array_ByValueInOut_This() + { + var testValues = Enumerable.Range(42, 15).Select(i => new IntStructWrapper { Value = i }); + + var testArray = testValues.ToArray(); + + testArray.DoubleValues(testArray.Length); + + Assert.Equal(testValues.Select(wrapper => wrapper.Value * 2), testArray.Select(wrapper => wrapper.Value)); + } + [Theory] [InlineData(true)] [InlineData(false)] @@ -421,6 +644,15 @@ public void NonBlittableElementArray_ByValue(bool result) Assert.Equal(result, NativeExportsNE.Arrays.AndAllMembers(array, array.Length)); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public void NonBlittableElementArray_ByValue_This(bool result) + { + BoolStruct[] array = GetBoolStructsToAnd(result); + Assert.Equal(result, array.AndAllMembers(array.Length)); + } + [Theory] [InlineData(true)] [InlineData(false)] @@ -451,6 +683,17 @@ public void NonBlittableElementArray_Out() Assert.Equal(expected, result); } + [Fact] + public void NonBlittableElementArray_Out_This() + { + BoolStruct[] array = GetBoolStructsToNegate(); + BoolStruct[] expected = GetNegatedBoolStructs(array); + + BoolStruct[] result; + array.NegateBools(array.Length, out result); + Assert.Equal(expected, result); + } + [Fact] public void NonBlittableElementArray_Return() { @@ -461,6 +704,16 @@ public void NonBlittableElementArray_Return() Assert.Equal(expected, result); } + [Fact] + public void NonBlittableElementArray_Return_This() + { + BoolStruct[] array = GetBoolStructsToNegate(); + BoolStruct[] expected = GetNegatedBoolStructs(array); + + BoolStruct[] result = array.NegateBools(array.Length); + Assert.Equal(expected, result); + } + private static BoolStruct[] GetBoolStructsToAnd(bool result) => new BoolStruct[] { new BoolStruct @@ -544,6 +797,36 @@ public void ArraysOfArrays() } } + [Fact] + public void ArraysOfArrays_This() + { + var random = new Random(42); + int numRows = random.Next(1, 5); + int numColumns = random.Next(1, 5); + int[][] matrix = new int[numRows][]; + for (int i = 0; i < numRows; i++) + { + matrix[i] = new int[numColumns]; + for (int j = 0; j < numColumns; j++) + { + matrix[i][j] = random.Next(); + } + } + + int[] numRowsArray = new int[numColumns]; + numRowsArray.AsSpan().Fill(numRows); + + int[][] transposed = matrix.TransposeMatrix(numRowsArray, numColumns); + + for (int i = 0; i < numRows; i++) + { + for (int j = 0; j < numColumns; j++) + { + Assert.Equal(matrix[i][j], transposed[j][i]); + } + } + } + private static string ReverseChars(string value) { if (value == null) diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/BlittableStructTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/BlittableStructTests.cs index f02379ddcabf2..c3a6ea8588837 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/BlittableStructTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/BlittableStructTests.cs @@ -41,6 +41,16 @@ public static partial void IncrementInvertPointerFieldsRefReturn( PointerFields input, ref PointerFields result); } + public static partial class IntStructExtensions + { + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "blittablestructs_return_instance")] + public static partial IntFields DoubleIntFields(this IntFields result); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "blittablestructs_double_intfields_refreturn")] + public static partial void DoubleIntFieldsOutReturn( + this IntFields input, + out IntFields result); + } public class BlittableStructTests { @@ -67,6 +77,11 @@ public void ValidateIntFields() Assert.Equal(initial, input); Assert.Equal(expected, result); } + { + var result = input.DoubleIntFields(); + Assert.Equal(initial, input); + Assert.Equal(expected, result); + } { var result = new IntFields(); NativeExportsNE.DoubleIntFieldsRefReturn(input, ref result); @@ -80,6 +95,12 @@ public void ValidateIntFields() Assert.Equal(initial, input); Assert.Equal(expected, result); } + { + IntFields result; + input.DoubleIntFieldsOutReturn(out result); + Assert.Equal(initial, input); + Assert.Equal(expected, result); + } { input = initial; diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs index 32a0d9a414e19..6204b100ce325 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs @@ -532,6 +532,15 @@ partial class Test } """; + public static string ExplicitThis => $$""" + using System.Runtime.InteropServices; + static partial class StringNativeExtensions + { + [LibraryImport("DoesNotExist")] + public static partial void Method(this int t); + } + """; + public static string BasicParametersAndModifiers(string preDeclaration = "") => BasicParametersAndModifiers(typeof(T).ToString(), preDeclaration); /// diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs index 4c80ac61cc460..43fdaf4930037 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs @@ -42,6 +42,7 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { ID(), CodeSnippets.DefaultParameters }; yield return new[] { ID(), CodeSnippets.UseCSharpFeaturesForConstants }; yield return new[] { ID(), CodeSnippets.LibraryImportInRefStruct }; + yield return new[] { ID(), CodeSnippets.ExplicitThis }; // Parameter / return types yield return new[] { ID(), CodeSnippets.BasicParametersAndModifiers() }; @@ -719,7 +720,7 @@ public class Basic { } [Theory] [MemberData(nameof(CodeSnippetsToVerifyNoTreesProduced))] - public async Task ValidateNoGeneratedOuptutForNoImport(string id, string source, TestTargetFramework framework) + public async Task ValidateNoGeneratedOutputForNoImport(string id, string source, TestTargetFramework framework) { TestUtils.Use(id); var test = new NoChangeTest(framework)