From 722a780eeda05891f8b01483292f99039cb95e82 Mon Sep 17 00:00:00 2001
From: George Panchuk
Date: Tue, 31 May 2022 17:56:56 +0300
Subject: [PATCH] new: replace max epochs with early stopping callback #113

---
 quaterion/main.py               | 12 ++++++++----
 quaterion/utils/progress_bar.py |  2 +-
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/quaterion/main.py b/quaterion/main.py
index 20b1dcb4..305118bb 100644
--- a/quaterion/main.py
+++ b/quaterion/main.py
@@ -6,7 +6,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from torch.utils.data import Dataset
-from pytorch_lightning.callbacks import ModelSummary
+from pytorch_lightning.callbacks import ModelSummary, EarlyStopping
 from quaterion_models import SimilarityModel
 
 from quaterion.dataset.similarity_data_loader import (
@@ -17,8 +17,9 @@
 from quaterion.eval.evaluator import Evaluator
 from quaterion.loss import GroupLoss, PairwiseLoss
 from quaterion.train.trainable_model import TrainableModel
-from quaterion.utils.progress_bar import QuaterionProgressBar
 from quaterion.train.callbacks import CleanupCallback, MetricsCallback
+from quaterion.utils.enums import TrainStage
+from quaterion.utils.progress_bar import QuaterionProgressBar
 
 
 class Quaterion:
@@ -119,8 +120,11 @@ def evaluate(
     def trainer_defaults():
         use_gpu = torch.cuda.is_available()
         defaults = {
-            "callbacks": [QuaterionProgressBar(), ModelSummary(max_depth=3)],
-            "max_epochs": 1000,
+            "callbacks": [
+                QuaterionProgressBar(),
+                EarlyStopping(f"{TrainStage.VALIDATION}_loss"),
+                ModelSummary(max_depth=3),
+            ],
             "gpus": int(use_gpu),
             "auto_select_gpus": use_gpu,
             "log_every_n_steps": 10,
diff --git a/quaterion/utils/progress_bar.py b/quaterion/utils/progress_bar.py
index 5f3b944c..5fd2f415 100644
--- a/quaterion/utils/progress_bar.py
+++ b/quaterion/utils/progress_bar.py
@@ -23,7 +23,7 @@ def __init__(
             theme=theme,
             console_kwargs=console_kwargs,
         )
-        self.predict_progress_bar_id: Optional[int] = None
+        self.predict_progress_bar_id = None
         self._caching = False
 
     def on_predict_batch_start(