Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release/7.0] Ensure we cleanup the marshalling for elements of collections (stateful and stateless) #76693

Merged
merged 7 commits into from
Oct 7, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,33 @@ protected StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo
StubCodeContext.Stage.Unmarshal));
}

protected StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context)
{
string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context);
StatementSyntax contentsCleanupStatements = GenerateContentsMarshallingStatement(info, context,
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(MarshallerHelpers.GetNativeSpanIdentifier(info, context)),
IdentifierName("Length")),
StubCodeContext.Stage.Cleanup);

if (contentsCleanupStatements.IsKind(SyntaxKind.EmptyStatement))
{
return EmptyStatement();
}

return Block(
LocalDeclarationStatement(VariableDeclaration(
GenericName(
Identifier(TypeNames.System_Span),
TypeArgumentList(SingletonSeparatedList(_unmanagedElementType))),
SingletonSeparatedList(
VariableDeclarator(
Identifier(nativeSpanIdentifier))
.WithInitializer(EqualsValueClause(
GetUnmanagedValuesDestination(info, context)))))),
contentsCleanupStatements);
}

protected StatementSyntax GenerateContentsMarshallingStatement(
TypePositionInfo info,
StubCodeContext context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,5 +292,37 @@ public static IEnumerable<TypePositionInfo> GetDependentElementsOfMarshallingInf
}
}
}

public static StatementSyntax SkipInitOrDefaultInit(TypePositionInfo info, StubCodeContext context)
{
(TargetFramework fmk, _) = context.GetTargetFramework();
if (info.ManagedType is not PointerTypeInfo
&& info.ManagedType is not ValueTypeInfo { IsByRefLike: true }
&& fmk is TargetFramework.Net)
{
// Use the Unsafe.SkipInit<T> API when available and
// managed type is usable as a generic parameter.
return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_CompilerServices_Unsafe),
IdentifierName("SkipInit")))
.WithArgumentList(
ArgumentList(SingletonSeparatedList(
Argument(IdentifierName(info.InstanceIdentifier))
.WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword))))));
}
else
{
// Assign out params to default
return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(info.InstanceIdentifier),
LiteralExpression(
SyntaxKind.DefaultLiteralExpression,
Token(SyntaxKind.DefaultKeyword))));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,27 @@ public StatefulLinearCollectionNonBlittableElementsMarshalling(
}

public TypeSyntax AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context);
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
StatementSyntax elementCleanup = GenerateElementCleanupStatement(info, context);

if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
{
yield return elementCleanup;
}

if (!_shape.HasFlag(MarshallerShape.Free))
yield break;

string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
// <marshaller>.Free();
yield return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(marshaller),
IdentifierName(ShapeMemberNames.Free)),
ArgumentList()));
}
public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);

public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ public StatelessFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller,

public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
foreach (StatementSyntax statement in _innerMarshaller.GenerateCleanupStatements(info, context))
{
yield return statement;
}
// <marshallerType>.Free(<nativeIdentifier>);
yield return ExpressionStatement(
InvocationExpression(
Expand Down Expand Up @@ -372,11 +376,19 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i
public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
{
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
yield return LocalDeclarationStatement(
VariableDeclaration(
PredefinedType(Token(SyntaxKind.IntKeyword)),
SingletonSeparatedList(
VariableDeclarator(MarshallerHelpers.GetNumElementsIdentifier(info, context)))));
VariableDeclarator(numElementsIdentifier))));
// Use the numElements local to ensure the compiler doesn't give errors for using an uninitialized variable.
// The value will never be used unless it has been initialized, so this is safe.
yield return MarshallerHelpers.SkipInitOrDefaultInit(
new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance)
{
InstanceIdentifier = numElementsIdentifier
}, context);
}

public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
Expand Down Expand Up @@ -512,7 +524,15 @@ public StatelessLinearCollectionNonBlittableElementsMarshalling(

public TypeSyntax AsNativeType(TypePositionInfo info) => _nativeTypeSyntax;

public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
StatementSyntax elementCleanup = GenerateElementCleanupStatement(info, context);

if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
{
yield return elementCleanup;
}
}

public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context)
{
Expand Down Expand Up @@ -588,11 +608,19 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i

public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
{
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
yield return LocalDeclarationStatement(
VariableDeclaration(
PredefinedType(Token(SyntaxKind.IntKeyword)),
SingletonSeparatedList(
VariableDeclarator(MarshallerHelpers.GetNumElementsIdentifier(info, context)))));
VariableDeclarator(numElementsIdentifier))));
// Use the numElements local to ensure the compiler doesn't give errors for using an uninitialized variable.
// The value will never be used unless it has been initialized, so this is safe.
yield return MarshallerHelpers.SkipInitOrDefaultInit(
new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance)
{
InstanceIdentifier = numElementsIdentifier
}, context);
}

public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => Array.Empty<StatementSyntax>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ private MarshallingInfo CreateNativeMarshallingInfo(
}

int maxIndirectionDepthUsedLocal = maxIndirectionDepthUsed;
Func<ITypeSymbol, MarshallingInfo> getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, new Dictionary<int, AttributeData>(), 1, ImmutableHashSet<string>.Empty, ref maxIndirectionDepthUsedLocal);
Func<ITypeSymbol, MarshallingInfo> getMarshallingInfoForElement = (ITypeSymbol elementType) => GetMarshallingInfo(elementType, useSiteAttributes, 1, inspectedElements, ref maxIndirectionDepthUsedLocal);
jkoritzinsky marked this conversation as resolved.
Show resolved Hide resolved
if (ManualTypeMarshallingHelper.TryGetLinearCollectionMarshallersFromEntryType(entryPointType, type, _compilation, getMarshallingInfoForElement, out CustomTypeMarshallers? collectionMarshallers))
{
maxIndirectionDepthUsed = maxIndirectionDepthUsedLocal;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,7 @@ public static VariableDeclarations GenerateDeclarationsForManagedToNative(BoundG

if (info.RefKind == RefKind.Out)
{
(TargetFramework fmk, _) = context.GetTargetFramework();
if (info.ManagedType is not PointerTypeInfo
&& info.ManagedType is not ValueTypeInfo { IsByRefLike: true }
&& fmk is TargetFramework.Net)
{
// Use the Unsafe.SkipInit<T> API when available and
// managed type is usable as a generic parameter.
initializations.Add(ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_CompilerServices_Unsafe),
IdentifierName("SkipInit")))
.WithArgumentList(
ArgumentList(SingletonSeparatedList(
Argument(IdentifierName(info.InstanceIdentifier))
.WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword)))))));
}
else
{
// Assign out params to default
initializations.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(info.InstanceIdentifier),
LiteralExpression(
SyntaxKind.DefaultLiteralExpression,
Token(SyntaxKind.DefaultKeyword)))));
}
initializations.Add(MarshallerHelpers.SkipInitOrDefaultInit(info, context));
}

// Declare variables for parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ public partial class Stateless
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")]
public static partial int SumWithBuffer([MarshalUsing(typeof(ListMarshallerWithBuffer<,>))] List<int> values, int numValues);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_ptr_array")]
public static unsafe partial int SumWithFreeTracking([MarshalUsing(typeof(ListMarshaller<,>)), MarshalUsing(typeof(IntWrapperMarshallerWithFreeCounts), ElementIndirectionDepth = 1)] List<IntWrapper> values, int numValues);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "double_values")]
public static partial int DoubleValues([MarshalUsing(typeof(ListMarshallerWithPinning<,>))] List<BlittableIntWrapper> values, int length);

Expand Down Expand Up @@ -99,6 +102,9 @@ public partial class Stateful
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")]
public static partial int Sum([MarshalUsing(typeof(ListMarshallerStateful<,>))] List<int> values, int numValues);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_ptr_array")]
public static unsafe partial int SumWithFreeTracking([MarshalUsing(typeof(ListMarshallerStateful<,>)), MarshalUsing(typeof(IntWrapperMarshallerWithFreeCounts), ElementIndirectionDepth = 1)] List<IntWrapper> values, int numValues);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array_ref")]
public static partial int SumInArray([MarshalUsing(typeof(ListMarshallerStateful<,>))] in List<int> values, int numValues);

Expand Down Expand Up @@ -369,6 +375,30 @@ public void NonBlittableElementCollection_GuaranteedUnmarshal()
Assert.True(NativeExportsNE.Collections.Stateful.ListGuaranteedUnmarshal<BoolStruct, BoolStructMarshaller.BoolStructNative>.Marshaller.ToManagedFinallyCalled);
}

[Fact]
public void ElementsFreed()
{
List<IntWrapper> list = new List<IntWrapper>
{
new IntWrapper { i = 1 },
new IntWrapper { i = 10 },
new IntWrapper { i = 24 },
new IntWrapper { i = 30 },
};

int startingCount = IntWrapperMarshallerWithFreeCounts.NumCallsToFree;

NativeExportsNE.Collections.Stateless.SumWithFreeTracking(list, list.Count);

Assert.Equal(startingCount + list.Count, IntWrapperMarshallerWithFreeCounts.NumCallsToFree);

startingCount = IntWrapperMarshallerWithFreeCounts.NumCallsToFree;

NativeExportsNE.Collections.Stateful.SumWithFreeTracking(list, list.Count);

Assert.Equal(startingCount + list.Count, IntWrapperMarshallerWithFreeCounts.NumCallsToFree);
}

private static List<BoolStruct> GetBoolStructsToAnd(bool result) => new List<BoolStruct>
{
new BoolStruct
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,31 @@ public static void Free(int* unmanaged)
}
}

[CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(IntWrapperMarshallerWithFreeCounts))]
public static unsafe class IntWrapperMarshallerWithFreeCounts
{
[ThreadStatic]
public static int NumCallsToFree = 0;

public static int* ConvertToUnmanaged(IntWrapper managed)
{
int* ret = (int*)Marshal.AllocCoTaskMem(sizeof(int));
*ret = managed.i;
return ret;
}

public static IntWrapper ConvertToManaged(int* unmanaged)
{
return new IntWrapper { i = *unmanaged };
}

public static void Free(int* unmanaged)
{
NumCallsToFree++;
Marshal.FreeCoTaskMem((IntPtr)unmanaged);
}
}

[CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(Marshaller))]
public static unsafe class IntWrapperMarshallerStateful
{
Expand Down Expand Up @@ -477,14 +502,14 @@ public void FromManaged(List<T> managed, Span<TUnmanagedElement> buffer)

_list = managed;
// Always allocate at least one byte when the list is zero-length.
int spaceToAllocate = Math.Max(managed.Count * sizeof(TUnmanagedElement), 1);
if (spaceToAllocate <= buffer.Length)
int countToAllocate = Math.Max(managed.Count, 1);
if (countToAllocate <= buffer.Length)
{
_span = buffer[0..spaceToAllocate];
_span = buffer[0..countToAllocate];
}
else
{
_allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate);
_allocatedMemory = Marshal.AllocCoTaskMem(countToAllocate * sizeof(TUnmanagedElement));
_span = new Span<TUnmanagedElement>((void*)_allocatedMemory, managed.Count);
}
}
Expand Down