Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize SpanHelpers<T>.IndexOf #60974

Merged
merged 10 commits into from
Nov 22, 2021
26 changes: 18 additions & 8 deletions src/libraries/System.Private.CoreLib/src/System/Array.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1232,18 +1232,28 @@ ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<char[]>(array))
}
else if (Unsafe.SizeOf<T>() == sizeof(int))
{
int result = SpanHelpers.IndexOf(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<int[]>(array)), startIndex),
Unsafe.As<T, int>(ref value),
count);
int result = typeof(T).IsValueType
? SpanHelpers.IndexOfValueType(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<int[]>(array)), startIndex),
Unsafe.As<T, int>(ref value),
count)
: SpanHelpers.IndexOf(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<int[]>(array)), startIndex),
Unsafe.As<T, int>(ref value),
count);
return (result >= 0 ? startIndex : 0) + result;
}
else if (Unsafe.SizeOf<T>() == sizeof(long))
{
int result = SpanHelpers.IndexOf(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<long[]>(array)), startIndex),
Unsafe.As<T, long>(ref value),
count);
int result = typeof(T).IsValueType
? SpanHelpers.IndexOfValueType(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<long[]>(array)), startIndex),
Unsafe.As<T, long>(ref value),
count)
: SpanHelpers.IndexOf(
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<long[]>(array)), startIndex),
Unsafe.As<T, long>(ref value),
count);
return (result >= 0 ? startIndex : 0) + result;
}
}
Expand Down
126 changes: 126 additions & 0 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using Internal.Runtime.CompilerServices;

Expand Down Expand Up @@ -225,6 +226,22 @@ public static unsafe bool Contains<T>(ref T searchSpace, T value, int length) wh
{
Debug.Assert(length >= 0);

if (typeof(T).IsValueType && RuntimeHelpers.IsBitwiseEquatable<T>())
{
// bool and char will already have been checked before, just do checks for types
Copy link
Member

@danmoseley danmoseley Oct 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed I see byte/bool and char are checked here
https://github.com/danmoseley/runtime/blob/f3ca6f91ba9c758bb246be8ba26bd356d3f9dda6/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs#L298

... why are 1 and 2 byte sizes treated specially there, and 4 and 8 byte sizes treated specially here? why not all in the same place?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I think I was referencing a different check for those types, but putting them all near the section you linked would be cleaner and more intuitive. I'll move the checks there instead.

// that are equal to sizeof(int) or sizeof(long)
if (Unsafe.SizeOf<T>() == sizeof(int))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this break for float and double? This is why we have IsBitwiseEquatable (see https://source.dot.net/System.Private.CoreLib/R/e4188e6833cbc739.html) as a helper API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I'll add a check.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need tests for float and double somewhere -- did anything fail before you fixed this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, nothing failed for me when testing locally using:

.\build.cmd clr+libs+libs.tests -c Checked -test
.\build.cmd clr+libs+libs.tests -c Release -test

But some new tests for float and double would probably be appropriate just to be sure. Do you know the best place to add those?

{
int result = IndexOfValueType(ref Unsafe.As<T, int>(ref searchSpace), Unsafe.As<T, int>(ref value), length);
return result != -1;
}
else if (Unsafe.SizeOf<T>() == sizeof(long))
{
int result = IndexOfValueType(ref Unsafe.As<T, long>(ref searchSpace), Unsafe.As<T, long>(ref value), length);
return result != -1;
}
}

nint index = 0; // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations

if (default(T) != null || (object)value != null)
Expand Down Expand Up @@ -291,6 +308,115 @@ public static unsafe bool Contains<T>(ref T searchSpace, T value, int length) wh
return true;
}

internal static unsafe int IndexOfValueType<T>(ref T searchSpace, T value, int length) where T : struct, IEquatable<T>
{
Debug.Assert(length >= 0);

nint index = 0; // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations
if (Vector.IsHardwareAccelerated && Vector<T>.IsTypeSupported && (Vector<T>.Count * 2) <= length)
{
Vector<T> valueVector = new Vector<T>(value);
Vector<T> compareVector = default;
Vector<T> matchVector = default;
if ((uint)length % (uint)Vector<T>.Count != 0)
{
// Number of elements is not a multiple of Vector<T>.Count, so do one
// check and shift only enough for the remaining set to be a multiple
// of Vecotr<T>.Count.
alexcovington marked this conversation as resolved.
Show resolved Hide resolved
compareVector = Unsafe.As<T, Vector<T>>(ref Unsafe.Add(ref searchSpace, index));
matchVector = Vector.Equals(valueVector, compareVector);
if (matchVector != Vector<T>.Zero)
{
goto VectorMatch;
}
index += length % Vector<T>.Count;
length -= length % Vector<T>.Count;
}
while (length > 0)
{
compareVector = Unsafe.As<T, Vector<T>>(ref Unsafe.Add(ref searchSpace, index));
matchVector = Vector.Equals(valueVector, compareVector);
if (matchVector != Vector<T>.Zero)
{
goto VectorMatch;
}
index += Vector<T>.Count;
length -= Vector<T>.Count;
}
goto NotFound;
VectorMatch:
for (int i = 0; i < Vector<T>.Count; i++)
if (compareVector[i].Equals(value))
return (int)(index + i);
}

while (length >= 8)
{
if (value.Equals(Unsafe.Add(ref searchSpace, index)))
goto Found;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 1)))
goto Found1;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 2)))
goto Found2;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 3)))
goto Found3;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 4)))
goto Found4;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 5)))
goto Found5;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 6)))
goto Found6;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 7)))
goto Found7;

length -= 8;
index += 8;
}

while (length >= 4)
{
if (value.Equals(Unsafe.Add(ref searchSpace, index)))
goto Found;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 1)))
goto Found1;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 2)))
goto Found2;
if (value.Equals(Unsafe.Add(ref searchSpace, index + 3)))
goto Found3;

length -= 4;
index += 4;
}

while (length > 0)
{
if (value.Equals(Unsafe.Add(ref searchSpace, index)))
goto Found;

index += 1;
length--;
}
NotFound:
return -1;

Found: // Workaround for https://github.com/dotnet/runtime/issues/8795
return (int)index;
Found1:
return (int)(index + 1);
Found2:
return (int)(index + 2);
Found3:
return (int)(index + 3);
Found4:
return (int)(index + 4);
Found5:
return (int)(index + 5);
Found6:
return (int)(index + 6);
Found7:
return (int)(index + 7);
}

public static unsafe int IndexOf<T>(ref T searchSpace, T value, int length) where T : IEquatable<T>
{
Debug.Assert(length >= 0);
Expand Down