Simplest prod(xs; dims) gradient #1
Conversation
I'm curious if @willtebbutt has any opinions on this.
My opinion is that I would like some finite-differencing tests to ensure correctness 🙂
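(For concreteness, a minimal sketch of the kind of finite-differencing check being requested, assuming FiniteDifferences.jl as the comparison oracle; the package choice and test shape are illustrative, not taken from this PR:)

```julia
using Tracker, FiniteDifferences, Test

r = rand(2, 3, 2)
f(w) = sum(prod(w, dims = (2, 3)))

# gradient via 5th-order central finite differences
fd = grad(central_fdm(5, 1), f, r)[1]

# gradient via Tracker's reverse-mode rule
tr = Tracker.data(Tracker.gradient(f, r)[1])

@test tr ≈ fd
```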
I didn't add tests here as there are already some, here. However they pass because they aren't using the custom gradient currently defined, but falling back to TrackedReal. Demonstration:

```julia
using Tracker, ForwardDiff
r = rand(2,3,2);

ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r)
Tracker.gradient(w->sum(prod(w, (2,3))), r)[1]      # 0.6 notation hits @grad definition, wrong answer
Tracker.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # correct via TrackedReal, slow

Tracker.@grad function prod(xs; dims=:)  # this PR
    p = prod(Tracker.data(xs); dims=dims)
    p, Δ -> (p ./ xs .* Δ,)
end

Tracker.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # now correct
```

What this PR doesn't try to do is to treat cases where some entries are zero correctly. That is more involved; in the Flux issue I wrote a few different ways (with a trade-off between complication & speed), but the suggestion was to first make a PR for the simplest case.

```julia
r[1,1,1] = 0;
ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r)
Tracker.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # contains NaN
```
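For the record, one of the simpler zero-handling strategies discussed in FluxML/Flux.jl#524 looks roughly like the sketch below, written here only for the full reduction `dims = :`; the function name `prod_pullback` is made up for illustration, and the per-slice `dims` case needs more bookkeeping:

```julia
# Sketch only: a zero-safe pullback for p = prod(xs), full-reduction case.
function prod_pullback(xs::AbstractArray, Δ::Number)
    nzero = count(iszero, xs)
    if nzero == 0
        (prod(xs) ./ xs) .* Δ        # the fast path this PR implements
    elseif nzero == 1
        # exactly one zero: only that entry has a nonzero partial,
        # equal to the product of all the other (nonzero) entries
        g = zero(xs)
        i = findfirst(iszero, xs)
        g[i] = prod(x for x in xs if !iszero(x)) * Δ
        g
    else
        zero(xs)                     # two or more zeros: all partials vanish
    end
end
```

This is where the complication-vs-speed trade-off shows up: even the no-zero case now makes two passes over `xs` (one to count zeros, one for the product).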
112: Simplest prod(x; dims) gradient r=dhairyagandhi96 a=mcabbott

The current gradient for `prod(x; dims)` gives incorrect results, this PR fixes it (parallel to FluxML/Tracker.jl#1):

```
julia> using Zygote, ForwardDiff

julia> r = rand(2,3,2);

julia> ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r)
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1]  # wrong answer!
2×3×2 Array{Float64,3}:
[:, :, 1] =
 5.93867e-6  4.30525e-6  2.31817e-5
 1.60301e-5  3.2038e-5   8.44331e-6

[:, :, 2] =
 1.95925e-5  6.33622e-5  6.89391e-6
 1.36795e-5  4.19746e-6  4.07989e-6

julia> Zygote.@adjoint function prod(xs; dims = :)  # as in this PR
           p = prod(xs; dims = dims)
           p, Δ -> (p ./ xs .* Δ,)
       end

julia> Zygote.refresh()

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1]  # now matches ForwardDiff
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601
```

This does not handle zeros in the array correctly -- see FluxML/Flux.jl#524 for attempts to do that. The `circshift(...` operation deleted here was a correct (but slow) gradient for `prod(x)`, but is clearly independent of `dims`.

The example above is almost the same as the one in the tests, which strangely passes without this PR. Perhaps something is wrong with `gradtest`?

```
julia> @test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
Test Passed

julia> @test gradtest(x -> prod(x), (3,4,5))
Test Passed
```

Co-authored-by: Michael Abbott <me@pseudomac>
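For reference, the rule implemented in both this PR and the Tracker one is just the product rule; a quick derivation, valid away from zeros:

```latex
p = \prod_i x_i
\quad\Longrightarrow\quad
\frac{\partial p}{\partial x_j} = \prod_{i \neq j} x_i = \frac{p}{x_j}
\qquad (x_j \neq 0)
```

Hence the pullback `Δ -> (p ./ xs .* Δ,)`, and hence the `NaN`s at zeros: the quotient `p / x_j` is undefined there, even though the true derivative (the product of the remaining entries) is finite.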
Okay, I'm happy with this. Apologies for leaving it open for so long.
bors r+
Oh, wait, no bors on this repo 😂. @mcabbott could you bump the patch version so that we can create a release?
Will not treat zeros correctly, see FluxML/Flux.jl#524
Sure, done!
As promised in FluxML/Flux.jl#524, this makes the gradient for `prod` understand the keyword `dims` (rather than falling back to TrackedReal). It does not treat zeros correctly; see discussion.