Skip to content

Commit

Permalink
Merge pull request #107 from JuliaDiff/mz/docs
Browse files Browse the repository at this point in the history
Add two usage examples to documentation
  • Loading branch information
mzgubic committed Jan 19, 2021
2 parents 2e8f88b + b805973 commit dfe6704
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
dev/

# Files generated by invoking Julia with --code-coverage
*.jl.cov
*.jl.*.cov
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.6.1"
version = "0.6.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

Expand Down
127 changes: 126 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,136 @@


[ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) helps you test [`ChainRulesCore.frule`](http://www.juliadiff.org/ChainRulesCore.jl/dev/api.html) and [`ChainRulesCore.rrule`](http://www.juliadiff.org/ChainRulesCore.jl/dev/api.html) methods, when adding rules for your functions in your own packages.

For information about ChainRules, including how to write rules, refer to the general ChainRules Documentation:
[![](https://img.shields.io/badge/docs-master-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/dev)
[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/stable)

## Canonical example

Let's suppose a custom transformation has been defined
```jldoctest ex; output = false
function two2three(x1::Float64, x2::Float64)
return 1.0, 2.0*x1, 3.0*x2
end
# output
two2three (generic function with 1 method)
```
along with the `frule`
```jldoctest ex; output = false
using ChainRulesCore
function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
y = two2three(x1, x2)
∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δx1, 3.0*Δx2)
return y, ∂y
end
# output
```
and `rrule`
```jldoctest ex; output = false
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
y = two2three(x1, x2)
function two2three_pullback(Ȳ)
return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3])
end
return y, two2three_pullback
end
# output
```

The [`frule_test`](@ref)/[`rrule_test`](@ref) helper function compares the `frule`/`rrule` outputs
to the gradients obtained by finite differencing.
They can be used for any type and number of inputs and outputs.

### Testing the `frule`

[`frule_test`](@ref) takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`.
The call will test the `frule` for function `f` at the point `x` in the domain. Keep
this in mind when testing discontinuous rules for functions like
[ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally
be tested at both `x` being above and below zero.
Additionally, choosing `` in an unfortunate way (e.g. as zeros) could hide
underlying problems with the defined `frule`.

```jldoctest ex; output = false
using ChainRulesTestUtils
x1, x2 = (3.33, -7.77)
ẋ1, ẋ2 = (rand(), rand())
frule_test(two2three, (x1, ẋ1), (x2, ẋ2))
# output
Test Summary: | Pass Total
Tuple{Float64,Float64,Float64}.1 | 1 1
Test Summary: | Pass Total
Tuple{Float64,Float64,Float64}.2 | 1 1
Test Summary: | Pass Total
Tuple{Float64,Float64,Float64}.3 | 1 1
Test Passed
```

### Testing the `rrule`

[`rrule_test`](@ref) takes in the function `f`, sensitivities of the function outputs ``,
and tuples `(x, x̄)` for each function argument `x`.
`` is the accumulated adjoint which can be set arbitrarily.
The call will test the `rrule` for function `f` at the point `x`, and similarly to
`frule` some rules should be tested at multiple points in the domain.
Choosing `` in an unfortunate way (e.g. as zeros) could hide underlying problems with
the `rrule`.
```jldoctest ex; output = false
x1, x2 = (3.33, -7.77)
x̄1, x̄2 = (rand(), rand())
ȳs = (rand(), rand(), rand())
rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2))
# output
Test Summary: |
Don't thunk only non_zero argument | No tests
Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false)
```

## Scalar example

For functions with a single argument and a single output, such as e.g. ReLU,
```jldoctest ex; output = false
function relu(x::Real)
return max(0, x)
end
# output
relu (generic function with 1 method)
```
with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
```jldoctest ex; output = false
@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)
# output
```

`test_scalar` function is provided to test both the `frule` and the `rrule` with a single
call.
```jldoctest ex; output = false
test_scalar(relu, 0.5)
test_scalar(relu, -0.5)
# output
Test Summary: | Pass Total
relu at 0.5, with tangent 1.0 | 3 3
Test Summary: | Pass Total
relu at 0.5, with cotangent 1.0 | 4 4
Test Summary: | Pass Total
relu at -0.5, with tangent 1.0 | 3 3
Test Summary: | Pass Total
relu at -0.5, with cotangent 1.0 | 4 4
```


# API Documentation

```@autodocs
Expand Down
2 changes: 1 addition & 1 deletion src/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The iterator wraps another iterator `data`, such as an array, that must have at
many features implemented as the test iterator and have a `FiniteDifferences.to_vec`
overload. By default, the iterator it has the same features as `data`.
The optional methods `eltype`, length`, and `size` are automatically defined and forwarded
The optional methods `eltype`, `length`, and `size` are automatically defined and forwarded
to `data` if the type arguments indicate that they should be defined.
"""
struct TestIterator{T,IS,IE}
Expand Down
2 changes: 2 additions & 0 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e
end
res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
Ω_ad, dΩ_ad = res
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
check_equal(Ω_ad, Ω; isapprox_kwargs...)
Expand Down Expand Up @@ -280,6 +281,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Re

check_inferred && _test_inferred(pullback, ȳ)
∂s = pullback(ȳ)
∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NO_FIELDS # No internal fields
Expand Down

0 comments on commit dfe6704

Please sign in to comment.