Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor CpuMathUtils #1229

Merged
merged 3 commits into from
Oct 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/Microsoft.ML.Core/Utilities/Contracts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,19 @@ public static void AssertNonWhiteSpace(this IExceptionContext ctx, string s, str
DbgFailEmpty(ctx, msg);
}

[Conditional("DEBUG")]
public static void AssertNonEmpty<T>(ReadOnlySpan<T> args)
{
if (args.IsEmpty)
DbgFail();
}
[Conditional("DEBUG")]
public static void AssertNonEmpty<T>(Span<T> args)
{
if (args.IsEmpty)
DbgFail();
}

[Conditional("DEBUG")]
public static void AssertNonEmpty<T>(ICollection<T> args)
{
Expand Down
147 changes: 74 additions & 73 deletions src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using nuint = System.UInt64;
Expand Down Expand Up @@ -448,7 +449,7 @@ public static unsafe void MatMulTranPX(bool add, AlignedArray mat, int[] rgposSr
// dst[i] += scale
public static unsafe void AddScalarU(float scalar, Span<float> dst)
{
fixed (float* pdst = dst)
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be good to have an analyzer for this...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily. You have to ensure that the span is non-empty, because if you call this on an empty span, you'll get back a garbage pointer.

Also, there are some concerns that the .NET Core team has that these calls shouldn't be necessary. See the conversation at dotnet/corefx#32669 (comment), so it may not be worth it to write the analyzer if the recommendation isn't to use this method everywhere.

{
float* pDstEnd = pdst + dst.Length;
float* pDstCurrent = pdst;
Expand Down Expand Up @@ -490,7 +491,7 @@ public static unsafe void Scale(float scale, Span<float> dst)
{
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
fixed (float* pd = dst)
fixed (float* pd = &MemoryMarshal.GetReference(dst))
{
float* pDstCurrent = pd;
int length = dst.Length;
Expand Down Expand Up @@ -606,12 +607,12 @@ public static unsafe void Scale(float scale, Span<float> dst)
}
}

public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> dst)
public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have a Contract.Assert that count is less than or equal to the length of both dst and src?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{
fixed (float* psrc = src)
fixed (float* pdst = dst)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pDstEnd = pdst + dst.Length;
float* pDstEnd = pdst + count;
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;

Expand Down Expand Up @@ -654,7 +655,7 @@ public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> ds
// dst[i] = a * (dst[i] + b)
public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
{
fixed (float* pdst = dst)
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pDstEnd = pdst + dst.Length;
float* pDstCurrent = pdst;
Expand Down Expand Up @@ -697,14 +698,14 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
}
}

public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> dst)
public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
{
fixed (float* psrc = src)
fixed (float* pdst = dst)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pEnd = pdst + dst.Length;
float* pEnd = pdst + count;

Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);

Expand Down Expand Up @@ -751,13 +752,13 @@ public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> ds
}
}

public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float> dst, Span<float> result)
public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<float> dst, Span<float> result, int count)
{
fixed (float* psrc = src)
fixed (float* pdst = dst)
fixed (float* pres = result)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
fixed (float* pres = &MemoryMarshal.GetReference(result))
{
float* pResEnd = pres + result.Length;
float* pResEnd = pres + count;
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pResCurrent = pres;
Expand Down Expand Up @@ -807,16 +808,16 @@ public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float
}
}

public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx, Span<float> dst)
public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
{
fixed (float* psrc = src)
fixed (int* pidx = idx)
fixed (float* pdst = dst)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
int* pIdxCurrent = pidx;
float* pDstCurrent = pdst;
int* pEnd = pidx + idx.Length;
int* pEnd = pidx + count;

Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);

Expand Down Expand Up @@ -858,14 +859,14 @@ public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx
}
}

public static unsafe void AddU(Span<float> src, Span<float> dst)
public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int count)
{
fixed (float* psrc = src)
fixed (float* pdst = dst)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pEnd = psrc + src.Length;
float* pEnd = psrc + count;

while (pSrcCurrent + 8 <= pEnd)
{
Expand Down Expand Up @@ -905,16 +906,16 @@ public static unsafe void AddU(Span<float> src, Span<float> dst)
}
}

public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
{
fixed (float* psrc = src)
fixed (int* pidx = idx)
fixed (float* pdst = dst)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
int* pIdxCurrent = pidx;
float* pDstCurrent = pdst;
int* pEnd = pidx + idx.Length;
int* pEnd = pidx + count;

while (pIdxCurrent + 8 <= pEnd)
{
Expand Down Expand Up @@ -950,16 +951,16 @@ public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
}
}

public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Span<float> dst)
public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan<float> src2, Span<float> dst, int count)
{
fixed (float* psrc1 = src1)
fixed (float* psrc2 = src2)
fixed (float* pdst = dst)
fixed (float* psrc1 = &MemoryMarshal.GetReference(src1))
fixed (float* psrc2 = &MemoryMarshal.GetReference(src2))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrc1Current = psrc1;
float* pSrc2Current = psrc2;
float* pDstCurrent = pdst;
float* pEnd = pdst + dst.Length;
float* pEnd = pdst + count;

while (pDstCurrent + 8 <= pEnd)
{
Expand Down Expand Up @@ -999,9 +1000,9 @@ public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Sp
}
}

public static unsafe float SumU(Span<float> src)
public static unsafe float SumU(ReadOnlySpan<float> src)
{
fixed (float* psrc = src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Expand Down Expand Up @@ -1037,9 +1038,9 @@ public static unsafe float SumU(Span<float> src)
}
}

public static unsafe float SumSqU(Span<float> src)
public static unsafe float SumSqU(ReadOnlySpan<float> src)
{
fixed (float* psrc = src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Expand Down Expand Up @@ -1081,9 +1082,9 @@ public static unsafe float SumSqU(Span<float> src)
}
}

public static unsafe float SumSqDiffU(float mean, Span<float> src)
public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src)
{
fixed (float* psrc = src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Expand Down Expand Up @@ -1130,9 +1131,9 @@ public static unsafe float SumSqDiffU(float mean, Span<float> src)
}
}

public static unsafe float SumAbsU(Span<float> src)
public static unsafe float SumAbsU(ReadOnlySpan<float> src)
{
fixed (float* psrc = src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Expand Down Expand Up @@ -1174,9 +1175,9 @@ public static unsafe float SumAbsU(Span<float> src)
}
}

public static unsafe float SumAbsDiffU(float mean, Span<float> src)
public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src)
{
fixed (float* psrc = src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Expand Down Expand Up @@ -1223,9 +1224,9 @@ public static unsafe float SumAbsDiffU(float mean, Span<float> src)
}
}

public static unsafe float MaxAbsU(Span<float> src)
public static unsafe float MaxAbsU(ReadOnlySpan<float> src)
{
fixed (float* psrc = src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Expand Down Expand Up @@ -1267,9 +1268,9 @@ public static unsafe float MaxAbsU(Span<float> src)
}
}

public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src)
{
fixed (float* psrc = src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcCurrent = psrc;
Expand Down Expand Up @@ -1316,14 +1317,14 @@ public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
}
}

public static unsafe float DotU(Span<float> src, Span<float> dst)
public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
{
fixed (float* psrc = src)
fixed (float* pdst = dst)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pSrcEnd = psrc + src.Length;
float* pSrcEnd = psrc + count;

Vector256<float> result256 = Avx.SetZeroVector256<float>();

Expand Down Expand Up @@ -1371,16 +1372,16 @@ public static unsafe float DotU(Span<float> src, Span<float> dst)
}
}

public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx)
public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, ReadOnlySpan<int> idx, int count)
{
fixed (float* psrc = src)
fixed (float* pdst = dst)
fixed (int* pidx = idx)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
{
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
int* pIdxCurrent = pidx;
int* pIdxEnd = pidx + idx.Length;
int* pIdxEnd = pidx + count;

Vector256<float> result256 = Avx.SetZeroVector256<float>();

Expand Down Expand Up @@ -1428,14 +1429,14 @@ public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx
}
}

public static unsafe float Dist2(Span<float> src, Span<float> dst)
public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
{
fixed (float* psrc = src)
fixed (float* pdst = dst)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
{
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pSrcEnd = psrc + src.Length;
float* pSrcEnd = psrc + count;

Vector256<float> sqDistanceVector256 = Avx.SetZeroVector256<float>();

Expand Down Expand Up @@ -1482,13 +1483,13 @@ public static unsafe float Dist2(Span<float> src, Span<float> dst)
}
}

public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, float threshold, Span<float> v, Span<float> w)
public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlySpan<float> src, float threshold, Span<float> v, Span<float> w)
{
fixed (float* psrc = src)
fixed (float* pdst1 = v)
fixed (float* pdst2 = w)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (float* pdst1 = &MemoryMarshal.GetReference(v))
fixed (float* pdst2 = &MemoryMarshal.GetReference(w))
{
float* pSrcEnd = psrc + src.Length;
float* pSrcEnd = psrc + count;
float* pSrcCurrent = psrc;
float* pDst1Current = pdst1;
float* pDst2Current = pdst2;
Expand Down Expand Up @@ -1544,14 +1545,14 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, flo
}
}

public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span<float> src, Span<int> indices, float threshold, Span<float> v, Span<float> w)
public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnlySpan<float> src, ReadOnlySpan<int> indices, float threshold, Span<float> v, Span<float> w)
{
fixed (float* psrc = src)
fixed (int* pidx = indices)
fixed (float* pdst1 = v)
fixed (float* pdst2 = w)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
fixed (int* pidx = &MemoryMarshal.GetReference(indices))
fixed (float* pdst1 = &MemoryMarshal.GetReference(v))
fixed (float* pdst2 = &MemoryMarshal.GetReference(w))
{
int* pIdxEnd = pidx + indices.Length;
int* pIdxEnd = pidx + count;
float* pSrcCurrent = psrc;
int* pIdxCurrent = pidx;

Expand Down
Loading