Skip to content

Commit

Permalink
Use MemoryMarshal.Cast in a few places (dotnet#99835)
Browse files Browse the repository at this point in the history
* Use MemoryMarshal.Cast in a few places

* Removed special-casing

* Fix build
  • Loading branch information
stephentoub committed Mar 19, 2024
1 parent 07c99ab commit 886bca3
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

#pragma warning disable CS8500 // This takes the address of, gets the size of, or declares a pointer to a managed type

namespace System.Numerics.Tensors
{
/// <summary>Performs primitive tensor operations over spans of memory.</summary>
Expand All @@ -23,7 +25,31 @@ private static void ValidateInputOutputSpanNonOverlapping<T>(ReadOnlySpan<T> inp
}

/// <summary>Throws an <see cref="OverflowException"/> for trying to negate the minimum value of a two-complement value.</summary>
internal static void ThrowNegateTwosCompOverflow() => throw new OverflowException(SR.Overflow_NegateTwosCompNum);
private static void ThrowNegateTwosCompOverflow() => throw new OverflowException(SR.Overflow_NegateTwosCompNum);

/// <summary>Creates a span of <typeparamref name="TTo"/> from a <typeparamref name="TFrom"/> when they're the same type.</summary>
/// <remarks>
/// This is the same as MemoryMarshal.Cast, except only to be used when TFrom and TTo are the same type or effectively
/// the same type (e.g. int and nint in a 32-bit process). MemoryMarshal.Cast can't currently be used as it's
/// TFrom/TTo are constrained to be value types.
/// </remarks>
private static unsafe Span<TTo> Rename<TFrom, TTo>(Span<TFrom> span)
{
Debug.Assert(sizeof(TFrom) == sizeof(TTo));
return *(Span<TTo>*)(&span);
}

/// <summary>Creates a span of <typeparamref name="TTo"/> from a <typeparamref name="TFrom"/> when they're the same type.</summary>
/// <remarks>
/// This is the same as MemoryMarshal.Cast, except only to be used when TFrom and TTo are the same type or effectively
/// the same type (e.g. int and nint in a 32-bit process). MemoryMarshal.Cast can't currently be used as it's
/// TFrom/TTo are constrained to be value types.
/// </remarks>
private static unsafe ReadOnlySpan<TTo> Rename<TFrom, TTo>(ReadOnlySpan<TFrom> span)
{
Debug.Assert(sizeof(TFrom) == sizeof(TTo));
return *(ReadOnlySpan<TTo>*)(&span);
}

/// <summary>Mask used to handle alignment elements before vectorized handling of the input.</summary>
/// <remarks>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1247,12 +1247,5 @@ static void VectorizedSmall8(ref TInput xRef, ref TOutput dRef, nuint remainder)
}
}
}

/// <summary>Creates a span of <typeparamref name="TTo"/> from a <typeparamref name="TTo"/> when they're the same type.</summary>
private static unsafe Span<TTo> Rename<TFrom, TTo>(Span<TFrom> span)
{
Debug.Assert(sizeof(TFrom) == sizeof(TTo));
return MemoryMarshal.CreateSpan(ref Unsafe.As<TFrom, TTo>(ref MemoryMarshal.GetReference(span)), span.Length);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -643,13 +643,6 @@ static Vector512<uint> SingleToHalfAsWidenedUInt32(Vector512<float> value)
}
}

/// <summary>Creates a span of <typeparamref name="TTo"/> from a <typeparamref name="TTo"/> when they're the same type.</summary>
private static unsafe ReadOnlySpan<TTo> Rename<TFrom, TTo>(ReadOnlySpan<TFrom> span)
{
Debug.Assert(sizeof(TFrom) == sizeof(TTo));
return MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TFrom, TTo>(ref MemoryMarshal.GetReference(span)), span.Length);
}

/// <summary>Gets whether <typeparamref name="T"/> is <see cref="uint"/> or <see cref="nuint"/> if in a 32-bit process.</summary>
private static bool IsUInt32Like<T>() => typeof(T) == typeof(uint) || (IntPtr.Size == 4 && typeof(T) == typeof(nuint));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,13 @@ public static void Round<T>(ReadOnlySpan<T> x, int digits, MidpointRounding mode
if (typeof(T) == typeof(float))
{
ReadOnlySpan<float> roundPower10Single = [1e0f, 1e1f, 1e2f, 1e3f, 1e4f, 1e5f, 1e6f];
roundPower10 = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<float, T>(ref MemoryMarshal.GetReference(roundPower10Single)), roundPower10Single.Length);
roundPower10 = Rename<float, T>(roundPower10Single);
}
else if (typeof(T) == typeof(double))
{
Debug.Assert(typeof(T) == typeof(double));
ReadOnlySpan<double> roundPower10Double = [1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15];
roundPower10 = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<double, T>(ref MemoryMarshal.GetReference(roundPower10Double)), roundPower10Double.Length);
roundPower10 = Rename<double, T>(roundPower10Double);
}
else
{
Expand Down
24 changes: 10 additions & 14 deletions src/libraries/System.Private.CoreLib/src/System/Number.Parsing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -883,16 +883,16 @@ internal static bool SpanStartsWith<TChar>(ReadOnlySpan<TChar> span, ReadOnlySpa
{
if (typeof(TChar) == typeof(char))
{
ReadOnlySpan<char> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, char>(ref MemoryMarshal.GetReference(span)), span.Length);
ReadOnlySpan<char> typedValue = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, char>(ref MemoryMarshal.GetReference(value)), value.Length);
ReadOnlySpan<char> typedSpan = MemoryMarshal.Cast<TChar, char>(span);
ReadOnlySpan<char> typedValue = MemoryMarshal.Cast<TChar, char>(value);
return typedSpan.StartsWith(typedValue, comparisonType);
}
else
{
Debug.Assert(typeof(TChar) == typeof(byte));

ReadOnlySpan<byte> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, byte>(ref MemoryMarshal.GetReference(span)), span.Length);
ReadOnlySpan<byte> typedValue = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, byte>(ref MemoryMarshal.GetReference(value)), value.Length);
ReadOnlySpan<byte> typedSpan = MemoryMarshal.Cast<TChar, byte>(span);
ReadOnlySpan<byte> typedValue = MemoryMarshal.Cast<TChar, byte>(value);
return typedSpan.StartsWithUtf8(typedValue, comparisonType);
}
}
Expand All @@ -903,17 +903,13 @@ internal static ReadOnlySpan<TChar> SpanTrim<TChar>(ReadOnlySpan<TChar> span)
{
if (typeof(TChar) == typeof(char))
{
ReadOnlySpan<char> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, char>(ref MemoryMarshal.GetReference(span)), span.Length);
ReadOnlySpan<char> result = typedSpan.Trim();
return MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<char, TChar>(ref MemoryMarshal.GetReference(result)), result.Length);
return MemoryMarshal.Cast<char, TChar>(MemoryMarshal.Cast<TChar, char>(span).Trim());
}
else
{
Debug.Assert(typeof(TChar) == typeof(byte));

ReadOnlySpan<byte> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, byte>(ref MemoryMarshal.GetReference(span)), span.Length);
ReadOnlySpan<byte> result = typedSpan.TrimUtf8();
return MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<byte, TChar>(ref MemoryMarshal.GetReference(result)), result.Length);
return MemoryMarshal.Cast<byte, TChar>(MemoryMarshal.Cast<TChar, byte>(span).TrimUtf8());
}
}

Expand All @@ -923,16 +919,16 @@ internal static bool SpanEqualsOrdinalIgnoreCase<TChar>(ReadOnlySpan<TChar> span
{
if (typeof(TChar) == typeof(char))
{
ReadOnlySpan<char> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, char>(ref MemoryMarshal.GetReference(span)), span.Length);
ReadOnlySpan<char> typedValue = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, char>(ref MemoryMarshal.GetReference(value)), value.Length);
ReadOnlySpan<char> typedSpan = MemoryMarshal.Cast<TChar, char>(span);
ReadOnlySpan<char> typedValue = MemoryMarshal.Cast<TChar, char>(value);
return typedSpan.EqualsOrdinalIgnoreCase(typedValue);
}
else
{
Debug.Assert(typeof(TChar) == typeof(byte));

ReadOnlySpan<byte> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, byte>(ref MemoryMarshal.GetReference(span)), span.Length);
ReadOnlySpan<byte> typedValue = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, byte>(ref MemoryMarshal.GetReference(value)), value.Length);
ReadOnlySpan<byte> typedSpan = MemoryMarshal.Cast<TChar, byte>(span);
ReadOnlySpan<byte> typedValue = MemoryMarshal.Cast<TChar, byte>(value);
return typedSpan.EqualsOrdinalIgnoreCaseUtf8(typedValue);
}
}
Expand Down

0 comments on commit 886bca3

Please sign in to comment.