Skip to content

Commit

Permalink
Merge #523
Browse files Browse the repository at this point in the history
523: fix prod with tuple arg r=CarloLucibello a=CarloLucibello

This fixes the following problem caused by a relaxation of the signature in #112 
```julia
julia> gradient(x -> prod((1,2,3)), 1)
ERROR: MethodError: no method matching prod(::Tuple{Int64,Int64,Int64}; dims=Colon())
Closest candidates are:
  prod(::Tuple{Any,Vararg{Any,N} where N}) at tuple.jl:385 got unsupported keyword argument "dims"
  prod(::Any) at reduce.jl:448 got unsupported keyword argument "dims"
  prod(::Any, ::StaticArrays.StaticArray{#s160,T,N} where N where #s160<:Tuple; dims) where T at /home/carlo/.julia/packages/StaticArrays/1g9bq/src/mapreduce.jl:234
  ...
Stacktrace:
 [1] #adjoint#3920 at /home/carlo/.julia/packages/Zygote/XCgv1/src/lib/array.jl:220 [inlined]
 [2] adjoint at ./none:0 [inlined]
 [3] _pullback at /home/carlo/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [4] #17 at ./REPL[18]:1 [inlined]
 [5] _pullback(::Zygote.Context, ::var"#17#18", ::Int64) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface2.jl:?
 [6] _pullback(::Function, ::Int64) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface.jl:31
 [7] pullback(::Function, ::Int64) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface.jl:37
 [8] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /home/carlo/.julia/packages/Zygote/XCgv1/src/compiler/interface.jl:46
 [9] top-level scope at REPL[18]:1
``` 

Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
  • Loading branch information
bors[bot] and CarloLucibello committed Feb 27, 2020
2 parents a543f1b + ad4a8f1 commit aaa5b01
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ end
return sum(abs2, X; dims=dims), Δ::Union{Number, AbstractArray}->(nothing, ((2Δ) .* X))
end

@adjoint function prod(xs; dims = :)
@adjoint function prod(xs::AbstractArray; dims = :)
p = prod(xs; dims = dims)
p, Δ -> (p ./ xs .* Δ,)
end
Expand Down
1 change: 1 addition & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ end

@test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4))
@test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2)

@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
Expand Down

0 comments on commit aaa5b01

Please sign in to comment.