
Add rule for prod #335

Merged: 18 commits from the prod2 branch into JuliaDiff:master on May 27, 2021
Conversation

@mcabbott (Member) commented Dec 26, 2020:

This adds a reverse-mode gradient for prod(x; dims), which should correctly treat zero entries.

It ends up a little more complicated than seems ideal. In particular, this won't work on CuArrays: at least when there are zeros, I think it will give a scalar-access warning (which might be better than NaNs); when there aren't, it should work. Is there a mechanism worked out for where a Cu version of ∇prod_dims! should live, if someone were to write one?

It also probably won't work well for second derivatives.
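
For context, a minimal sketch of the zero-entry problem (an illustration only, using ForwardDiff as the reference answer rather than this PR's code):

using ForwardDiff

x = [0.0, 2.0, 3.0]
y = prod(x)                      # 0.0

y ./ x                           # naive reverse rule: [NaN, 0.0, 0.0] -- blows up at the zero entry
ForwardDiff.gradient(prod, x)    # correct gradient:   [6.0, 0.0, 0.0]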

@codecov-io commented Dec 26, 2020:

Codecov Report

Merging #335 (9d1b3ba) into master (ebc99f7) will increase coverage by 0.09%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master     #335      +/-   ##
==========================================
+ Coverage   97.64%   97.73%   +0.09%     
==========================================
  Files          18       18              
  Lines        1018     1061      +43     
==========================================
+ Hits          994     1037      +43     
  Misses         24       24              
Impacted Files                    Coverage Δ
src/rulesets/Base/mapreduce.jl    100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update ebc99f7...839a6f3

@simeonschaub (Member) left a comment:

LGTM overall, just some comments.

(Resolved review threads on src/rulesets/Base/mapreduce.jl and test/rulesets/Base/mapreduce.jl.)
##### `prod`
#####

function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:CommutativeMulNumber}
Member:

Apologies for diving in with my usual request, but is there a sensible way that we could restrict the type here, since the tests currently only look at Arrays? e.g. I imagine that at least one of Fill, Diagonal, StaticArray, etc. will do something weird here. Would StridedArray suffice for your use case?

@mcabbott (Member Author) commented Dec 28, 2020:

There is a test with PermutedDimsArray, which isn't a StridedArray, and I think it ought to work fine with StaticArrays, although I have not tested that in any depth. Diagonal seems to work, although I struggle to imagine why calling prod on one would be a good idea; weird things do happen:

julia> unthunk(rrule(prod, Diagonal(SA[1,2,3,4]))[2](1.0)[2])
4×4 Diagonal{Float64, MVector{4, Float64}}:
 0.0   ⋅    ⋅    ⋅ 
  ⋅   0.0   ⋅    ⋅ 
  ⋅    ⋅   0.0   ⋅ 
  ⋅    ⋅    ⋅   0.0
  
julia> unthunk(rrule(prod, Fill(2,3))[2](1.0)[2])
3-element Vector{Float64}:
 4.0
 4.0
 4.0  

Fill makes a Vector gradient. Somehow rrule(sum, Fill(2,3)) makes a Fill, because it simply broadcasts rather than calling similar. Is this something the package aims to guarantee? I don't see a test for it. Elsewhere it chooses similar over broadcasting to avoid other issues.
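
A quick way to see the distinction (a sketch, assuming FillArrays' broadcast specializations behave as described above):

using FillArrays

f = Fill(2.0, 3)
typeof(f ./ 3)       # broadcasting over a Fill is specialized, so the lazy structure survives
typeof(similar(f))   # similar(::Fill) falls back to a plain Vector{Float64}, hence the dense gradient here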

Member:

> There is a test with PermutedDimsArray, which isn't a StridedArray

Apologies, it's late here. They are quite clearly in the tests...

> Fill makes a Vector gradient

We would definitely want the output of rrule w.r.t. a Fill argument to be either another Fill or an appropriate Composite. This is the kind of thing that e.g. Zygote can probably get right without a rule in ChainRules, so I think the ideal solution here is just not to implement a rule that covers Fill.

Also, should the result with Diagonal have zeros on the diagonal?

@mcabbott (Member Author) commented Dec 28, 2020:

> Also, should the result with Diagonal have zeros on the diagonal?

Something is broken on the last commit, but not this; this one you can do in your head: it's a product of mostly zeros, so the gradient with respect to the nonzero entries still vanishes.

> Fill makes a Vector gradient
> We would definitely want

This could be arranged at the cost of more complexity... although possibly Fill ought to define similar more like that of Diagonal, if it wishes to be preserved under such things?

Although clearly not all gradients are going to preserve this structure:

julia> gradient(sum∘cumsum, Fill(1,3))[1]
3-element Vector{Int64}:
 3
 2
 1
 
 julia> gradient(x -> sum(cumsum(x, dims=1)), Diagonal([1,2,3]))[1]  # what should this produce?

And here's how much Zygote can figure out right now:

julia> function myprod(xs)
           out = 1.0
           for x in xs
               out *= x
           end
           out
       end
myprod (generic function with 1 method)

julia> gradient(myprod, Fill(1,3))[1]
3-element Vector{Float64}:
 1.0
 1.0
 1.0

Member:

> Something is broken on the last commit, but not this; this one you can do in your head: it's a product of mostly zeros, so the gradient with respect to the nonzero entries still vanishes.

Yes, good point.

> julia> gradient(sum∘cumsum, Fill(1,3))[1]

This should produce a Composite with an appropriate value field.

> julia> gradient(x -> sum(cumsum(x, dims=1)), Diagonal([1,2,3]))[1]  # what should this produce?

This should also produce either a Diagonal or an appropriate Composite.

But with all of these types, I'm not saying that your PR needs to cover them. I'm merely suggesting it address the minimal set of types that you're confident are handled correctly, and assume that AD can do a reasonable job of deriving the others.

> julia> gradient(myprod, Fill(1,3))[1]

The answer to this is the result of a bug in Zygote that I should fix -- it looks like an example of what I'm commenting on here, where getindex has been implemented for too broad a set of types. Zygote really should be able to derive the rule for this properly; i.e. getindex only ever returns the value field of a Fill, so you shouldn't even need a rule for getindex for Fill.
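
Roughly speaking, indexing a Fill behaves like the toy type below (a paraphrase for illustration, not FillArrays' actual source), which is why AD should be able to derive its getindex pullback generically:

struct ToyFill{T} <: AbstractVector{T}   # stand-in for FillArrays.Fill
    value::T
    len::Int
end
Base.size(f::ToyFill) = (f.len,)
Base.getindex(f::ToyFill, i::Int) = (checkbounds(f, i); f.value)   # every index returns the one stored value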

Member:

So that we can get this merged, shall we change this to StridedArray and then make a follow-up later?
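
Concretely, the suggested restriction would amount to something like the signature below (a sketch of the proposal, not necessarily the final merged signature):

function rrule(::typeof(prod), x::StridedArray{T}; dims=:) where {T<:CommutativeMulNumber}
    # same body as the current AbstractArray method, just with the argument type narrowed
end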

@nickrobinson251 added the "type constraints" label (Potentially raises a question about how tightly to constrain argument types for a rule. See #232) on Jan 1, 2021
@sethaxen mentioned this pull request on Mar 9, 2021
@oxinabox (Member) commented Mar 9, 2021:

bump

@codecov-commenter commented May 7, 2021:

Codecov Report

Merging #335 (f756149) into master (cbba09c) will decrease coverage by 0.01%.
The diff coverage is 98.03%.


@@            Coverage Diff             @@
##           master     #335      +/-   ##
==========================================
- Coverage   98.49%   98.48%   -0.02%     
==========================================
  Files          23       23              
  Lines        1929     1980      +51     
==========================================
+ Hits         1900     1950      +50     
- Misses         29       30       +1     
Impacted Files                    Coverage Δ
src/rulesets/Base/mapreduce.jl    98.71% <98.03%> (-1.29%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update cbba09c...f756149

Comment on lines +71 to +68
@test_skip test_rrule(prod, xs ⊢ rand(T,4,4))
@test_skip test_rrule(prod, xs ⊢ rand(T,4,4), fkwargs=(dims=2,))
@mcabbott (Member Author):

Is there a bug in how FiniteDifferences does this, or am I thinking incorrectly about what it should produce?

using ForwardDiff, FiniteDifferences, LinearAlgebra  # LinearAlgebra for Symmetric, Diagonal, etc.
xs = Symmetric(reshape(1:16,4,4)./10)
xm = Matrix(xs)

g1 = ForwardDiff.gradient(prod, xm) # symmetric, but not Symmetric
g2 = grad(central_fdm(5, 1), prod, xm)[1]
g3 = grad(central_fdm(5, 1), prod, xs)[1]  # is Symmetric, and differs

g1 ≈ g2
diag(g1) ≈ diag(g3)
UnitUpperTriangular(g1) ≈ UnitUpperTriangular(g3 ./ 2)  # this seems weird

With dims:

g4 = ForwardDiff.gradient(x -> sum(prod(x,dims=1)), xm) # no longer symmetric
g5 = grad(central_fdm(5, 1), x -> sum(prod(x,dims=1)), xm)[1] 
g6 = grad(central_fdm(5, 1), x -> sum(prod(x,dims=1)), xs)[1] 

g4 ≈ g5

proj(m) = (m .+ m')./2;
proj(g4) ≈ proj(proj(g4)) # it's a projection

fold(m) = m .+ m' .- Diagonal(m)
fold(g4) ≈ g6

Member:

I suggest checking what FiniteDifferences.to_vec is outputting.
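
For example, one could inspect it along these lines (a sketch; to_vec returns the flattened vector together with a reconstruction closure):

using FiniteDifferences, LinearAlgebra

xs = Symmetric(reshape(1:16, 4, 4) ./ 10)
v, from_vec = FiniteDifferences.to_vec(xs)

length(v)                   # how many independent entries does to_vec see?
from_vec(ones(length(v)))   # and what structure does a unit perturbation map back onto?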

@mcabbott (Member Author):

For now I've left a simpler test here, checking that it runs without error on a Symmetric: unthunk(rrule(prod, Symmetric(ones(T,2,2)))[2](1.0)[2]) == [1 1; 1 1]. I very much doubt this case is going to see much use, but it shouldn't give an error.

Member:

Would it be better to have a test with Symmetric([2.0 3.0; 3.0 2.0])? Ones are hard to trust.

@mcabbott (Member Author):

OK, I've switched it. It's mostly to check that this doesn't give an error; the computation here does not care at all what kind of matrix it gets. Although ones is in fact sufficient to see this weirdness:

julia> grad(central_fdm(5, 1), prod, Symmetric(ones(2,2)))[1]
2×2 Symmetric{Float64, Matrix{Float64}}:
 1.0  2.0
 2.0  1.0

julia> ForwardDiff.gradient(prod, ones(2,2))
2×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0

@mcabbott (Member Author) commented May 13, 2021:

Status here is that:

  • Something is weird with tests of Symmetric matrices. I think the rule is fine, and gives the same answer as collect(s). (It's unlikely anyone wants that anyway, for prod.)
  • I found a type instability which was making this 30x slower than it has to be, now fixed.
  • It's actually not hard to write this with map instead of explicit iteration; it's a few times slower when there is one zero (or in each such column/slice) but might be more generic, more likely to work for CuArrays.
julia> x=rand(10,100); x[1:21:end].=0; # half the columns have a zero

julia> @btime gradient(x -> sum(prod(x, dims=1)), $x)[1]  # Zygote, with NaN
  1.292 μs (3 allocations: 8.83 KiB)
10×100 Matrix{Float64}:
 NaN    1.64248e-6    0.0  2.85863e-5     0.0      0.0  0.00412328    0.0  0.000262674

julia> @btime gradient(x -> sum(prod(x, dims=1)), $x)[1]   # this PR, original, with a type instability
  56.000 μs (1706 allocations: 51.06 KiB)
10×100 Matrix{Float64}:
 1.499e-5  2.54822e-5  0.0         0.000129398    0.000634891  0.0         0.00343482

julia> @btime gradient(x -> sum(prod(x, dims=1)), $x)[1]  # with Val(Int(dims)), same as dims=1 hard coded
  1.850 μs (5 allocations: 8.86 KiB)
10×100 Matrix{Float64}:
 1.12526e-6  0.000483517  0.0         5.28479e-6    0.000555819  0.0         0.00126222

julia> @btime gradient(x -> sum(prod(x, dims=1)), $x)[1]  # version with map, not indexing
  2.704 μs (6 allocations: 8.89 KiB)
10×100 Matrix{Float64}:
 5.36396e-7  0.0370314  0.0         0.022377    0.0           8.56615e-6  0.0         1.3721e-5
 0.0         0.0327162  4.17844e-5  0.00593767  0.0            7.05661e-7  0.0         0.0244155

The current PR is the second-to-last variant; the code for the last variant reads:

function ∇prod(x::AbstractArray, dy::Number=1, y::Number=prod(x))
    numzero = iszero(y) ? count(iszero, x) : 0
    dx = if numzero == 0  # This can happen while y==0, if there are several small xs
        map(xi -> conj(y / xi) * dy, x)
    elseif numzero == 1
        y_rest = prod(xi -> ifelse(iszero(xi), one(xi), conj(xi)), x)
        val1 = y_rest * dy / one(eltype(x))  # Divide for type stability: ∇prod([0,2,3], 1) should be floats
        map(xi -> ifelse(iszero(xi), val1, zero(val1)), x)
    else
        val0 = zero(conj(y / one(eltype(x))) * dy)
        map(xi -> val0, x)  # Construct a zero array in the same manner, same type
    end
    return dx
end

function ∇prod!(dx, x, dy::Number=1, y::Number=prod(x))
    numzero = iszero(y) ? count(iszero, x) : 0
    if numzero == 0  # This can happen while y==0, if there are several small xs
        dx .+= conj.(y ./ x) .* dy
    elseif numzero == 1
        y_rest = prod(xi -> ifelse(iszero(xi), one(xi), conj(xi)), x)
        val = y_rest * dy
        dx .+= ifelse.(iszero.(x), val, zero.(x))
    else
        # numzero > 1, then all first derivatives are zero
    end
    return dx
end
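
For a quick hand-check, the map-based ∇prod above behaves like this on small inputs (values traced from the code, not from running the test suite):

∇prod([1.0, 2.0, 3.0])   # [6.0, 3.0, 2.0] -- each entry is the product of the others
∇prod([0.0, 2.0, 3.0])   # [6.0, 0.0, 0.0] -- finite where the naive y ./ x rule gives NaN
∇prod([0.0, 0.0, 3.0])   # [0.0, 0.0, 0.0] -- with two zeros, all first derivatives vanish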

But my vote is to leave it; someone else can switch to this later if they see a need.

@mcabbott (Member Author):

Bump?

The failure on nightly looks unrelated: LinearAlgebra/factorization.jl:182.

@oxinabox (Member):

Array concerns can be addressed in a follow-up if and when the need arises.

@oxinabox merged commit 30ea23f into JuliaDiff:master on May 27, 2021
@oxinabox (Member):

Thanks, sorry about the long delay

@mcabbott deleted the prod2 branch on May 27, 2021, 19:30
@mcabbott mentioned this pull request on May 28, 2021
bors bot added a commit to FluxML/Zygote.jl that referenced this pull request Jun 8, 2021
988: delete rule for prod to use ChainRules' r=oxinabox a=oxinabox

@mcabbott  added a rule for prod into ChainRules
JuliaDiff/ChainRules.jl#335

It's better than the one in Zygote as it gets the right answer even if one of the elements is zero.

So we can delete the old one here, but leave the tests in place per our policy, as a double-check against regressions in ChainRules.

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
Labels: type constraints (Potentially raises a question about how tightly to constrain argument types for a rule. See #232)

8 participants