Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add early stopping utils #1545

Merged
merged 20 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,5 @@ Flux.nfan
Flux.throttle
Flux.stop
Flux.skip
Flux.plateau
```
56 changes: 56 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,59 @@ modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
isleaflike(x) = Functors.isleaf(x)
isleaflike(::Tuple{Vararg{<:Number}}) = true
isleaflike(::Tuple{Vararg{<:AbstractArray{<:Number}}}) = true

"""
plateau(f, patience; delta = -, min_delta = 0)

Return a function that internally counts by one when
`delta(best_score, f(...)) <= min_delta` where
`best_score` is the last seen best value of `f(...)`.
If the count is greater than or equal to `patience`,
the function returns `true`, otherwise it returns `false`.
The count is reset when `delta(best_score, f(...)) > min_delta`.

The keyword argument `delta` is a function of the form
`delta(best_score, current_score)`.
If you are using some increasing metric (e.g. accuracy),
you can customize the `delta` function:
`(best_score, score) -> score - best_score`.
darsnack marked this conversation as resolved.
Show resolved Hide resolved

A common use case of `plateau` is early stopping. For this,
we have added `early_stopping` as an alias to `plateau`.
Note that you can do more generic things with `plateau`,
for example reducing the learning rate when the training loss plateaus.

# Examples
```jldoctest
julia> l = 0;

julia> function loss()
global l += 1
end; # pseudo loss function that returns increasing values

julia> es = Flux.early_stopping(loss, 3);

julia> Flux.@epochs 10 begin
es() && break
end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
```
"""
function plateau(f, patience; delta = -, init_score = 0, min_delta = 0)
best_score = init_score
count = 0

let best_score = best_score, count = count
darsnack marked this conversation as resolved.
Show resolved Hide resolved
function on_plateau(args...; kwargs...)
score = f(args...; kwargs...)
Δ = delta(best_score, score)
count = Δ < min_delta ? count + 1 : 0
best_score = Δ < 0 ? best_score : score
return count >= patience
end
end
end

const early_stopping = plateau
67 changes: 67 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,70 @@ end
LayerNorm(8)))
@test length(modules) == 5
end

@testset "Plateau" begin
function metric(; step = 1)
l = 0
return () -> l += step
end
darsnack marked this conversation as resolved.
Show resolved Hide resolved

@testset "args & kwargs" begin
es = Flux.plateau((x; y = 1) -> x + y, 10; min_delta=2)

n_iter = 0
while n_iter < 99
es(-n_iter; y=-n_iter) && break
n_iter += 1
end

@test n_iter == 99
end

@testset "delta" begin
es = Flux.plateau(metric(), 10; delta=(best_score, score) -> score - best_score)

n_iter = 0
while n_iter < 99
es() && break
n_iter += 1
end

@test n_iter == 99
end

@testset "init score" begin
es = Flux.plateau(metric(), 10; init_score=10)

n_iter = 0
while n_iter < 99
es() && break
n_iter += 1
end

@test n_iter == 10
end

@testset "min delta" begin
es = Flux.plateau(metric(step=-2), 10; min_delta=1)

n_iter = 0
while n_iter < 99
es() && break
n_iter += 1
end

@test n_iter == 99
end

@testset "patience" begin
es = Flux.plateau(metric(), 10)

n_iter = 0
while n_iter < 99
es() && break
n_iter += 1
end

@test n_iter == 9
end
end