-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor CpuMathUtils #1229
Refactor CpuMathUtils #1229
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
|
||
using System; | ||
using System.Runtime.CompilerServices; | ||
using System.Runtime.InteropServices; | ||
using System.Runtime.Intrinsics; | ||
using System.Runtime.Intrinsics.X86; | ||
using nuint = System.UInt64; | ||
|
@@ -448,7 +449,7 @@ public static unsafe void MatMulTranPX(bool add, AlignedArray mat, int[] rgposSr | |
// dst[i] += scale | ||
public static unsafe void AddScalarU(float scalar, Span<float> dst) | ||
{ | ||
fixed (float* pdst = dst) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pDstEnd = pdst + dst.Length; | ||
float* pDstCurrent = pdst; | ||
|
@@ -490,7 +491,7 @@ public static unsafe void Scale(float scale, Span<float> dst) | |
{ | ||
fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) | ||
fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) | ||
fixed (float* pd = dst) | ||
fixed (float* pd = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pDstCurrent = pd; | ||
int length = dst.Length; | ||
|
@@ -606,12 +607,12 @@ public static unsafe void Scale(float scale, Span<float> dst) | |
} | ||
} | ||
|
||
public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> dst) | ||
public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we have a … There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These already happen in the calling … |
||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* pdst = dst) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pDstEnd = pdst + dst.Length; | ||
float* pDstEnd = pdst + count; | ||
float* pSrcCurrent = psrc; | ||
float* pDstCurrent = pdst; | ||
|
||
|
@@ -654,7 +655,7 @@ public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> ds | |
// dst[i] = a * (dst[i] + b) | ||
public static unsafe void ScaleAddU(float a, float b, Span<float> dst) | ||
{ | ||
fixed (float* pdst = dst) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pDstEnd = pdst + dst.Length; | ||
float* pDstCurrent = pdst; | ||
|
@@ -697,14 +698,14 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst) | |
} | ||
} | ||
|
||
public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> dst) | ||
public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* pdst = dst) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pSrcCurrent = psrc; | ||
float* pDstCurrent = pdst; | ||
float* pEnd = pdst + dst.Length; | ||
float* pEnd = pdst + count; | ||
|
||
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale); | ||
|
||
|
@@ -751,13 +752,13 @@ public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> ds | |
} | ||
} | ||
|
||
public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float> dst, Span<float> result) | ||
public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<float> dst, Span<float> result, int count) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* pdst = dst) | ||
fixed (float* pres = result) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
fixed (float* pres = &MemoryMarshal.GetReference(result)) | ||
{ | ||
float* pResEnd = pres + result.Length; | ||
float* pResEnd = pres + count; | ||
float* pSrcCurrent = psrc; | ||
float* pDstCurrent = pdst; | ||
float* pResCurrent = pres; | ||
|
@@ -807,16 +808,16 @@ public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float | |
} | ||
} | ||
|
||
public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx, Span<float> dst) | ||
public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (int* pidx = idx) | ||
fixed (float* pdst = dst) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (int* pidx = &MemoryMarshal.GetReference(idx)) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pSrcCurrent = psrc; | ||
int* pIdxCurrent = pidx; | ||
float* pDstCurrent = pdst; | ||
int* pEnd = pidx + idx.Length; | ||
int* pEnd = pidx + count; | ||
|
||
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale); | ||
|
||
|
@@ -858,14 +859,14 @@ public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx | |
} | ||
} | ||
|
||
public static unsafe void AddU(Span<float> src, Span<float> dst) | ||
public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int count) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* pdst = dst) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pSrcCurrent = psrc; | ||
float* pDstCurrent = pdst; | ||
float* pEnd = psrc + src.Length; | ||
float* pEnd = psrc + count; | ||
|
||
while (pSrcCurrent + 8 <= pEnd) | ||
{ | ||
|
@@ -905,16 +906,16 @@ public static unsafe void AddU(Span<float> src, Span<float> dst) | |
} | ||
} | ||
|
||
public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst) | ||
public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (int* pidx = idx) | ||
fixed (float* pdst = dst) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (int* pidx = &MemoryMarshal.GetReference(idx)) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pSrcCurrent = psrc; | ||
int* pIdxCurrent = pidx; | ||
float* pDstCurrent = pdst; | ||
int* pEnd = pidx + idx.Length; | ||
int* pEnd = pidx + count; | ||
|
||
while (pIdxCurrent + 8 <= pEnd) | ||
{ | ||
|
@@ -950,16 +951,16 @@ public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst) | |
} | ||
} | ||
|
||
public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Span<float> dst) | ||
public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan<float> src2, Span<float> dst, int count) | ||
{ | ||
fixed (float* psrc1 = src1) | ||
fixed (float* psrc2 = src2) | ||
fixed (float* pdst = dst) | ||
fixed (float* psrc1 = &MemoryMarshal.GetReference(src1)) | ||
fixed (float* psrc2 = &MemoryMarshal.GetReference(src2)) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pSrc1Current = psrc1; | ||
float* pSrc2Current = psrc2; | ||
float* pDstCurrent = pdst; | ||
float* pEnd = pdst + dst.Length; | ||
float* pEnd = pdst + count; | ||
|
||
while (pDstCurrent + 8 <= pEnd) | ||
{ | ||
|
@@ -999,9 +1000,9 @@ public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Sp | |
} | ||
} | ||
|
||
public static unsafe float SumU(Span<float> src) | ||
public static unsafe float SumU(ReadOnlySpan<float> src) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
{ | ||
float* pSrcEnd = psrc + src.Length; | ||
float* pSrcCurrent = psrc; | ||
|
@@ -1037,9 +1038,9 @@ public static unsafe float SumU(Span<float> src) | |
} | ||
} | ||
|
||
public static unsafe float SumSqU(Span<float> src) | ||
public static unsafe float SumSqU(ReadOnlySpan<float> src) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
{ | ||
float* pSrcEnd = psrc + src.Length; | ||
float* pSrcCurrent = psrc; | ||
|
@@ -1081,9 +1082,9 @@ public static unsafe float SumSqU(Span<float> src) | |
} | ||
} | ||
|
||
public static unsafe float SumSqDiffU(float mean, Span<float> src) | ||
public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
{ | ||
float* pSrcEnd = psrc + src.Length; | ||
float* pSrcCurrent = psrc; | ||
|
@@ -1130,9 +1131,9 @@ public static unsafe float SumSqDiffU(float mean, Span<float> src) | |
} | ||
} | ||
|
||
public static unsafe float SumAbsU(Span<float> src) | ||
public static unsafe float SumAbsU(ReadOnlySpan<float> src) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
{ | ||
float* pSrcEnd = psrc + src.Length; | ||
float* pSrcCurrent = psrc; | ||
|
@@ -1174,9 +1175,9 @@ public static unsafe float SumAbsU(Span<float> src) | |
} | ||
} | ||
|
||
public static unsafe float SumAbsDiffU(float mean, Span<float> src) | ||
public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
{ | ||
float* pSrcEnd = psrc + src.Length; | ||
float* pSrcCurrent = psrc; | ||
|
@@ -1223,9 +1224,9 @@ public static unsafe float SumAbsDiffU(float mean, Span<float> src) | |
} | ||
} | ||
|
||
public static unsafe float MaxAbsU(Span<float> src) | ||
public static unsafe float MaxAbsU(ReadOnlySpan<float> src) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
{ | ||
float* pSrcEnd = psrc + src.Length; | ||
float* pSrcCurrent = psrc; | ||
|
@@ -1267,9 +1268,9 @@ public static unsafe float MaxAbsU(Span<float> src) | |
} | ||
} | ||
|
||
public static unsafe float MaxAbsDiffU(float mean, Span<float> src) | ||
public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
{ | ||
float* pSrcEnd = psrc + src.Length; | ||
float* pSrcCurrent = psrc; | ||
|
@@ -1316,14 +1317,14 @@ public static unsafe float MaxAbsDiffU(float mean, Span<float> src) | |
} | ||
} | ||
|
||
public static unsafe float DotU(Span<float> src, Span<float> dst) | ||
public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* pdst = dst) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pSrcCurrent = psrc; | ||
float* pDstCurrent = pdst; | ||
float* pSrcEnd = psrc + src.Length; | ||
float* pSrcEnd = psrc + count; | ||
|
||
Vector256<float> result256 = Avx.SetZeroVector256<float>(); | ||
|
||
|
@@ -1371,16 +1372,16 @@ public static unsafe float DotU(Span<float> src, Span<float> dst) | |
} | ||
} | ||
|
||
public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx) | ||
public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, ReadOnlySpan<int> idx, int count) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* pdst = dst) | ||
fixed (int* pidx = idx) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
fixed (int* pidx = &MemoryMarshal.GetReference(idx)) | ||
{ | ||
float* pSrcCurrent = psrc; | ||
float* pDstCurrent = pdst; | ||
int* pIdxCurrent = pidx; | ||
int* pIdxEnd = pidx + idx.Length; | ||
int* pIdxEnd = pidx + count; | ||
|
||
Vector256<float> result256 = Avx.SetZeroVector256<float>(); | ||
|
||
|
@@ -1428,14 +1429,14 @@ public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx | |
} | ||
} | ||
|
||
public static unsafe float Dist2(Span<float> src, Span<float> dst) | ||
public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* pdst = dst) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (float* pdst = &MemoryMarshal.GetReference(dst)) | ||
{ | ||
float* pSrcCurrent = psrc; | ||
float* pDstCurrent = pdst; | ||
float* pSrcEnd = psrc + src.Length; | ||
float* pSrcEnd = psrc + count; | ||
|
||
Vector256<float> sqDistanceVector256 = Avx.SetZeroVector256<float>(); | ||
|
||
|
@@ -1482,13 +1483,13 @@ public static unsafe float Dist2(Span<float> src, Span<float> dst) | |
} | ||
} | ||
|
||
public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, float threshold, Span<float> v, Span<float> w) | ||
public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlySpan<float> src, float threshold, Span<float> v, Span<float> w) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (float* pdst1 = v) | ||
fixed (float* pdst2 = w) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (float* pdst1 = &MemoryMarshal.GetReference(v)) | ||
fixed (float* pdst2 = &MemoryMarshal.GetReference(w)) | ||
{ | ||
float* pSrcEnd = psrc + src.Length; | ||
float* pSrcEnd = psrc + count; | ||
float* pSrcCurrent = psrc; | ||
float* pDst1Current = pdst1; | ||
float* pDst2Current = pdst2; | ||
|
@@ -1544,14 +1545,14 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, flo | |
} | ||
} | ||
|
||
public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span<float> src, Span<int> indices, float threshold, Span<float> v, Span<float> w) | ||
public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnlySpan<float> src, ReadOnlySpan<int> indices, float threshold, Span<float> v, Span<float> w) | ||
{ | ||
fixed (float* psrc = src) | ||
fixed (int* pidx = indices) | ||
fixed (float* pdst1 = v) | ||
fixed (float* pdst2 = w) | ||
fixed (float* psrc = &MemoryMarshal.GetReference(src)) | ||
fixed (int* pidx = &MemoryMarshal.GetReference(indices)) | ||
fixed (float* pdst1 = &MemoryMarshal.GetReference(v)) | ||
fixed (float* pdst2 = &MemoryMarshal.GetReference(w)) | ||
{ | ||
int* pIdxEnd = pidx + indices.Length; | ||
int* pIdxEnd = pidx + count; | ||
float* pSrcCurrent = psrc; | ||
int* pIdxCurrent = pidx; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be good to have an analyzer for this...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not necessarily. You have to ensure that the span is non-empty, because if you call this on an empty span, you'll get back a garbage pointer.
Also, the .NET Core team has raised concerns that these calls shouldn't be necessary at all; see the conversation at dotnet/corefx#32669 (comment). So it may not be worth writing the analyzer if the recommendation ends up not being to use this method everywhere.