diff --git a/CHANGELOG.md b/CHANGELOG.md
index a31996bf07..7af5aec234 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Changed default inference device to AUTO in https://github.com/openvinotoolkit/anomalib/pull/1534
+
 ### Deprecated
 
 - Support only Python 3.10 and greater in https://github.com/openvinotoolkit/anomalib/pull/1299
diff --git a/requirements/openvino.txt b/requirements/openvino.txt
index 48946cc0e9..cb68a34a59 100644
--- a/requirements/openvino.txt
+++ b/requirements/openvino.txt
@@ -1,6 +1,3 @@
-defusedxml==0.7.1
-requests>=2.26.0
-networkx~=2.5
-nncf>=2.1.0
-onnx>=1.10.1
-openvino-dev>=2022.3.0
+openvino-dev>=2023.0
+nncf>=2.5.0
+onnx>=1.13.1
diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py
index bcac9a8ea3..e1bbe93747 100644
--- a/src/anomalib/cli/cli.py
+++ b/src/anomalib/cli/cli.py
@@ -303,7 +303,7 @@ def train(self) -> Callable[..., _EVALUATE_OUTPUT]:
         return self.engine.train
 
     @property
-    def export(self) -> Callable[..., None]:
+    def export(self) -> Callable[..., Path | None]:
         """Export the model using engine's export method."""
         return self.engine.export
 
diff --git a/src/anomalib/deploy/export.py b/src/anomalib/deploy/export.py
index 2706ffc8fb..d9dbf26d9d 100644
--- a/src/anomalib/deploy/export.py
+++ b/src/anomalib/deploy/export.py
@@ -52,7 +52,7 @@ def export_to_torch(
     export_path: Path | str,
     transform: dict[str, Any] | AnomalibDataset | AnomalibDataModule | A.Compose,
     task: TaskType | None = None,
-) -> None:
+) -> Path:
     """Export AnomalibModel to torch.
 
     Args:
@@ -64,6 +64,9 @@
         task (TaskType | None): Task type should be provided if transforms is of type dict or A.Compose object.
             Defaults to ``None``.
 
+    Returns:
+        Path: Path to the exported pytorch model.
+
     Examples:
         Assume that we have a model to train and we want to export it to torch format.
 
@@ -92,10 +95,12 @@
     """
     export_path = _create_export_path(export_path, ExportMode.TORCH)
     metadata = get_metadata(task=task, transform=transform, model=model)
+    pt_model_path = export_path / "model.pt"
     torch.save(
         obj={"model": model.model, "metadata": metadata},
-        f=export_path / "model.pt",
+        f=pt_model_path,
     )
+    return pt_model_path
 
 
 def export_to_onnx(
@@ -177,9 +182,9 @@ def export_to_openvino(
     model: AnomalyModule,
     input_size: tuple[int, int],
     transform: dict[str, Any] | AnomalibDataset | AnomalibDataModule | A.Compose,
-    mo_args: dict[str, Any] | None = None,
+    ov_args: dict[str, Any] | None = None,
     task: TaskType | None = None,
-) -> None:
+) -> Path:
     """Convert onnx model to OpenVINO IR.
 
     Args:
@@ -189,11 +194,14 @@
         transform (dict[str, Any] | AnomalibDataset | AnomalibDataModule | A.Compose): Data transforms (augmentations)
             used for the model. When using dict, ensure that the transform dict is in the format required by
             Albumentations.
-        mo_args: Model optimizer arguments for OpenVINO model conversion.
+        ov_args: Model optimizer arguments for OpenVINO model conversion.
             Defaults to ``None``.
         task (TaskType | None): Task type should be provided if transforms is of type dict or A.Compose object.
             Defaults to ``None``.
 
+    Returns:
+        Path: Path to the exported OpenVINO model.
+
     Raises:
         ModuleNotFoundError: If OpenVINO is not installed.
 
@@ -232,13 +240,15 @@ def export_to_openvino(
 
     """
     model_path = export_to_onnx(model, input_size, export_path, transform, task, ExportMode.OPENVINO)
-    mo_args = {} if mo_args is None else mo_args
+    ov_model_path = model_path.with_suffix(".xml")
+    ov_args = {} if ov_args is None else ov_args
     if convert_model is not None and serialize is not None:
-        model = convert_model(input_model=str(model_path), output_dir=str(model_path.parent), **mo_args)
-        serialize(model, model_path.with_suffix(".xml"))
-    else:
-        logger.exception("Could not find OpenVINO methods. Please check OpenVINO installation.")
-        raise ModuleNotFoundError
+        model = convert_model(model_path, **ov_args)
+        serialize(model, ov_model_path)
+        return ov_model_path
+
+    logger.exception("Could not find OpenVINO methods. Please check OpenVINO installation.")
+    raise ModuleNotFoundError
 
 
 def get_metadata(
diff --git a/src/anomalib/deploy/inferencers/openvino_inferencer.py b/src/anomalib/deploy/inferencers/openvino_inferencer.py
index 2a139d20ac..51512d2c36 100644
--- a/src/anomalib/deploy/inferencers/openvino_inferencer.py
+++ b/src/anomalib/deploy/inferencers/openvino_inferencer.py
@@ -21,7 +21,7 @@
 logger = logging.getLogger("anomalib")
 
 if find_spec("openvino") is not None:
-    from openvino.runtime import Core
+    import openvino.runtime as ov
 
     if TYPE_CHECKING:
         from openvino.runtime import CompiledModel
@@ -37,11 +37,12 @@ class OpenVINOInferencer(Inferencer):
         metadata (str | Path | dict, optional): Path to metadata file or a dict object defining
             the metadata.
             Defaults to ``None``.
-        device (str | None, optional): Device to run the inference on.
-            Defaults to ``CPU``.
+        device (str | None, optional): Device to run the inference on (AUTO, CPU, GPU, NPU).
+            Defaults to ``AUTO``.
         task (TaskType | None, optional): Task type.
             Defaults to ``None``.
-
+        config (dict | None, optional): Configuration parameters for the inference.
+            Defaults to ``None``.
     Examples:
         Assume that we have an OpenVINO IR model and metadata files in the following structure:
 
@@ -89,7 +90,7 @@ def __init__(
         self,
         path: str | Path | tuple[bytes, bytes],
         metadata: str | Path | dict | None = None,
-        device: str | None = "CPU",
+        device: str | None = "AUTO",
         task: str | None = None,
         config: dict | None = None,
     ) -> None:
@@ -112,11 +113,10 @@ def load_model(self, path: str | Path | tuple[bytes, bytes]) -> tuple[Any, Any,
             [tuple[str, str, ExecutableNetwork]]: Input and Output blob names
                 together with the Executable network.
         """
-        ie_core = Core()
+        core = ov.Core()
         # If tuple of bytes is passed
-
         if isinstance(path, tuple):
-            model = ie_core.read_model(model=path[0], weights=path[1], init_from_buffer=True)
+            model = core.read_model(model=path[0], weights=path[1])
         else:
             path = path if isinstance(path, Path) else Path(path)
             if path.suffix in (".bin", ".xml"):
@@ -124,18 +124,18 @@ def load_model(self, path: str | Path | tuple[bytes, bytes]) -> tuple[Any, Any,
                     bin_path, xml_path = path, path.with_suffix(".xml")
                 elif path.suffix == ".xml":
                     xml_path, bin_path = path, path.with_suffix(".bin")
-                model = ie_core.read_model(xml_path, bin_path)
+                model = core.read_model(xml_path, bin_path)
             elif path.suffix == ".onnx":
-                model = ie_core.read_model(path)
+                model = core.read_model(path)
             else:
                 msg = f"Path must be .onnx, .bin or .xml file. Got {path.suffix}"
                 raise ValueError(msg)
         # Create cache folder
         cache_folder = Path("cache")
         cache_folder.mkdir(exist_ok=True)
-        ie_core.set_property({"CACHE_DIR": cache_folder})
+        core.set_property({"CACHE_DIR": cache_folder})
 
-        compile_model = ie_core.compile_model(model=model, device_name=self.device, config=self.config)
+        compile_model = core.compile_model(model=model, device_name=self.device, config=self.config)
         input_blob = compile_model.input(0)
         output_blob = compile_model.output(0)
 
@@ -143,7 +143,7 @@ def load_model(self, path: str | Path | tuple[bytes, bytes]) -> tuple[Any, Any,
         return input_blob, output_blob, compile_model
 
     def pre_process(self, image: np.ndarray) -> np.ndarray:
-        """Pre process the input image by applying transformations.
+        """Pre-process the input image by applying transformations.
 
         Args:
             image (np.ndarray): Input image.
@@ -178,8 +178,8 @@ def post_process(self, predictions: np.ndarray, metadata: dict | DictConfig | No
 
         Args:
             predictions (np.ndarray): Raw output predicted by the model.
-            metadata (Dict, optional): Meta data. Post-processing step sometimes requires
-                additional meta data such as image shape. This variable comprises such info.
+            metadata (Dict, optional): Metadata. Post-processing step sometimes requires
+                additional metadata such as image shape. This variable comprises such info.
                 Defaults to None.
 
         Returns:
diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py
index a84a16fda0..96caef228f 100644
--- a/src/anomalib/engine/engine.py
+++ b/src/anomalib/engine/engine.py
@@ -459,9 +459,9 @@ def export(
         datamodule: AnomalibDataModule | None = None,
         dataset: AnomalibDataset | None = None,
         input_size: tuple[int, int] | None = None,
-        mo_args: dict[str, Any] | None = None,
+        ov_args: dict[str, Any] | None = None,
         ckpt_path: str | None = None,
-    ) -> None:
+    ) -> Path | None:
         """Export the model in the specified format.
 
         Args:
@@ -479,10 +479,13 @@ def export(
                 is optional. Defaults to None.
             input_size (tuple[int, int] | None, optional): This is required only if the model is exported to ONNX and
                 OpenVINO format. Defaults to None.
-            mo_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer.
+            ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer.
                 Defaults to None.
             ckpt_path (str | None): Checkpoint path. If provided, the model will be loaded from this path.
 
+        Returns:
+            Path: Path to the exported model.
+
         Raises:
             ValueError: If Dataset, Datamodule, and transform are not provided.
             TypeError: If path to the transform file is not a string or Path.
@@ -509,28 +512,30 @@ def export(
             logger.exception(f"Unknown type {type(transform)} for transform.")
             raise TypeError
 
-        if export_mode in (ExportMode.OPENVINO, ExportMode.ONNX):
-            assert input_size is not None, "input_size must be provided for OpenVINO and ONNX export modes."
         if export_path is None:
             export_path = Path(self.trainer.default_root_dir)
+
         if export_mode == ExportMode.TORCH:
-            export_to_torch(model=model, export_path=export_path, transform=transform, task=self.task)
-        elif export_mode == ExportMode.ONNX:
+            return export_to_torch(model=model, export_path=export_path, transform=transform, task=self.task)
+        if export_mode == ExportMode.ONNX:
             assert input_size is not None, "input_size must be provided for ONNX export mode."
-            export_to_onnx(
+            return export_to_onnx(
                 model=model,
                 input_size=input_size,
                 export_path=export_path,
                 transform=transform,
                 task=self.task,
             )
-        else:
+        if export_mode == ExportMode.OPENVINO:
             assert input_size is not None, "input_size must be provided for OpenVINO export mode."
-            export_to_openvino(
+            return export_to_openvino(
                 model=model,
                 input_size=input_size,
                 export_path=export_path,
                 transform=transform,
                 task=self.task,
-                mo_args=mo_args,
+                ov_args=ov_args,
             )
+
+        logging.error(f"Export mode {export_mode} is not supported yet.")
+        return None
diff --git a/src/anomalib/pipelines/benchmarking/benchmark.py b/src/anomalib/pipelines/benchmarking/benchmark.py
index b5c5b038a2..b28da259ba 100644
--- a/src/anomalib/pipelines/benchmarking/benchmark.py
+++ b/src/anomalib/pipelines/benchmarking/benchmark.py
@@ -143,7 +143,7 @@ def get_single_model_metrics(
                 model=model,
                 input_size=input_size,
                 transform=engine.trainer.datamodule.test_data.transform,
-                mo_args={},
+                ov_args={},
                 task=engine.trainer.datamodule.test_data.task,
             )
             openvino_throughput = get_openvino_throughput(model_path=project_path, test_dataset=datamodule.test_data)
diff --git a/tests/integration/tools/test_gradio_entrypoint.py b/tests/integration/tools/test_gradio_entrypoint.py
index 6e5b0f7de9..b92e716734 100644
--- a/tests/integration/tools/test_gradio_entrypoint.py
+++ b/tests/integration/tools/test_gradio_entrypoint.py
@@ -78,7 +78,7 @@ def test_openvino_inference(
             model=model,
             input_size=(256, 256),
             transform=transforms_config,
-            mo_args={},
+            ov_args={},
             task=TaskType.SEGMENTATION,
         )
 
diff --git a/tests/integration/tools/test_openvino_entrypoint.py b/tests/integration/tools/test_openvino_entrypoint.py
index 99c30a1088..32db0c01e6 100644
--- a/tests/integration/tools/test_openvino_entrypoint.py
+++ b/tests/integration/tools/test_openvino_entrypoint.py
@@ -49,7 +49,7 @@ def test_openvino_inference(
             model=model,
             input_size=(256, 256),
             transform=transforms_config,
-            mo_args={},
+            ov_args={},
             task=TaskType.SEGMENTATION,
         )
 
diff --git a/tests/unit/deploy/test_inferencer.py b/tests/unit/deploy/test_inferencer.py
index 26c8bad672..c31a329ae9 100644
--- a/tests/unit/deploy/test_inferencer.py
+++ b/tests/unit/deploy/test_inferencer.py
@@ -108,21 +108,21 @@ def test_openvino_inference(task: TaskType, ckpt_path: Callable[[str], Path], da
     """
     model = Padim()
    engine = Engine(task=task)
-    export_path = ckpt_path("Padim").parent.parent
+    export_dir = ckpt_path("Padim").parent.parent
     datamodule = MVTec(root=dataset_path / "mvtec", category="dummy")
-    engine.export(
+    exported_xml_file_path = engine.export(
         model=model,
         export_mode=ExportMode.OPENVINO,
         input_size=(256, 256),
-        export_path=export_path,
+        export_path=export_dir,
         datamodule=datamodule,
         ckpt_path=str(ckpt_path("Padim")),
     )
 
     # Test OpenVINO inferencer
     openvino_inferencer = OpenVINOInferencer(
-        export_path / "weights/openvino/model.xml",
-        export_path / "weights/openvino/metadata.json",
+        exported_xml_file_path,
+        exported_xml_file_path.parent / "metadata.json",
     )
     openvino_dataloader = _MockImageLoader([256, 256], total_count=1)
     for image in openvino_dataloader():
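Example (not part of the patch): a minimal sketch of the workflow this change enables, assuming a Padim model and an MVTec-style dataset. Engine.export now returns the path of the exported model, which can be passed straight to OpenVINOInferencer, whose device now defaults to AUTO. The dataset root, category, checkpoint path, and image path below are illustrative placeholders; import locations follow the test files above and may vary between anomalib versions.

from pathlib import Path

from anomalib import TaskType
from anomalib.data import MVTec
from anomalib.deploy import ExportMode, OpenVINOInferencer
from anomalib.engine import Engine
from anomalib.models import Padim

# Placeholder dataset root and checkpoint; replace with real paths.
model = Padim()
datamodule = MVTec(root="./datasets/MVTec", category="bottle")
engine = Engine(task=TaskType.SEGMENTATION)

# Engine.export now returns the path to the exported model
# (the OpenVINO .xml file for ExportMode.OPENVINO) instead of None.
exported_xml_path: Path | None = engine.export(
    model=model,
    export_mode=ExportMode.OPENVINO,
    input_size=(256, 256),
    export_path="./results",
    datamodule=datamodule,
    ckpt_path="./results/weights/model.ckpt",  # illustrative checkpoint location
)

if exported_xml_path is not None:
    # device defaults to "AUTO" after this change; pass "CPU", "GPU" or "NPU" to pin it.
    inferencer = OpenVINOInferencer(
        path=exported_xml_path,
        metadata=exported_xml_path.parent / "metadata.json",
    )
    predictions = inferencer.predict(image="./datasets/MVTec/bottle/test/broken_large/000.png")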