Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rule for cumprod #420

Merged
merged 15 commits into from
Aug 27, 2021
93 changes: 93 additions & 0 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,96 @@ function ∇prod_one_zero!(dx, x, dy::Number=1) # Assumes exactly one x is zero
dx[i_zero] += p_rest * dy
return
end

#####
##### `cumprod`
#####

function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1)
y = cumprod(x; dims=dims) # does nothing unless dims == 1
project_x = ProjectTo(x)
function cumprod_pullback_1(dy_raw)
dy = unthunk(dy_raw)
dx_thunk = InplaceableThunk(
dx -> if dims == 1
∇cumprod!(dx, x, dy, y)
else
dx .+= dy
end
,
@thunk project_x(if dims == 1
∇cumprod(x, dy, y)
else
dy
end)
)
return (NoTangent(), dx_thunk)
end
return y, cumprod_pullback_1
end

function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer)
y = cumprod(x; dims=dims)
project_x = ProjectTo(x)
function cumprod_pullback_2(dy_raw)
dy = unthunk(dy_raw)
dx_thunk = InplaceableThunk(
dx -> if dims <= ndims(x)
vald = Val(Int(dims))
∇cumprod_dim!(dx, vald, x, dy, y)
else
dx .+= dy
end
,
@thunk project_x(if dims <= ndims(x)
vald = Val(Int(dims))
∇cumprod_dim(vald, x, dy, y)
else
dy
end)
)
return (NoTangent(), dx_thunk)
end
return y, cumprod_pullback_2
end

function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y=cumprod(x; dims=dim)) where {dim}
T = promote_type(eltype(x), eltype(dy))
dx = fill!(similar(x, T, axes(x)), zero(T))
∇cumprod_dim!(dx, vald, x, dy, y)
return dx
end

@inline function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim}
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x))
for ind in Iterators.product(iters...)
@views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...])
end
return dx
end

function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
T = promote_type(eltype(x), eltype(dy)) # really needs to allow dy * y / x
dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
∇cumprod!(dx, x, dy, y)
return dx
end

@inline function ∇cumprod!(dx::AbstractVector, x::AbstractVector, dy, y)
lo, hi = firstindex(x), lastindex(x)
z = something(findfirst(iszero, x), hi+1)
acc = zero(eltype(dy))
@inbounds for k in z-1:-1:lo
acc += y[k] * dy[k]
dx[k] += acc / x[k]
end
@inbounds if z != hi+1
yk = z==1 ? one(eltype(y)) : y[z-1] # will be prod(x[j] for j=1:k if j!=z)
dx[z] += yk * dy[z]
for k in (z+1):hi
yk *= x[k]
dx[z] += yk * dy[k]
end
end
return dx
end
38 changes: 38 additions & 0 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,41 @@
end
end # prod
end

@testset "Accumulations" begin
@testset "cumprod" begin
v = round.(10 .* randn(9), sigdigits=3)
test_rrule(cumprod, v)
v[3] = 0
test_rrule(cumprod, v)
v[6] = 0
test_rrule(cumprod, v)

@testset "higher dimensions, dims=$dims" for dims in (1,2,3)
m = round.(10 .* randn(4,5), sigdigits=3)
test_rrule(cumprod, m; fkwargs=(;dims=dims), atol=0.1)
m[2,2] = 0
m[2,4] = 0
test_rrule(cumprod, m; fkwargs=(;dims=dims))

t = round.(10 .* randn(3,3,3), sigdigits=3)
test_rrule(cumprod, t; fkwargs=(;dims=dims))
t[2,2,2] = 0
t[2,3,3] = 0
test_rrule(cumprod, t; fkwargs=(;dims=dims))
end

@testset "types" begin
back = rrule(cumprod, [1, 2, 3])[2] # rule allows integer input, but test_rrule does not
@test unthunk(back(fill(0.5, 3))[2]) == [9/2, 2, 1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we testing values here?
It would be good to add comments explaining what we are particularly checking for that test_rrule will not catch

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_rrule doesn't seem to accept integer input, this tests that the rule still does.

julia> test_rrule(cumprod, [1,2,3])
test_rrule: cumprod on Vector{Int64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/6oOem/src/testers.jl:227
  Got exception outside of a @test
  InexactError: Int64(0.99)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comment to that effect?

Yeah can't test methods that require integers with finite differencing.
Since once you apply a finite difference you get a Float64 instead.
Which means you don't hit the method for integers.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment now says "# rule allows integer input, but test_rrule does not"


back = rrule(cumprod, PermutedDimsArray([1 2; 3 4], (2,1)); dims=1)[2]
@test unthunk(back(ones(Float32, 2,2))[2]) == [3 5; 1 3]

@test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) # forward pass fails, so can't test gradient

back = rrule(cumprod, Diagonal([1, 2]); dims=1)[2]
@test unthunk(back(fill(0.5, 2, 2))[2]) ≈ [1/2 0; 0 0] # ProjectTo'd to Diagonal now
end
end
end