Skip to content

Commit

Permalink
Only check type-stability of frule and rrule if primal is type-stable (
Browse files Browse the repository at this point in the history
…#103)

* Add utility function to check whether a function is type-stable

* Only check if frule and rrule are type-stable if primal is

* Add tests

* Increment version number

* Clarify that this checks inferrability
  • Loading branch information
sethaxen committed Jan 12, 2021
1 parent cb004ef commit 6a38764
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.6.0"
version = "0.6.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
29 changes: 24 additions & 5 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ function _test_inferred(f, args...; kwargs...)
end
end

"""
_is_inferrable(f, args...; kwargs...) -> Bool
Return whether the return type of `f(args...; kwargs...)` is inferrable.
"""
function _is_inferrable(f, args...; kwargs...)
try
_test_inferred(f, args...; kwargs...)
return true
catch ErrorException
return false
end
end

"""
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)
Expand All @@ -197,7 +211,8 @@ end
- `ẋ`: differential w.r.t. `x` (should generally be set randomly).
`fkwargs` are passed to `f` as keyword arguments.
If `check_inferred=true`, then the type-stability of the `frule` is checked.
If `check_inferred=true`, then the inferrability of the `frule` is checked, as long as `f`
is itself inferrable.
All remaining keyword arguments are passed to `isapprox`.
"""
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e-9, fdm=_fdm, fkwargs::NamedTuple=NamedTuple(), check_inferred::Bool=true, kwargs...)
Expand All @@ -208,7 +223,9 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e

xs = first.(xẋs)
ẋs = last.(xẋs)
check_inferred && _test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
_test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
end
res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
Ω_ad, dΩ_ad = res
Expand Down Expand Up @@ -239,8 +256,8 @@ end
- `x̄`: currently accumulated adjoint (should generally be set randomly).
`fkwargs` are passed to `f` as keyword arguments.
If `check_inferred=true`, then the type-stability of the `rrule` and the pullback it
returns are checked.
If `check_inferred=true`, then the inferrability of the `rrule` is checked — if `f` is
itself inferrable — along with the inferrability of the pullback it returns.
All remaining keyword arguments are passed to `isapprox`.
"""
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e-9, fdm=_fdm, check_inferred::Bool=true, fkwargs::NamedTuple=NamedTuple(), kwargs...)
Expand All @@ -252,7 +269,9 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Re
# Check correctness of evaluation.
xs = first.(xx̄s)
accumulated_x̄ = last.(xx̄s)
check_inferred && _test_inferred(rrule, f, xs...; fkwargs...)
if check_inferred && _is_inferrable(f, xs...; fkwargs...)
_test_inferred(rrule, f, xs...; fkwargs...)
end
res = rrule(f, xs...; fkwargs...)
res === nothing && throw(MethodError(rrule, typeof((f, xs...))))
y_ad, pullback = res
Expand Down
14 changes: 14 additions & 0 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ f_noninferrable_frule(x) = x
f_noninferrable_rrule(x) = x
f_noninferrable_pullback(x) = x
f_noninferrable_thunk(x, y) = x + y
f_inferrable_pullback_only(x) = x > 0 ? Float64(x) : Float32(x)


function finplace!(x; y = [1])
y[1] = 2
Expand Down Expand Up @@ -173,6 +175,18 @@ end
rrule_test(f_noninferrable_thunk, z̄, (x, x̄), (y, ȳ); check_inferred = false)
@test_throws ErrorException rrule_test(f_noninferrable_thunk, z̄, (x, x̄), (y, ȳ))
end

@testset "check non-inferrable primal still passes if pullback inferrable" begin
function ChainRulesCore.frule((_, Δx), ::typeof(f_inferrable_pullback_only), x)
return (x > 0 ? Float64(x) : Float32(x), x > 0 ? Float64(Δx) : Float32(Δx))
end
function ChainRulesCore.rrule(::typeof(f_inferrable_pullback_only), x)
f_inferrable_pullback_only_pullback(Δy) = (NO_FIELDS, oftype(x, Δy))
return x > 0 ? Float64(x) : Float32(x), f_inferrable_pullback_only_pullback
end
frule_test(f_inferrable_pullback_only, (x, ẋ); check_inferred = true)
rrule_test(f_inferrable_pullback_only, z̄, (x, x̄); check_inferred = true)
end
end

@testset "test derivative conjugated in pullback" begin
Expand Down

2 comments on commit 6a38764

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/27868

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.1 -m "<description of version>" 6a38764ce8b0eb79865f4a31ac8dbeadb8268e80
git push origin v0.6.1

Please sign in to comment.