Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unexpected key pixel_metrics.AUPRO.fpr_limit #1055

Merged
merged 11 commits into from
Oct 24, 2023
28 changes: 28 additions & 0 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import importlib
import logging
from abc import ABC
from typing import Any, OrderedDict
Expand Down Expand Up @@ -234,11 +235,38 @@ def _load_normalization_class(self, state_dict: OrderedDict[str, Tensor]) -> Non
else:
warn("No known normalization found in model weights.")

def _load_metrics(self, state_dict: OrderedDict[str, Tensor]) -> None:
"""Load metrics from saved checkpoint."""
self._set_metrics("pixel", state_dict)
self._set_metrics("image", state_dict)

def _set_metrics(self, name: str, state_dict: OrderedDict[str, Tensor]):
"""Sets the pixel/image metrics.

Args:
name (str): is it pixel or image.
state_dict (OrderedDict[str, Tensor]): state dict of the model.
"""
metric_keys = [key for key in state_dict.keys() if key.startswith(f"{name}_metrics")]
if not hasattr(self, f"{name}_metrics") and any(metric_keys):
metrics = AnomalibMetricCollection([], prefix=f"{name}_")
for key in metric_keys:
class_name = key.split(".")[1]
try:
metrics_module = importlib.import_module("anomalib.utils.metrics")
metrics_cls = getattr(metrics_module, class_name)
except Exception as exception:
raise ImportError(f"Class {class_name} not found in module anomalib.utils.metrics") from exception
metrics.add_metrics(metrics_cls())
setattr(self, f"{name}_metrics", metrics)

def load_state_dict(self, state_dict: OrderedDict[str, Tensor], strict: bool = True):
"""Load state dict from checkpoint.

Ensures that normalization and thresholding attributes is properly setup before model is loaded.
"""
# Used to load missing normalization and threshold parameters
self._load_normalization_class(state_dict)
# Used to load metrics if there is any related data in state_dict
self._load_metrics(state_dict)
return super().load_state_dict(state_dict, strict=strict)
8 changes: 7 additions & 1 deletion src/anomalib/utils/callbacks/metrics_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ def setup(

if isinstance(pl_module, AnomalyModule):
pl_module.image_metrics = create_metric_collection(image_metric_names, "image_")
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")
if hasattr(pl_module, "pixel_metrics"):
new_metrics = create_metric_collection(pixel_metric_names, "pixel_")
for name in new_metrics.keys():
if name not in pl_module.pixel_metrics.keys():
pl_module.pixel_metrics.add_metrics(new_metrics[name.split("_")[1]])
else:
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")

pl_module.image_metrics.set_threshold(pl_module.image_threshold.value)
pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value)