Skip to content

Commit

Permalink
reshape(::Array, Val{N}) always returns an Array
Browse files Browse the repository at this point in the history
(cherry picked from commit a3e6fcf)
ref #18160

reshape: only call to_shape when it will return Dims

(cherry picked from commit d92b2db)
  • Loading branch information
timholy authored and tkelman committed Aug 21, 2016
1 parent a6a7fd2 commit 69a605b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
31 changes: 21 additions & 10 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ start(R::ReshapedArrayIterator) = start(R.iter)
end
length(R::ReshapedArrayIterator) = length(R.iter)

reshape(parent::AbstractArray, shp::Tuple) = _reshape(parent, to_shape(shp))
reshape(parent::AbstractArray, dims::IntOrInd...) = reshape(parent, dims)
reshape(parent::AbstractArray, shp::NeedsShaping) = reshape(parent, to_shape(shp))
reshape(parent::AbstractArray, dims::Dims) = _reshape(parent, dims)

reshape{T,N}(parent::AbstractArray{T,N}, ndims::Type{Val{N}}) = parent
function reshape{T,AN,N}(parent::AbstractArray{T,AN}, ndims::Type{Val{N}})
Expand All @@ -47,24 +48,34 @@ end
# dimensionality N, either filling with OneTo(1) or collapsing the
# product of trailing dims into the last element
@pure rdims{N}(out::NTuple{N}, inds::Tuple{}, ::Type{Val{N}}) = out
@pure rdims{N}(out::NTuple{N}, inds::Tuple{Any, Vararg{Any}}, ::Type{Val{N}}) = (front(out)..., length(last(out)) * prod(map(length, inds)))
@pure function rdims{N}(out::NTuple{N}, inds::Tuple{Any, Vararg{Any}}, ::Type{Val{N}})
l = length(last(out)) * prod(map(length, inds))
(front(out)..., OneTo(l))
end
@pure rdims{N}(out::Tuple, inds::Tuple{}, ::Type{Val{N}}) = rdims((out..., OneTo(1)), (), Val{N})
@pure rdims{N}(out::Tuple, inds::Tuple{Any, Vararg{Any}}, ::Type{Val{N}}) = rdims((out..., first(inds)), tail(inds), Val{N})

function _reshape(parent::AbstractArray, dims::Dims)
n = _length(parent)
prod(dims) == n || throw(DimensionMismatch("parent has $n elements, which is incompatible with size $dims"))
__reshape((parent, linearindexing(parent)), dims)
end
_reshape(R::ReshapedArray, dims::Dims) = _reshape(R.parent, dims)
# _reshape on Array returns an Array
_reshape(parent::Vector, dims::Dims{1}) = parent
_reshape(parent::Array, dims::Dims{1}) = reshape(parent, dims)
_reshape(parent::Array, dims::Dims) = reshape(parent, dims)

# When reshaping Vector->Vector, don't wrap with a ReshapedArray
_reshape{T}(v::ReshapedArray{T,1}, dims::Tuple{Int}) = _reshape(v.parent, dims)
function _reshape(v::AbstractVector, dims::Tuple{Int})
function _reshape(v::AbstractVector, dims::Dims{1})
len = dims[1]
len == length(v) || throw(DimensionMismatch("parent has $(length(v)) elements, which is incompatible with length $len"))
v
end
# General reshape
function _reshape(parent::AbstractArray, dims::Dims)
n = _length(parent)
prod(dims) == n || throw(DimensionMismatch("parent has $n elements, which is incompatible with size $dims"))
__reshape((parent, linearindexing(parent)), dims)
end

# Reshaping a ReshapedArray
_reshape{T}(v::ReshapedArray{T,1}, dims::Dims{1}) = _reshape(v.parent, dims)
_reshape(R::ReshapedArray, dims::Dims) = _reshape(R.parent, dims)

function __reshape(p::Tuple{AbstractArray,LinearSlow}, dims::Dims)
parent = p[1]
Expand Down
8 changes: 8 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ a = zeros(0, 5) # an empty linearslow array
s = view(a, :, [2,3,5])
@test length(reshape(s, length(s))) == 0

# reshape(a, Val{N})
a = ones(Int,3,3)
s = view(a, 1:2, 1:2)
for N in (1,3)
@test isa(reshape(a, Val{N}), Array{Int,N})
@test isa(reshape(s, Val{N}), Base.ReshapedArray{Int,N})
end

@test reshape(1:5, (5,)) === 1:5
@test reshape(1:5, 5) === 1:5

Expand Down

0 comments on commit 69a605b

Please sign in to comment.