Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add @adjoint for relu to increase speed #367

Merged
merged 12 commits into from
Oct 24, 2019
2 changes: 1 addition & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ include("lib/number.jl")
include("lib/base.jl")
include("lib/array.jl")
include("lib/buffer.jl")
include("lib/nnlib.jl")
include("lib/broadcast.jl")
include("lib/nnlib.jl")
findmyway marked this conversation as resolved.
Show resolved Hide resolved
include("lib/forward.jl")
include("lib/utils.jl")
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("lib/distances.jl")
Expand Down
9 changes: 8 additions & 1 deletion src/lib/nnlib.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
using NNlib
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, ∇conv_data, ∇depthwiseconv_data, maxpool, meanpool, σ
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, ∇conv_data, ∇depthwiseconv_data, maxpool, meanpool, σ, relu

# Custom adjoint for broadcasting `identity`: the forward pass is a no-op
# (returns `x` unchanged), and the pullback forwards the incoming cotangent
# `Δ` untouched, with `nothing` as the gradient slot for the `identity`
# function argument itself.
@adjoint function Base.Broadcast.broadcasted(::typeof(identity), x::Numeric)
    return x, Δ -> (nothing, Δ)
end
findmyway marked this conversation as resolved.
Show resolved Hide resolved

# Custom adjoint for broadcasting `relu`, avoiding the generic (slower)
# broadcast-differentiation machinery. Forward pass computes `y = relu.(x)`;
# the pullback passes `Δ` through where the output is positive and zero
# elsewhere (the first `nothing` is the gradient slot for `relu` itself).
@adjoint function Base.Broadcast.broadcasted(::typeof(relu), x::Numeric)
    y = relu.(x)
    relu_pullback(Δ) = (nothing, ifelse.(y .> 0, Δ, zero.(y)))
    return y, relu_pullback
end

@adjoint function σ(x::Real)
y = σ(x)
Expand Down
8 changes: 7 additions & 1 deletion test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ Random.seed!(0)

@test gradient(//, 2, 3) === (1//3, -2//9)

@test gradtest((x, W, b) -> identity.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> identity.(W*x .+ b), (5,3), (2,5), 2)

@test gradtest((x, W, b) -> relu.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> relu.(W*x .+ b), (5,3), (2,5), 2)

@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
Expand All @@ -40,7 +46,7 @@ Random.seed!(0)
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> sum(abs2, x), randn(4, 3, 2))
@test gradtest(x -> sum(abs2, x; dims=1), randn(4, 3, 2))
@test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
@test_broken gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))

@test gradtest(x -> softmax(x).*(1:3), 3)
Expand Down