Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add early stopping utils #1545

Merged
merged 20 commits into from
May 14, 2021
Merged

Add early stopping utils #1545

merged 20 commits into from
May 14, 2021

Conversation

queensferryme
Copy link
Contributor

@queensferryme queensferryme commented Mar 24, 2021

This pull request introduces a utility function early_stopping(f; min_delta=0, patience=3) for conveniently implementing early stopping. Its motive was discussed in #227 back in 2018. Also AFAIU, early stopping is widely adopted in training large-scale neural networks nowadays as a mechanism to prevent overfitting. Therefore, I believe a built-in utility function for early stopping would be beneficial for Flux.jl users.

The implementation is heavily based on tf.keras.callbacks.EarlyStopping and ignite.handlers.EarlyStopping.

This is a draft implementation. Documentation and tests are not added yet. As I am new to Flux.jl, I think advice and opinions from maintainers would be helpful before I proceed.

Example Usage:

# test_ea.jl
using Flux

function loss()
    v = 1
    return () -> v += 1
end

ea = Flux.early_stopping(loss(), patience=5)
Flux.@epochs 10 begin
    ea() || break
end

Output:

➜ julia test_ea.jl
 Activating environment at `~/Developer/Flux.jl/Project.toml`
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
[ Info: Epoch 4
[ Info: Epoch 5
[ Info: Epoch 6
➜

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • API changes require approval from a committer (different from the author, if applicable)

@DhairyaLGandhi
Copy link
Member

DhairyaLGandhi commented Mar 24, 2021

Thanks for the contribution! Seems generally handy and useful. Do you think it would make more sense as a callback?

@DhairyaLGandhi
Copy link
Member

We can always do it that way later though. For now some docs and tests would be helpful.

@DhairyaLGandhi
Copy link
Member

DhairyaLGandhi commented Mar 24, 2021

Something like this might make sense too once we merge #1471

struct EarlyStopping
  l::AbstractVector
  patience
  tolerance
end

EarlyStopping(memory, patience, tolerance) = EarlyStopping(CircularBuffer(memory), patience, tolerance)

function (es::EarlyStopping)(l, gs, d)
  roi = es.l[end-patience:end]
  if l within tolerance
    push!(es.l, l)
  else
    Flux.stop()
  end
end

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution. Can you add a doc string, entry in the documentation, and some tests?

src/utils.jl Outdated Show resolved Hide resolved
@darsnack
Copy link
Member

No need to limit this utility by pigeon-holing it into a callback-biased implementation. I think the current approach in the PR is best. It can easily be used in a CB like:

ea = Flux.early_stopping(loss(), patience=5)
cb = () -> ea() && Flux.stop()

Perhaps we can have a version of stop() that accepts a function. i.e. Flux.stop(f) returns a function that calls Flux.stop() when f() returns true.

@DhairyaLGandhi
Copy link
Member

Implementing it as a callback doesn't take away anything from it being used outside of a callback context.

We should also generally want this to be decoupled from the loss function. It would hold unnecessary references alive.

Anyway this would be expected to be used within a train loop 99% of the times, so it's worth it to make that api first class.

@darsnack
Copy link
Member

darsnack commented Mar 24, 2021

The callback system uses Flux.stop() whereas this just returns true when a condition is met. This implementation is more generic because it can be used in more than Flux.train!. There are other ways to make the Flux.train! API "first-class" without limiting these utilities to only be used in the stopping implementation used by Flux.train!.

@darsnack
Copy link
Member

Anyways, let's not open an argument. Like you said, we can do the callback (or not callback) later.

@darsnack
Copy link
Member

As I am new to Flux.jl, I think advice and opinions from maintainers would be helpful before I proceed.

The tests should be added under test/utils.jl. You can probably do a test with a pseudo-loss function that you know will hit the early stop criterion in n iterations, run the loop with the early stopping criterion in place. Then test if n matches the number you expect.

Docs will include a docstring and an entry in docs/src/utilities.md.

src/utils.jl Outdated Show resolved Hide resolved
@queensferryme
Copy link
Contributor Author

Implementing it as a callback doesn't take away anything from it being used outside of a callback context.

We should also generally want this to be decoupled from the loss function. It would hold unnecessary references alive.

Anyway this would be expected to be used within a train loop 99% of the times, so it's worth it to make that api first class.

I have considered implementing a callback function for early stopping, which seems more natural and integrated into the current Flux workflow. However, here are some considerations that drive me to this closure-based implementation:

  1. In my experience, most early stopping happens between loops instead of between batches (within one train loop) because loss generally tends to oscillate greatly between batches. That means early stopping bwtween batches is likely to break the entire training process even when the model is still underfitting.

  2. As far as I understand, Flux.stop only terminates the current train loop, whereas in early stopping we'd like that the current train loop along with any follow-up loops are all stopped. Since there aren't any inter-loop hooks for Flux right now, I think manually breaking from loops like ea() || break becomes the only choice.

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good. I think once you figure out the increasing/decreasing metric part, then this will likely be ready to go.

Your feedback on early stopping is appreciated. I think it is just more proof that we should do the most generic thing in this context, which is to write a utility that returns true or false. It will make the utility more flexible to use a variety of contexts. There are several options for how we integrate a boolean utility into the rest of the callback system.

src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
@DhairyaLGandhi
Copy link
Member

re 2. since Flux.stop throws, it will actually terminate all loops outside of train and will break the loop when used inside train, so it can be used either way.

@darsnack
Copy link
Member

It can be used either way if the user is willing to a) write boiler plate try/catch code or b) accept an ugly exception message. A boolean utility truly can be used however the user wants, because it makes no impositions on the user.

src/utils.jl Outdated Show resolved Hide resolved
@queensferryme
Copy link
Contributor Author

re 2. since Flux.stop throws, it will actually terminate all loops outside of train and will break the loop when used inside train, so it can be used either way.

My personal experience tells otherwise:
Screenshot

I believe this is because StopException and SkipException is captured within Flux.train!

catch ex
if ex isa StopException
break
elseif ex isa SkipException
continue
else
rethrow(ex)
end
end

Maybe this is an unexpected behavior?

@darsnack
Copy link
Member

No that's expected. I think what Dhairya was saying is that if you call the callback in your outer loop (over epochs) then it will throw an exception and stop the program. That's not great IMO because it means you either accept an ugly exception, or you write some boilerplate try/catch code to make it pretty. I don't see why we would want the user to write more boilerplate when ea() || break works just fine.

@queensferryme
Copy link
Contributor Author

Correct, it could get a condition function maybe which can operate on a fixed size of vector of some quantity of interest (typically loss).

I am also considering @DhairyaLGandhi's suggestion. Maybe we could provide a variant of early_stopping that holds a fixed size buffer of the last n collected metrics, as well as a condition function that operates on the buffer to decide whether we should early stop now. This could add more flexibility to the behavior of early stopping.

I will try implement if this sounds like a good idea to you both @DhairyaLGandhi @darsnack.

src/utils.jl Outdated Show resolved Hide resolved
@darsnack
Copy link
Member

Maybe we could provide a variant of early_stopping that holds a fixed size buffer of the last n collected metrics, as well as a condition function that operates on the buffer to decide whether we should early stop now.

My only concern here is calling this "early stopping," since it doesn't match the accepted understanding of the term. To @DhairyaLGandhi's point, the current approach doesn't stop anything, it just returns a boolean that can be used to stop something. If we were to move to this more generic and powerful approach, I would suggest we consider renaming the function.

The other alternative is to build this utility for applying a threshold on a window of values, then make early_stopping be syntactic sugar for passing the correct value for condition so that the it matches the conventional understanding.

src/utils.jl Outdated Show resolved Hide resolved
`true` to stop, `false` to continue
@queensferryme
Copy link
Contributor Author

The other alternative is to build this utility for applying a threshold on a window of values, then make early_stopping be syntactic sugar for passing the correct value for condition so that the it matches the conventional understanding.

Sounds somewhat troublesome and complicated to use 😂 I think I will stick with this implemention for now, which seems more intuitive.

@darsnack
Copy link
Member

darsnack commented Apr 8, 2021

Since the current design allows for more than just stopping, do we still want to consider a name change? PyTorch's ReduceLROnPlateau is an example that uses the same "early stopping" logic to decide when to advance the schedule. We have the opportunity to do something more generic here. ReduceLROnPlateau is hardcoded to a specific criterion for advancing the schedule, and a specific type of schedule. We can write a more generic version (see here) that works with any underlying schedule and predicate for advancing.

But if we use the utility here as our predicate, then the name early_stopping is weird. And in general, this utility doesn't do any stopping itself.

src/utils.jl Outdated Show resolved Hide resolved
@queensferryme
Copy link
Contributor Author

I also thought we were going with patience instead of plateau? The former can be better associated with being more general purpose than plateau would indicate.

But this utility specifically will always be a "plateau." You could turn the patience off completely by setting it to 1. But it's still a plateau because the utility measures when the distance between two points is small enough (by definition a "plateau"). BTW I am not tied to plateau. I just think patience is too uninformative.

For the current implementation, yes it will always be a plateau, and patience may seem too general in this case. In my opinion, a general patience utility should have only three components: a maximum patience, an internal count as well as a predicate that operates over a set of variables and returns eithr true or false. When the predicate says true, the internal count is increased by one; otherwise it is reset to zero.

So, in our case, plateau is more like a specialization of patience, where the predicate is reduced to delta < min_delta. The current plateau will not be able to do something like "stop the training when the loss has been below some threshold for n times". We can do this more general patience thing in another pull request and provide some pre-baked routines maybe, if you are interested.

src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
@darsnack
Copy link
Member

darsnack commented May 2, 2021

What about something like:

function patience(predicate, patience)
  on_trigger = let count = 0
   (args...; kwargs...) -> begin
      count = predicate(args...; kwargs...) ? count + 1 : 0

      return count >= patience
    end
  end

  return on_trigger
end

function plateau(f, p; distance = -, min_dist = 0, init_score = 0)
  on_plateau = let best_score = init_score
    (args...; kwargs...) -> begin
        score = f(args...; kwargs...)
        Δ = distance(best_score, score)
        best_score = Δ < 0 ? best_score : score

        return Δ < min_dist
     end
  end

  return patience(on_plateau, p)
end

const early_stopping = plateau

Personally, I think the patience logic is fairly simple. The hard part is the score closure. If it wasn't for the best score tracking, then I would suggest completely foregoing plateau in the example above. But I think the logic in plateau is cumbersome enough that it warrants its own function.

@queensferryme
Copy link
Contributor Author

What about something like:

function patience(predicate, patience)
  on_trigger = let count = 0
   (args...; kwargs...) -> begin
      count = predicate(args...; kwargs...) ? count + 1 : 0

      return count >= patience
    end
  end

  return on_trigger
end

function plateau(f, p; distance = -, min_dist = 0, init_score = 0)
  on_plateau = let best_score = init_score
    (args...; kwargs...) -> begin
        score = f(args...; kwargs...)
        Δ = distance(best_score, score)
        best_score = Δ < 0 ? best_score : score

        return Δ < min_dist
     end
  end

  return patience(on_plateau, p)
end

const early_stopping = plateau

Personally, I think the patience logic is fairly simple. The hard part is the score closure. If it wasn't for the best score tracking, then I would suggest completely foregoing plateau in the example above. But I think the logic in plateau is cumbersome enough that it warrants its own function.

This looks great!

Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good, thanks. I noticed a couple questions that I missed before.

docs/src/utilities.md Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking really good. I like the doc additions. I think we should be GTM after this last round of revisions.

cc @DhairyaLGandhi so he can look over things

docs/src/utilities.md Outdated Show resolved Hide resolved
docs/src/utilities.md Outdated Show resolved Hide resolved
docs/src/utilities.md Outdated Show resolved Hide resolved
docs/src/utilities.md Show resolved Hide resolved
docs/src/utilities.md Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
test/utils.jl Outdated Show resolved Hide resolved
test/utils.jl Show resolved Hide resolved
darsnack
darsnack previously approved these changes May 6, 2021
@DhairyaLGandhi
Copy link
Member

Cool, I'll give this a once over this weekend

docs/src/utilities.md Outdated Show resolved Hide resolved
docs/src/utilities.md Outdated Show resolved Hide resolved
src/utils.jl Outdated Show resolved Hide resolved
src/utils.jl Outdated
"""
function patience(predicate, wait)
on_trigger = let count = 0
(args...; kwargs...) -> begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on_ trigger should be the name of the function imo. Good to have it named for stacktraces

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is that done with a let?

Like

let ...
  function name(...)
    ...
  end
end

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I believe so

docs/src/utilities.md Outdated Show resolved Hide resolved
@darsnack
Copy link
Member

@queensferryme If you can address the latest comments when you get a chance, we can finish this PR up! Thanks for all your work and patience.

@darsnack
Copy link
Member

Great, thanks @queensferryme!

bors r+

bors bot added a commit that referenced this pull request May 14, 2021
1545: Add early stopping utils r=darsnack a=queensferryme

This pull request introduces a utility function `early_stopping(f; min_delta=0, patience=3)` for conveniently implementing early stopping. Its motive was discussed in #227 back in 2018. Also AFAIU, early stopping is widely adopted in training large-scale neural networks nowadays as a mechanism to prevent overfitting. Therefore, I believe a built-in utility function for early stopping would be beneficial for Flux.jl users.

The implementation is heavily based on [tf.keras.callbacks.EarlyStopping](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping) and [ignite.handlers.EarlyStopping](https://github.com/pytorch/ignite/blob/master/ignite/handlers/early_stopping.py).

> This is a _draft_ implementation. Documentation and tests are not added yet. As I am new to Flux.jl, I think advice and opinions from maintainers would be helpful before I proceed.

**Example Usage**:
```julia
# test_ea.jl
using Flux

function loss()
    v = 1
    return () -> v += 1
end

ea = Flux.early_stopping(loss(), patience=5)
Flux.@epochs 10 begin
    ea() || break
end
```

**Output**:

```
➜ julia test_ea.jl
 Activating environment at `~/Developer/Flux.jl/Project.toml`
[ Info: Epoch 1
[ Info: Epoch 2
[ Info: Epoch 3
[ Info: Epoch 4
[ Info: Epoch 5
[ Info: Epoch 6
➜
```

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] API changes require approval from a committer (different from the author, if applicable)


Co-authored-by: Queensferry <queensferry.me@gmail.com>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
@DhairyaLGandhi
Copy link
Member

Thanks all

@DhairyaLGandhi DhairyaLGandhi merged commit e6629dd into FluxML:master May 14, 2021
@queensferryme
Copy link
Contributor Author

@DhairyaLGandhi @darsnack Thanks for your patience and time!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants