From 6a38764ce8b0eb79865f4a31ac8dbeadb8268e80 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Jan 2021 13:09:36 -0800 Subject: [PATCH] Only check type-stability of frule and rrule if primal is type-stable (#103) * Add utility function to check whether a function is type-stable * Only check if frule and rrule are type-stable if primal is * Add tests * Increment version number * Clarify that this checks inferrability --- Project.toml | 2 +- src/testers.jl | 29 ++++++++++++++++++++++++----- test/testers.jl | 14 ++++++++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 387a7c32..09128742 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.6.0" +version = "0.6.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/testers.jl b/src/testers.jl index 16259993..ec79fc8c 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -188,6 +188,20 @@ function _test_inferred(f, args...; kwargs...) end end +""" + _is_inferrable(f, args...; kwargs...) -> Bool + +Return whether the return type of `f(args...; kwargs...)` is inferrable. +""" +function _is_inferrable(f, args...; kwargs...) + try + _test_inferred(f, args...; kwargs...) + return true + catch ErrorException + return false + end +end + """ frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...) @@ -197,7 +211,8 @@ end - `ẋ`: differential w.r.t. `x` (should generally be set randomly). `fkwargs` are passed to `f` as keyword arguments. -If `check_inferred=true`, then the type-stability of the `frule` is checked. +If `check_inferred=true`, then the inferrability of the `frule` is checked, as long as `f` +is itself inferrable. All remaining keyword arguments are passed to `isapprox`. """ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e-9, fdm=_fdm, fkwargs::NamedTuple=NamedTuple(), check_inferred::Bool=true, kwargs...) @@ -208,7 +223,9 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e xs = first.(xẋs) ẋs = last.(xẋs) - check_inferred && _test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) + if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...) + _test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) + end res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) res === nothing && throw(MethodError(frule, typeof((f, xs...)))) Ω_ad, dΩ_ad = res @@ -239,8 +256,8 @@ end - `x̄`: currently accumulated adjoint (should generally be set randomly). `fkwargs` are passed to `f` as keyword arguments. -If `check_inferred=true`, then the type-stability of the `rrule` and the pullback it -returns are checked. +If `check_inferred=true`, then the inferrability of the `rrule` is checked — if `f` is +itself inferrable — along with the inferrability of the pullback it returns. All remaining keyword arguments are passed to `isapprox`. """ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e-9, fdm=_fdm, check_inferred::Bool=true, fkwargs::NamedTuple=NamedTuple(), kwargs...) @@ -252,7 +269,9 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Re # Check correctness of evaluation. xs = first.(xx̄s) accumulated_x̄ = last.(xx̄s) - check_inferred && _test_inferred(rrule, f, xs...; fkwargs...) + if check_inferred && _is_inferrable(f, xs...; fkwargs...) + _test_inferred(rrule, f, xs...; fkwargs...) + end res = rrule(f, xs...; fkwargs...) res === nothing && throw(MethodError(rrule, typeof((f, xs...)))) y_ad, pullback = res diff --git a/test/testers.jl b/test/testers.jl index 4b5bb33d..bf2240da 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -12,6 +12,8 @@ f_noninferrable_frule(x) = x f_noninferrable_rrule(x) = x f_noninferrable_pullback(x) = x f_noninferrable_thunk(x, y) = x + y +f_inferrable_pullback_only(x) = x > 0 ? Float64(x) : Float32(x) + function finplace!(x; y = [1]) y[1] = 2 @@ -173,6 +175,18 @@ end rrule_test(f_noninferrable_thunk, z̄, (x, x̄), (y, ȳ); check_inferred = false) @test_throws ErrorException rrule_test(f_noninferrable_thunk, z̄, (x, x̄), (y, ȳ)) end + + @testset "check non-inferrable primal still passes if pullback inferrable" begin + function ChainRulesCore.frule((_, Δx), ::typeof(f_inferrable_pullback_only), x) + return (x > 0 ? Float64(x) : Float32(x), x > 0 ? Float64(Δx) : Float32(Δx)) + end + function ChainRulesCore.rrule(::typeof(f_inferrable_pullback_only), x) + f_inferrable_pullback_only_pullback(Δy) = (NO_FIELDS, oftype(x, Δy)) + return x > 0 ? Float64(x) : Float32(x), f_inferrable_pullback_only_pullback + end + frule_test(f_inferrable_pullback_only, (x, ẋ); check_inferred = true) + rrule_test(f_inferrable_pullback_only, z̄, (x, x̄); check_inferred = true) + end end @testset "test derivative conjugated in pullback" begin