Skip to content

Commit

Permalink
Fixing SetSlice, Reshape, TryCopyTo. (dotnet#107852)
Browse files Browse the repository at this point in the history
* working

* comments from PR

* can always reshape to self

* fixed tests

* comments from PR

* fixing tests
  • Loading branch information
michaelgsharp authored and sirntar committed Sep 30, 2024
1 parent aaeb350 commit fa11890
Show file tree
Hide file tree
Showing 9 changed files with 387 additions and 172 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,7 @@
<data name="ThrowArgument_StackShapesNotSame" xml:space="preserve">
<value>All tensors must have the same shape.</value>
</data>
<data name="Argument_CannotReshapeNonContiguousOrDense" xml:space="preserve">
<value>The Tensor provided is either non-contiguous or non-dense. Reshape only works with contiguous and dense memory. You may need to Broadcast or Copy the data to be contiguous.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -520,31 +520,39 @@ public void CopyTo(scoped TensorSpan<T> destination)
// Using "if (!TryCopyTo(...))" results in two branches: one for the length
// check, and one for the result of TryCopyTo. Since these checks are equivalent,
// we can optimize by performing the check once ourselves then calling Memmove directly.
if (_shape.FlattenedLength <= destination.FlattenedLength)
if (TensorHelpers.IsBroadcastableTo(Lengths, destination.Lengths))
{
scoped Span<nint> curIndexes;
nint[]? curIndexesArray;

if (Rank > TensorShape.MaxInlineRank)
{
curIndexesArray = ArrayPool<nint>.Shared.Rent(Rank);
curIndexes = curIndexesArray.AsSpan(0, Rank);
curIndexesArray = ArrayPool<nint>.Shared.Rent(destination.Rank);
curIndexes = curIndexesArray.AsSpan(0, destination.Rank);

}
else
{
curIndexesArray = null;
curIndexes = stackalloc nint[Rank];
curIndexes = stackalloc nint[destination.Rank];
}
curIndexes.Clear();

nint copiedValues = 0;
TensorSpan<T> slice = destination.Slice(_shape.Lengths);
while (copiedValues < _shape.FlattenedLength)
nint[] tempLengths = Tensor.GetSmallestBroadcastableLengths(Lengths, destination.Lengths);

TensorSpan<T> destinationSlice = destination.Slice(tempLengths);
ReadOnlyTensorSpan<T> srcSlice = Tensor.LazyBroadcast(this, tempLengths);
nint copyLength = srcSlice.Strides[^1] == 1 && TensorHelpers.IsContiguousAndDense(srcSlice) ? srcSlice.Lengths[^1] : 1;
int indexToAdjust = srcSlice.Strides[^1] == 1 && TensorHelpers.IsContiguousAndDense(srcSlice) ? srcSlice.Rank - 2 : srcSlice.Rank - 1;

while (copiedValues < destination.FlattenedLength)
{
TensorSpanHelpers.Memmove(ref Unsafe.Add(ref slice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, Strides, Lengths)), ref Unsafe.Add(ref _reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, Strides, Lengths)), Lengths[Rank - 1]);
TensorSpanHelpers.AdjustIndexes(Rank - 2, 1, curIndexes, _shape.Lengths);
copiedValues += Lengths[Rank - 1];
TensorSpanHelpers.Memmove(ref Unsafe.Add(ref destinationSlice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, destinationSlice.Strides, destinationSlice.Lengths)), ref Unsafe.Add(ref srcSlice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, srcSlice.Strides, srcSlice.Lengths)), copyLength);
TensorSpanHelpers.AdjustIndexes(indexToAdjust, 1, curIndexes, tempLengths);
copiedValues += copyLength;
}
Debug.Assert(copiedValues == _shape.FlattenedLength, "Didn't copy the right amount to the array.");
Debug.Assert(copiedValues == destination.FlattenedLength, "Didn't copy the right amount to the array.");

if (curIndexesArray != null)
ArrayPool<nint>.Shared.Return(curIndexesArray);
Expand All @@ -568,32 +576,40 @@ public bool TryCopyTo(scoped TensorSpan<T> destination)
{
bool retVal = false;

if (_shape.FlattenedLength <= destination.FlattenedLength)
if (TensorHelpers.IsBroadcastableTo(Lengths, destination.Lengths))
{
scoped Span<nint> curIndexes;
nint[]? curIndexesArray;

if (Rank > TensorShape.MaxInlineRank)
{
curIndexesArray = ArrayPool<nint>.Shared.Rent(Rank);
curIndexes = curIndexesArray.AsSpan(0, Rank);
curIndexesArray = ArrayPool<nint>.Shared.Rent(destination.Rank);
curIndexes = curIndexesArray.AsSpan(0, destination.Rank);

}
else
{
curIndexesArray = null;
curIndexes = stackalloc nint[Rank];
curIndexes = stackalloc nint[destination.Rank];
}
curIndexes.Clear();

nint copiedValues = 0;
TensorSpan<T> slice = destination.Slice(_shape.Lengths);
while (copiedValues < _shape.FlattenedLength)
nint[] tempLengths = Tensor.GetSmallestBroadcastableLengths(Lengths, destination.Lengths);

TensorSpan<T> destinationSlice = destination.Slice(tempLengths);
ReadOnlyTensorSpan<T> srcSlice = Tensor.LazyBroadcast(this, tempLengths);
nint copyLength = srcSlice.Strides[^1] == 1 && TensorHelpers.IsContiguousAndDense(srcSlice) ? srcSlice.Lengths[^1] : 1;
int indexToAdjust = srcSlice.Strides[^1] == 1 && TensorHelpers.IsContiguousAndDense(srcSlice) ? srcSlice.Rank - 2 : srcSlice.Rank - 1;

while (copiedValues < destination.FlattenedLength)
{
TensorSpanHelpers.Memmove(ref Unsafe.Add(ref slice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, Strides, Lengths)), ref Unsafe.Add(ref _reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, Strides, Lengths)), Lengths[Rank - 1]);
TensorSpanHelpers.AdjustIndexes(Rank - 2, 1, curIndexes, _shape.Lengths);
copiedValues += Lengths[Rank - 1];
TensorSpanHelpers.Memmove(ref Unsafe.Add(ref destinationSlice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, destinationSlice.Strides, destinationSlice.Lengths)), ref Unsafe.Add(ref srcSlice._reference, TensorSpanHelpers.ComputeLinearIndex(curIndexes, srcSlice.Strides, srcSlice.Lengths)), copyLength);
TensorSpanHelpers.AdjustIndexes(indexToAdjust, 1, curIndexes, tempLengths);
copiedValues += copyLength;
}
Debug.Assert(copiedValues == destination.FlattenedLength, "Didn't copy the right amount to the array.");
retVal = true;
Debug.Assert(copiedValues == _shape.FlattenedLength, "Didn't copy the right amount to the array.");

if (curIndexesArray != null)
ArrayPool<nint>.Shared.Return(curIndexesArray);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2724,6 +2724,14 @@ public static Tensor<T> PermuteDimensions<T>(this Tensor<T> tensor, params ReadO
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
public static Tensor<T> Reshape<T>(this Tensor<T> tensor, params ReadOnlySpan<nint> lengths)
{
if (tensor.Lengths.SequenceEqual(lengths))
return tensor;

if (!TensorHelpers.IsContiguousAndDense<T>(tensor) && !tensor.Strides.Contains(0))
{
ThrowHelper.ThrowArgument_CannotReshapeNonContiguousOrDense();
}

nint[] arrLengths = lengths.ToArray();
// Calculate wildcard info.
if (lengths.Contains(-1))
Expand All @@ -2745,7 +2753,33 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, params ReadOnlySpan<ni
nint tempLinear = TensorSpanHelpers.CalculateTotalLength(arrLengths);
if (tempLinear != tensor.FlattenedLength)
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
nint[] strides = TensorSpanHelpers.CalculateStrides(arrLengths);

nint[] strides;

// If we contain a 0 stride we can only add dimensions of length 1.
if (tensor.Strides.Contains(0))
{
List<nint> origStrides = new List<nint>(tensor.Strides.ToArray());
int lengthOffset = 0;
for (int i = 0; i < arrLengths.Length; i++)
{
if (lengthOffset < tensor.Rank && arrLengths[i] == tensor.Lengths[lengthOffset])
lengthOffset++;
else if (arrLengths[i] == 1)
{
if (lengthOffset == tensor.Rank)
origStrides.Add(tensor.Strides[lengthOffset - 1]);
else
origStrides.Insert(i, tensor.Strides[i] * tensor.Lengths[i]);
}
else
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
}
strides = origStrides.ToArray();
}
else
strides = TensorSpanHelpers.CalculateStrides(arrLengths);

return new Tensor<T>(tensor._values, arrLengths, strides);
}

Expand All @@ -2758,6 +2792,14 @@ public static Tensor<T> Reshape<T>(this Tensor<T> tensor, params ReadOnlySpan<ni
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, params scoped ReadOnlySpan<nint> lengths)
{
if (tensor.Lengths.SequenceEqual(lengths))
return tensor;

if (!TensorHelpers.IsContiguousAndDense<T>(tensor) && !tensor.Strides.Contains(0))
{
ThrowHelper.ThrowArgument_CannotReshapeNonContiguousOrDense();
}

nint[] arrLengths = lengths.ToArray();
// Calculate wildcard info.
if (lengths.Contains(-1))
Expand All @@ -2779,7 +2821,35 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, params scop
nint tempLinear = TensorSpanHelpers.CalculateTotalLength(arrLengths);
if (tempLinear != tensor.FlattenedLength)
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
nint[] strides = TensorSpanHelpers.CalculateStrides(arrLengths);

nint[] strides;

// If we contain a 0 stride we can only add dimensions of length 1.
if (tensor.Strides.Contains(0))
{
List<nint> origStrides = new List<nint>(tensor.Strides.ToArray());
int lengthOffset = 0;
for (int i = 0; i < arrLengths.Length; i++)
{
if (lengthOffset < tensor.Rank && arrLengths[i] == tensor.Lengths[lengthOffset])
{
lengthOffset++;
}
else if (arrLengths[i] == 1)
{
if (lengthOffset == tensor.Rank)
origStrides.Add(tensor.Strides[lengthOffset - 1]);
else
origStrides.Insert(i, tensor.Strides[i] * tensor.Lengths[i]);
}
else
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
}
strides = origStrides.ToArray();
}
else
strides = TensorSpanHelpers.CalculateStrides(arrLengths);

TensorSpan<T> output = new TensorSpan<T>(ref tensor._reference, arrLengths, strides, tensor._shape._memoryLength);
return output;
}
Expand All @@ -2793,6 +2863,14 @@ public static TensorSpan<T> Reshape<T>(in this TensorSpan<T> tensor, params scop
/// <param name="lengths"><see cref="ReadOnlySpan{T}"/> with the new dimensions.</param>
public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> tensor, params scoped ReadOnlySpan<nint> lengths)
{
if (tensor.Lengths.SequenceEqual(lengths))
return tensor;

if (!TensorHelpers.IsContiguousAndDense<T>(tensor) && !tensor.Strides.Contains(0))
{
ThrowHelper.ThrowArgument_CannotReshapeNonContiguousOrDense();
}

nint[] arrLengths = lengths.ToArray();
// Calculate wildcard info.
if (lengths.Contains(-1))
Expand All @@ -2814,7 +2892,33 @@ public static ReadOnlyTensorSpan<T> Reshape<T>(in this ReadOnlyTensorSpan<T> ten
nint tempLinear = TensorSpanHelpers.CalculateTotalLength(arrLengths);
if (tempLinear != tensor.FlattenedLength)
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
nint[] strides = TensorSpanHelpers.CalculateStrides(arrLengths);

nint[] strides;

// If we contain a 0 stride we can only add dimensions of length 1.
if (tensor.Strides.Contains(0))
{
List<nint> origStrides = new List<nint>(tensor.Strides.ToArray());
int lengthOffset = 0;
for (int i = 0; i < arrLengths.Length; i++)
{
if (lengthOffset < tensor.Rank && arrLengths[i] == tensor.Lengths[lengthOffset])
lengthOffset++;
else if (arrLengths[i] == 1)
{
if (lengthOffset == tensor.Rank)
origStrides.Add(tensor.Strides[lengthOffset - 1]);
else
origStrides.Insert(i, tensor.Strides[i] * tensor.Lengths[i]);
}
else
ThrowHelper.ThrowArgument_InvalidReshapeDimensions();
}
strides = origStrides.ToArray();
}
else
strides = TensorSpanHelpers.CalculateStrides(arrLengths);

ReadOnlyTensorSpan<T> output = new ReadOnlyTensorSpan<T>(ref tensor._reference, arrLengths, strides, tensor._shape._memoryLength);
return output;
}
Expand Down Expand Up @@ -3053,14 +3157,17 @@ public static ref readonly TensorSpan<T> SetSlice<T>(this in TensorSpan<T> tenso
TensorSpan<T> srcSpan;
if (ranges == ReadOnlySpan<NRange>.Empty)
{
if (!tensor.Lengths.SequenceEqual(values.Lengths))
if (!TensorHelpers.IsBroadcastableTo(values.Lengths, tensor.Lengths))
ThrowHelper.ThrowArgument_SetSliceNoRange(nameof(values));
srcSpan = tensor.Slice(tensor.Lengths);
srcSpan = tensor;
}
else
srcSpan = tensor.Slice(ranges);

if (!srcSpan.Lengths.SequenceEqual(values.Lengths))
if (!TensorHelpers.IsContiguousAndDense<T>(srcSpan))
ThrowHelper.ThrowArgument_SetSliceInvalidShapes(nameof(values));

if (!TensorHelpers.IsBroadcastableTo(values.Lengths, srcSpan.Lengths))
ThrowHelper.ThrowArgument_SetSliceInvalidShapes(nameof(values));

values.CopyTo(srcSpan);
Expand Down Expand Up @@ -3555,8 +3662,13 @@ public static Tensor<T> Unsqueeze<T>(this Tensor<T> tensor, int dimension)

List<nint> tempLengths = tensor._lengths.ToList();
tempLengths.Insert(dimension, 1);
nint[] lengths = tempLengths.ToArray();
nint[] strides = TensorSpanHelpers.CalculateStrides(lengths);
nint[] lengths = [.. tempLengths];
List<nint> tempStrides = tensor.Strides.ToArray().ToList();
if (dimension == tensor.Rank)
tempStrides.Add(tensor.Strides[dimension - 1]);
else
tempStrides.Insert(dimension, tensor.Strides[dimension] * tensor.Lengths[dimension]);
nint[] strides = [.. tempStrides];
return new Tensor<T>(tensor._values, lengths, strides);
}

Expand All @@ -3574,8 +3686,13 @@ public static TensorSpan<T> Unsqueeze<T>(in this TensorSpan<T> tensor, int dimen

List<nint> tempLengths = tensor.Lengths.ToArray().ToList();
tempLengths.Insert(dimension, 1);
nint[] lengths = tempLengths.ToArray();
nint[] strides = TensorSpanHelpers.CalculateStrides(lengths);
nint[] lengths = [.. tempLengths];
List<nint> tempStrides = tensor.Strides.ToArray().ToList();
if (dimension == tensor.Rank)
tempStrides.Add(tensor.Strides[dimension - 1]);
else
tempStrides.Insert(dimension, tensor.Strides[dimension] * tensor.Lengths[dimension]);
nint[] strides = [.. tempStrides];
return new TensorSpan<T>(ref tensor._reference, lengths, strides, tensor._shape._memoryLength);
}

Expand All @@ -3593,8 +3710,13 @@ public static ReadOnlyTensorSpan<T> Unsqueeze<T>(in this ReadOnlyTensorSpan<T> t

List<nint> tempLengths = tensor.Lengths.ToArray().ToList();
tempLengths.Insert(dimension, 1);
nint[] lengths = tempLengths.ToArray();
nint[] strides = TensorSpanHelpers.CalculateStrides(lengths);
nint[] lengths = [.. tempLengths];
List<nint> tempStrides = tensor.Strides.ToArray().ToList();
if (dimension == tensor.Rank)
tempStrides.Add(tensor.Strides[dimension - 1]);
else
tempStrides.Insert(dimension, tensor.Strides[dimension] * tensor.Lengths[dimension]);
nint[] strides = [.. tempStrides];
return new ReadOnlyTensorSpan<T>(ref tensor._reference, lengths, strides, tensor._shape._memoryLength);
}
#endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ internal static bool IsBroadcastableTo(ReadOnlySpan<nint> lengths1, ReadOnlySpan
nint s1;
nint s2;

if (lengths1.Length == 0 || lengths2.Length == 0)
return false;

while (lengths1Index >= 0 || lengths2Index >= 0)
{
// if a dimension is missing in one of the shapes, it is considered to be 1
Expand All @@ -56,7 +59,7 @@ internal static bool IsBroadcastableTo(ReadOnlySpan<nint> lengths1, ReadOnlySpan
else
s2 = lengths2[lengths2Index--];

if (s1 == s2 || (s1 == 1 && s2 != 1) || (s2 == 1 && s1 != 1)) { }
if (s1 == s2 || (s1 == 1 && s2 > 1) || (s2 == 1 && s1 > 1)) { }
else
{
areCompatible = false;
Expand Down
Loading

0 comments on commit fa11890

Please sign in to comment.