Skip to content

Commit

Permalink
Improve the error message for constructors (#254)
Browse files Browse the repository at this point in the history
* show Type{Foo} rather than DataType in the rrule MethodError

* add rule to the docs
  • Loading branch information
mzgubic committed Jul 12, 2022
1 parent 22d8446 commit b105a9f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
24 changes: 22 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ test_scalar: relu at -0.5 | 11 11

## Testing constructors and functors (callable objects)

Testing constructor and functors works as you would expect. For struct `Foo`
Testing constructor and functors works as you would expect. For struct `Foo`,
```julia
struct Foo
a::Float64
Expand All @@ -127,7 +127,27 @@ Base.length(::Foo) = 1
Base.iterate(f::Foo) = iterate(f.a)
Base.iterate(f::Foo, state) = iterate(f.a, state)
```
the `f/rrule`s can be tested by

after defining the constructor and functor `f/rule`s,

```julia
function ChainRulesCore.rrule(::Type{Foo}, val) # constructor rrule
y = Foo(val)
Foo_pb(ΔFoo) = (NoTangent(), unthunk(ΔFoo).a)
return y, Foo_pb
end

function ChainRulesCore.rrule(foo::Foo, val) # functor rrule
y = foo(val)
function foo_pb(Δ)
Δut = unthunk(Δ)
return (Tangent{Foo}(;a=Δut), Δut)
end
return y, foo_pb
end
```

both `f/rrule`s can be tested by
```julia
test_rrule(Foo, rand()) # constructor

Expand Down
4 changes: 2 additions & 2 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ function test_frule(
end

res = call_on_copy(frule_f, config, tangents, primals...)
res === nothing && throw(MethodError(frule_f, typeof(primals)))
res === nothing && throw(MethodError(frule_f, Tuple{Core.Typeof.(primals)...}))
@test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
Ω_ad, dΩ_ad = res
Ω = call_on_copy(primals...)
Expand Down Expand Up @@ -201,7 +201,7 @@ function test_rrule(
_test_inferred(rrule_f, config, primals...; fkwargs...)
end
res = rrule_f(config, primals...; fkwargs...)
res === nothing && throw(MethodError(rrule_f, typeof(primals)))
res === nothing && throw(MethodError(rrule_f, Tuple{Core.Typeof.(primals)...}))
y_ad, pullback = res
y = call(primals...)
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct
Expand Down

0 comments on commit b105a9f

Please sign in to comment.