Simplest prod(xs; dims) gradient #1

Merged · 2 commits · Aug 11, 2020

Conversation

mcabbott (Member)

As promised in FluxML/Flux.jl#524, this makes the gradient for `prod` understand the keyword `dims` (rather than falling back to `TrackedReal`). It does not treat zeros correctly; see the discussion below.
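
For context, the rule this gradient relies on is the identity that the derivative of a product with respect to one factor is the product of the remaining factors. A minimal illustration (not code from this PR):

```
# ∂/∂xᵢ prod(x) == prod(x) / xᵢ whenever xᵢ != 0,
# i.e. the product of every entry except xᵢ.
x = [2.0, 3.0, 5.0]
p = prod(x)   # 30.0
p ./ x        # [15.0, 10.0, 6.0], the products of the other entries
```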

@MikeInnes (Member) commented Apr 4, 2019

I'm curious if @willtebbutt has any opinions on this.

@willtebbutt (Member)

My opinion is that I would like some finite-differencing tests to ensure correctness 🙂

@mcabbott (Member, Author) commented Apr 5, 2019

I didn't add tests here, as there are already some (here). However, they pass only because they aren't hitting the custom gradient currently defined, but are instead falling back to `TrackedReal`. Demonstration:

```
using Tracker, ForwardDiff

r = rand(2,3,2);
ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r)

Tracker.gradient(w->sum(prod(w, (2,3))), r)[1]      # Julia 0.6-style positional dims hits the @grad definition: wrong answer
Tracker.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # correct via TrackedReal, but slow

Tracker.@grad function prod(xs; dims=:)  # this PR
  p = prod(Tracker.data(xs); dims=dims)
  p, Δ -> (p ./ xs .* Δ,)
end

Tracker.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # now correct
```

What this PR doesn't try to do is handle the case where some entries are zero. That is more involved; in the Flux issue I sketched a few different approaches (with a trade-off between complexity and speed), but the suggestion was to first make a PR for the simplest case.

```
r[1,1,1] = 0;

ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r)
Tracker.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # contains NaN
```
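
A zero-safe (but slow) variant of the kind discussed in FluxML/Flux.jl#524 might look roughly like the sketch below; the helper name and exact form are illustrative, not code from that issue. Each gradient entry is the product of every *other* element in its slice, so no division by zero can occur:

```
# Hypothetical zero-safe gradient for prod(xs; dims) with a tuple of dims.
# out[i] is the product of every entry in i's slice except xs[i] itself,
# so slices containing zeros give exact results instead of NaN.
function ∇prod_safe(xs::AbstractArray, Δ, dims::Tuple)
    keep = setdiff(1:ndims(xs), dims)   # dimensions not reduced over
    out = similar(xs)
    for i in CartesianIndices(xs)
        acc = one(eltype(xs))
        for j in CartesianIndices(xs)
            # same slice as i (agrees on all kept dimensions), excluding i itself
            if j != i && all(d -> j[d] == i[d], keep)
                acc *= xs[j]
            end
        end
        out[i] = acc
    end
    out .* Δ   # Δ has size 1 along `dims`, so it broadcasts across each slice
end

r = rand(2,3,2); r[1,1,1] = 0;
∇prod_safe(r, ones(2,1,1), (2,3))   # agrees with ForwardDiff, no NaN
```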

bors bot added a commit to FluxML/Zygote.jl that referenced this pull request Feb 26, 2020
112: Simplest prod(x; dims) gradient r=dhairyagandhi96 a=mcabbott

The current gradient for `prod(x; dims)` gives incorrect results; this PR fixes it (parallel to FluxML/Tracker.jl#1):
```
julia> using Zygote, ForwardDiff

julia> r = rand(2,3,2);

julia> ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r)
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387 
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # wrong answer!
2×3×2 Array{Float64,3}:
[:, :, 1] =
 5.93867e-6  4.30525e-6  2.31817e-5
 1.60301e-5  3.2038e-5   8.44331e-6

[:, :, 2] =
 1.95925e-5  6.33622e-5  6.89391e-6
 1.36795e-5  4.19746e-6  4.07989e-6

julia> Zygote.@adjoint function prod(xs; dims = :) # as in this PR
         p = prod(xs; dims = dims)
         p, Δ -> (p ./ xs .* Δ,)
       end

julia> Zygote.refresh()

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # now matches ForwardDiff
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387 
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601
```
This does not handle zeros in the array correctly -- see FluxML/Flux.jl#524 for attempts at that. The `circshift(...)` operation deleted here was a correct (but slow) gradient for `prod(x)` over all elements, but it is clearly independent of `dims`, so it cannot be right for reductions over only some dimensions.
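
For reference, a product-of-all-others gradient via `circshift` might look roughly like this (a hypothetical reconstruction for illustration, not the deleted code verbatim); it never divides by an entry, so zeros are handled exactly, but it allocates `length(x)-1` shifted copies:

```
# Sketch of a circshift-based pullback for prod(x) over all elements (dims = :).
# Multiplying the k-shifted copies for k = 1 .. n-1 gives, at position i,
# the product of every element except x[i].
function ∇prod_circshift(xs::AbstractArray, Δ)
    v = vec(xs)
    others = reduce((a, b) -> a .* b,
                    (circshift(v, k) for k in 1:length(v)-1);
                    init = ones(eltype(v), length(v)))
    reshape(others, size(xs)) .* Δ
end
```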

The example above is almost the same as the one in the tests, which strangely passes even without this PR. Perhaps something is wrong with `gradtest`?
```
julia> @test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
Test Passed

julia> @test gradtest(x -> prod(x), (3,4,5))
Test Passed
```
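
Independent of whatever `gradtest` is doing, a hand-rolled central-difference check is one way to sanity-check such an adjoint. A sketch (this is not the test suite's helper, and the step size is an assumption):

```
using ForwardDiff

# Central finite differences, element by element; 2*length(x) evaluations of f.
function fd_gradient(f, x; h = cbrt(eps(eltype(x))))
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    g
end

r = rand(3,4,5)
g1 = fd_gradient(w -> sum(prod(w, dims = (2, 3))), r)
g2 = ForwardDiff.gradient(w -> sum(prod(w, dims = (2, 3))), r)
maximum(abs, g1 .- g2)   # should be tiny
```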

Co-authored-by: Michael Abbott <me@pseudomac>
@mcabbott closed this Aug 9, 2020
@mcabbott reopened this Aug 9, 2020
@willtebbutt (Member) left a comment

Okay, I'm happy with this. Apologies for leaving it open for so long.

@willtebbutt (Member)

bors r+

@willtebbutt (Member)

Oh, wait, no bors on this repo 😂. @mcabbott could you bump the patch version so that we can create a release?

mcabbott and others added 2 commits August 10, 2020 13:49
@mcabbott (Member, Author)

Sure, done!
