Simplest prod(x; dims) gradient #112

Merged (1 commit) on Feb 26, 2020

mcabbott
Member

The current gradient for prod(x; dims) gives incorrect results; this PR fixes it (parallel to FluxML/Tracker.jl#1):

julia> using Zygote, ForwardDiff

julia> r = rand(2,3,2);

julia> ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r)
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387 
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # wrong answer!
2×3×2 Array{Float64,3}:
[:, :, 1] =
 5.93867e-6  4.30525e-6  2.31817e-5
 1.60301e-5  3.2038e-5   8.44331e-6

[:, :, 2] =
 1.95925e-5  6.33622e-5  6.89391e-6
 1.36795e-5  4.19746e-6  4.07989e-6

julia> Zygote.@adjoint function prod(xs; dims = :) # as in this PR
         p = prod(xs; dims = dims)
         p, Δ -> (p ./ xs .* Δ,)
       end

julia> Zygote.refresh()

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # now matches ForwardDiff
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387 
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601

This does not handle zeros in the array correctly, because p ./ xs divides by zero at those entries; see FluxML/Flux.jl#524 for attempts to do that. The circshift(... operation deleted here was a correct (but slow) gradient for prod(x), but it clearly does not depend on dims.
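
To make the zeros caveat concrete, here is a minimal sketch in plain Julia (no Zygote needed). `prod_pullback` is a hypothetical helper, not part of the PR; it just applies the same `p ./ xs .* Δ` rule as the adjoint above:

```julia
# Hypothetical helper: the same rule as the adjoint in this PR, outside Zygote.
function prod_pullback(xs, Δ; dims = :)
    p = prod(xs; dims = dims)   # forward pass
    return p ./ xs .* Δ         # d prod / d xs[i] = p / xs[i], valid only when xs[i] != 0
end

xs = [2.0, 3.0, 4.0]
prod_pullback(xs, 1.0)          # [12.0, 8.0, 6.0]: each entry is the product of the others

xz = [2.0, 0.0, 4.0]
g = prod_pullback(xz, 1.0)      # the exact gradient would be [0.0, 8.0, 0.0]
isnan(g[2])                     # true, because 0.0 / 0.0 is NaN
```

Only the entry at the zero goes wrong here; the other entries come out as 0.0, which is the right answer.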

The example above is almost the same as the one in the tests, which, strangely, passes without this PR. Perhaps something is wrong with gradtest?

julia> @test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
Test Passed

julia> @test gradtest(x -> prod(x), (3,4,5))
Test Passed
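
For an independent check that does not go through gradtest, something like the following sketch compares the `p ./ xs` rule against central finite differences in plain Julia. `fd_gradient` is a hypothetical helper written for this comment, and the step size and tolerance are arbitrary choices:

```julia
# Hypothetical helper: central finite differences of a scalar function f at x.
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h    # perturb one entry upwards
        xm = copy(x); xm[i] -= h    # and downwards
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

f(w) = sum(prod(w, dims = (2, 3)))
x = rand(3, 4, 5) .+ 0.5                    # entries in [0.5, 1.5), away from zero

analytic = prod(x, dims = (2, 3)) ./ x      # rule from this PR with Δ = ones
numeric  = fd_gradient(f, x)
isapprox(analytic, numeric; rtol = 1e-5)    # expected to hold away from zeros
```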

@mcabbott
Member Author

Bump?

Just cleaned up the mess that GitHub's merge tool sometimes creates with Unicode y, ȳ -> etc. Locally I had test failures in gradtest((x, w) -> conv(x, w, cdims), x, w) etc., which I think are unrelated.

@DhairyaLGandhi
Member

On a cursory look, it looks about right to me. Hopping on a flight now, but in the meantime:

bors try

bors bot added a commit that referenced this pull request Feb 14, 2020
@bors
Contributor

bors bot commented Feb 14, 2020

try

Build succeeded

@mcabbott mcabbott mentioned this pull request Feb 25, 2020
@CarloLucibello
Member

bors r+

@CarloLucibello
Member

@maleadt is bors down?

@DhairyaLGandhi
Member

bors r+

@maleadt
Contributor

maleadt commented Feb 26, 2020

https://app.bors.tech/
Looks down.

@CarloLucibello
Member

bors r+

@bors
Contributor

bors bot commented Feb 26, 2020

🔒 Permission denied

Existing reviewers: click here to make CarloLucibello a reviewer

@CarloLucibello
Member

@MikeInnes @dhairyagandhi96 @maleadt could someone add me to bors' reviewers?

@DhairyaLGandhi
Member

Could you try now?

@CarloLucibello
Member

bors r+

@bors
Contributor

bors bot commented Feb 26, 2020

🔒 Permission denied

Existing reviewers: click here to make CarloLucibello a reviewer

@CarloLucibello
Member

> Could you try now?

no luck

@DhairyaLGandhi
Member

bors delegate=CarloLucibello

@bors
Contributor

bors bot commented Feb 26, 2020

✌️ CarloLucibello can now approve this pull request. To approve and merge a pull request, simply reply with bors r+. More detailed instructions are available here.

@DhairyaLGandhi
Member

Should do for now, will look at the dashboard in my morn

@DhairyaLGandhi
Member

bors r+

@CarloLucibello
Member

> Should do for now, will look at the dashboard in my morn

I think I am in Flux's list now but not in Zygote's

@DhairyaLGandhi
Member

That's what caught me off guard; bors-ng/bors-ng#517 suggests the fix too.

@bors
Contributor

bors bot commented Feb 26, 2020

Build succeeded

@bors bors bot merged commit af498fa into FluxML:master Feb 26, 2020
@mcabbott mcabbott deleted the patch-4 branch February 26, 2020 20:58
bors bot added a commit that referenced this pull request Feb 27, 2020
523: fix prod with tuple arg r=CarloLucibello a=CarloLucibello

This fixes the following problem caused by a relaxation of the signature in #112 
```julia
julia> gradient(x -> prod((1,2,3)), 1)
ERROR: MethodError: no method matching prod(::Tuple{Int64,Int64,Int64}; dims=Colon())
Closest candidates are:
  prod(::Tuple{Any,Vararg{Any,N} where N}) at tuple.jl:385 got unsupported keyword argument "dims"
  prod(::Any) at reduce.jl:448 got unsupported keyword argument "dims"
  prod(::Any, ::StaticArrays.StaticArray{#s160,T,N} where N where #s160<:Tuple; dims) where T at /home/carlo/.julia/packages/StaticArrays/1g9bq/src/mapreduce.jl:234
  ...
Stacktrace:
 [1] #adjoint#3920 at /home/carlo/.julia/packages/Zygote/XCgv1/src/lib/array.jl:220 [inlined]
 [2] adjoint at ./none:0 [inlined]
 [3] _pullback at /home/carlo/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [4] #17 at ./REPL[18]:1 [inlined]
 [5] _pullback(::Zygote.Context, ::var"#17#18", ::Int64) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface2.jl:?
 [6] _pullback(::Function, ::Int64) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface.jl:31
 [7] pullback(::Function, ::Int64) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface.jl:37
 [8] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface.jl:46
 [9] top-level scope at REPL[18]:1
``` 

Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
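
What the follow-up commit actually changed is not shown in this thread. As a rough illustration only, one plausible way to avoid the MethodError is to restrict the `dims`-aware rule to arrays, so that tuples never receive a `dims` keyword. The sketch below uses plain Julia dispatch; `strict_rule` is a hypothetical stand-in, not Zygote code:

```julia
# Hypothetical stand-ins to illustrate dispatch; this is not Zygote's source.
strict_rule(xs::AbstractArray; dims = :) = prod(xs; dims = dims)  # arrays accept dims
strict_rule(xs) = prod(xs)                                        # fallback, no dims keyword

strict_rule((1, 2, 3))               # 6: the tuple takes the fallback method
strict_rule([1 2; 3 4]; dims = 1)    # 1×2 matrix [3 8]
```

With the array method typed as `xs::AbstractArray`, a tuple argument never reaches the code path that forwards `dims = :` to `prod`.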