πŸš€ Anomalib Pipelines #2005

Merged
Changes from 11 commits
10 changes: 9 additions & 1 deletion src/anomalib/cli/cli.py
@@ -16,6 +16,7 @@
from rich import traceback

from anomalib import TaskType, __version__
from anomalib.cli.pipelines import add_pipeline_subparsers, run_pipeline
from anomalib.cli.utils.help_formatter import CustomHelpFormatter, get_short_docstring
from anomalib.cli.utils.openvino import add_openvino_export_arguments
from anomalib.loggers import configure_logger
@@ -92,6 +93,7 @@ def anomalib_subcommands() -> dict[str, dict[str, str]]:
"train": {"description": "Fit the model and then call test on the trained model."},
"predict": {"description": "Run inference on a model."},
"export": {"description": "Export the model to ONNX or OpenVINO format."},
"pipeline": {"description": "Run a pipeline of jobs."},
}

def add_subcommands(self, **kwargs) -> None:
@@ -240,6 +242,10 @@ def add_export_arguments(self, parser: ArgumentParser) -> None:
add_openvino_export_arguments(parser)
self.add_arguments_to_parser(parser)

def add_pipeline_arguments(self, parser: ArgumentParser) -> None:
"""Add pipeline arguments to the parser."""
add_pipeline_subparsers(parser)

def _set_install_subcommand(self, action_subcommand: _ActionSubCommands) -> None:
sub_parser = ArgumentParser(formatter_class=CustomHelpFormatter)
sub_parser.add_argument(
@@ -288,7 +294,7 @@ def instantiate_classes(self) -> None:
self.model = self._get(self.config_init, "model")
self._configure_optimizers_method_to_model()
self.instantiate_engine()
else:
elif self.config["subcommand"] != "pipeline":
Contributor: Which approach would be better?

    anomalib pipeline benchmark --arg1 val1 --arg2 val2

or

    anomalib benchmark --arg1 val1 --arg2 val2

Just setting the stage here for discussion.

Contributor: I would prefer anomalib benchmark ..., but if someone implements a custom pipeline, I feel they should be able to run it without making changes to the CLI. In that case they might have to use anomalib pipeline custom_pipeline?

self.config_init = self.parser.instantiate_classes(self.config)
subcommand = self.config["subcommand"]
if subcommand in ("train", "export"):
@@ -353,6 +359,8 @@ def _run_subcommand(self) -> None:
fn = getattr(self.engine, self.subcommand)
fn_kwargs = self._prepare_subcommand_kwargs(self.subcommand)
fn(**fn_kwargs)
elif self.subcommand == "pipeline":
run_pipeline(self.config)
else:
self.config_init = self.parser.instantiate_classes(self.config)
getattr(self, f"{self.subcommand}")()
36 changes: 36 additions & 0 deletions src/anomalib/cli/pipelines.py
@@ -0,0 +1,36 @@
"""Subcommand for pipelines."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from jsonargparse import ArgumentParser, Namespace

from anomalib.utils.exceptions import try_import

if try_import("anomalib.pipelines"):
from anomalib.pipelines import Benchmark
from anomalib.pipelines.components.base import Pipeline

PIPELINE_REGISTRY: dict[str, Pipeline] | None = {
"benchmark": Benchmark(),
}
Contributor: Is this multi-line intentional? It looks like it could fit on a single line?

else:
PIPELINE_REGISTRY = None


def add_pipeline_subparsers(parser: ArgumentParser) -> None:
"""Add subparsers for pipelines."""
if PIPELINE_REGISTRY is not None:
subcommands = parser.add_subcommands(dest="subcommand", help="Run Pipelines", required=True)
for name, pipeline in PIPELINE_REGISTRY.items():
subcommands.add_subcommand(name, pipeline.get_parser(), help=f"Run {name} pipeline")


def run_pipeline(args: Namespace) -> None:
"""Run pipeline."""
if PIPELINE_REGISTRY is not None:
config = args.pipeline[args.pipeline.subcommand]
PIPELINE_REGISTRY[args.pipeline.subcommand].run(config)
else:
msg = "Pipeline is not available"
raise ValueError(msg)
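
To make the new subcommand concrete, a hypothetical invocation might look like the line below (the flag values are illustrative; the benchmark-specific options are registered by BenchmarkJobGenerator.add_arguments further down, nested under the job name):

    anomalib pipeline benchmark --accelerator cuda --benchmark.seed 42 --benchmark.model Padim --benchmark.data MVTec

run_pipeline then looks up the chosen subcommand in PIPELINE_REGISTRY and passes it the matching slice of the parsed config.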
26 changes: 20 additions & 6 deletions src/anomalib/data/__init__.py
@@ -3,7 +3,6 @@
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import importlib
import logging
from enum import Enum
@@ -29,20 +28,35 @@
)


def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
class UnknownDatamoduleError(ModuleNotFoundError):
...


def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule:
"""Get Anomaly Datamodule.

Args:
config (DictConfig | ListConfig): Configuration of the anomaly model.
config (DictConfig | ListConfig | dict): Configuration of the anomaly model.

Returns:
PyTorch Lightning DataModule
"""
logger.info("Loading the datamodule")

module = importlib.import_module(".".join(config.data.class_path.split(".")[:-1]))
dataclass = getattr(module, config.data.class_path.split(".")[-1])
init_args = {**config.data.get("init_args", {})} # get dict
if isinstance(config, dict):
config = DictConfig(config)

try:
_config = config.data if "data" in config else config
if len(_config.class_path.split(".")) > 1:
module = importlib.import_module(".".join(_config.class_path.split(".")[:-1]))
else:
module = importlib.import_module("anomalib.data")
except ModuleNotFoundError as exception:
logger.exception(f"ModuleNotFoundError: {_config.class_path}")
raise UnknownDatamoduleError from exception
dataclass = getattr(module, _config.class_path.split(".")[-1])
init_args = {**_config.get("init_args", {})} # get dict
if "image_size" in init_args:
init_args["image_size"] = to_tuple(init_args["image_size"])

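As a quick illustration of the widened signature, a plain dict now works as well; a minimal sketch (assuming MVTec is exported from anomalib.data, so the bare class name resolves through the fallback import):

    from anomalib.data import get_datamodule

    # A dotted class_path is imported as-is; a bare name falls back to anomalib.data.
    datamodule = get_datamodule({"class_path": "MVTec", "init_args": {"category": "bottle"}})

A class_path whose module cannot be imported now surfaces as UnknownDatamoduleError rather than a bare ModuleNotFoundError.
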
8 changes: 8 additions & 0 deletions src/anomalib/pipelines/__init__.py
@@ -0,0 +1,8 @@
"""Pipelines for end-to-end usecases."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .benchmark import Benchmark

__all__ = ["Benchmark"]
8 changes: 8 additions & 0 deletions src/anomalib/pipelines/benchmark/__init__.py
@@ -0,0 +1,8 @@
"""Benchmarking."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .pipeline import Benchmark

__all__ = ["Benchmark"]
83 changes: 83 additions & 0 deletions src/anomalib/pipelines/benchmark/generator.py
@@ -0,0 +1,83 @@
"""Benchmark job generator."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from argparse import SUPPRESS
from collections.abc import Generator

from jsonargparse import ArgumentParser, Namespace
from jsonargparse._optionals import get_doc_short_description

from anomalib.data import AnomalibDataModule, get_datamodule
from anomalib.models import AnomalyModule, get_model
from anomalib.pipelines.components import JobGenerator, dict_from_namespace, hide_output
from anomalib.pipelines.components.actions import GridSearchAction, get_iterator_from_grid_dict

from .job import BenchmarkJob


class BenchmarkJobGenerator(JobGenerator):
"""Generate BenchmarkJob."""

def __init__(self, accelerator: str) -> None:
self.accelerator = accelerator
Contributor: I'm not sure about the terminology here. We use this variable mainly to distinguish between CPU and GPU, but I'm not sure CPU is technically considered an accelerator. Maybe device would be a more suitable name?

Collaborator (Author): Originally this was called device. I think we discussed changing it to accelerator to be in line with Lightning's terminology. I have no preference here, so I can rename it once we finalise the name.


@property
def job_class(self) -> type:
"""Return the job class."""
return BenchmarkJob

@staticmethod
def add_arguments(parser: ArgumentParser) -> None:
"""Add job specific arguments to the parser."""
group = parser.add_argument_group("Benchmark job specific arguments.")
group.add_argument(
f"--{BenchmarkJob.name}.seed",
type=int | dict[str, list[int]],
default=42,
help="Seed for reproducibility.",
)
BenchmarkJobGenerator._add_subclass_arguments(group, AnomalyModule, f"{BenchmarkJob.name}.model")
BenchmarkJobGenerator._add_subclass_arguments(group, AnomalibDataModule, f"{BenchmarkJob.name}.data")

@hide_output
def generate_jobs(self, args: Namespace) -> Generator[BenchmarkJob, None, None]:
"""Return iterator based on the arguments."""
container = {
"seed": args.seed,
"data": dict_from_namespace(args.data),
"model": dict_from_namespace(args.model),
}
for _container in get_iterator_from_grid_dict(container):
yield BenchmarkJob(
accelerator=self.accelerator,
seed=_container["seed"],
model=get_model(_container["model"]),
datamodule=get_datamodule(_container["data"]),
)
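
The grid expansion itself lives in get_iterator_from_grid_dict, outside this diff; the sketch below shows the intent, where the {"grid": ...} key is an assumption inferred from the seed argument's int | dict[str, list[int]] type hint:

    # Hypothetical input container with one grid-searched field.
    container = {
        "seed": {"grid": [0, 42]},
        "model": {"class_path": "Padim"},
        "data": {"class_path": "MVTec"},
    }
    # get_iterator_from_grid_dict would then yield one flattened config per
    # grid combination, e.g. {"seed": 0, ...} followed by {"seed": 42, ...},
    # and generate_jobs wraps each one in a BenchmarkJob.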

@staticmethod
def _add_subclass_arguments(parser: ArgumentParser, baseclass: type, key: str) -> None:
Contributor: Can you explain why we need this method?

Collaborator (Author): We can't call jsonargparse's default add_subclass_arguments because we need to override the action. This method passes GridSearchAction as the action parameter when registering class arguments.

"""Adds the subclass of the provided class to the parser under nested_key."""
doc_group = get_doc_short_description(baseclass, logger=parser.logger)
group = parser._create_group_if_requested( # noqa: SLF001
baseclass,
nested_key=key,
as_group=True,
doc_group=doc_group,
config_load=False,
instantiate=False,
)

with GridSearchAction.allow_default_instance_context():
action = group.add_argument(
f"--{key}",
metavar="CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE",
help=(
'One or more arguments specifying "class_path"'
f' and "init_args" for any subclass of {baseclass.__name__}.'
),
default=SUPPRESS,
action=GridSearchAction(typehint=baseclass, enable_path=True, logger=parser.logger),
)
action.sub_add_kwargs = {"fail_untyped": True, "sub_configs": True, "instantiate": True}
101 changes: 101 additions & 0 deletions src/anomalib/pipelines/benchmark/job.py
@@ -0,0 +1,101 @@
"""Benchmarking job."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

import pandas as pd
from lightning import seed_everything
from rich.console import Console
from rich.table import Table

from anomalib.data import AnomalibDataModule
from anomalib.engine import Engine
from anomalib.models import AnomalyModule
from anomalib.pipelines.components import (
Job,
hide_output,
)

logger = logging.getLogger(__name__)


class BenchmarkJob(Job):
"""Benchmarking job."""

name = "benchmark"

def __init__(self, accelerator: str, model: AnomalyModule, datamodule: AnomalibDataModule, seed: int) -> None:
super().__init__()
self.accelerator = accelerator
self.model = model
self.datamodule = datamodule
self.seed = seed

@hide_output
def run(
self,
task_id: int | None = None,
) -> dict[str, Any]:
"""Run the benchmark."""
devices: str | list[int] = "auto"
if task_id is not None:
devices = [task_id]
logger.info(f"Running job {self.model.__class__.__name__} with device {task_id}")
with TemporaryDirectory() as temp_dir:
seed_everything(self.seed)
engine = Engine(
accelerator=self.accelerator,
devices=devices,
default_root_dir=temp_dir,
)
engine.fit(self.model, self.datamodule)
test_results = engine.test(self.model, self.datamodule)
output = {
"seed": self.seed,
"accelerator": self.accelerator,
"model": self.model.__class__.__name__,
"data": self.datamodule.__class__.__name__,
"category": self.datamodule.category,
**test_results[0],
}
logger.info(f"Completed with result {output}")
return output

@staticmethod
def collect(results: list[dict[str, Any]]) -> pd.DataFrame:
"""Gather the results returned from run."""
output: dict[str, Any] = {}
for key in results[0]:
output[key] = []
for result in results:
for key, value in result.items():
output[key].append(value)
return pd.DataFrame(output)
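
For reference, collect simply pivots the per-job result dicts into columns (it assumes every result shares the keys of the first one); a toy illustration with made-up numbers:

    rows = [
        {"model": "Padim", "image_AUROC": 0.95},
        {"model": "Patchcore", "image_AUROC": 0.98},
    ]
    df = BenchmarkJob.collect(rows)
    # -> a DataFrame with columns "model" and "image_AUROC", one row per job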

@staticmethod
def save(result: pd.DataFrame) -> None:
"""Save the result to a csv file."""
BenchmarkJob._print_tabular_results(result)
file_path = Path("runs") / BenchmarkJob.name / datetime.now().strftime("%Y-%m-%d-%H:%M:%S") / "results.csv"
file_path.parent.mkdir(parents=True, exist_ok=True)
result.to_csv(file_path, index=False)
logger.info(f"Saved results to {file_path}")

@staticmethod
def _print_tabular_results(gathered_result: pd.DataFrame) -> None:
"""Print the tabular results."""
if gathered_result is not None:
console = Console()
table = Table(title=f"{BenchmarkJob.name} Results", show_header=True, header_style="bold magenta")
_results = gathered_result.to_dict("list")
for column in _results:
table.add_column(column)
for row in zip(*_results.values(), strict=False):
table.add_row(*[str(value) for value in row])
console.print(table)
44 changes: 44 additions & 0 deletions src/anomalib/pipelines/benchmark/pipeline.py
@@ -0,0 +1,44 @@
"""Benchmarking."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
from jsonargparse import ArgumentParser, Namespace

from anomalib.pipelines.components.base import Pipeline, Runner
from anomalib.pipelines.components.runners import ParallelRunner, SerialRunner

from .generator import BenchmarkJobGenerator


class Benchmark(Pipeline):
"""Benchmarking orchestrator."""
Contributor: Are we still using "orchestrator", or is it a leftover?

Collaborator (Author): I missed this.


def _setup_runners(self, args: Namespace) -> list[Runner]:
"""Setup the runners for the pipeline."""
accelerators = args.accelerator if isinstance(args.accelerator, list) else [args.accelerator]
runners: list[Runner] = []
for accelerator in accelerators:
if accelerator == "cpu":
runners.append(SerialRunner(BenchmarkJobGenerator("cpu")))
elif accelerator == "cuda":
runners.append(
ParallelRunner(BenchmarkJobGenerator("cuda"), n_jobs=torch.cuda.device_count()),
)
Contributor: This seems to fit on a single line. Did you intentionally break it across multiple lines?

else:
msg = f"Unsupported accelerator: {accelerator}"
raise ValueError(msg)
return runners
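
In other words, each requested accelerator gets its own runner; a mixed request produces the following:

    # args.accelerator == ["cpu", "cuda"] yields, in order:
    #   SerialRunner(BenchmarkJobGenerator("cpu"))             # jobs run one by one
    #   ParallelRunner(BenchmarkJobGenerator("cuda"),
    #                  n_jobs=torch.cuda.device_count())       # one worker per GPU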

def get_parser(self, parser: ArgumentParser | None = None) -> ArgumentParser:
"""Add arguments to the parser."""
parser = super().get_parser(parser)
parser.add_argument(
"--accelerator",
type=str | list[str],
default="cuda",
help="Hardware to run the benchmark on.",
)
BenchmarkJobGenerator.add_arguments(parser)
Contributor: Would an implementer always need to pass the parser to the JobGenerator, or is this implementation-specific?

Collaborator (Author): They have to pass it manually each time, as job-specific arguments are added to the same parser.

return parser