
🐞 Add device flag #601

Merged (9 commits) on Nov 7, 2022
30 changes: 26 additions & 4 deletions anomalib/deploy/inferencers/torch_inferencer.py
@@ -30,15 +30,19 @@ class TorchInferencer(Inferencer):
model_source (Union[str, Path, AnomalyModule]): Path to the model ckpt file or the Anomaly model.
meta_data_path (Union[str, Path], optional): Path to metadata file. If none, it tries to load the params
from the model state_dict. Defaults to None.
device (str, optional): Device to use for inference. Options are auto, cpu, cuda, gpu. Defaults to "auto".
"""

def __init__(
self,
config: Union[str, Path, DictConfig, ListConfig],
model_source: Union[str, Path, AnomalyModule],
meta_data_path: Union[str, Path] = None,
meta_data_path: Optional[Union[str, Path]] = None,
device: str = "auto",
):

self.device = self._get_device(device)

# Check and load the configuration
if isinstance(config, (str, Path)):
self.config = get_configurable_parameters(config_path=config)
@@ -55,6 +59,24 @@ def __init__(

self.meta_data = self._load_meta_data(meta_data_path)

def _get_device(self, device: str) -> torch.device:
"""Get the device to use for inference.

Args:
device (str): Device to use for inference. Options are auto, cpu, cuda.

Returns:
torch.device: Device to use for inference.
"""
if device not in ("auto", "cpu", "cuda", "gpu"):
raise ValueError(f"Unknown device {device}")

if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
elif device == "gpu":
device = "cuda"
return torch.device(device)

def _load_meta_data(self, path: Optional[Union[str, Path]] = None) -> Union[Dict, DictConfig]:
"""Load metadata from file or from model state dict.

@@ -82,9 +104,9 @@ def load_model(self, path: Union[str, Path]) -> AnomalyModule:
(AnomalyModule): PyTorch Lightning model.
"""
model = get_model(self.config)
model.load_state_dict(torch.load(path)["state_dict"])
model.load_state_dict(torch.load(path, map_location=self.device)["state_dict"])
model.eval()
return model
return model.to(self.device)

def pre_process(self, image: np.ndarray) -> Tensor:
"""Pre process the input image by applying transformations.
@@ -105,7 +127,7 @@ def pre_process(self, image: np.ndarray) -> Tensor:
if len(processed_image) == 3:
processed_image = processed_image.unsqueeze(0)

return processed_image
return processed_image.to(self.device)

def forward(self, image: Tensor) -> Tensor:
"""Forward-Pass input tensor to the model.
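
Taken together, the hunks above route the checkpoint (map_location), the model weights (model.to) and every pre-processed image onto the device resolved by _get_device. A minimal usage sketch of the new flag, not part of this diff, assuming TorchInferencer is importable from anomalib.deploy and using placeholder paths:

from anomalib.deploy import TorchInferencer  # assumed import path

# "auto" resolves to CUDA when torch.cuda.is_available(), otherwise CPU;
# "gpu" is accepted as an alias for "cuda".
inferencer = TorchInferencer(
    config="path/to/config.yaml",       # placeholder path
    model_source="path/to/model.ckpt",  # placeholder path
    device="auto",
)
print(inferencer.device)  # e.g. device(type='cuda') on a CUDA machine
prediction = inferencer.predict("path/to/image.png")  # placeholder image path
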
51 changes: 11 additions & 40 deletions anomalib/utils/sweep/helpers/inference.py
@@ -5,9 +5,8 @@

import time
from pathlib import Path
from typing import Iterable, List, Union
from typing import Union

import numpy as np
import torch
from omegaconf import DictConfig, ListConfig
from torch.utils.data import DataLoader
@@ -16,35 +15,6 @@
from anomalib.models.components import AnomalyModule


class MockImageLoader:
"""Create mock images for inference on CPU based on the specifics of the original torch test dataset.

Uses yield so as to avoid storing everything in the memory.

Args:
image_size (List[int]): Size of input image
total_count (int): Total images in the test dataset
"""

def __init__(self, image_size: List[int], total_count: int):
self.total_count = total_count
self.image_size = image_size
self.image = np.ones((*self.image_size, 3)).astype(np.uint8)

def __len__(self):
"""Get total count of images."""
return self.total_count

def __call__(self) -> Iterable[np.ndarray]:
"""Yield batch of generated images.

Args:
idx (int): Unused
"""
for _ in range(self.total_count):
yield self.image


def get_torch_throughput(
config: Union[DictConfig, ListConfig], model: AnomalyModule, test_dataset: DataLoader
) -> float:
@@ -60,12 +30,15 @@ def get_torch_throughput(
"""
torch.set_grad_enabled(False)
model.eval()
inferencer = TorchInferencer(config, model)
torch_dataloader = MockImageLoader(config.dataset.image_size, len(test_dataset))

device = config.trainer.accelerator
if device == "gpu":
device = "cuda"

inferencer = TorchInferencer(config, model.to(device), device=device)
start_time = time.time()
# Since we don't care about performance metrics and just the throughput, use mock data.
for image in torch_dataloader():
inferencer.predict(image)
for image_path in test_dataset.samples.image_path:
inferencer.predict(image_path)

# get throughput
inference_time = time.time() - start_time
@@ -89,11 +62,9 @@ def get_openvino_throughput(config: Union[DictConfig, ListConfig], model_path: P
inferencer = OpenVINOInferencer(
config, model_path / "openvino" / "model.xml", model_path / "openvino" / "meta_data.json"
)
openvino_dataloader = MockImageLoader(config.dataset.image_size, total_count=len(test_dataset))
start_time = time.time()
# Create test images on CPU. Since we don't care about performance metrics and just the throughput, use mock data.
for image in openvino_dataloader():
inferencer.predict(image)
for image_path in test_dataset.samples.image_path:
inferencer.predict(image_path)

# get throughput
inference_time = time.time() - start_time
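
Both helpers now time predictions over the real test images instead of generated mock data, which is why MockImageLoader could be deleted. The tail of each function is not shown above; a hypothetical helper mirroring the visible pattern, with the closing division assumed:

import time

def measure_throughput(inferencer, test_dataset) -> float:
    """Hypothetical sketch: time predictions over the real test images and
    return images per second (the final division is assumed, not shown above)."""
    start_time = time.time()
    for image_path in test_dataset.samples.image_path:
        inferencer.predict(image_path)
    inference_time = time.time() - start_time
    return len(test_dataset) / inference_time
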
2 changes: 1 addition & 1 deletion tests/pre_merge/deploy/test_inferencer.py
@@ -73,7 +73,7 @@ def test_torch_inference(self, model_name: str, category: str = "shapes", path:
model.eval()

# Test torch inferencer
torch_inferencer = TorchInferencer(model_config, model)
torch_inferencer = TorchInferencer(model_config, model, device="cpu")
torch_dataloader = MockImageLoader(model_config.dataset.image_size, total_count=1)
with torch.no_grad():
for image in torch_dataloader():
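
The pre-merge test pins device="cpu", presumably so it does not depend on a GPU being present on CI. Hypothetical unit checks, not part of this PR, for the flag semantics introduced by _get_device (the method never touches self, so it can be called unbound):

import pytest
import torch

from anomalib.deploy.inferencers.torch_inferencer import TorchInferencer

def test_device_flag_semantics():
    # "gpu" is an alias for "cuda"; "cpu" stays on CPU; anything else is rejected.
    assert TorchInferencer._get_device(None, "gpu") == torch.device("cuda")
    assert TorchInferencer._get_device(None, "cpu") == torch.device("cpu")
    with pytest.raises(ValueError):
        TorchInferencer._get_device(None, "tpu")
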
2 changes: 1 addition & 1 deletion tools/benchmarking/benchmark.py
@@ -126,7 +126,7 @@ def get_single_model_metrics(model_config: Union[DictConfig, ListConfig], openvi
data = {
"Training Time (s)": training_time,
"Testing Time (s)": testing_time,
"Inference Throughput (fps)": throughput,
f"Inference Throughput {model_config.trainer.accelerator} (fps)": throughput,
"OpenVINO Inference Throughput (fps)": openvino_throughput,
}
for key, val in test_results[0].items():
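
The benchmark row key now embeds the trainer accelerator, which distinguishes CPU and GPU runs in the collected metrics. Illustrative only, with names taken from the hunk above:

accelerator = "gpu"  # stands in for model_config.trainer.accelerator
key = f"Inference Throughput {accelerator} (fps)"
print(key)  # Inference Throughput gpu (fps)
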
14 changes: 13 additions & 1 deletion tools/inference/torch_inference.py
@@ -11,6 +11,8 @@
from argparse import ArgumentParser, Namespace
from pathlib import Path

import torch

from anomalib.data.utils import (
generate_output_image_filename,
get_image_filenames,
@@ -31,6 +33,14 @@ def get_args() -> Namespace:
parser.add_argument("--weights", type=Path, required=True, help="Path to model weights")
parser.add_argument("--input", type=Path, required=True, help="Path to an image to infer.")
parser.add_argument("--output", type=Path, required=False, help="Path to save the output image.")
parser.add_argument(
"--device",
type=str,
required=False,
default="auto",
help="Device to use for inference. Defaults to auto.",
choices=["auto", "cpu", "gpu", "cuda"], # cuda and gpu are the same but provided for convenience
)
parser.add_argument(
"--task",
type=str,
@@ -69,8 +79,10 @@ def infer() -> None:
# information regarding the data, model, train and inference details.
args = get_args()

torch.set_grad_enabled(False)

# Create the inferencer and visualizer.
inferencer = TorchInferencer(config=args.config, model_source=args.weights)
inferencer = TorchInferencer(config=args.config, model_source=args.weights, device=args.device)
visualizer = Visualizer(mode=args.visualization_mode, task=args.task)

filenames = get_image_filenames(path=args.input)
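
With the new flag, a run like

    python tools/inference/torch_inference.py --config <config.yaml> --weights <model.ckpt> --input <image.png> --device gpu

(placeholder paths) pins inference to CUDA; gpu and cuda are interchangeable spellings, and omitting --device falls back to auto, which selects CUDA when available and CPU otherwise.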