Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor fixes #1182

Merged
merged 12 commits into from
Jul 20, 2023
8 changes: 7 additions & 1 deletion src/anomalib/data/utils/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,16 @@ def get_transforms(
if isinstance(config, DictConfig):
logger.info("Loading transforms from config File")
transforms_list = []

if "Resize" not in config.keys() and image_size is not None:
resize_height, resize_width = get_image_height_and_width(image_size)
transforms_list.append(A.Resize(height=resize_height, width=resize_width, always_apply=True))
logger.info("Resize %s added!", (resize_height, resize_width))

for key, value in config.items():
if hasattr(A, key):
transform = getattr(A, key)(**value)
logger.info(f"Transform {transform} added!")
logger.info("Transform %s added!", transform)
transforms_list.append(transform)
else:
raise ValueError(f"Transformation {key} is not part of albumentations")
Expand Down
5 changes: 4 additions & 1 deletion src/anomalib/deploy/inferencers/openvino_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import logging
from importlib.util import find_spec
from pathlib import Path
from typing import Any
Expand All @@ -18,10 +19,12 @@

from .base_inferencer import Inferencer

logger = logging.getLogger("anomalib")

if find_spec("openvino") is not None:
from openvino.runtime import Core
else:
raise ImportError("OpenVINO is not installed. Please install OpenVINO to use OpenVINOInferencer.")
logger.warning("OpenVINO is not installed. Please install OpenVINO to use OpenVINOInferencer.")


class OpenVINOInferencer(Inferencer):
Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _collect_outputs(
image_metric.update(output["pred_scores"], output["label"].int())
if "mask" in output.keys() and "anomaly_maps" in output.keys():
pixel_metric.cpu()
pixel_metric.update(output["anomaly_maps"], output["mask"].int())
pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))

@staticmethod
def _post_process(outputs: STEP_OUTPUT) -> None:
Expand Down
17 changes: 3 additions & 14 deletions src/anomalib/models/efficient_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from anomalib.data.utils import DownloadInfo, download_and_extract
from anomalib.models.components import AnomalyModule

from .torch_model import EfficientAdModel, EfficientAdModelSize
from .torch_model import EfficientAdModel, EfficientAdModelSize, reduce_tensor_elems

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -192,19 +192,8 @@ def _get_quantiles_of_maps(self, maps: list[Tensor]) -> tuple[Tensor, Tensor]:
Returns:
tuple[Tensor, Tensor]: Two scalars - the 90% and the 99.5% quantile.
"""
maps_flat = torch.flatten(torch.cat(maps))
# torch.quantile only works with input size up to 2**24 elements, see
# https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291
# if we have more elements we need to decrease the size
# we do this by sampling random elements of maps_flat because then
# the locations of the quantiles (90% and 99.5%) will still be
# valid even though they might not be the exact quantiles.
max_input_size = 2**24
if len(maps_flat) > max_input_size:
# select a random subset with max_input_size elements.
perm = torch.randperm(len(maps_flat), device=self.device)
idx = perm[:max_input_size]
maps_flat = maps_flat[idx]

maps_flat = reduce_tensor_elems(torch.cat(maps))
qa = torch.quantile(maps_flat, q=0.9).to(self.device)
qb = torch.quantile(maps_flat, q=0.995).to(self.device)
return qa, qb
Expand Down
25 changes: 24 additions & 1 deletion src/anomalib/models/efficient_ad/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,29 @@ def imagenet_norm_batch(x):
return x_norm


def reduce_tensor_elems(tensor: torch.Tensor, m=2**24) -> torch.Tensor:
"""Flattens n-dimensional tensors, selects m elements from it
and returns the selected elements as tensor. It is used to select
at most 2**24 for torch.quantile operation, as it is the maximum
supported number of elements.
https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291

Args:
tensor (torch.Tensor): input tensor from which elements are selected
m (int): number of maximum tensor elements. Default: 2**24

Returns:
Tensor: reduced tensor
"""
tensor = torch.flatten(tensor)
if len(tensor) > m:
# select a random subset with m elements.
perm = torch.randperm(len(tensor), device=tensor.device)
idx = perm[:m]
tensor = tensor[idx]
return tensor


class EfficientAdModelSize(str, Enum):
"""Supported EfficientAd model sizes"""

Expand Down Expand Up @@ -123,7 +146,6 @@ class Decoder(nn.Module):
def __init__(self, out_channels, padding, img_size, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.img_size = img_size
self.last_upsample = 64 if padding else 56
self.last_upsample = int(img_size / 4) if padding else int(img_size / 4) - 8
self.deconv1 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
self.deconv2 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
Expand Down Expand Up @@ -279,6 +301,7 @@ def forward(self, batch: Tensor, batch_imagenet: Tensor = None) -> Tensor | dict

if self.training:
# Student loss
distance_st = reduce_tensor_elems(distance_st)
d_hard = torch.quantile(distance_st, 0.999)
loss_hard = torch.mean(distance_st[distance_st >= d_hard])
student_output_penalty = self.student(batch_imagenet)[:, : self.teacher_out_channels, :, :]
Expand Down
Loading