diff --git a/Project.toml b/Project.toml index 676fa291..0dea4ffa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.7.3" +version = "0.7.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/src/index.md b/docs/src/index.md index a632f404..4feb87e3 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -60,7 +60,7 @@ julia> using ChainRulesTestUtils; julia> test_frule(two2three, 3.33, -7.77); Test Summary: | Pass Total -test_frule: two2three on Float64,Float64 | 5 5 +test_frule: two2three on Float64,Float64 | 6 6 ``` ### Testing the `rrule` @@ -71,7 +71,7 @@ The call will test the `rrule` for function `f` at the point `x`, and similarly ```jldoctest ex; output = false julia> test_rrule(two2three, 3.33, -7.77); Test Summary: | Pass Total -test_rrule: two2three on Float64,Float64 | 6 6 +test_rrule: two2three on Float64,Float64 | 7 7 ``` ## Scalar example @@ -98,11 +98,11 @@ call. ```jldoctest ex; output = false julia> test_scalar(relu, 0.5); Test Summary: | Pass Total -test_scalar: relu at 0.5 | 7 7 +test_scalar: relu at 0.5 | 9 9 julia> test_scalar(relu, -0.5); Test Summary: | Pass Total -test_scalar: relu at -0.5 | 7 7 +test_scalar: relu at -0.5 | 9 9 ``` ## Specifying Tangents diff --git a/src/testers.jl b/src/testers.jl index c00eedd6..a80b565c 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -67,11 +67,11 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), end """ - test_frule(f, inputs...; kwargs...) + test_frule(f, args..; kwargs...) # Arguments - `f`: Function for which the `frule` should be tested. -- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ` +- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ` - `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). - `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided Non-differentiable arguments, such as indices, should have `ẋ` set as `NoTangent()`. @@ -87,7 +87,7 @@ end """ function test_frule( f, - inputs...; + args...; output_tangent=Auto(), fdm=_fdm, check_inferred::Bool=true, @@ -99,10 +99,10 @@ function test_frule( # To simplify some of the calls we make later lets group the kwargs for reuse isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...) - @testset "test_frule: $f on $(_string_typeof(inputs))" begin + @testset "test_frule: $f on $(_string_typeof(args))" begin _ensure_not_running_on_functor(f, "test_frule") - xẋs = auto_primal_and_tangent.(inputs) + xẋs = auto_primal_and_tangent.(args) xs = primal.(xẋs) ẋs = tangent.(xẋs) if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...) @@ -110,7 +110,7 @@ function test_frule( end res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) res === nothing && throw(MethodError(frule, typeof((f, xs...)))) - res isa Tuple || error("The frule should return (y, ∂y), not $res.") + @test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any} Ω_ad, dΩ_ad = res Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...) test_approx(Ω_ad, Ω; isapprox_kwargs...) @@ -135,11 +135,11 @@ function test_frule( end """ - test_rrule(f, inputs...; kwargs...) + test_rrule(f, args...; kwargs...) # Arguments - `f`: Function to which rule should be applied. -- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ` +- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ` - `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). - `x̄`: currently accumulated cotangent, will be generated automatically if not provided Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`. @@ -155,7 +155,7 @@ end """ function test_rrule( f, - inputs...; + args...; output_tangent=Auto(), fdm=_fdm, check_inferred::Bool=true, @@ -167,11 +167,11 @@ function test_rrule( # To simplify some of the calls we make later lets group the kwargs for reuse isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...) - @testset "test_rrule: $f on $(_string_typeof(inputs))" begin + @testset "test_rrule: $f on $(_string_typeof(args))" begin _ensure_not_running_on_functor(f, "test_rrule") # Check correctness of evaluation. - xx̄s = auto_primal_and_tangent.(inputs) + xx̄s = auto_primal_and_tangent.(args) xs = primal.(xx̄s) accumulated_x̄ = tangent.(xx̄s) if check_inferred && _is_inferrable(f, xs...; fkwargs...) @@ -191,6 +191,8 @@ function test_rrule( ∂self = ∂s[1] x̄s_ad = ∂s[2:end] @test ∂self === NoTangent() # No internal fields + msg = "The pullback should return 1 cotangent for each primal input." + @test_msg msg length(x̄s_ad) == length(args) # Correctness testing via finite differencing. # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 diff --git a/test/testers.jl b/test/testers.jl index 7db1d8df..8cbf24a6 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -495,6 +495,22 @@ end @test fails(() -> test_frule(my_identity2, 2.2)) @test fails(() -> test_rrule(my_identity2, 2.2)) end + + @testset "wrong number of outputs #167" begin + foo(x, y) = x + 2y + + function ChainRulesCore.frule((_, ẋ, ẏ), ::typeof(foo), x, y) + return foo(x, y), ẋ + 2ẏ, NoTangent() # extra derivative + end + + function ChainRulesCore.rrule(::typeof(foo), x, y) + foo_pullback(dz) = NoTangent(), dz # missing derivative + return foo(x,y), foo_pullback + end + + @test fails(() -> test_frule(foo, 2.1, 2.1)) + @test fails(() -> test_rrule(foo, 21.0, 32.0)) + end end @testset "Tuple primal that is not equal to differential backing" begin