Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Span.Reverse fast path performance #70944

Merged
merged 11 commits into from
Nov 18, 2022
100 changes: 55 additions & 45 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers.Binary;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Numerics;
Expand Down Expand Up @@ -1155,21 +1156,23 @@ private static unsafe nuint UnalignedCountVector128(ref byte searchSpace)

public static void Reverse(ref byte buf, nuint length)
{
if (Avx2.IsSupported && (nuint)Vector256<byte>.Count * 2 <= length)
Debug.Assert(length > 0);
ref byte first = ref buf;
ref byte last = ref Unsafe.NullRef<byte>();
nuint lastOffset = length;

if (Avx2.IsSupported && lastOffset >= (nuint)Vector256<byte>.Count * 2)
{
Vector256<byte> reverseMask = Vector256.Create(
(byte)15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, // first 128-bit lane
15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); // second 128-bit lane
nuint numElements = (nuint)Vector256<byte>.Count;
nuint numIters = (length / numElements) / 2;
for (nuint i = 0; i < numIters; i++)
{
nuint firstOffset = i * numElements;
nuint lastOffset = length - ((1 + i) * numElements);

// Load in values from beginning and end of the array.
Vector256<byte> tempFirst = Vector256.LoadUnsafe(ref buf, firstOffset);
Vector256<byte> tempLast = Vector256.LoadUnsafe(ref buf, lastOffset);
last = ref Unsafe.Subtract(ref Unsafe.Add(ref first, (int)lastOffset), (nuint)Vector256<byte>.Count);
do
{
// Load the values into vectors
Vector256<byte> tempFirst = Vector256.LoadUnsafe(ref first);
Vector256<byte> tempLast = Vector256.LoadUnsafe(ref last);

// Avx2 operates on two 128-bit lanes rather than the full 256-bit vector.
// Perform a shuffle to reverse each 128-bit lane, then permute to finish reversing the vector:
Expand All @@ -1196,48 +1199,55 @@ public static void Reverse(ref byte buf, nuint length)
tempLast = Avx2.Permute2x128(tempLast, tempLast, 0b00_01);

// Store the reversed vectors
tempLast.StoreUnsafe(ref buf, firstOffset);
tempFirst.StoreUnsafe(ref buf, lastOffset);
}
buf = ref Unsafe.Add(ref buf, numIters * numElements);
length -= numIters * numElements * 2;
tempLast.StoreUnsafe(ref first);
tempFirst.StoreUnsafe(ref last);

first = ref Unsafe.Add(ref first, (nuint)Vector256<byte>.Count);
last = ref Unsafe.Subtract(ref last, (nuint)Vector256<byte>.Count);
lastOffset -= (nuint)Vector256<byte>.Count * 2;
} while (lastOffset >= (nuint)Vector256<byte>.Count * 2);
yesmey marked this conversation as resolved.
Show resolved Hide resolved
}
else if (Vector128.IsHardwareAccelerated && (nuint)Vector128<byte>.Count * 2 <= length)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the removal of Vector128 code path causes a 10% regression for larger inputs on arm64. This needs to be addressed. I can run the benchmarks for you on an arm64 machine, just ping me here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it back to the previous behavior, can you run the benchmarks again when you got time, please?


// Use ReverseEndianness on 8 bytes pairs
if (lastOffset >= (sizeof(long) * 2))
{
nuint numElements = (nuint)Vector128<byte>.Count;
nuint numIters = (length / numElements) / 2;
for (nuint i = 0; i < numIters; i++)
last = ref Unsafe.Subtract(ref Unsafe.Add(ref first, (int)lastOffset), (nuint)sizeof(long));
do
{
nuint firstOffset = i * numElements;
nuint lastOffset = length - ((1 + i) * numElements);

// Load in values from beginning and end of the array.
Vector128<byte> tempFirst = Vector128.LoadUnsafe(ref buf, firstOffset);
Vector128<byte> tempLast = Vector128.LoadUnsafe(ref buf, lastOffset);

// Shuffle to reverse each vector:
// +---------------------------------------------------------------+
// | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P |
// +---------------------------------------------------------------+
// --->
// +---------------------------------------------------------------+
// | P | O | N | M | L | K | J | I | H | G | F | E | D | C | B | A |
// +---------------------------------------------------------------+
tempFirst = Vector128.Shuffle(tempFirst, Vector128.Create(
(byte)15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
tempLast = Vector128.Shuffle(tempLast, Vector128.Create(
(byte)15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
long tempFirst = Unsafe.ReadUnaligned<long>(ref first);
long tempLast = Unsafe.ReadUnaligned<long>(ref last);

// Store the reversed vectors
tempLast.StoreUnsafe(ref buf, firstOffset);
tempFirst.StoreUnsafe(ref buf, lastOffset);
}
buf = ref Unsafe.Add(ref buf, numIters * numElements);
length -= numIters * numElements * 2;
// swap and store in reversed position
Unsafe.WriteUnaligned(ref first, BinaryPrimitives.ReverseEndianness(tempLast));
Unsafe.WriteUnaligned(ref last, BinaryPrimitives.ReverseEndianness(tempFirst));

first = ref Unsafe.Add(ref first, (nuint)sizeof(long));
last = ref Unsafe.Subtract(ref last, (nuint)sizeof(long));
lastOffset -= sizeof(long) * 2;
} while (lastOffset >= (sizeof(long) * 2));
}

// Use ReverseEndianness on 4 bytes pairs
if (lastOffset >= (sizeof(int) * 2))
{
last = ref Unsafe.Subtract(ref Unsafe.Add(ref first, (int)lastOffset), (nuint)sizeof(int));
do
{
int tempFirst = Unsafe.ReadUnaligned<int>(ref first);
int tempLast = Unsafe.ReadUnaligned<int>(ref last);

// swap and store in reversed position
Unsafe.WriteUnaligned(ref first, BinaryPrimitives.ReverseEndianness(tempLast));
Unsafe.WriteUnaligned(ref last, BinaryPrimitives.ReverseEndianness(tempFirst));

first = ref Unsafe.Add(ref first, (nuint)sizeof(int));
last = ref Unsafe.Subtract(ref last, (nuint)sizeof(int));
lastOffset -= sizeof(int) * 2;
} while (lastOffset >= (sizeof(int) * 2));
}

// Store any remaining values one-by-one
ReverseInner(ref buf, length);
ReverseInner(ref first, lastOffset);
}
}
}
77 changes: 41 additions & 36 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs
Original file line number Diff line number Diff line change
Expand Up @@ -733,23 +733,26 @@ private static unsafe nint UnalignedCountVector128(ref char searchSpace)

public static void Reverse(ref char buf, nuint length)
{
if (Avx2.IsSupported && (nuint)Vector256<short>.Count * 2 <= length)
Debug.Assert(length > 0);
ref char first = ref buf;
ref char last = ref Unsafe.NullRef<char>();
nuint lastOffset = length;

if (Avx2.IsSupported && lastOffset >= (nuint)Vector256<short>.Count * 2)
{
ref byte bufByte = ref Unsafe.As<char, byte>(ref buf);
nuint byteLength = length * sizeof(char);
last = ref Unsafe.Subtract(ref Unsafe.Add(ref first, (int)lastOffset), (nuint)Vector256<short>.Count);

Vector256<byte> reverseMask = Vector256.Create(
(byte)14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, // first 128-bit lane
14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1); // second 128-bit lane
nuint numElements = (nuint)Vector256<byte>.Count;
nuint numIters = (byteLength / numElements) / 2;
for (nuint i = 0; i < numIters; i++)
do
{
nuint firstOffset = i * numElements;
nuint lastOffset = byteLength - ((1 + i) * numElements);
ref byte firstByte = ref Unsafe.As<char, byte>(ref first);
ref byte lastByte = ref Unsafe.As<char, byte>(ref last);

// Load in values from beginning and end of the array.
Vector256<byte> tempFirst = Vector256.LoadUnsafe(ref bufByte, firstOffset);
Vector256<byte> tempLast = Vector256.LoadUnsafe(ref bufByte, lastOffset);
// Load the values into vectors
Vector256<byte> tempFirst = Vector256.LoadUnsafe(ref firstByte);
Vector256<byte> tempLast = Vector256.LoadUnsafe(ref lastByte);

// Avx2 operates on two 128-bit lanes rather than the full 256-bit vector.
// Perform a shuffle to reverse each 128-bit lane, then permute to finish reversing the vector:
Expand All @@ -770,27 +773,27 @@ public static void Reverse(ref char buf, nuint length)
tempLast = Avx2.Permute2x128(tempLast, tempLast, 0b00_01);

// Store the reversed vectors
tempLast.StoreUnsafe(ref bufByte, firstOffset);
tempFirst.StoreUnsafe(ref bufByte, lastOffset);
}
bufByte = ref Unsafe.Add(ref bufByte, numIters * numElements);
length -= numIters * (nuint)Vector256<short>.Count * 2;
// Store any remaining values one-by-one
buf = ref Unsafe.As<byte, char>(ref bufByte);
tempLast.StoreUnsafe(ref firstByte);
tempFirst.StoreUnsafe(ref lastByte);

first = ref Unsafe.Add(ref first, (nuint)Vector256<short>.Count);
last = ref Unsafe.Subtract(ref last, (nuint)Vector256<short>.Count);
lastOffset -= (nuint)Vector256<short>.Count * 2;
} while (lastOffset >= (nuint)Vector256<short>.Count * 2);
}
else if (Vector128.IsHardwareAccelerated && (nuint)Vector128<short>.Count * 2 <= length)

if (Vector128.IsHardwareAccelerated && lastOffset >= (nuint)Vector128<short>.Count * 2)
{
ref short bufShort = ref Unsafe.As<char, short>(ref buf);
nuint numElements = (nuint)Vector128<short>.Count;
nuint numIters = (length / numElements) / 2;
for (nuint i = 0; i < numIters; i++)
last = ref Unsafe.Subtract(ref Unsafe.Add(ref first, (int)lastOffset), (nuint)Vector128<short>.Count);

do
{
nuint firstOffset = i * numElements;
nuint lastOffset = length - ((1 + i) * numElements);
ref short firstByte = ref Unsafe.As<char, short>(ref first);
ref short lastByte = ref Unsafe.As<char, short>(ref last);

// Load in values from beginning and end of the array.
Vector128<short> tempFirst = Vector128.LoadUnsafe(ref bufShort, firstOffset);
Vector128<short> tempLast = Vector128.LoadUnsafe(ref bufShort, lastOffset);
// Load the values into vectors
Vector128<short> tempFirst = Vector128.LoadUnsafe(ref firstByte);
Vector128<short> tempLast = Vector128.LoadUnsafe(ref lastByte);

// Shuffle to reverse each vector:
// +-------------------------------+
Expand All @@ -804,15 +807,17 @@ public static void Reverse(ref char buf, nuint length)
tempLast = Vector128.Shuffle(tempLast, Vector128.Create(7, 6, 5, 4, 3, 2, 1, 0));

// Store the reversed vectors
tempLast.StoreUnsafe(ref bufShort, firstOffset);
tempFirst.StoreUnsafe(ref bufShort, lastOffset);
}
bufShort = ref Unsafe.Add(ref bufShort, numIters * numElements);
length -= numIters * (nuint)Vector128<short>.Count * 2;
// Store any remaining values one-by-one
buf = ref Unsafe.As<short, char>(ref bufShort);
tempLast.StoreUnsafe(ref firstByte);
tempFirst.StoreUnsafe(ref lastByte);

first = ref Unsafe.Add(ref first, (nuint)Vector128<short>.Count);
last = ref Unsafe.Subtract(ref last, (nuint)Vector128<short>.Count);
lastOffset -= (nuint)Vector128<short>.Count * 2;
} while (lastOffset >= (nuint)Vector128<short>.Count * 2);
}
ReverseInner(ref buf, length);

// Store any remaining values one-by-one
ReverseInner(ref first, lastOffset);
}
}
}
Loading