Skip to content

Commit

Permalink
Test the number of outputs in frule and rrule are correct (#168) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mzgubic committed Jun 4, 2021
1 parent a6bc18b commit 3bde970
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.7.3"
version = "0.7.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
8 changes: 4 additions & 4 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ julia> using ChainRulesTestUtils;
julia> test_frule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_frule: two2three on Float64,Float64 | 5 5
test_frule: two2three on Float64,Float64 | 6 6
```

### Testing the `rrule`
Expand All @@ -71,7 +71,7 @@ The call will test the `rrule` for function `f` at the point `x`, and similarly
```jldoctest ex; output = false
julia> test_rrule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_rrule: two2three on Float64,Float64 | 6 6
test_rrule: two2three on Float64,Float64 | 7 7
```

## Scalar example
Expand All @@ -98,11 +98,11 @@ call.
```jldoctest ex; output = false
julia> test_scalar(relu, 0.5);
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 7 7
test_scalar: relu at 0.5 | 9 9
julia> test_scalar(relu, -0.5);
Test Summary: | Pass Total
test_scalar: relu at -0.5 | 7 7
test_scalar: relu at -0.5 | 9 9
```

## Specifying Tangents
Expand Down
24 changes: 13 additions & 11 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
end

"""
test_frule(f, inputs...; kwargs...)
test_frule(f, args..; kwargs...)
# Arguments
- `f`: Function for which the `frule` should be tested.
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
Non-differentiable arguments, such as indices, should have `ẋ` set as `NoTangent()`.
Expand All @@ -87,7 +87,7 @@ end
"""
function test_frule(
f,
inputs...;
args...;
output_tangent=Auto(),
fdm=_fdm,
check_inferred::Bool=true,
Expand All @@ -99,18 +99,18 @@ function test_frule(
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

@testset "test_frule: $f on $(_string_typeof(inputs))" begin
@testset "test_frule: $f on $(_string_typeof(args))" begin
_ensure_not_running_on_functor(f, "test_frule")

xẋs = auto_primal_and_tangent.(inputs)
xẋs = auto_primal_and_tangent.(args)
xs = primal.(xẋs)
ẋs = tangent.(xẋs)
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
_test_inferred(frule, (NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
end
res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
@test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
Ω_ad, dΩ_ad = res
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
test_approx(Ω_ad, Ω; isapprox_kwargs...)
Expand All @@ -135,11 +135,11 @@ function test_frule(
end

"""
test_rrule(f, inputs...; kwargs...)
test_rrule(f, args...; kwargs...)
# Arguments
- `f`: Function to which rule should be applied.
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.
Expand All @@ -155,7 +155,7 @@ end
"""
function test_rrule(
f,
inputs...;
args...;
output_tangent=Auto(),
fdm=_fdm,
check_inferred::Bool=true,
Expand All @@ -167,11 +167,11 @@ function test_rrule(
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

@testset "test_rrule: $f on $(_string_typeof(inputs))" begin
@testset "test_rrule: $f on $(_string_typeof(args))" begin
_ensure_not_running_on_functor(f, "test_rrule")

# Check correctness of evaluation.
xx̄s = auto_primal_and_tangent.(inputs)
xx̄s = auto_primal_and_tangent.(args)
xs = primal.(xx̄s)
accumulated_x̄ = tangent.(xx̄s)
if check_inferred && _is_inferrable(f, xs...; fkwargs...)
Expand All @@ -191,6 +191,8 @@ function test_rrule(
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NoTangent() # No internal fields
msg = "The pullback should return 1 cotangent for each primal input."
@test_msg msg length(x̄s_ad) == length(args)

# Correctness testing via finite differencing.
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
Expand Down
16 changes: 16 additions & 0 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,22 @@ end
@test fails(() -> test_frule(my_identity2, 2.2))
@test fails(() -> test_rrule(my_identity2, 2.2))
end

@testset "wrong number of outputs #167" begin
foo(x, y) = x + 2y

function ChainRulesCore.frule((_, ẋ, ẏ), ::typeof(foo), x, y)
return foo(x, y), ẋ + 2ẏ, NoTangent() # extra derivative
end

function ChainRulesCore.rrule(::typeof(foo), x, y)
foo_pullback(dz) = NoTangent(), dz # missing derivative
return foo(x,y), foo_pullback
end

@test fails(() -> test_frule(foo, 2.1, 2.1))
@test fails(() -> test_rrule(foo, 21.0, 32.0))
end
end

@testset "Tuple primal that is not equal to differential backing" begin
Expand Down

2 comments on commit 3bde970

@mzgubic
Copy link
Member Author

@mzgubic mzgubic commented on 3bde970 Jun 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/38186

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.4 -m "<description of version>" 3bde97012e02e90fec01d866d223ae04cb976c22
git push origin v0.7.4

Please sign in to comment.