πŸ”¨ Update lightning inference #2018

Merged 2 commits on Apr 24, 2024
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- Add data_path argument to predict entrypoint and add properties for retrieving model path by @djdameln in https://github.com/openvinotoolkit/anomalib/pull/2018

### Changed

- πŸ”¨Rename OptimalF1 to F1Max for consistency with the literature, by @samet-akcay in https://github.com/openvinotoolkit/anomalib/pull/1980
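The changelog entry above refers to the new `data_path` parameter on `Engine.predict` (see the engine.py diff below). A minimal usage sketch, assuming a Padim model; the folder and checkpoint paths are placeholders, and only the `data_path` and `ckpt_path` parameters come from this PR:

```python
from anomalib.engine import Engine
from anomalib.models import Padim  # any AnomalyModule subclass works here

engine = Engine()
model = Padim()

# Predict directly from an image file or a folder of images; the engine wraps
# the path in a PredictDataset/DataLoader internally.
predictions = engine.predict(
    model=model,
    data_path="datasets/my_images",                # placeholder image folder
    ckpt_path="results/padim/weights/model.ckpt",  # placeholder checkpoint
)
```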
29 changes: 3 additions & 26 deletions src/anomalib/cli/cli.py
@@ -6,7 +6,6 @@
import logging
from collections.abc import Callable, Sequence
from functools import partial
from inspect import signature
from pathlib import Path
from types import MethodType
from typing import Any
@@ -29,7 +28,6 @@
from torch.utils.data import DataLoader, Dataset

from anomalib.data import AnomalibDataModule
from anomalib.data.predict import PredictDataset
from anomalib.engine import Engine
from anomalib.metrics.threshold import BaseThreshold
from anomalib.models import AnomalyModule
@@ -216,7 +214,7 @@ def add_predict_arguments(self, parser: ArgumentParser) -> None:
added = parser.add_method_arguments(
Engine,
"predict",
skip={"model", "dataloaders", "datamodule", "dataset"},
skip={"model", "dataloaders", "datamodule", "dataset", "data_path"},
)
self.subcommand_method_arguments["predict"] = added
self.add_arguments_to_parser(parser)
@@ -267,8 +265,6 @@ def before_instantiate_classes(self) -> None:
"""Modify the configuration to properly instantiate classes and sets up tiler."""
subcommand = self.config["subcommand"]
if subcommand in (*self.subcommands(), "train", "predict"):
if self.config["subcommand"] == "predict" and isinstance(self.config["predict"]["data"], str | Path):
self.config["predict"]["data"] = self._set_predict_dataloader_namespace(self.config["predict"]["data"])
self.config[subcommand] = update_config(self.config[subcommand])

def instantiate_classes(self) -> None:
@@ -415,27 +411,6 @@ def _add_trainer_arguments_to_parser(
**scheduler_kwargs,
)

def _set_predict_dataloader_namespace(self, data_path: str | Path | Namespace) -> Namespace:
"""Set the predict dataloader namespace.

If the argument is of type str or Path, then it is assumed to be the path to the prediction data and is
assigned to PredictDataset.

Args:
data_path (str | Path | Namespace): Path to the data.

Returns:
Namespace: Namespace containing the predict dataloader.
"""
if isinstance(data_path, str | Path):
init_args = {key: value.default for key, value in signature(PredictDataset).parameters.items()}
init_args["path"] = data_path
data_path = Namespace(
class_path="anomalib.data.predict.PredictDataset",
init_args=Namespace(init_args),
)
return data_path

def _add_default_arguments_to_parser(self, parser: ArgumentParser) -> None:
"""Adds default arguments to the parser."""
parser.add_argument(
@@ -463,6 +438,8 @@ def _prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]:
fn_kwargs["datamodule"] = self.datamodule
elif isinstance(self.datamodule, DataLoader):
fn_kwargs["dataloaders"] = self.datamodule
elif isinstance(self.datamodule, Path | str):
fn_kwargs["data_path"] = self.datamodule
return fn_kwargs

def _parser(self, subcommand: str | None) -> ArgumentParser:
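As the cli.py diff above shows, the CLI no longer builds a `PredictDataset` namespace itself; a plain string or `Path` given as the data argument is simply forwarded to `Engine.predict` as `data_path`. A toy illustration of that routing (the helper name `route_data` is made up for this sketch and mirrors the new `elif` branch in `_prepare_subcommand_kwargs`):

```python
from pathlib import Path

from torch.utils.data import DataLoader

from anomalib.data import AnomalibDataModule


def route_data(data, fn_kwargs: dict) -> dict:
    """Toy version of the kwarg routing done in _prepare_subcommand_kwargs."""
    if isinstance(data, AnomalibDataModule):
        fn_kwargs["datamodule"] = data
    elif isinstance(data, DataLoader):
        fn_kwargs["dataloaders"] = data
    elif isinstance(data, (Path, str)):
        fn_kwargs["data_path"] = data  # new in this PR: handled inside Engine.predict
    return fn_kwargs
```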
49 changes: 38 additions & 11 deletions src/anomalib/engine/engine.py
@@ -225,6 +225,28 @@ def threshold_callback(self) -> _ThresholdCallback | None:
raise ValueError(msg)
return callbacks[0] if len(callbacks) > 0 else None

@property
def checkpoint_callback(self) -> ModelCheckpoint | None:
"""The ``ModelCheckpoint`` callback in the trainer.callbacks list, or ``None`` if it doesn't exist.

Returns:
ModelCheckpoint | None: ModelCheckpoint callback, if available.
"""
if self._trainer is None:
return None
return self.trainer.checkpoint_callback

@property
def best_model_path(self) -> str | None:
"""The path to the best model checkpoint.

Returns:
str: Path to the best model checkpoint.
"""
if self.checkpoint_callback is None:
return None
return self.checkpoint_callback.best_model_path

def _setup_workspace(
self,
model: AnomalyModule,
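The two properties added above expose the best checkpoint produced during training without digging into the trainer's callback list. A hedged sketch, assuming an MVTec datamodule and Padim model as placeholders and that a ModelCheckpoint callback is configured:

```python
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import Padim

engine = Engine()
model = Padim()
datamodule = MVTec()  # placeholder datamodule

engine.fit(model=model, datamodule=datamodule)

# best_model_path returns None when no trainer or ModelCheckpoint callback
# exists yet; otherwise it proxies the callback's best_model_path value.
ckpt = engine.best_model_path
predictions = engine.predict(model=model, data_path="datasets/my_images", ckpt_path=ckpt)
```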
@@ -672,6 +694,7 @@ def predict(
dataset: Dataset | PredictDataset | None = None,
return_predictions: bool | None = None,
ckpt_path: str | Path | None = None,
data_path: str | Path | None = None,
) -> _PREDICT_OUTPUT | None:
"""Predict using the model using the trainer.

Expand Down Expand Up @@ -703,6 +726,9 @@ def predict(
Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
if a checkpoint callback is configured.
Defaults to None.
data_path (str | Path | None):
Path to the image or folder containing images to generate predictions for.
Defaults to None.

Returns:
_PREDICT_OUTPUT | None: Predictions.
@@ -743,18 +769,19 @@
if not ckpt_path:
logger.warning("ckpt_path is not provided. Model weights will not be loaded.")

# Handle the instance when a dataset is passed to the predict method
# Collect dataloaders
if dataloaders is None:
dataloaders = []
elif isinstance(dataloaders, DataLoader):
dataloaders = [dataloaders]
elif not isinstance(dataloaders, list):
msg = f"Unknown type for dataloaders {type(dataloaders)}"
raise TypeError(msg)
if dataset is not None:
dataloader = DataLoader(dataset)
if dataloaders is None:
dataloaders = dataloader
elif isinstance(dataloaders, DataLoader):
dataloaders = [dataloaders, dataloader]
elif isinstance(dataloaders, list): # dataloader is a list
dataloaders.append(dataloader)
else:
msg = f"Unknown type for dataloaders {type(dataloaders)}"
raise TypeError(msg)
dataloaders.append(DataLoader(dataset))
if data_path is not None:
dataloaders.append(DataLoader(PredictDataset(data_path)))
dataloaders = dataloaders or None

self._setup_dataset_task(dataloaders, datamodule)
self._setup_transform(model or self.model, datamodule=datamodule, dataloaders=dataloaders, ckpt_path=ckpt_path)
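With the reworked collection logic above, `dataloaders`, `dataset`, and `data_path` can be combined freely; a dataset and a path are each wrapped in their own `DataLoader` and appended to the list. A hedged usage sketch (the folder and checkpoint paths are placeholders):

```python
from anomalib.data.predict import PredictDataset
from anomalib.engine import Engine
from anomalib.models import Padim

engine = Engine()
model = Padim()

# Both sources below end up as separate DataLoaders inside predict().
predictions = engine.predict(
    model=model,
    dataset=PredictDataset("datasets/batch_a"),    # placeholder folder, passed as a Dataset
    data_path="datasets/batch_b",                  # placeholder folder, wrapped in PredictDataset internally
    ckpt_path="results/padim/weights/model.ckpt",  # placeholder checkpoint
)
```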