🖌 refactor export callback (#640)
* refactor export callback

* refactor export functions

* Rename export_convert to export

* Rename optimize to export + fix tests

* Fix imports

* Address tests

* Add nosec to suppress subprocess warnings

* Add nosec to suppress run
ashwinvaidya17 authored Oct 20, 2022
1 parent 84a8e06 commit 406f79a
Showing 12 changed files with 89 additions and 60 deletions.
4 changes: 2 additions & 2 deletions anomalib/deploy/__init__.py
@@ -3,7 +3,7 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

+from .export import ExportMode, export, get_model_metadata
 from .inferencers import Inferencer, OpenVINOInferencer, TorchInferencer
-from .optimize import export_convert, get_model_metadata

-__all__ = ["Inferencer", "OpenVINOInferencer", "TorchInferencer", "export_convert", "get_model_metadata"]
+__all__ = ["ExportMode", "Inferencer", "OpenVINOInferencer", "TorchInferencer", "export", "get_model_metadata"]
82 changes: 60 additions & 22 deletions anomalib/deploy/optimize.py → anomalib/deploy/export.py
@@ -4,9 +4,10 @@
 # SPDX-License-Identifier: Apache-2.0

 import json
-import os
+import subprocess  # nosec
+from enum import Enum
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple, Union

 import numpy as np
 import torch
@@ -16,6 +17,13 @@
 from anomalib.models.components import AnomalyModule


+class ExportMode(str, Enum):
+    """Model export mode."""
+
+    ONNX = "onnx"
+    OPENVINO = "openvino"
+
+
 def get_model_metadata(model: AnomalyModule) -> Dict[str, Tensor]:
     """Get meta data related to normalization from model.
@@ -41,38 +49,68 @@ def get_model_metadata(model: AnomalyModule) -> Dict[str, Tensor]:
     return meta_data


-def export_convert(
+def export(
     model: AnomalyModule,
     input_size: Union[List[int], Tuple[int, int]],
-    export_mode: str,
-    export_path: Optional[Union[str, Path]] = None,
+    export_mode: ExportMode,
+    export_root: Union[str, Path],
 ):
-    """Export the model to onnx format and convert to OpenVINO IR. Metadata.json is generated regardless of export mode.
+    """Export the model to onnx format and (optionally) convert to OpenVINO IR if export mode is set to OpenVINO.
+
+    Metadata.json is generated regardless of export mode.

     Args:
         model (AnomalyModule): Model to convert.
         input_size (Union[List[int], Tuple[int, int]]): Image size used as the input for onnx converter.
-        export_path (Union[str, Path]): Path to exported OpenVINO IR.
-        export_mode (str): Mode to export onnx or openvino
+        export_root (Union[str, Path]): Path to exported ONNX/OpenVINO IR.
+        export_mode (ExportMode): Mode to export the model. ONNX or OpenVINO.
     """
-    height, width = input_size
-    onnx_path = os.path.join(str(export_path), "model.onnx")
-    torch.onnx.export(
-        model.model,
-        torch.zeros((1, 3, height, width)).to(model.device),
-        onnx_path,
-        opset_version=11,
-        input_names=["input"],
-        output_names=["output"],
-    )
-    export_path = os.path.join(str(export_path), export_mode)
-    if export_mode == "openvino":
-        optimize_command = "mo --input_model " + str(onnx_path) + " --output_dir " + str(export_path)
-        assert os.system(optimize_command) == 0, "OpenVINO conversion failed"
+    # Write metadata to json file. The file is written in the same directory as the target model.
+    export_path: Path = Path(str(export_root)) / export_mode.value
+    export_path.mkdir(parents=True, exist_ok=True)
     with open(Path(export_path) / "meta_data.json", "w", encoding="utf-8") as metadata_file:
         meta_data = get_model_metadata(model)
         # Convert metadata from torch
         for key, value in meta_data.items():
             if isinstance(value, Tensor):
                 meta_data[key] = value.numpy().tolist()
         json.dump(meta_data, metadata_file, ensure_ascii=False, indent=4)
+
+    onnx_path = _export_to_onnx(model, input_size, export_path)
+    if export_mode == ExportMode.OPENVINO:
+        _export_to_openvino(export_path, onnx_path)
+
+
+def _export_to_onnx(model: AnomalyModule, input_size: Union[List[int], Tuple[int, int]], export_path: Path) -> Path:
+    """Export model to onnx.
+
+    Args:
+        model (AnomalyModule): Model to export.
+        input_size (Union[List[int], Tuple[int, int]]): Image size used as the input for onnx converter.
+        export_path (Path): Path to the root folder of the exported model.
+
+    Returns:
+        Path: Path to the exported onnx model.
+    """
+    onnx_path = export_path / "model.onnx"
+    torch.onnx.export(
+        model.model,
+        torch.zeros((1, 3, *input_size)).to(model.device),
+        onnx_path,
+        opset_version=11,
+        input_names=["input"],
+        output_names=["output"],
+    )
+
+    return onnx_path
+
+
+def _export_to_openvino(export_path: Union[str, Path], onnx_path: Path):
+    """Convert onnx model to OpenVINO IR.
+
+    Args:
+        export_path (Union[str, Path]): Path to the root folder of the exported model.
+        onnx_path (Path): Path to the exported onnx model.
+    """
+    optimize_command = ["mo", "--input_model", str(onnx_path), "--output_dir", str(export_path)]
+    subprocess.run(optimize_command, check=True)  # nosec
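Taken together, the refactor splits the old string-driven export_convert into a small pipeline: write meta_data.json, export ONNX, then optionally invoke the OpenVINO Model Optimizer via subprocess.run instead of os.system. A usage sketch, assuming `model` is some trained AnomalyModule (the output root is hypothetical):

    from pathlib import Path

    from anomalib.deploy import ExportMode, export

    # `model` is assumed to be a trained AnomalyModule, e.g. after trainer.fit(...)
    export(
        model=model,
        input_size=(256, 256),
        export_mode=ExportMode.OPENVINO,
        export_root=Path("results/exported_models"),  # hypothetical root
    )

    # Expected layout, per the updated tests and sweep helper:
    #
    # results/exported_models/
    # └── openvino/
    #     ├── meta_data.json   # written for every export mode
    #     ├── model.onnx       # intermediate ONNX export
    #     ├── model.bin        # OpenVINO IR weights
    #     └── model.xml        # OpenVINO IR topology
    #
    # ExportMode.ONNX instead yields onnx/meta_data.json and onnx/model.onnx.

Running the Model Optimizer through an argument list with check=True avoids shell injection (hence the # nosec markers the commit messages mention) and turns a failed conversion into a CalledProcessError, rather than relying on an assert on the shell's exit status, which would be stripped under python -O.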
2 changes: 1 addition & 1 deletion anomalib/deploy/inferencers/torch_inferencer.py
@@ -13,7 +13,7 @@
 from torch import Tensor

 from anomalib.config import get_configurable_parameters
-from anomalib.deploy.optimize import get_model_metadata
+from anomalib.deploy.export import get_model_metadata
 from anomalib.models import get_model
 from anomalib.models.components import AnomalyModule
 from anomalib.pre_processing import PreProcessor
2 changes: 1 addition & 1 deletion anomalib/models/ganomaly/config.yaml
@@ -63,7 +63,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.

 optimization:
-  export_mode: ""
+  export_mode: null

 # PL Trainer Args. Don't add extra parameter here.
 trainer:
4 changes: 3 additions & 1 deletion anomalib/utils/callbacks/__init__.py
@@ -14,6 +14,8 @@
 from omegaconf import DictConfig, ListConfig, OmegaConf
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint

+from anomalib.deploy import ExportMode
+
 from .cdf_normalization import CdfNormalizationCallback
 from .graph import GraphLogger
 from .metrics_configuration import MetricsConfigurationCallback
@@ -134,7 +136,7 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
                 input_size=config.model.input_size,
                 dirpath=config.project.path,
                 filename="model",
-                export_mode=config.optimization.export_mode,
+                export_mode=ExportMode(config.optimization.export_mode),
             )
         )
     else:
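Because ExportMode subclasses str, the raw YAML value maps straight onto the enum; a small sketch of the behaviour get_callbacks now leans on (ExportMode(None) raises ValueError, which is presumably why the ganomaly config switches its default to null and why this branch only runs when an export mode is configured):

    from anomalib.deploy import ExportMode

    assert ExportMode("openvino") is ExportMode.OPENVINO
    assert ExportMode.ONNX == "onnx"  # str mixin: members compare equal to their values

    # ExportMode(None) raises ValueError, so a config with `export_mode: null`
    # must be filtered out before reaching this constructor.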
8 changes: 4 additions & 4 deletions anomalib/utils/callbacks/export.py
@@ -10,7 +10,7 @@
 from pytorch_lightning import Callback
 from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY

-from anomalib.deploy import export_convert
+from anomalib.deploy import ExportMode, export
 from anomalib.models.components import AnomalyModule

 logger = logging.getLogger(__name__)
@@ -28,7 +28,7 @@ class ExportCallback(Callback):
         filename (str): Name of output model
     """

-    def __init__(self, input_size: Tuple[int, int], dirpath: str, filename: str, export_mode: str):
+    def __init__(self, input_size: Tuple[int, int], dirpath: str, filename: str, export_mode: ExportMode):
         self.input_size = input_size
         self.dirpath = dirpath
         self.filename = filename
@@ -42,9 +42,9 @@ def on_train_end(self, trainer, pl_module: AnomalyModule) -> None:  # pylint: di
         """
         logger.info("Exporting the model")
         os.makedirs(self.dirpath, exist_ok=True)
-        export_convert(
+        export(
             model=pl_module,
             input_size=self.input_size,
-            export_path=self.dirpath,
+            export_root=self.dirpath,
             export_mode=self.export_mode,
         )
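A minimal sketch of wiring the updated callback into a PyTorch Lightning run; the directory, input size, and trainer arguments are placeholders, and the export itself fires in on_train_end:

    from pytorch_lightning import Trainer

    from anomalib.deploy import ExportMode
    from anomalib.utils.callbacks.export import ExportCallback

    export_callback = ExportCallback(
        input_size=(256, 256),
        dirpath="results/export",     # hypothetical output directory
        filename="model",
        export_mode=ExportMode.ONNX,  # annotation now asks for ExportMode, not str
    )
    trainer = Trainer(callbacks=[export_callback], max_epochs=1)
    # trainer.fit(model, datamodule=datamodule)  # triggers the export on train end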
4 changes: 3 additions & 1 deletion anomalib/utils/sweep/helpers/inference.py
@@ -86,7 +86,9 @@ def get_openvino_throughput(config: Union[DictConfig, ListConfig], model_path: P
     Returns:
         float: Inference throughput
     """
-    inferencer = OpenVINOInferencer(config, model_path / "model.xml", model_path / "meta_data.json")
+    inferencer = OpenVINOInferencer(
+        config, model_path / "openvino" / "model.xml", model_path / "openvino" / "meta_data.json"
+    )
     openvino_dataloader = MockImageLoader(config.dataset.image_size, total_count=len(test_dataset))
     start_time = time.time()
     # Create test images on CPU. Since we don't care about performance metrics and just the throughput, use mock data.
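This call site only tracks the new directory layout: because export() nests artifacts under the mode name, ONNX and OpenVINO exports of the same model can share one root without overwriting each other. A small sketch of the assumption (model_path is whatever was passed to export() as export_root):

    from pathlib import Path

    model_path = Path("results/exported_models")  # hypothetical export root
    xml_path = model_path / "openvino" / "model.xml"  # was model_path / "model.xml"
    metadata_path = model_path / "openvino" / "meta_data.json"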
9 changes: 5 additions & 4 deletions tests/pre_merge/deploy/test_inferencer.py
@@ -14,7 +14,8 @@
 from anomalib.config import get_configurable_parameters
 from anomalib.data import get_datamodule
-from anomalib.deploy import OpenVINOInferencer, TorchInferencer, export_convert
+from anomalib.deploy import OpenVINOInferencer, TorchInferencer, export
+from anomalib.deploy.export import ExportMode
 from anomalib.models import get_model
 from anomalib.utils.callbacks import get_callbacks
 from tests.helpers.dataset import TestDataset, get_dataset_path
@@ -102,11 +103,11 @@ def test_openvino_inference(self, model_name: str, category: str = "shapes", pat

         trainer.fit(model=model, datamodule=datamodule)

-        export_convert(
+        export(
             model=model,
             input_size=model_config.dataset.image_size,
-            export_path=export_path,
-            export_mode="openvino",
+            export_root=export_path,
+            export_mode=ExportMode.OPENVINO,
         )

         # Test OpenVINO inferencer
tests/pre_merge/utils/callbacks/export_callback/test_export.py
@@ -5,6 +5,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping

+from anomalib.deploy import ExportMode
 from anomalib.utils.callbacks.export import ExportCallback
 from tests.helpers.config import get_test_configurable_parameters
 from tests.pre_merge.utils.callbacks.export_callback.dummy_lightning_model import (
@@ -15,7 +16,7 @@

 @pytest.mark.parametrize(
     "export_mode",
-    ["openvino", "onnx"],
+    [ExportMode.OPENVINO, ExportMode.ONNX],
 )
 def test_export_model_callback(export_mode):
     """Tests if an optimized model is created."""
@@ -47,9 +48,9 @@ def test_export_model_callback(export_mode):
     )
     trainer.fit(model, datamodule=datamodule)

-    if "openvino" in export_mode:
+    if export_mode == ExportMode.OPENVINO:
         assert os.path.exists(os.path.join(tmp_dir, "openvino/model.bin")), "Failed to generate OpenVINO model"
-    elif "onnx" in export_mode:
-        assert os.path.exists(os.path.join(tmp_dir, "model.onnx")), "Failed to generate ONNX model"
+    elif export_mode == ExportMode.ONNX:
+        assert os.path.exists(os.path.join(tmp_dir, "onnx/model.onnx")), "Failed to generate ONNX model"
     else:
         raise ValueError(f"Unknown export_mode {export_mode}. Supported modes: onnx or openvino.")
6 changes: 4 additions & 2 deletions tools/benchmarking/benchmark.py
@@ -21,10 +21,12 @@
 import torch
 from omegaconf import DictConfig, ListConfig, OmegaConf
 from pytorch_lightning import Trainer, seed_everything
-from utils import convert_to_openvino, upload_to_comet, upload_to_wandb, write_metrics
+from utils import upload_to_comet, upload_to_wandb, write_metrics

 from anomalib.config import get_configurable_parameters, update_input_size_config
 from anomalib.data import get_datamodule
+from anomalib.deploy import export
+from anomalib.deploy.export import ExportMode
 from anomalib.models import get_model
 from anomalib.utils.loggers import configure_logger
 from anomalib.utils.sweep import (
@@ -115,7 +117,7 @@ def get_single_model_metrics(model_config: Union[DictConfig, ListConfig], openvi
         # Create dirs for openvino model export
         openvino_export_path = project_path / Path("exported_models")
         openvino_export_path.mkdir(parents=True, exist_ok=True)
-        convert_to_openvino(model, openvino_export_path, model_config.model.input_size)
+        export(model, model_config.model.input_size, ExportMode.OPENVINO, openvino_export_path)
         openvino_throughput = get_openvino_throughput(
             model_config, openvino_export_path, datamodule.test_dataloader().dataset
         )
3 changes: 1 addition & 2 deletions tools/benchmarking/utils/__init__.py
@@ -3,7 +3,6 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-from .convert import convert_to_openvino
 from .metrics import upload_to_comet, upload_to_wandb, write_metrics

-__all__ = ["convert_to_openvino", "write_metrics", "upload_to_comet", "upload_to_wandb"]
+__all__ = ["write_metrics", "upload_to_comet", "upload_to_wandb"]
16 changes: 0 additions & 16 deletions tools/benchmarking/utils/convert.py

This file was deleted.
