Skip to content

Commit

Permalink
eachslice, eachrow, eachcol (introduced in #29749) now return a…
Browse files Browse the repository at this point in the history
…n `EachSlice` object (along with `EachRow`/`EachCol` aliases). The main benefit is that it will allow dispatch on the iterator to provide more efficient methods, e.g.

```
sum(A::EachRow) = vec(sum(parent(A), dims=1))
```

This will encourage the use of `eachcol`/`eachrow` to resolve ambiguities in user-facing APIs, in particular, the "obsverations as rows vs columns" problem in the statistics/ML packages.

This also makes `eachslice` work over multiple dimensions.
  • Loading branch information
simonbyrne committed Jun 12, 2019
1 parent b777409 commit 5728764
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 9 deletions.
45 changes: 38 additions & 7 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,26 @@ _reperr(s, n, N) = throw(ArgumentError("number of " * s * " repetitions " *
return R
end

struct EachSlice{A,I,L}
arr::A # underlying array
cartiter::I # CartesianIndices iterator
lookup::L # dimension look up: dimension index in cartiter, or nothing
end

function iterate(s::EachSlice, state...)
r = iterate(s.cartiter, state...)
isnothing(r) && return r
(c,nextstate) = r
view(s.arr, map(l -> isnothing(l) ? (:) : c[l], s.lookup)...), nextstate
end

size(s::EachSlice) = size(s.cartiter)
length(s::EachSlice) = length(s.cartiter)
ndims(s::EachSlice) = ndims(s.cartiter)
IteratorSize(::Type{EachSlice{A,I,L}}) where {A,I,L} = IteratorSize(I)

parent(s::EachSlice) = s.arr

"""
eachrow(A::AbstractVecOrMat)
Expand All @@ -418,8 +438,13 @@ See also [`eachcol`](@ref) and [`eachslice`](@ref).
!!! compat "Julia 1.1"
This function requires at least Julia 1.1.
"""
eachrow(A::AbstractVecOrMat) = (view(A, i, :) for i in axes(A, 1))
function eachrow(A::AbstractVecOrMat)
iter = CartesianIndices((axes(A,1),))
lookup = (1,nothing)
EachSlice(A,iter,lookup)
end

const EachRow{A,I} = EachSlice{A,I,Tuple{Int,Nothing}}

"""
eachcol(A::AbstractVecOrMat)
Expand All @@ -432,7 +457,12 @@ See also [`eachrow`](@ref) and [`eachslice`](@ref).
!!! compat "Julia 1.1"
This function requires at least Julia 1.1.
"""
eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2))
function eachcol(A::AbstractVecOrMat)
iter = CartesianIndices((axes(A,2),))
lookup = (nothing,1)
EachSlice(A,iter,lookup)
end
const EachCol{A,I} = EachSlice{A,I,Tuple{Nothing,Int}}

"""
eachslice(A::AbstractArray; dims)
Expand All @@ -449,9 +479,10 @@ See also [`eachrow`](@ref), [`eachcol`](@ref), and [`selectdim`](@ref).
This function requires at least Julia 1.1.
"""
@inline function eachslice(A::AbstractArray; dims)
length(dims) == 1 || throw(ArgumentError("only single dimensions are supported"))
dim = first(dims)
dim <= ndims(A) || throw(DimensionMismatch("A doesn't have $dim dimensions"))
idx1, idx2 = ntuple(d->(:), dim-1), ntuple(d->(:), ndims(A)-dim)
return (view(A, idx1..., i, idx2...) for i in axes(A, dim))
for dim in dims
dim <= ndims(A) || throw(DimensionMismatch("A doesn't have $dim dimensions"))
end
iter = CartesianIndices(map(dim -> axes(A,dim), dims))
lookup = ntuple(dim -> findfirst(isequal(dim), dims), ndims(A))
EachSlice(A,iter,lookup)
end
22 changes: 20 additions & 2 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2005,19 +2005,38 @@ end
end

# row/column/slice iterator tests
using Base: eachrow, eachcol
using Base: eachrow, eachcol, EachRow, EachCol
@testset "row/column/slice iterators" begin

@test eachrow(ones(3)) isa EachRow
@test !(eachrow(ones(3)) isa EachCol)
@test eachcol(ones(3)) isa EachCol
@test !(eachcol(ones(3)) isa EachRow)

@test eachrow(ones(3,3)) isa EachRow
@test !(eachrow(ones(3,3)) isa EachCol)
@test eachcol(ones(3,3)) isa EachCol
@test !(eachcol(ones(3,3)) isa EachRow)

# Simple ones
M = [1 2 3; 4 5 6; 7 8 9]
@test collect(eachrow(M)) == collect(eachslice(M, dims = 1)) == [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
@test collect(eachcol(M)) == collect(eachslice(M, dims = 2)) == [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
@test_throws DimensionMismatch eachslice(M, dims = 4)
@test eltype(eachrow(M)) == typeof(first(eachrow(M)))
@test eltype(eachcol(M)) == typeof(first(eachcol(M)))

# Higher-dimensional case
M = reshape([(1:16)...], 2, 2, 2, 2)
@test_throws MethodError collect(eachrow(M))
@test_throws MethodError collect(eachcol(M))
@test collect(eachslice(M, dims = 1))[1][:, :, 1] == [1 5; 3 7]
@test collect(eachslice(M, dims = (1,4)))[1, 1] == [1 5; 3 7]
@test collect(eachslice(M, dims = (1,4)))[1, 2] == [9 13; 11 15]
@test collect(eachslice(M, dims = (4,1)))[1, 1] == [1 5; 3 7]
@test collect(eachslice(M, dims = (4,1)))[1, 2] == [2 6; 4 8]
@test eltype(eachslice(M, dims=1)) == typeof(first(eachslice(M, dims=1)))
@test eltype(eachslice(M, dims=(4,1))) == typeof(first(eachslice(M, dims=(4,1))))
end

###
Expand Down Expand Up @@ -2630,4 +2649,3 @@ end

# Fix oneunit bug for unitful arrays
@test oneunit([Second(1) Second(2); Second(3) Second(4)]) == [Second(1) Second(0); Second(0) Second(1)]

0 comments on commit 5728764

Please sign in to comment.