Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add early stopping utils #1545

Merged
merged 20 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,5 @@ Flux.nfan
Flux.throttle
Flux.stop
Flux.skip
Flux.early_stopping
```
44 changes: 44 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,47 @@ modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
isleaflike(x) = Functors.isleaf(x)
isleaflike(::Tuple{Vararg{<:Number}}) = true
isleaflike(::Tuple{Vararg{<:AbstractArray{<:Number}}}) = true

"""
early_stopping(f; delta=-, min_delta=0, patience=3)

Return a function that evaluates the metric `f` and compares its value
against its value on last invocation. When the difference has been less
than `min_delta` for at least `patience` times, `true` is returned,
otherwise `false` is returned.

By default, `early_stopping` expects the metric `f` to be minimized.
darsnack marked this conversation as resolved.
Show resolved Hide resolved
However, if you are using some increasing metric, accuracy for example,
you can change `early_stopping`'s behaviour by customizing the `delta`
function: `(best_score, score) -> score - best_score`.
darsnack marked this conversation as resolved.
Show resolved Hide resolved

# Examples
```jldoctest
julia> function loss()
l = 0
return () -> l += 1
darsnack marked this conversation as resolved.
Show resolved Hide resolved
end # pseudo loss function that returns increasing values
loss (generic function with 1 method)

julia> es = Flux.early_stopping(loss(); patience=3);

julia> Flux.@epochs 10 begin
es() || break
end
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
```
"""
function early_stopping(f; delta = -, min_delta = 0, patience = 3)
best_score = f()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to have a way to pass f some args. We also don't want to have to call f() twice here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second call only happens when the closure is called? This call only happens once to initialize best_score. Since the correct initial value will depend on f and delta, maybe it would be better to pass best_score in as a keyword argument. The default can be f().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what args were you thinking of passing to f? The only additional variables introduced into the scope are patience and best_score. Just trying to understand how the additional args might be used.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well the closure would be called frequently. Optimising here makes sense.

I was referring to passing arguments to f - be that the loss function or for delayed evaluation etc. Having the general case here is nicer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh are you saying something like

return function (args...)
    score = f(args...)
    # etc
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea. Should we just make it args + kwargs?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well the closure would be called frequently. Optimising here makes sense.

Yeah but you need the updated score from f. What would you optimize out? Could you share a snippet of what you mean? I am just confused what you want to eliminate.

count = 0

return function ()
score = f()
Δ = delta(best_score, score)
count = Δ < min_delta ? count + 1 : 0
best_score = Δ < 0 ? best_score : score
return count < patience
darsnack marked this conversation as resolved.
Show resolved Hide resolved
end
end
40 changes: 40 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,43 @@ end
LayerNorm(8)))
@test length(modules) == 5
end

@testset "Early stopping" begin
function metric(; step = 1)
l = 0
return () -> l += step
end
darsnack marked this conversation as resolved.
Show resolved Hide resolved

@testset "delta" begin
es = Flux.early_stopping(metric(); delta=(best_score, score) -> score - best_score)

n_iter = 0
while n_iter < 99
es() ? n_iter += 1 : break
end

@test n_iter == 99
end

@testset "min delta" begin
es = Flux.early_stopping(metric(step=-2); min_delta=1)

n_iter = 0
while n_iter < 99
es() ? n_iter += 1 : break
end

@test n_iter == 99
end

@testset "patience" begin
es = Flux.early_stopping(metric(); patience=10)

n_iter = 0
while n_iter < 99
es() ? n_iter += 1 : break
end

@test n_iter == 9
end
end