diff --git a/src/anomalib/models/components/base/anomaly_module.py b/src/anomalib/models/components/base/anomaly_module.py index b9a7368337..0bf32ae053 100644 --- a/src/anomalib/models/components/base/anomaly_module.py +++ b/src/anomalib/models/components/base/anomaly_module.py @@ -5,6 +5,7 @@ from __future__ import annotations +import importlib import logging from abc import ABC from typing import Any, OrderedDict @@ -234,6 +235,31 @@ def _load_normalization_class(self, state_dict: OrderedDict[str, Tensor]) -> Non else: warn("No known normalization found in model weights.") + def _load_metrics(self, state_dict: OrderedDict[str, Tensor]) -> None: + """Load metrics from saved checkpoint.""" + self._set_metrics("pixel", state_dict) + self._set_metrics("image", state_dict) + + def _set_metrics(self, name: str, state_dict: OrderedDict[str, Tensor]): + """Sets the pixel/image metrics. + + Args: + name (str): is it pixel or image. + state_dict (OrderedDict[str, Tensor]): state dict of the model. + """ + metric_keys = [key for key in state_dict.keys() if key.startswith(f"{name}_metrics")] + if not hasattr(self, f"{name}_metrics") and any(metric_keys): + metrics = AnomalibMetricCollection([], prefix=f"{name}_") + for key in metric_keys: + class_name = key.split(".")[1] + try: + metrics_module = importlib.import_module("anomalib.utils.metrics") + metrics_cls = getattr(metrics_module, class_name) + except Exception as exception: + raise ImportError(f"Class {class_name} not found in module anomalib.utils.metrics") from exception + metrics.add_metrics(metrics_cls()) + setattr(self, f"{name}_metrics", metrics) + def load_state_dict(self, state_dict: OrderedDict[str, Tensor], strict: bool = True): """Load state dict from checkpoint. @@ -241,4 +267,6 @@ def load_state_dict(self, state_dict: OrderedDict[str, Tensor], strict: bool = T """ # Used to load missing normalization and threshold parameters self._load_normalization_class(state_dict) + # Used to load metrics if there is any related data in state_dict + self._load_metrics(state_dict) return super().load_state_dict(state_dict, strict=strict) diff --git a/src/anomalib/utils/callbacks/metrics_configuration.py b/src/anomalib/utils/callbacks/metrics_configuration.py index 31113961ec..113dc5ce4c 100644 --- a/src/anomalib/utils/callbacks/metrics_configuration.py +++ b/src/anomalib/utils/callbacks/metrics_configuration.py @@ -77,7 +77,13 @@ def setup( if isinstance(pl_module, AnomalyModule): pl_module.image_metrics = create_metric_collection(image_metric_names, "image_") - pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_") + if hasattr(pl_module, "pixel_metrics"): + new_metrics = create_metric_collection(pixel_metric_names, "pixel_") + for name in new_metrics.keys(): + if name not in pl_module.pixel_metrics.keys(): + pl_module.pixel_metrics.add_metrics(new_metrics[name.split("_")[1]]) + else: + pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_") pl_module.image_metrics.set_threshold(pl_module.image_threshold.value) pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value) diff --git a/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-02-serialized.yaml b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-02-serialized.yaml new file mode 100644 index 0000000000..8c5bbabb65 --- /dev/null +++ b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-02-serialized.yaml @@ -0,0 +1,19 @@ +metrics: + pixel: + F1Score: + class_path: torchmetrics.F1Score + init_args: + compute_on_cpu: true + AUPRO: + class_path: anomalib.utils.metrics.AUPRO + init_args: + compute_on_cpu: true + image: + F1Score: + class_path: torchmetrics.F1Score + init_args: + compute_on_cpu: true + AUROC: + class_path: anomalib.utils.metrics.AUROC + init_args: + compute_on_cpu: true diff --git a/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-02.yaml b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-02.yaml new file mode 100644 index 0000000000..4697c9d478 --- /dev/null +++ b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/data/config-good-02.yaml @@ -0,0 +1,19 @@ +metrics: + pixel: + F1Score: + class_path: torchmetrics.F1Score + init_args: + compute_on_cpu: true + AUROC: + class_path: anomalib.utils.metrics.AUROC + init_args: + compute_on_cpu: true + image: + F1Score: + class_path: torchmetrics.F1Score + init_args: + compute_on_cpu: true + AUROC: + class_path: anomalib.utils.metrics.AUROC + init_args: + compute_on_cpu: true diff --git a/tests/pre_merge/utils/callbacks/metrics_configuration_callback/test_metrics_configuration_callback.py b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/test_metrics_configuration_callback.py index 5e377a2b90..68ba2486fe 100644 --- a/tests/pre_merge/utils/callbacks/metrics_configuration_callback/test_metrics_configuration_callback.py +++ b/tests/pre_merge/utils/callbacks/metrics_configuration_callback/test_metrics_configuration_callback.py @@ -1,8 +1,11 @@ +from itertools import chain from pathlib import Path +from collections import OrderedDict import pytest import pytorch_lightning as pl from omegaconf import OmegaConf +import torch from anomalib.models.components import AnomalyModule from anomalib.utils.callbacks.metrics_configuration import MetricsConfigurationCallback @@ -61,3 +64,54 @@ def test_metric_collection_configuration_callback(config_from_yaml): assert isinstance( dummy_anomaly_module.pixel_metrics, AnomalibMetricCollection ), f"{dummy_anomaly_module.pixel_metrics}" + + +@pytest.mark.parametrize( + ["ori_config_from_yaml", "saved_config_from_yaml"], + [("data/config-good-02.yaml", "data/config-good-02-serialized.yaml")], +) +def test_metric_collection_configuration_deserialzation_callback(ori_config_from_yaml, saved_config_from_yaml): + """Test if metrics are properly instantiated during deserialzation.""" + + ori_config_from_yaml_res = OmegaConf.load(Path(__file__).parent / ori_config_from_yaml) + saved_config_from_yaml_res = OmegaConf.load(Path(__file__).parent / saved_config_from_yaml) + callback = MetricsConfigurationCallback( + task="segmentation", + image_metrics=ori_config_from_yaml_res.metrics.image, + pixel_metrics=ori_config_from_yaml_res.metrics.pixel, + ) + + dummy_logger = DummyLogger() + dummy_anomaly_module = _DummyAnomalyModule() + trainer = pl.Trainer( + callbacks=[callback], logger=dummy_logger, enable_checkpointing=False, default_root_dir=dummy_logger.tempdir + ) + + saved_image_state_dict = OrderedDict( + { + "image_metrics." + k: torch.tensor(1.0) + for k, v in saved_config_from_yaml_res.metrics.image.items() + if v["class_path"].startswith("anomalib.utils.metrics") + } + ) + saved_pixel_state_dict = OrderedDict( + { + "pixel_metrics." + k: torch.tensor(1.0) + for k, v in saved_config_from_yaml_res.metrics.pixel.items() + if v["class_path"].startswith("anomalib.utils.metrics") + } + ) + + final_state_dict = OrderedDict(chain(saved_image_state_dict.items(), saved_pixel_state_dict.items())) + + dummy_anomaly_module._load_metrics(final_state_dict) + callback.setup(trainer, dummy_anomaly_module, DummyDataModule()) + + assert isinstance( + dummy_anomaly_module.image_metrics, AnomalibMetricCollection + ), f"{dummy_anomaly_module.image_metrics}" + assert isinstance( + dummy_anomaly_module.pixel_metrics, AnomalibMetricCollection + ), f"{dummy_anomaly_module.pixel_metrics}" + + assert sorted((list(dummy_anomaly_module.pixel_metrics))) == ["AUPRO", "AUROC", "F1Score"]