🐞 Fix Rich Progress with Patchcore Training #2062

Merged
4 changes: 2 additions & 2 deletions src/anomalib/engine/engine.py
@@ -9,7 +9,7 @@
 from typing import Any

 import torch
-from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.trainer import Trainer
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -406,7 +406,7 @@ def _setup_transform(

     def _setup_anomalib_callbacks(self) -> None:
         """Set up callbacks for the trainer."""
-        _callbacks: list[Callback] = []
+        _callbacks: list[Callback] = [RichProgressBar(), RichModelSummary()]

Collaborator (Author) commented on the line above:
This change is not directly related to the bug, but I am now using RichProgressBar in the trainer. I don't mind if we drop this, as it is not necessary.

         # Add ModelCheckpoint if it is not in the callbacks list.
         has_checkpoint_callback = any(isinstance(c, ModelCheckpoint) for c in self._cache.args["callbacks"])
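With this change, `_setup_anomalib_callbacks` registers Rich-based callbacks by default, so every `Engine` run gets a Rich progress bar and a Rich model summary unless the user supplies their own. As a rough illustration (a plain-Lightning sketch, not the Engine's actual wiring), the defaults amount to:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichModelSummary, RichProgressBar

# Illustrative only: the same two callbacks that _setup_anomalib_callbacks now
# injects by default. RichProgressBar drives a live display on rich's console,
# which is why a nested rich.progress.track call during training can collide
# with it (the bug addressed in the k_center_greedy.py change below).
trainer = Trainer(callbacks=[RichProgressBar(), RichModelSummary()])
```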
4 changes: 0 additions & 4 deletions src/anomalib/loggers/mlflow.py
@@ -7,12 +7,8 @@
 from lightning.pytorch.utilities import rank_zero_only
 from matplotlib.figure import Figure

-from anomalib.utils.exceptions.imports import try_import
-
 from .base import ImageLoggerBase

-try_import("mlflow")
-

 class AnomalibMLFlowLogger(ImageLoggerBase, MLFlowLogger):
     """Logger for MLFlow.
4 changes: 2 additions & 2 deletions src/anomalib/models/components/sampling/k_center_greedy.py
@@ -8,10 +8,10 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-from rich.progress import track
 from torch.nn import functional as F  # noqa: N812

 from anomalib.models.components.dimensionality_reduction import SparseRandomProjection
+from anomalib.utils.rich import safe_track


 class KCenterGreedy:
@@ -98,7 +98,7 @@ def select_coreset_idxs(self, selected_idxs: list[int] | None = None) -> list[int]:

         selected_coreset_idxs: list[int] = []
         idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
-        for _ in track(range(self.coreset_size), description="Selecting Coreset Indices."):
+        for _ in safe_track(sequence=range(self.coreset_size), description="Selecting Coreset Indices."):
            self.update_distances(cluster_centers=[idx])
            idx = self.get_new_idx()
            if idx in selected_idxs:
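Replacing `track` with `safe_track` is the core of the fix: during PatchCore training the trainer's `RichProgressBar` already owns a live display on rich's console, so starting a second one for coreset selection fails. A rough stand-alone reproduction of that conflict (plain rich, no Lightning involved; behaviour assumed from rich's single-live-display rule) looks like this:

```python
from rich.progress import Progress, track

# An outer live display is already running, much like RichProgressBar's during fit ...
with Progress() as fit_progress:
    fit_progress.add_task("fit", total=1)
    # ... so starting a second live display on the same console is expected to raise
    # rich.errors.LiveError: "Only one live display may be active at once".
    for _ in track(range(10), description="Selecting Coreset Indices."):
        pass
```

`safe_track` sidesteps this by temporarily detaching the console's current live display, as implemented in the new `anomalib/utils/rich.py` below.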
47 changes: 47 additions & 0 deletions src/anomalib/utils/rich.py
@@ -0,0 +1,47 @@
"""Custom rich methods."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Generator, Iterable
from typing import TYPE_CHECKING, Any

from rich import get_console
from rich.progress import track

if TYPE_CHECKING:
    from rich.live import Live


class CacheRichLiveState:
    """Cache the live state of the console.

    Note: This is a bit dangerous as it accesses private attributes of the console.
    Use this with caution.
    """

    def __init__(self) -> None:
        self.console = get_console()
        self.live: "Live" | None = None

    def __enter__(self) -> None:
        """Save the live state of the console."""
        # Need to access private attribute to get the live state
        with self.console._lock:  # noqa: SLF001
            self.live = self.console._live  # noqa: SLF001
        self.console.clear_live()

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:  # noqa: ANN401
        """Restore the live state of the console."""
        if self.live:
            self.console.clear_live()
            self.console.set_live(self.live)


def safe_track(*args, **kwargs) -> Generator[Iterable, Any, Any]:
    """Wraps ``rich.progress.track`` with a context manager to cache the live state.

    For parameters look at ``rich.progress.track``.
    """
    with CacheRichLiveState():
        yield from track(*args, **kwargs)
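
A usage sketch mirroring the call site in `k_center_greedy.py`: `safe_track` is a drop-in replacement for `rich.progress.track`, stashing whatever live display the console currently owns and restoring it once the loop finishes.

```python
from anomalib.utils.rich import safe_track

# Same call signature as rich.progress.track. The surrounding CacheRichLiveState
# context detaches the console's current live display (e.g. the one owned by
# RichProgressBar) so track() can start its own, then puts it back afterwards.
for _ in safe_track(sequence=range(100), description="Selecting Coreset Indices."):
    ...  # one unit of work per step

# Note: safe_track is itself a generator, so the cached live state is only
# restored after the loop has been fully consumed (or the generator is closed).
```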