Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add early stopping utils #1545

Merged
merged 20 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions docs/src/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,73 @@ Flux.throttle
Flux.stop
Flux.skip
```

queensferryme marked this conversation as resolved.
Show resolved Hide resolved
## Patience Helpers

Flux provides utilities for controlling your training procedure according to some monitored condition and a maximum `patience`. For example, you can use `early_stopping` to stop training when the model is converging or deteriorating, or you can use `plateau` to check if the model is stagnating.

darsnack marked this conversation as resolved.
Show resolved Hide resolved
For example, below we create a pseudo-loss function that decreases, bottoms out, then increases. The early stopping trigger will break the loop before the loss increases too much.
```julia
# create a pseudo-loss that decreases for 4 calls, then starts increasing
# we call this like loss()
loss = let t = 0
() -> begin
t += 1
(t - 4) ^ 2
end
end

# create an early stopping trigger
# returns true when the loss increases for two consecutive steps
es = early_stopping(loss, 2; init_score = 9)

# this will stop at the 6th (4 decreasing + 2 increasing calls) epoch
@epochs 10 begin
es() && break
end
```

The keyword argument `distance` of `early_stopping` is a function of the form `distance(best_score, score)`. By default `distance` is `-`, which implies that the monitored metric `f` is expected to be decreasing and mimimized. If you use some increasing metric (e.g. accuracy), you can customize the `distance` function: `(best_score, score) -> score - best_score`.
```julia
# create a pseudo-accuracy that increases by 0.01 each time from 0 to 1
# we call this like acc()
acc = let v = 0
() -> v = max(1, v + 0.01)
end

# create an early stopping trigger for accuracy
es = early_stopping(acc, 3; delta = (best_score, score) -> score - best_score)

# this will iterate until the 10th epoch
@epochs 10 begin
es() && break
end
```

`early_stopping` and `plateau` are both built on top of `patience`. You can use `patience` to build your own triggers that use a patient counter. For example, if you want to trigger when the loss is below a threshold for several consecutive iterations:
```julia
threshold(f, thresh, delay) = patience(delay) do
f() < thresh
end
```

darsnack marked this conversation as resolved.
Show resolved Hide resolved
Both `predicate` in `patience` and `f` in `early_stopping` / `plateau` can accept extra arguments. You can pass such extra arguments to `predicate` or `f` through the returned function:
```julia
trigger = patience((a; b) -> a > b, 3)

# this will iterate until the 10th epoch
@epochs 10 begin
trigger(1; b = 2) && break
end

# this will stop at the 3rd epoch
@epochs 10 begin
trigger(3; b = 2) && break
end
```

```@docs
Flux.patience
Flux.early_stopping
Flux.plateau
```
113 changes: 113 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,116 @@ modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
isleaflike(x) = Functors.isleaf(x)
isleaflike(::Tuple{Vararg{<:Number}}) = true
isleaflike(::Tuple{Vararg{<:AbstractArray{<:Number}}}) = true

"""
patience(predicate, wait)

Return a function that internally counts by one when
`predicate(...) == true`, otherwise the count is reset to zero.
If the count is greater than or equal to `wait`,
the function returns `true`, otherwise it returns `false`.

# Examples
```jldoctest
julia> loss() = rand();

julia> trigger = Flux.patience(() -> loss() < 1, 3);

julia> Flux.@epochs 10 begin
trigger() && break
end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
```
"""
function patience(predicate, wait)
let count = 0
function on_trigger(args...; kwargs...)
count = predicate(args...; kwargs...) ? count + 1 : 0

return count >= wait
end
end
end

"""
early_stopping(f, delay; distance = -, init_score = 0, min_dist = 0)

Return a function that internally counts by one when
`distance(best_score, f(...)) <= min_dist`, where
`best_score` is the last seen best value of `f(...)`.
If the count is greater than or equal to `delay`,
the function returns `true`, otherwise it returns `false`.
The count is reset when `distance(best_score, f(...)) > min_dist`.

# Examples
```jldoctest
julia> loss = let l = 0
() -> l += 1
end; # pseudo loss function that returns increasing values

julia> es = Flux.early_stopping(loss, 3);

julia> Flux.@epochs 10 begin
es() && break
end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
```
"""
function early_stopping(f, delay; distance = -, init_score = 0, min_dist = 0)
trigger = let best_score = init_score
(args...; kwargs...) -> begin
score = f(args...; kwargs...)
Δ = distance(best_score, score)
best_score = Δ < 0 ? best_score : score

return Δ < min_dist
end
end

return patience(trigger, delay)
end

"""
plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6)

Return a function that internally counts by one when
`abs(distance(last_score, f(...))) <= min_dist`, where
`last_score` holds the last value of `f(...)`.
If the count is greater than or equal to `width`,
the function returns `true`, otherwise it returns `false`.
The count is reset when `abs(distance(last_score, f(...))) > min_dist`.

# Examples
```jldoctest
julia> f = let v = 10
() -> v = v / abs(v) - v
end; # -9, 8, -7, 6, ...

julia> trigger = Flux.plateau(f, 3; init_score=10, min_dist=18);

julia> Flux.@epochs 10 begin
trigger() && break
end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
[ Info: Epoch 4
```
"""
function plateau(f, width; distance = -, init_score = 0, min_dist = 1f-6)
is_plateau = let last_score = init_score
(args...; kwargs...) -> begin
score = f(args...; kwargs...)
Δ = abs(distance(last_score, score))
last_score = score

return Δ < min_dist
end
end

return patience(is_plateau, width)
end
77 changes: 77 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,80 @@ end
LayerNorm(8)))
@test length(modules) == 5
end

@testset "Patience triggers" begin
@testset "patience" begin
trigger = Flux.patience(() -> true, 3)

@test trigger() == false
@test trigger() == false
@test trigger() == true

v = [false, true, false, true, true, true]
trigger = let v = v
Flux.patience(i -> v[i], 3)
end

n_iter = 0
for i in 1:length(v)
trigger(i) && break
n_iter += 1
end

@test n_iter == 5
end
darsnack marked this conversation as resolved.
Show resolved Hide resolved

@testset "early stopping" begin
@testset "args & kwargs" begin
es = Flux.early_stopping((x; y = 1) -> x + y, 10; min_dist=3)

n_iter = 0
while n_iter < 99
es(-n_iter; y=-n_iter) && break
n_iter += 1
end

@test n_iter == 9
end

@testset "distance" begin
es = Flux.early_stopping(identity, 10; distance=(best_score, score) -> score - best_score)

n_iter = 0
while n_iter < 99
es(n_iter) && break
n_iter += 1
end

@test n_iter == 99
end

@testset "init_score" begin
es = Flux.early_stopping(identity, 10; init_score=10)

n_iter = 0
while n_iter < 99
es(n_iter) && break
n_iter += 1
end

@test n_iter == 10
end
end

@testset "plateau" begin
f = let v = 10
() -> v = v / abs(v) - v
end

trigger = Flux.plateau(f, 3, init_score=10, min_dist=18)

n_iter = 0
while n_iter < 99
trigger() && break
n_iter += 1
end

@test n_iter == 3
end
end