diff --git a/quaterion/main.py b/quaterion/main.py index 09676783..2dc6b849 100644 --- a/quaterion/main.py +++ b/quaterion/main.py @@ -3,7 +3,7 @@ import torch import pytorch_lightning as pl from torch.utils.data import Dataset -from pytorch_lightning.callbacks import ModelSummary +from pytorch_lightning.callbacks import ModelSummary, EarlyStopping from quaterion_models import SimilarityModel from quaterion.dataset.similarity_data_loader import ( @@ -14,8 +14,9 @@ from quaterion.eval.evaluator import Evaluator from quaterion.loss import GroupLoss, PairwiseLoss from quaterion.train.trainable_model import TrainableModel -from quaterion.utils.progress_bar import QuaterionProgressBar from quaterion.train.callbacks import CleanupCallback, MetricsCallback +from quaterion.utils.enums import TrainStage +from quaterion.utils.progress_bar import QuaterionProgressBar class Quaterion: @@ -110,8 +111,11 @@ def evaluate( def trainer_defaults(): use_gpu = torch.cuda.is_available() defaults = { - "callbacks": [QuaterionProgressBar(), ModelSummary(max_depth=3)], - "max_epochs": 1000, + "callbacks": [ + QuaterionProgressBar(), + EarlyStopping(f"{TrainStage.VALIDATION}_loss"), + ModelSummary(max_depth=3), + ], "gpus": int(use_gpu), "auto_select_gpus": use_gpu, "log_every_n_steps": 10, diff --git a/quaterion/utils/progress_bar.py b/quaterion/utils/progress_bar.py index 5f3b944c..5fd2f415 100644 --- a/quaterion/utils/progress_bar.py +++ b/quaterion/utils/progress_bar.py @@ -23,7 +23,7 @@ def __init__( theme=theme, console_kwargs=console_kwargs, ) - self.predict_progress_bar_id: Optional[int] = None + self.predict_progress_bar_id = None self._caching = False def on_predict_batch_start(