
Synchronize model checkpointing and servable model saving #29

Closed
monatis opened this issue Feb 10, 2022 · 13 comments
Assignees: monatis
Labels: enhancement (New feature or request), help wanted (Extra attention is needed), wontfix (This will not be worked on)

Comments

@monatis (Contributor) commented Feb 10, 2022

Currently, TrainableModel.save_servable() is called by the user at the end of the training loop. This is problematic because we may end up saving an overfitted state of the model even if we are monitoring an evaluation metric with pl.callbacks.ModelCheckpoint. So we need to come up with a way to synchronize the two.

Possible solution

  • We may need to subclass ModelCheckpoint inside quaterion for synchronization.
  • We may accept additional keyword arguments in Quaterion.fit to automatically save servable checkpoints to a specified directory at a specified interval (sketched below).
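
A rough illustration of how such keyword arguments might look (the argument names servable_save_path and servable_save_interval are assumptions for the sake of discussion, not part of the current Quaterion.fit signature; model, trainer, and the dataloaders are assumed to be defined already):

from quaterion import Quaterion

Quaterion.fit(
    trainable_model=model,
    trainer=trainer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    servable_save_path="servables/",   # hypothetical: directory for servable checkpoints
    servable_save_interval=1,          # hypothetical: save a servable checkpoint every N epochs
)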
@monatis added the enhancement and help wanted labels on Feb 10, 2022
@generall (Member):
@joein WDYT?

@joein (Member) commented Feb 11, 2022

Previously we loaded from the best checkpoint after Quaterion.fit and then called save_servable() (roughly as in the sketch below). We obviously don't want the user to implement this at the end of every training run.
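
For reference, the manual pattern described above looked roughly like this (a sketch; the monitored metric name and paths are illustrative, and model and the dataloaders are assumed to be defined already):

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from quaterion import Quaterion

checkpoint_callback = ModelCheckpoint(monitor="validation_loss", mode="min")
trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=10)
Quaterion.fit(
    trainable_model=model,
    trainer=trainer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
)

# restore the best weights tracked by the checkpoint callback, then save them as servable
model.load_state_dict(torch.load(checkpoint_callback.best_model_path)["state_dict"])
model.save_servable("servable/")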

At first sight, I'd rather not save two versions of checkpoints simultaneously; it may consume space and also makes a bit of a mess.
If we want to inherit from ModelCheckpoint and use it somewhere internally, we need to figure out whether the user currently uses ModelCheckpoint in training (we can do that, e.g. by looking at the list of callbacks) and specify its parameters somehow (copy them from the original checkpoint, or pass them in fit?).

Maybe we can accept a ModelCheckpoint instance as an argument in save_servable() if the user wants to save from the best checkpoint, and otherwise save the last model state? The problem here is that the user almost always wants to save the best available model, so saving the last model state as the default can be ambiguous.

@monatis (Contributor, Author) commented Feb 11, 2022

Ok, so my proposal is as follows:

  • Implement a SaveServableCallback that hooks into on_fit_end, restores the checkpoint, and saves it as servable (sketched below).
  • In TrainableModel, hook into configure_callbacks, return SaveServableCallback, plus ModelCheckpoint if the current list of callbacks doesn't include an instance of it already.
  • In Quaterion.fit, accept a keyword argument save_path: str = 'default/path/to/save/servable', and set it on TrainableModel before calling Trainer.fit so that SaveServableCallback returned from configure_callbacks can pick it up.
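
A minimal sketch of what SaveServableCallback could look like (the class does not exist yet; its name and details are assumptions for discussion):

import torch
from pytorch_lightning.callbacks import Callback


class SaveServableCallback(Callback):
    """Restore the best checkpoint (if one was tracked) on fit end and save the model as servable."""

    def __init__(self, save_path: str):
        self.save_path = save_path

    def on_fit_end(self, trainer, pl_module):
        # trainer.checkpoint_callback is the attached ModelCheckpoint, or None
        checkpoint_callback = trainer.checkpoint_callback
        if checkpoint_callback is not None and checkpoint_callback.best_model_path:
            state = torch.load(checkpoint_callback.best_model_path)["state_dict"]
            pl_module.load_state_dict(state)
        pl_module.save_servable(self.save_path)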

I think such a design will

  1. minimize the effort expected from the user,
  2. still be highly configurable,
  3. respect the user's preferences.

WDYT?

@joein (Member) commented Feb 11, 2022

configure_callbacks is a part of pl.LightningModule and the user may want to override it, thus breaking the SaveServableCallback logic, couldn't they? The way out could be to use configure_save instead of configure_callbacks (or a more appropriate name), like we did with configure_caches.

I think that adding a ModelCheckpoint if one is not present may be counterintuitive: if the user does not expect to get the best checkpoint, why should we return it to them?

I like the overall idea with the callback and a separate path for the servable model; we can try to develop this idea further.

@monatis (Contributor, Author) commented Feb 11, 2022

configure_callbacks is a part of pl.LightningModule and the user may want to override it

Good point. But they still have the option to pass SaveServableCallback themselves in this case.

adding a ModelCheckpoint if one is not present may be counterintuitive

I guess we can inspect the attributes of the user-defined ModelCheckpoint (if any) and infer whether they are trying to get the best or last checkpoint (see below). Alternatively, Quaterion.fit can accept another keyword argument, restore_best_checkpoint: bool = True, which will have no effect if users pass a ModelCheckpoint themselves.

Or, another alternative could be:

  • Subclass ModelCheckpoint and implement the save_servable logic in its on_fit_end (sketched after the pros and cons below).
  • If the monitor keyword argument is set to a metric, restore the best checkpoint (and the last one otherwise) before calling save_servable.
  • In the documentation, instruct users to use this subclass instead of ModelCheckpoint to benefit from this functionality.

Pros:

  1. quite a bit simpler, and thus less error-prone.
  2. intuitive, because there is no implicit default behavior.

Cons:

  1. no saving by default, which requires the user's explicit intent to save a servable checkpoint.
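
A minimal sketch of this alternative, assuming a subclass named ServableModelCheckpoint (the name and details are illustrative, not existing code):

import torch
from pytorch_lightning.callbacks import ModelCheckpoint


class ServableModelCheckpoint(ModelCheckpoint):
    def __init__(self, servable_path: str, **kwargs):
        super().__init__(**kwargs)
        self.servable_path = servable_path

    def on_fit_end(self, trainer, pl_module):
        # restore the best checkpoint if a metric is monitored, the last one otherwise
        # (note: last_model_path is only populated when save_last=True)
        path = self.best_model_path if self.monitor is not None else self.last_model_path
        if path:
            pl_module.load_state_dict(torch.load(path)["state_dict"])
        pl_module.save_servable(self.servable_path)

Users would then pass ServableModelCheckpoint(servable_path=..., monitor=...) to the Trainer instead of a plain ModelCheckpoint.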

@joein (Member) commented Feb 11, 2022

Good point. But they still have the option to pass SaveServableCallback themselves in this case.

The user can break saving unintentionally: for example, they override configure_callbacks to add their own user_callback_one, and then saving is broken, which is surprising and hard to debug. Eventually it turns out that if they override configure_callbacks, they need to pass SaveServableCallback separately, and that is not obvious.

I guess we can inspect the attributes of the user-defined ModelCheckpoint (if any) and infer whether they are trying to get the best or last checkpoint (see below).

Yes, we can, but we also have a problem here with multiple ModelCheckpoint instances: which one should we take by default?

Alternatively, Quaterion.fit can accept another keyword argument, restore_best_checkpoint: bool = True, which will have no effect if users pass a ModelCheckpoint themselves.

It does not seem right to me to save the best checkpoint when the user does not request this via ModelCheckpoint themselves. Maybe they don't want to store any checkpoints at all.

Or, another alternative could be:

Yes, maybe, but I think I'd like to avoid it unless we definitely can't survive without it.

@monatis (Contributor, Author) commented Feb 11, 2022

unless we definitely can't survive without it

We can, but quaterion may not, without a user-friendly mechanism for saving and/or restoring checkpoints properly and with ease.

In fact, I'm for sensible defaults whenever possible. quaterion will be strong if it can equip users with the ability to quickly materialize their ideas with minimal effort and minimum expertise in ML / ML libs. Otherwise, there exist many other great libraries / tools for ninjas out there.

@joein (Member) commented Feb 11, 2022

Yes, I support the overall idea, and I also support sensible defaults.

Let's give it some time.

@monatis (Contributor, Author) commented Feb 15, 2022

After some discussion with @joein yesterday, the current proposal is as follows:

  • Instead of introducing any extra callback, change the signature of the save_servable method to save_servable(path: str, checkpoint_callback: ModelCheckpoint = None, strategy: SaveStrategy = SaveStrategy.AUTO) (sketched below).
  • If the optional checkpoint_callback argument is passed, try to restore a checkpoint before saving a servable version according to the strategy argument, which is an enumeration of AUTO, BEST, and LAST.
  • Otherwise, save the current state.
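
A sketch of how this signature could be implemented inside TrainableModel (the SaveStrategy enum and the restore logic are assumptions derived from the proposal above, not existing code):

from enum import Enum
from typing import Optional

import torch
from pytorch_lightning.callbacks import ModelCheckpoint


class SaveStrategy(Enum):
    AUTO = "auto"  # best checkpoint if a metric is monitored, last otherwise
    BEST = "best"
    LAST = "last"


def save_servable(
    self,
    path: str,
    checkpoint_callback: Optional[ModelCheckpoint] = None,
    strategy: SaveStrategy = SaveStrategy.AUTO,
):
    if checkpoint_callback is not None:
        use_best = strategy is SaveStrategy.BEST or (
            strategy is SaveStrategy.AUTO and checkpoint_callback.monitor is not None
        )
        checkpoint_path = (
            checkpoint_callback.best_model_path
            if use_best
            else checkpoint_callback.last_model_path
        )
        if checkpoint_path:
            self.load_state_dict(torch.load(checkpoint_path)["state_dict"])
    # ... continue with the existing servable-saving logic on the (possibly restored) state ...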

@joein (Member) commented Feb 15, 2022

There is one more problem, I guess.

To restore the full model state from a checkpoint, we need to call TrainableModel.load_from_checkpoint(chkpt.best_model_path).

Actually, it will be a TrainableModel descendant, which can have custom *args and **kwargs in __init__.
We need to pass them to load_from_checkpoint to obtain the original model state.
load_from_checkpoint does not allow *args, only **kwargs, to be passed to the model constructor, so it ends up as

save_servable(
    path: str,
    checkpoint_callback: ModelCheckpoint = None,
    strategy: SaveStrategy = SaveStrategy.AUTO,
    model_params: Optional[Dict] = None
)

where model_params is the joined *args and **kwargs.

It's a fairly straightforward approach, but it's also the point where it starts looking clumsy to me.

[UPDATE]:
In load_from_checkpoint:

Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the arguments passed to __init__ in the checkpoint under hyper_parameters

[UPDATE]:
self.save_hyperparameters() has to be called to save *args and **kwargs in the checkpoint; we can add it in TrainableModel.__init__() (sketched below). Therefore there is no need for a model_params argument in save_servable.
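
A sketch of where that call would go (the real TrainableModel.__init__ signature may differ; this only shows the placement):

import pytorch_lightning as pl


class TrainableModel(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Lightning stores the constructor arguments in the checkpoint under
        # "hyper_parameters", so load_from_checkpoint can rebuild the model
        # without an extra model_params argument.
        self.save_hyperparameters()
        # ... the rest of the existing TrainableModel initialization ...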

@monatis (Contributor, Author) commented Feb 16, 2022

TrainableModel.load_from_checkpoint

We don't need it. Doing the following inside TrainableModel.save_servable is enough:

def save_servable(...):
    # ...
    self.load_state_dict(torch.load(checkpoint_path)["state_dict"])
    # we restored the state from the given checkpoint, now save it as servable

@generall (Member):
LGTM

@monatis self-assigned this Mar 10, 2022
@generall (Member):
After giving it some thought, I decided that this functionality is not something our framework should be responsible for.
Checkpointing is an optional feature of PyTorch Lightning, and having additional logic on top of it, which has nothing to do with similarity learning, might complicate the automation in unpredictable ways.
