Skip to content

Commit

Permalink
Merge pull request #451 from mcabbott/cat
Browse files Browse the repository at this point in the history
Improved rules for `cat`s
  • Loading branch information
oxinabox committed Jun 24, 2021
2 parents 52a0eea + 8060bd5 commit 830e97d
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ChainRulesCore = "0.10.4"
ChainRulesTestUtils = "0.7.9"
Compat = "3.30"
Compat = "3.31"
FiniteDifferences = "0.12.8"
StaticArrays = "1.2"
julia = "1"
Expand Down
192 changes: 152 additions & 40 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,66 +24,178 @@ end
##### `hcat` (🐈)
#####

function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...)
function hcat_pullback(Ȳ)
Xs = (A, Bs...)
ntuple(length(Bs) + 2) do full_i
full_i == 1 && return NoTangent()

i = full_i - 1
l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0)
u = l + size(Xs[i], 2)
dim = u > l + 1 ? (l+1:u) : u
# NOTE: The copy here is defensive, since `selectdim` returns a view which we can
# materialize with `copy`
copy(selectdim(Ȳ, 2, dim))
function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray
ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability
sizes = map(size, Xs) # this avoids closing over Xs
function 🐈_pullback(dY)
hi = Ref(0) # Ref avoids hi::Core.Box
dXs = map(sizes) do sizeX
ndimsX = length(sizeX)
lo = hi[] + 1
hi[] += get(sizeX, 2, 1)
ind = ntuple(ndimsY) do d
if d==2
d > ndimsX ? lo : lo:hi[]
else
d > ndimsX ? 1 : (:)
end
end
if ndimsX > 0
# Here InplaceableThunk breaks @inferred, removed for now
# InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...))
dY[ind...]
else
# This is a hack to perhaps avoid GPU scalar indexing
sum(view(dY, ind...))
end
end
return (NoTangent(), dXs...)
end
return hcat(A, Bs...), hcat_pullback
return Y, 🐈_pullback
end

function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
function reduce_hcat_pullback(ΔY)
sizes = size.(As, 2)
cumsizes = cumsum(sizes)
∂As = map(cumsizes, sizes) do post, diff
pre = post - diff + 1
return ΔY[:, pre:post]
widths = map(A -> size(A,2), As)
function reduce_hcat_pullback_2(dY)
hi = Ref(0)
dAs = map(widths) do w
lo = hi[]+1
hi[] += w
dY[:, lo:hi[]]
end
return (NoTangent(), NoTangent(), ∂As)
return (NoTangent(), NoTangent(), dAs)
end
return reduce(hcat, As), reduce_hcat_pullback
return reduce(hcat, As), reduce_hcat_pullback_2
end

function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVector})
axe = axes(As,1)
function reduce_hcat_pullback_1(dY)
hi = Ref(0)
dAs = map(_ -> dY[:, hi[]+=1], axe)
return (NoTangent(), NoTangent(), dAs)
end
return reduce(hcat, As), reduce_hcat_pullback_1
end

#####
##### `vcat`
#####

function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...)
function vcat_pullback(Ȳ)
n = size(A, 1)
∂A = copy(selectdim(Ȳ, 1, 1:n))
∂Bs = ntuple(length(Bs)) do i
l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0)
u = l + size(Bs[i], 1)
copy(selectdim(Ȳ, 1, l+1:u))
function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
Y = vcat(Xs...)
ndimsY = Val(ndims(Y))
sizes = map(size, Xs)
function vcat_pullback(dY)
hi = Ref(0)
dXs = map(sizes) do sizeX
ndimsX = length(sizeX)
lo = hi[] + 1
hi[] += get(sizeX, 1, 1)
ind = ntuple(ndimsY) do d
if d==1
d > ndimsX ? lo : lo:hi[]
else
d > ndimsX ? 1 : (:)
end
end
if ndimsX > 0
# InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...))
dY[ind...]
else
sum(view(dY, ind...))
end
end
return (NoTangent(), ∂A, ∂Bs...)
return (NoTangent(), dXs...)
end
return vcat(A, Bs...), vcat_pullback
return Y, vcat_pullback
end

function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
function reduce_vcat_pullback(ΔY)
sizes = size.(As, 1)
cumsizes = cumsum(sizes)
∂As = map(cumsizes, sizes) do post, diff
pre = post - diff + 1
return ΔY[pre:post, :]
Y = reduce(vcat, As)
ndimsY = Val(ndims(Y))
heights = map(A -> size(A,1), As)
function reduce_vcat_pullback(dY)
hi = Ref(0)
dAs = map(heights) do z
lo = hi[]+1
hi[] += z
ind = ntuple(d -> d==1 ? (lo:hi[]) : (:), ndimsY)
dY[ind...]
end
return (NoTangent(), NoTangent(), dAs)
end
return Y, reduce_vcat_pullback
end

#####
##### `cat`
#####

_val(::Val{x}) where {x} = x

function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims)
Y = cat(Xs...; dims=dims)
cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims)
ndimsY = Val(ndims(Y))
sizes = map(size, Xs)
function cat_pullback(dY)
prev = fill(0, _val(ndimsY)) # note that Y always has 1-based indexing, even if X isa OffsetArray
dXs = map(sizes) do sizeX
ndimsX = length(sizeX)
index = ntuple(ndimsY) do d
if d in cdims
d > ndimsX ? (prev[d]+1) : (prev[d]+1:prev[d]+sizeX[d])
else
d > ndimsX ? 1 : (:)
end
end
for d in cdims
prev[d] += get(sizeX, d, 1)
end
if ndimsX > 0
# InplaceableThunk(@thunk(dY[index...]), dX -> dX .+= view(dY, index...))
dY[index...]
else
sum(view(dY, index...))
end
end
return (NoTangent(), dXs...)
end
return Y, cat_pullback
end

#####
##### `hvcat`
#####

function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...)
Y = hvcat(rows, values...)
cols = size(Y,2)
ndimsY = Val(ndims(Y))
sizes = map(size, values)
function hvcat_pullback(dY)
prev = fill(0, 2)
dXs = map(sizes) do sizeX
ndimsX = length(sizeX)
index = ntuple(ndimsY) do d
if d in (1, 2)
d > ndimsX ? (prev[d]+1) : (prev[d]+1:prev[d]+sizeX[d])
else
d > ndimsX ? 1 : (:)
end
end
prev[2] += get(sizeX, 2, 1)
if prev[2] == cols
prev[2] = 0
prev[1] += get(sizeX, 1, 1)
end
dY[index...]
end
return (NoTangent(), NoTangent(), ∂As)
return (NoTangent(), NoTangent(), dXs...)
end
return reduce(vcat, As), reduce_vcat_pullback
return Y, hvcat_pullback
end

#####
Expand Down
58 changes: 42 additions & 16 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,57 @@
end

@testset "hcat" begin
A = randn(3, 2)
B = randn(3)
C = randn(3, 3)
test_rrule(hcat, A, B, C; check_inferred=false)
test_rrule(hcat, randn(3, 2), randn(3), randn(3, 3); check_inferred=VERSION>v"1.1")
test_rrule(hcat, rand(), rand(1,2), rand(1,2,1); check_inferred=VERSION>v"1.1")
test_rrule(hcat, rand(3,1,1,2), rand(3,3,1,2); check_inferred=VERSION>v"1.1")
end

@testset "reduce hcat" begin
A = randn(3, 2)
B = randn(3, 1)
C = randn(3, 3)
test_rrule(reduce, hcat NoTangent(), [A, B, C])
mats = [randn(3, 2), randn(3, 1), randn(3, 3)]
test_rrule(reduce, hcat NoTangent(), mats)

vecs = [rand(3) for _ in 1:4]
test_rrule(reduce, hcat NoTangent(), vecs)

mix = AbstractVecOrMat[rand(4,2), rand(4)] # this is weird, but does hit the fast path
test_rrule(reduce, hcat NoTangent(), mix)

adjs = vec([randn(2, 4), randn(1, 4), randn(3, 4)]') # not a Vector
# test_rrule(reduce, hcat ⊢ NoTangent(), adjs ⊢ map(m -> rand(size(m)), adjs))
dy = 1 ./ reduce(hcat, adjs)
@test rrule(reduce, hcat, adjs)[2](dy)[3] rrule(reduce, hcat, collect.(adjs))[2](dy)[3]
end

@testset "vcat" begin
A = randn(2, 4)
B = randn(1, 4)
C = randn(3, 4)
test_rrule(vcat, A, B, C; check_inferred=false)
test_rrule(vcat, randn(2, 4), randn(1, 4), randn(3, 4); check_inferred=VERSION>v"1.1")
test_rrule(vcat, rand(), rand(); check_inferred=VERSION>v"1.1")
test_rrule(vcat, rand(), rand(3), rand(3,1,1); check_inferred=VERSION>v"1.1")
test_rrule(vcat, rand(3,1,2), rand(4,1,2); check_inferred=VERSION>v"1.1")
end

@testset "reduce vcat" begin
A = randn(2, 4)
B = randn(1, 4)
C = randn(3, 4)
test_rrule(reduce, vcat NoTangent(), [A, B, C])
mats = [randn(2, 4), randn(1, 4), randn(3, 4)]
test_rrule(reduce, vcat NoTangent(), mats)

vecs = [rand(2), rand(3), rand(4)]
test_rrule(reduce, vcat NoTangent(), vecs)

mix = AbstractVecOrMat[rand(4,1), rand(4)]
test_rrule(reduce, vcat NoTangent(), mix)
end

@testset "cat" begin
test_rrule(cat, rand(2, 4), rand(1, 4); fkwargs=(dims=1,), check_inferred=VERSION>v"1.1")
test_rrule(cat, rand(2, 4), rand(2); fkwargs=(dims=Val(2),), check_inferred=VERSION>v"1.1")
test_rrule(cat, rand(), rand(2, 3); fkwargs=(dims=[1,2],), check_inferred=VERSION>v"1.1")
test_rrule(cat, rand(1), rand(3, 2, 1); fkwargs=(dims=(1,2),), check_inferred=false) # infers Tuple{Zero, Vector{Float64}, Any}
end

@testset "hvcat" begin
test_rrule(hvcat, 2 NoTangent(), rand(ComplexF64, 6)...; check_inferred=VERSION>v"1.1")
test_rrule(hvcat, (2, 1) NoTangent(), rand(), rand(1,1), rand(2,2); check_inferred=VERSION>v"1.1")
test_rrule(hvcat, 1 NoTangent(), rand(3)' rand(1,3), transpose(rand(3)) rand(1,3); check_inferred=VERSION>v"1.1")
test_rrule(hvcat, 1 NoTangent(), rand(0,3), rand(2,3), rand(1,3,1); check_inferred=VERSION>v"1.1")
end

@testset "fill" begin
Expand Down

0 comments on commit 830e97d

Please sign in to comment.