Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release/7.0] Only stack allocate when marshalling Utf8 arguments #74553

Merged
merged 2 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1474,7 +1474,7 @@ private bool ShouldBePinned
get
{
return MarshalDirection == MarshalDirection.Forward
&& MarshallerType != MarshallerType.Field
&& MarshallerType == MarshallerType.Argument
&& !IsManagedByRef
&& In
&& !Out;
Expand Down Expand Up @@ -1672,7 +1672,11 @@ protected override void TransformManagedToNative(ILCodeStream codeStream)
{
ILEmitter emitter = _ilCodeStreams.Emitter;

if (In && !Out && !IsManagedByRef)
if (MarshalDirection == MarshalDirection.Forward
&& MarshallerType == MarshallerType.Argument
&& !IsManagedByRef
&& In
&& !Out)
{
TypeDesc marshallerIn = MarshallerIn;

Expand Down
19 changes: 14 additions & 5 deletions src/tests/nativeaot/SmokeTests/PInvoke/PInvoke.cs
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,8 @@ public struct SequentialStruct
public float f2;
[MarshalAs(UnmanagedType.LPStr)]
public String f3;
[MarshalAs(UnmanagedType.LPUTF8Str)]
public String f4;
}

[StructLayout(LayoutKind.Sequential)]
Expand All @@ -713,6 +715,8 @@ public class SequentialClass
public float f2;
[MarshalAs(UnmanagedType.LPStr)]
public String f3;
[MarshalAs(UnmanagedType.LPUTF8Str)]
public String f4;
}

// A second struct with the same name but nested. Regression test against native types being mangled into
Expand Down Expand Up @@ -824,15 +828,16 @@ private static void TestStruct()
ss.f1 = 1;
ss.f2 = 10.0f;
ss.f3 = "Hello";
ss.f4 = "Hola";

ThrowIfNotEquals(true, StructTest(ss), "Struct marshalling scenario1 failed.");

StructTest_ByRef(ref ss);
ThrowIfNotEquals(true, ss.f1 == 2 && ss.f2 == 11.0 && ss.f3.Equals("Ifmmp"), "Struct marshalling scenario2 failed.");
ThrowIfNotEquals(true, ss.f1 == 2 && ss.f2 == 11.0 && ss.f3.Equals("Ifmmp") && ss.f4.Equals("Ipmb"), "Struct marshalling scenario2 failed.");

SequentialStruct ss2 = new SequentialStruct();
StructTest_ByOut(out ss2);
ThrowIfNotEquals(true, ss2.f0 == 1 && ss2.f1 == 1.0 && ss2.f2 == 1.0 && ss2.f3.Equals("0123456"), "Struct marshalling scenario3 failed.");
ThrowIfNotEquals(true, ss2.f0 == 1 && ss2.f1 == 1.0 && ss2.f2 == 1.0 && ss2.f3.Equals("0123456") && ss2.f4.Equals("789"), "Struct marshalling scenario3 failed.");

NesterOfSequentialStruct.SequentialStruct ss3 = new NesterOfSequentialStruct.SequentialStruct();
ss3.f1 = 10.0f;
Expand Down Expand Up @@ -861,6 +866,7 @@ private static void TestStruct()
ssa[i].f1 = i;
ssa[i].f2 = i*i;
ssa[i].f3 = i.LowLevelToString();
ssa[i].f4 = "u8" + i.LowLevelToString();
}
ThrowIfNotEquals(true, StructTest_Array(ssa, ssa.Length), "Array of struct marshalling failed");

Expand Down Expand Up @@ -923,9 +929,10 @@ private static void TestLayoutClassPtr()
ss.f1 = 1;
ss.f2 = 10.0f;
ss.f3 = "Hello";
ss.f4 = "Hola";

ClassTest(ss);
ThrowIfNotEquals(true, ss.f1 == 2 && ss.f2 == 11.0 && ss.f3.Equals("Ifmmp"), "LayoutClassPtr marshalling scenario1 failed.");
ThrowIfNotEquals(true, ss.f1 == 2 && ss.f2 == 11.0 && ss.f3.Equals("Ifmmp") && ss.f4.Equals("Ipmb"), "LayoutClassPtr marshalling scenario1 failed.");
}

#if OPTIMIZED_MODE_WITHOUT_SCANNER
Expand Down Expand Up @@ -955,20 +962,22 @@ private static void TestAsAny()
sc.f1 = 1;
sc.f2 = 10.0f;
sc.f3 = "Hello";
sc.f4 = "Hola";

AsAnyTest(sc);
ThrowIfNotEquals(true, sc.f1 == 2 && sc.f2 == 11.0 && sc.f3.Equals("Ifmmp"), "AsAny marshalling scenario1 failed.");
ThrowIfNotEquals(true, sc.f1 == 2 && sc.f2 == 11.0 && sc.f3.Equals("Ifmmp") && sc.f4.Equals("Ipmb"), "AsAny marshalling scenario1 failed.");

SequentialStruct ss = new SequentialStruct();
ss.f0 = 100;
ss.f1 = 1;
ss.f2 = 10.0f;
ss.f3 = "Hello";
ss.f4 = "Hola";

object o = ss;
AsAnyTest(o);
ss = (SequentialStruct)o;
ThrowIfNotEquals(true, ss.f1 == 2 && ss.f2 == 11.0 && ss.f3.Equals("Ifmmp"), "AsAny marshalling scenario2 failed.");
ThrowIfNotEquals(true, ss.f1 == 2 && ss.f2 == 11.0 && ss.f3.Equals("Ifmmp") && sc.f4.Equals("Ipmb"), "AsAny marshalling scenario2 failed.");
}

private static void TestLayoutClass()
Expand Down
22 changes: 21 additions & 1 deletion src/tests/nativeaot/SmokeTests/PInvoke/PInvokeNative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ struct NativeSequentialStruct
int a;
float b;
char *str;
char *u8str;
};

struct NativeSequentialStruct2
Expand All @@ -494,6 +495,9 @@ DLL_EXPORT bool __stdcall StructTest(NativeSequentialStruct nss)
if (!CompareAnsiString(nss.str, "Hello"))
return false;

if (!CompareAnsiString(nss.u8str, "Hola"))
return false;

return true;
}

Expand All @@ -519,6 +523,13 @@ DLL_EXPORT void __stdcall StructTest_ByRef(NativeSequentialStruct *nss)
*p = *p + 1;
p++;
}

p = nss->u8str;
while (*p != '\0')
{
*p = *p + 1;
p++;
}
}

DLL_EXPORT void __stdcall StructTest_ByOut(NativeSequentialStruct *nss)
Expand All @@ -529,14 +540,18 @@ DLL_EXPORT void __stdcall StructTest_ByOut(NativeSequentialStruct *nss)

int arrSize = 7;
char *p;
p = (char *)MemAlloc(sizeof(char) * arrSize);
p = (char *)MemAlloc(sizeof(char) * arrSize + 1);

for (int i = 0; i < arrSize; i++)
{
*(p + i) = i + '0';
}
*(p + arrSize) = '\0';
nss->str = p;

p = (char *)MemAlloc(sizeof(char) * 4);
strcpy(p, "789");
nss->u8str = p;
}

DLL_EXPORT bool __stdcall StructTest_Array(NativeSequentialStruct *nss, int length)
Expand All @@ -558,6 +573,11 @@ DLL_EXPORT bool __stdcall StructTest_Array(NativeSequentialStruct *nss, int leng

if (CompareAnsiString(expected, nss[i].str) == 0)
return false;

sprintf(expected, "u8%d", i);

if (CompareAnsiString(expected, nss[i].u8str) == 0)
return false;
}
return true;
}
Expand Down