
Add rule for prod #335

Merged (18 commits, May 27, 2021)
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.65"
version = "0.7.66"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
81 changes: 81 additions & 0 deletions src/rulesets/Base/mapreduce.jl
@@ -55,3 +55,84 @@ function rrule(
end
return y, sum_abs2_pullback
end

#####
##### `prod`
#####

function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:CommutativeMulNumber}
Member
Apologies for diving in with my usual request, but is there a sensible way that we could restrict the type here, since the tests currently only look at Arrays? e.g. I imagine that at least one of Fill, Diagonal, StaticArray, etc. will do something weird here. Would StridedArray suffice for your use case?

Member Author (@mcabbott, Dec 28, 2020)

There is a test with PermutedDimsArray, which isn't a StridedArray, and I think it ought to work fine with StaticArrays, although I have not tested that in any depth. Diagonal seems to work although I struggle to imagine why calling prod on one would be a good idea, but weird things happen:

julia> unthunk(rrule(prod, Diagonal(SA[1,2,3,4]))[2](1.0)[2])
4×4 Diagonal{Float64, MVector{4, Float64}}:
 0.0   ⋅    ⋅    ⋅ 
  ⋅   0.0   ⋅    ⋅ 
  ⋅    ⋅   0.0   ⋅ 
  ⋅    ⋅    ⋅   0.0
  
julia> unthunk(rrule(prod, Fill(2,3))[2](1.0)[2])
3-element Vector{Float64}:
 4.0
 4.0
 4.0  

Fill makes a Vector gradient. Somehow rrule(sum, Fill(2,3)) makes a Fill, because it simply broadcasts rather than calling similar. Is this something the package aims to guarantee? I don't see a test for it. Elsewhere it chooses similar over broadcasting to avoid other issues.

Member

There is a test with PermutedDimsArray, which isn't a StridedArray

Apologies, it's late here. They are quite clearly in the tests...

Fill makes a Vector gradient

We would definitely want the output of rrule w.r.t. a Fill argument to be either another Fill or an appropriate Composite. This is the kind of thing that e.g. Zygote can probably get right without a rule in ChainRules, so I think the ideal solution here is just not to implement a rule that covers Fill.

Also, should the result with Diagonal have zeros on the diagonal?

Member Author (@mcabbott, Dec 28, 2020)

Also, should the result with Diagonal have zeros on the diagonal?

Something is broken on the last commit, but not this, this one you can do in your head: It's a product of mostly zeros, so the gradient with respect to nonzero entries still vanishes.
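
(A quick check of that claim, as a sketch using ForwardDiff, which is not a dependency of this PR:)

```julia
using ForwardDiff

# With two (or more) zeros present, every partial derivative is itself
# a product containing at least one zero, so the whole gradient vanishes:
ForwardDiff.gradient(prod, [2.0, 0.0, 0.0, 5.0])  # -> [0.0, 0.0, 0.0, 0.0]
```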

Fill makes a Vector gradient

We would definitely want

This could be arranged at the cost of more complexity... although possibly Fill ought to define similar more like that of Diagonal if it wishes to be preserved under such things?

Although clearly not all gradients are going to preserve this structure:

julia> gradient(sum∘cumsum, Fill(1,3))[1]
3-element Vector{Int64}:
 3
 2
 1
 
julia> gradient(x -> sum(cumsum(x, dims=1)), Diagonal([1,2,3]))[1]  # what should this produce?

And here's how much Zygote can figure out right now:

julia> function myprod(xs)
           out = 1.0
           for x in xs
               out *= x
           end
           out
       end
myprod (generic function with 1 method)

julia> gradient(myprod, Fill(1,3))[1]
3-element Vector{Float64}:
 1.0
 1.0
 1.0

Member

Something is broken on the last commit, but not this, this one you can do in your head: It's a product of mostly zeros, so the gradient with respect to nonzero entries still vanishes.

Yes, good point.

julia> gradient(sum∘cumsum, Fill(1,3))[1]

This should produce a Composite with an appropriate value field.

julia> gradient(x -> sum(cumsum(x, dims=1)), Diagonal([1,2,3]))[1] # what should this produce?

This should also produce either a Diagonal or an appropriate Composite.

But with all of these types, I'm not saying that your PR needs to cover them. I'm purely suggesting it addresses the minimal set of types that you're confident are done correctly, and assumes that AD can do a reasonable job of deriving the others.

julia> gradient(myprod, Fill(1,3))[1]

The answer to this is the result of a bug in Zygote that I should fix -- it looks like an example of what I'm commenting on here, where getindex has been implemented for too broad a set of types. Zygote really should be able to derive the rule for this properly. i.e. getindex only ever returns the value field of a Fill, so you shouldn't even need a rule for getindex for Fill.

Member

So we can get this merged, shall we change this to StridedArray and then we can make a follow up later?

    y = prod(x; dims=dims)
    # vald = dims isa Colon ? nothing : dims isa Integer ? Val(Int(dims)) : Val(Tuple(dims))
    function prod_pullback(dy)
        x_thunk = InplaceableThunk(
            # Out-of-place version
            @thunk if dims === (:)
                ∇prod(x, dy, y)
            elseif any(iszero, x)  # Then, and only then, will ./x lead to NaN
                vald = dims isa Colon ? nothing : dims isa Integer ? Val(Int(dims)) : Val(Tuple(dims))
                ∇prod_dims(vald, x, dy, y)  # Val(Int(dims)) is about 2x faster than Val(Tuple(dims))
            else
                conj.(y ./ x) .* dy
            end
            ,
            # In-place version -- same branching
            dx -> if dims === (:)
                ∇prod!(dx, x, dy, y)
            elseif any(iszero, x)
                vald = dims isa Colon ? nothing : dims isa Integer ? Val(Int(dims)) : Val(Tuple(dims))
                ∇prod_dims!(dx, vald, x, dy, y)
            else
                dx .+= conj.(y ./ x) .* dy
            end
        )
        return (NO_FIELDS, x_thunk)
    end
    return y, prod_pullback
end
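
As a sanity check of the rule above: in the zero-free case the pullback reduces to the familiar closed form prod(x) ./ x scaled by the cotangent (a sketch, assuming this PR's rule is loaded via ChainRules):

```julia
using ChainRules: rrule, unthunk

x = [2.0, 3.0, 4.0]
y, pb = rrule(prod, x)   # y == 24.0
dself, dx = pb(1.0)      # pullback with cotangent 1.0
unthunk(dx)              # == y ./ x == [12.0, 8.0, 6.0]
```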

function ∇prod_dims(vald::Val{dims}, x, dy, y=prod(x; dims=dims)) where {dims}
    T = promote_type(eltype(x), eltype(dy))
    dx = fill!(similar(x, T, axes(x)), zero(T))
    ∇prod_dims!(dx, vald, x, dy, y)
    return dx
end

function ∇prod_dims!(dx, ::Val{dims}, x, dy, y) where {dims}
    iters = ntuple(d -> d in dims ? tuple(:) : axes(x,d), ndims(x))  # Without Val(dims) this is a serious type instability
    @inbounds for ind in Iterators.product(iters...)
        jay = map(i -> i isa Colon ? 1 : i, ind)
        @views ∇prod!(dx[ind...], x[ind...], dy[jay...], y[jay...])
    end
    return dx
end
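
To illustrate the iteration above: for a 2×3 matrix with dims = (1,), iters pins a Colon in the reduced dimension, so Iterators.product visits one column slice per step, while jay selects the matching 1×1 slice of dy and y (a standalone sketch, not part of the diff):

```julia
x = ones(2, 3)
dims = (1,)
iters = ntuple(d -> d in dims ? tuple(:) : axes(x, d), ndims(x))
# iters == ((Colon(),), Base.OneTo(3))
for ind in Iterators.product(iters...)        # ind is (:, 1), (:, 2), (:, 3)
    jay = map(i -> i isa Colon ? 1 : i, ind)  # jay is (1, 1), (1, 2), (1, 3)
end
```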

function ∇prod(x, dy::Number=1, y::Number=prod(x))
    T = promote_type(eltype(x), eltype(dy))
    dx = fill!(similar(x, T, axes(x)), zero(T))  # axes(x) makes MArray on StaticArrays, Array for structured matrices
    ∇prod!(dx, x, dy, y)
    return dx
end

function ∇prod!(dx, x, dy::Number=1, y::Number=prod(x))
    numzero = iszero(y) ? count(iszero, x) : 0
    if numzero == 0  # y == 0 can still occur here, if several small xs underflow
        dx .+= conj.(y ./ x) .* dy
    elseif numzero == 1
        ∇prod_one_zero!(dx, x, dy)
    else
        # numzero > 1, so all first derivatives are zero
    end
    return dx
end
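
The comment on the numzero == 0 branch covers an underflow corner: prod can return exactly zero even when no entry of x is zero, in which case y ./ x is zero rather than NaN, so that branch still behaves sensibly (a sketch, mirroring the Float32 test below):

```julia
v = [1f-5, 1f-10, 1f-15, 1f-20]
y = prod(v)                # underflows to 0.0f0, yet count(iszero, v) == 0
dx = conj.(y ./ v) .* 1f0  # no NaN: 0 ./ v is just zeros
# dx == zeros(Float32, 4); the true partials mostly underflow as well
```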

function ∇prod_one_zero!(dx, x, dy::Number=1)  # assumes exactly one entry of x is zero
    i_zero = 0
    p_rest = one(promote_type(eltype(x), typeof(dy)))
    for i in eachindex(x)
        xi = @inbounds x[i]
        p_rest *= ifelse(iszero(xi), one(xi), conj(xi))
        i_zero = ifelse(iszero(xi), i, i_zero)
    end
    dx[i_zero] += p_rest * dy
    return
end
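
A hand-checkable example of the numzero == 1 path: only the zero entry gets a nonzero partial, equal to the product of all the remaining entries (a sketch, assuming the helpers above are in scope):

```julia
x = [2.0, 0.0, 3.0]  # exactly one zero
dx = zeros(3)
∇prod!(dx, x)        # dispatches to ∇prod_one_zero!
# dx == [0.0, 6.0, 0.0]: ∂prod/∂x[2] == 2.0 * 3.0; the other partials vanish
```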
56 changes: 56 additions & 0 deletions test/rulesets/Base/mapreduce.jl
@@ -20,4 +20,60 @@
end
end
end # sum abs2

@testset "prod" begin
    @testset "Array{$T}" for T in [Float64, ComplexF64]
        @testset "size = $sz, dims = $dims" for (sz, dims) in [
            ((12,), :), ((12,), 1),
            ((3,4), 1), ((3,4), 2), ((3,4), :), ((3,4), [1,2]),
            ((3,4,1), 1), ((3,2,2), 3), ((3,2,2), 2:3),
        ]
            x = randn(T, sz)
            test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
            x[1] = 0
            test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
            x[5] = 0
            test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)
            x[3] = x[7] = 0  # two zeros along some slice, for any dims
            test_rrule(prod, x; fkwargs=(dims=dims,), check_inferred=true)

            if ndims(x) == 3
                xp = PermutedDimsArray(x, (3,2,1))  # not a StridedArray
                xpdot, xpbar = permutedims(rand(T, sz), (3,2,1)), permutedims(rand(T, sz), (3,2,1))
                test_rrule(prod, xp ⊢ xpbar; fkwargs=(dims=dims,), check_inferred=true)
            end
        end

        @testset "structured wrappers" begin
            # Adjoint -- like PermutedDimsArray this may actually be used
            xa = adjoint(rand(T,4,4))
            test_rrule(prod, xa ⊢ rand(T,4,4))
            test_rrule(prod, xa ⊢ rand(T,4,4), fkwargs=(dims=2,))
            @test unthunk(rrule(prod, adjoint(rand(T,3,3)))[2](1.0)[2]) isa Matrix
            @test unthunk(rrule(prod, adjoint(rand(T,3,3)), dims=1)[2](ones(1,3))[2]) isa Matrix

            # Diagonal -- a stupid thing to do, product of zeros! Shouldn't be an error though:
            @test iszero(unthunk(rrule(prod, Diagonal(rand(T,3)))[2](1.0)[2]))
            @test iszero(unthunk(rrule(prod, Diagonal(rand(T,3)), dims=1)[2](ones(1,3))[2]))
            @test unthunk(rrule(prod, Diagonal(rand(T,1)))[2](1.0)[2]) == hcat(1)  # 1x1 sparse matrix
            @test unthunk(rrule(prod, Diagonal(ones(T,2)), dims=1)[2](ones(1,2))[2]) == [0 1; 1 0]

            # Triangular -- almost equally stupid
            @test iszero(unthunk(rrule(prod, UpperTriangular(rand(T,3,3)))[2](1.0)[2]))
            @test unthunk(rrule(prod, UpperTriangular(ones(T,2,2)))[2](1.0)[2]) == [0 0; 1 0]

            # Symmetric -- at least this doesn't have zeros, still an unlikely combination
            xs = Symmetric(rand(T,4,4))
            @test_skip test_rrule(prod, xs ⊢ rand(T,4,4))
            @test_skip test_rrule(prod, xs ⊢ rand(T,4,4), fkwargs=(dims=2,))
Comment on lines +67 to +68
Member Author

Is there a bug in how FiniteDifferences does this, or am I thinking incorrectly about what it should produce?

using ForwardDiff, FiniteDifferences
xs = Symmetric(reshape(1:16,4,4)./10)
xm = Matrix(xs)

g1 = ForwardDiff.gradient(prod, xm) # symmetric, but not Symmetric
g2 = grad(central_fdm(5, 1), prod, xm)[1]
g3 = grad(central_fdm(5, 1), prod, xs)[1]  # is Symmetric, and differs

g1 ≈ g2
diag(g1) ≈ diag(g3)
UnitUpperTriangular(g1) ≈ UnitUpperTriangular(g3 ./ 2)  # this seems weird

With dims:

g4 = ForwardDiff.gradient(x -> sum(prod(x,dims=1)), xm) # no longer symmetric
g5 = grad(central_fdm(5, 1), x -> sum(prod(x,dims=1)), xm)[1] 
g6 = grad(central_fdm(5, 1), x -> sum(prod(x,dims=1)), xs)[1] 

g4 ≈ g5

proj(m) = (m .+ m')./2;
proj(g4) ≈ proj(proj(g4)) # it's a projection

fold(m) = m .+ m' .- Diagonal(m)
fold(g4) ≈ g6
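
The off-diagonal factor of two can be reproduced without any AD package: nudging one stored entry of a Symmetric moves both A[i,j] and A[j,i], so differencing through the wrapper doubles the off-diagonal sensitivity (a sketch using only LinearAlgebra):

```julia
using LinearAlgebra

h = 1e-6
f(t) = prod(Symmetric([0.1 0.2+t; 0.0 0.4]))  # perturbs the stored upper entry
g(t) = prod([0.1 0.2+t; 0.2 0.4])             # perturbs only one entry of a plain Matrix

(f(h) - f(0)) / (g(h) - g(0))  # ≈ 2: the wrapper doubles the off-diagonal sensitivity
```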


Member

I suggest checking what FiniteDifferences.to_vec is outputting.

Member Author

For now I've left here a simpler test that it does run without error on a Symmetric: unthunk(rrule(prod, Symmetric(ones(T,2,2)))[2](1.0)[2]) == [1 1; 1 1]. I very much doubt this case is going to see use, but it shouldn't give an error.

Member

Would it be better to have a test with Symmetric([2.0 3.0; 3.0 2.0])? Ones are hard to trust.

Member Author

OK, I've switched it. It's mostly to check this doesn't give an error, the computation here does not care at all what kind of matrix it gets. Although ones is in fact sufficient to see this weirdness:

julia> grad(central_fdm(5, 1), prod, Symmetric(ones(2,2)))[1]
2×2 Symmetric{Float64, Matrix{Float64}}:
 1.0  2.0
 2.0  1.0

julia> ForwardDiff.gradient(prod, ones(2,2))
2×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0

            @test unthunk(rrule(prod, Symmetric(T[1 2; -333 4]))[2](1.0)[2]) == [16 8; 8 4]
        end
    end
    @testset "Array{Float32}, no zero entries" begin
        v = [1f-5, 1f-10, 1f-15, 1f-20]
        @test prod(v) == 0
        @test unthunk(rrule(prod, v)[2](1f0)[2]) == zeros(4)
        test_rrule(prod, v)
    end
end  # prod
end