From 268a6bfeb93c32b46ef902dcd6116a803c95e08a Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 12 Apr 2022 15:18:02 +0200 Subject: [PATCH] Refactor model implementations (#225) * refactor CFlow implementation * refactor DFM implementation * refactor PADIM implementation * refactor PatchCore implementation * refactor STFPM implementation * revert model tests * remove unintentionally committed file * model.py -> lightning_model.py --- anomalib/models/__init__.py | 4 +- anomalib/models/cflow/__init__.py | 9 +- anomalib/models/cflow/anomaly_map.py | 97 +++++ anomalib/models/cflow/lightning_model.py | 156 ++++++++ anomalib/models/cflow/model.py | 354 ------------------ anomalib/models/cflow/torch_model.py | 130 +++++++ .../models/cflow/{backbone.py => utils.py} | 19 +- anomalib/models/dfkde/__init__.py | 4 + .../dfkde/{model.py => lightning_model.py} | 0 anomalib/models/dfm/__init__.py | 4 + .../dfm/{model.py => lightning_model.py} | 2 +- .../dfm/{dfm_model.py => torch_model.py} | 2 +- anomalib/models/ganomaly/__init__.py | 2 +- .../ganomaly/{model.py => lightning_model.py} | 0 anomalib/models/padim/__init__.py | 4 + anomalib/models/padim/anomaly_map.py | 146 ++++++++ anomalib/models/padim/lightning_model.py | 101 +++++ anomalib/models/padim/model.py | 346 ----------------- anomalib/models/padim/torch_model.py | 140 +++++++ anomalib/models/patchcore/__init__.py | 4 + anomalib/models/patchcore/anomaly_map.py | 100 +++++ anomalib/models/patchcore/lightning_model.py | 108 ++++++ anomalib/models/patchcore/model.py | 334 ----------------- anomalib/models/patchcore/torch_model.py | 169 +++++++++ anomalib/models/stfpm/__init__.py | 9 +- anomalib/models/stfpm/anomaly_map.py | 98 +++++ anomalib/models/stfpm/lightning_model.py | 102 +++++ anomalib/models/stfpm/model.py | 313 ---------------- anomalib/models/stfpm/torch_model.py | 156 ++++++++ 29 files changed, 1550 insertions(+), 1363 deletions(-) create mode 100644 anomalib/models/cflow/anomaly_map.py create mode 100644 anomalib/models/cflow/lightning_model.py delete mode 100644 anomalib/models/cflow/model.py create mode 100644 anomalib/models/cflow/torch_model.py rename anomalib/models/cflow/{backbone.py => utils.py} (85%) rename anomalib/models/dfkde/{model.py => lightning_model.py} (100%) rename anomalib/models/dfm/{model.py => lightning_model.py} (98%) rename anomalib/models/dfm/{dfm_model.py => torch_model.py} (99%) rename anomalib/models/ganomaly/{model.py => lightning_model.py} (100%) create mode 100644 anomalib/models/padim/anomaly_map.py create mode 100644 anomalib/models/padim/lightning_model.py delete mode 100644 anomalib/models/padim/model.py create mode 100644 anomalib/models/padim/torch_model.py create mode 100644 anomalib/models/patchcore/anomaly_map.py create mode 100644 anomalib/models/patchcore/lightning_model.py delete mode 100644 anomalib/models/patchcore/model.py create mode 100644 anomalib/models/patchcore/torch_model.py create mode 100644 anomalib/models/stfpm/anomaly_map.py create mode 100644 anomalib/models/stfpm/lightning_model.py delete mode 100644 anomalib/models/stfpm/model.py create mode 100644 anomalib/models/stfpm/torch_model.py diff --git a/anomalib/models/__init__.py b/anomalib/models/__init__.py index a450724f24..f6c3bc4884 100644 --- a/anomalib/models/__init__.py +++ b/anomalib/models/__init__.py @@ -25,7 +25,7 @@ # TODO(AlexanderDokuchaev): Workaround of wrapping by NNCF. # Can't not wrap `spatial_softmax2d` if use import_module. 
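The first hunk below switches `get_model` from importing `anomalib.models.<name>.model` to importing the model package itself, which now re-exports its Lightning class from `__init__.py`. A hypothetical standalone sketch of that lookup (the model name is illustrative):

    from importlib import import_module

    name = "padim"  # illustrative; any refactored model package works the same way
    module = import_module(f"anomalib.models.{name}")  # the package, not <name>.model
    model_class = getattr(module, f"{name.capitalize()}Lightning")  # e.g. PadimLightning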
-from anomalib.models.padim.model import PadimLightning # noqa: F401 +from anomalib.models.padim.lightning_model import PadimLightning # noqa: F401 def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule: @@ -62,7 +62,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule: raise ValueError(f"Unknown model {config.model.name} for OpenVINO model!") else: if config.model.name in torch_model_list: - module = import_module(f"anomalib.models.{config.model.name}.model") + module = import_module(f"anomalib.models.{config.model.name}") model = getattr(module, f"{config.model.name.capitalize()}Lightning") else: raise ValueError(f"Unknown model {config.model.name}!") diff --git a/anomalib/models/cflow/__init__.py b/anomalib/models/cflow/__init__.py index 1a4c68fe91..e187049804 100644 --- a/anomalib/models/cflow/__init__.py +++ b/anomalib/models/cflow/__init__.py @@ -1,7 +1,4 @@ -"""Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows. - -[CW-AD](https://arxiv.org/pdf/2107.12571v1.pdf) -""" +"""Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows.""" # Copyright (C) 2020 Intel Corporation # @@ -16,3 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. + +from .lightning_model import CflowLightning + +__all__ = ["CflowLightning"] diff --git a/anomalib/models/cflow/anomaly_map.py b/anomalib/models/cflow/anomaly_map.py new file mode 100644 index 0000000000..8e6ec82179 --- /dev/null +++ b/anomalib/models/cflow/anomaly_map.py @@ -0,0 +1,97 @@ +"""Anomaly Map Generator for CFlow model implementation.""" + +# Copyright (C) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +from typing import List, Tuple, Union, cast + +import torch +import torch.nn.functional as F +from omegaconf import ListConfig +from torch import Tensor + + +class AnomalyMapGenerator: + """Generate Anomaly Heatmap.""" + + def __init__( + self, + image_size: Union[ListConfig, Tuple], + pool_layers: List[str], + ): + self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True) + self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size) + self.pool_layers: List[str] = pool_layers + + def compute_anomaly_map( + self, distribution: Union[List[Tensor], List[List]], height: List[int], width: List[int] + ) -> Tensor: + """Compute the layer map based on likelihood estimation. 
+ + Args: + distribution: Probability distribution for each decoder block + height: blocks height + width: blocks width + + Returns: + Final Anomaly Map + + """ + + test_map: List[Tensor] = [] + for layer_idx in range(len(self.pool_layers)): + test_norm = torch.tensor(distribution[layer_idx], dtype=torch.double) # pylint: disable=not-callable + test_norm -= torch.max(test_norm) # normalize likelihoods to (-Inf:0] by subtracting a constant + test_prob = torch.exp(test_norm) # convert to probs in range [0:1] + test_mask = test_prob.reshape(-1, height[layer_idx], width[layer_idx]) + # upsample + test_map.append( + F.interpolate( + test_mask.unsqueeze(1), size=self.image_size, mode="bilinear", align_corners=True + ).squeeze() + ) + # score aggregation + score_map = torch.zeros_like(test_map[0]) + for layer_idx in range(len(self.pool_layers)): + score_map += test_map[layer_idx] + score_mask = score_map + # invert probs to anomaly scores + anomaly_map = score_mask.max() - score_mask + + return anomaly_map + + def __call__(self, **kwargs: Union[List[Tensor], List[int], List[List]]) -> Tensor: + """Returns anomaly_map. + + Expects `distribution`, `height` and 'width' keywords to be passed explicitly + + Example + >>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size), + >>> pool_layers=pool_layers) + >>> output = self.anomaly_map_generator(distribution=dist, height=height, width=width) + + Raises: + ValueError: `distribution`, `height` and 'width' keys are not found + + Returns: + torch.Tensor: anomaly map + """ + if not ("distribution" in kwargs and "height" in kwargs and "width" in kwargs): + raise KeyError(f"Expected keys `distribution`, `height` and `width`. Found {kwargs.keys()}") + + # placate mypy + distribution: List[Tensor] = cast(List[Tensor], kwargs["distribution"]) + height: List[int] = cast(List[int], kwargs["height"]) + width: List[int] = cast(List[int], kwargs["width"]) + return self.compute_anomaly_map(distribution, height, width) diff --git a/anomalib/models/cflow/lightning_model.py b/anomalib/models/cflow/lightning_model.py new file mode 100644 index 0000000000..b7d51e2b85 --- /dev/null +++ b/anomalib/models/cflow/lightning_model.py @@ -0,0 +1,156 @@ +"""CFLOW: Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows. + +https://arxiv.org/pdf/2107.12571v1.pdf +""" + +# Copyright (C) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
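Before moving on to the Lightning module: a minimal usage sketch of the `AnomalyMapGenerator` defined in the new anomaly_map.py above. Shapes, layer names and image size are made up; it assumes torch and anomalib are importable.

    import torch

    from anomalib.models.cflow.anomaly_map import AnomalyMapGenerator

    generator = AnomalyMapGenerator(image_size=(256, 256), pool_layers=["layer2", "layer3"])
    # one flattened log-likelihood tensor per pooling layer: batch of 2, 32x32 and 16x16 grids
    distribution = [torch.randn(2 * 32 * 32), torch.randn(2 * 16 * 16)]
    anomaly_map = generator(distribution=distribution, height=[32, 16], width=[32, 16])
    assert anomaly_map.shape == (2, 256, 256)  # likelihoods inverted: higher = more anomalous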
+ +import einops +import torch +import torch.nn.functional as F +from pytorch_lightning.callbacks import EarlyStopping +from torch import optim + +from anomalib.models.cflow.torch_model import CflowModel +from anomalib.models.cflow.utils import get_logp, positional_encoding_2d +from anomalib.models.components import AnomalyModule + +__all__ = ["CflowLightning"] + + +class CflowLightning(AnomalyModule): + """PL Lightning Module for the CFLOW algorithm.""" + + def __init__(self, hparams): + super().__init__(hparams) + + self.model: CflowModel = CflowModel(hparams) + self.loss_val = 0 + self.automatic_optimization = False + + def configure_callbacks(self): + """Configure model-specific callbacks.""" + early_stopping = EarlyStopping( + monitor=self.hparams.model.early_stopping.metric, + patience=self.hparams.model.early_stopping.patience, + mode=self.hparams.model.early_stopping.mode, + ) + return [early_stopping] + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configures optimizers for each decoder. + + Returns: + Optimizer: Adam optimizer for each decoder + """ + decoders_parameters = [] + for decoder_idx in range(len(self.model.pool_layers)): + decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters())) + + optimizer = optim.Adam( + params=decoders_parameters, + lr=self.hparams.model.lr, + ) + return optimizer + + def training_step(self, batch, _): # pylint: disable=arguments-differ + """Training Step of CFLOW. + + For each batch, decoder layers are trained with a dynamic fiber batch size. + Training step is performed manually as multiple training steps are involved + per batch of input images + + Args: + batch: Input batch + _: Index of the batch. + + Returns: + Loss value for the batch + + """ + opt = self.optimizers() + self.model.encoder.eval() + + images = batch["image"] + activation = self.model.encoder(images) + avg_loss = torch.zeros([1], dtype=torch.float64).to(images.device) + + height = [] + width = [] + for layer_idx, layer in enumerate(self.model.pool_layers): + encoder_activations = activation[layer].detach() # BxCxHxW + + batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size() + image_size = im_height * im_width + embedding_length = batch_size * image_size # number of rows in the conditional vector + + height.append(im_height) + width.append(im_width) + # repeats positional encoding for the entire batch 1 C H W to B C H W + pos_encoding = einops.repeat( + positional_encoding_2d(self.model.condition_vector, im_height, im_width).unsqueeze(0), + "b c h w-> (tile b) c h w", + tile=batch_size, + ).to(images.device) + c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c") # BHWxP + e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c") # BHWxC + perm = torch.randperm(embedding_length) # BHW + decoder = self.model.decoders[layer_idx].to(images.device) + + fiber_batches = embedding_length // self.model.fiber_batch_size # number of fiber batches + assert fiber_batches > 0, "Make sure we have enough fibers, otherwise decrease N or batch-size!" 
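(A self-contained toy version of the fiber indexing used in the loop that follows, with made-up sizes:)

    import torch

    batch_size, im_height, im_width, fiber_batch_size = 2, 8, 8, 32  # made-up sizes
    embedding_length = batch_size * im_height * im_width  # 128 rows, one per spatial fiber
    perm = torch.randperm(embedding_length)  # shuffles fibers across optimizer steps
    for batch_num in range(embedding_length // fiber_batch_size):
        idx = torch.arange(batch_num * fiber_batch_size, (batch_num + 1) * fiber_batch_size)
        rows = perm[idx]  # indices of the fibers consumed by one manual optimizer step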
+ + for batch_num in range(fiber_batches): # per-fiber processing + opt.zero_grad() + if batch_num < (fiber_batches - 1): + idx = torch.arange( + batch_num * self.model.fiber_batch_size, (batch_num + 1) * self.model.fiber_batch_size + ) + else: # When non-full batch is encountered batch_num * N will go out of bounds + idx = torch.arange(batch_num * self.model.fiber_batch_size, embedding_length) + # get random vectors + c_p = c_r[perm[idx]] # NxP + e_p = e_r[perm[idx]] # NxC + # decoder returns the transformed variable z and the log Jacobian determinant + p_u, log_jac_det = decoder(e_p, [c_p]) + # + decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det) + log_prob = decoder_log_prob / dim_feature_vector # likelihood per dim + loss = -F.logsigmoid(log_prob) + self.manual_backward(loss.mean()) + opt.step() + avg_loss += loss.sum() + + return {"loss": avg_loss} + + def validation_step(self, batch, _): # pylint: disable=arguments-differ + """Validation Step of CFLOW. + + Similar to the training step, encoder features + are extracted from the CNN for each batch, and anomaly + map is computed. + + Args: + batch: Input batch + _: Index of the batch. + + Returns: + Dictionary containing images, anomaly maps, true labels and masks. + These are required in `validation_epoch_end` for feature concatenation. + + """ + batch["anomaly_maps"] = self.model(batch["image"]) + + return batch diff --git a/anomalib/models/cflow/model.py b/anomalib/models/cflow/model.py deleted file mode 100644 index 656c903ef2..0000000000 --- a/anomalib/models/cflow/model.py +++ /dev/null @@ -1,354 +0,0 @@ -"""CFLOW: Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows. - -https://arxiv.org/pdf/2107.12571v1.pdf -""" - -# Copyright (C) 2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. - -from typing import List, Tuple, Union, cast - -import einops -import numpy as np -import torch -import torch.nn.functional as F -import torchvision -from omegaconf import DictConfig, ListConfig -from pytorch_lightning.callbacks import EarlyStopping -from torch import Tensor, nn, optim - -from anomalib.models.cflow.backbone import cflow_head, positional_encoding_2d -from anomalib.models.components import AnomalyModule, FeatureExtractor - -__all__ = ["AnomalyMapGenerator", "CflowModel", "CflowLightning"] - - -def get_logp(dim_feature_vector: int, p_u: torch.Tensor, logdet_j: torch.Tensor) -> torch.Tensor: - """Returns the log likelihood estimation. 
-
-    Args:
-        dim_feature_vector (int): Dimensions of the condition vector
-        p_u (torch.Tensor): Random variable u
-        logdet_j (torch.Tensor): log of determinant of jacobian returned from the invertable decoder
-
-    Returns:
-        torch.Tensor: Log probability
-    """
-    ln_sqrt_2pi = -np.log(np.sqrt(2 * np.pi))  # ln(sqrt(2*pi))
-    logp = dim_feature_vector * ln_sqrt_2pi - 0.5 * torch.sum(p_u**2, 1) + logdet_j
-    return logp
-
-
-class AnomalyMapGenerator:
-    """Generate Anomaly Heatmap."""
-
-    def __init__(
-        self,
-        image_size: Union[ListConfig, Tuple],
-        pool_layers: List[str],
-    ):
-        self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True)
-        self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
-        self.pool_layers: List[str] = pool_layers
-
-    def compute_anomaly_map(
-        self, distribution: Union[List[Tensor], List[List]], height: List[int], width: List[int]
-    ) -> Tensor:
-        """Compute the layer map based on likelihood estimation.
-
-        Args:
-            distribution: Probability distribution for each decoder block
-            height: blocks height
-            width: blocks width
-
-        Returns:
-            Final Anomaly Map
-
-        """
-
-        test_map: List[Tensor] = []
-        for layer_idx in range(len(self.pool_layers)):
-            test_norm = torch.tensor(distribution[layer_idx], dtype=torch.double)  # pylint: disable=not-callable
-            test_norm -= torch.max(test_norm)  # normalize likelihoods to (-Inf:0] by subtracting a constant
-            test_prob = torch.exp(test_norm)  # convert to probs in range [0:1]
-            test_mask = test_prob.reshape(-1, height[layer_idx], width[layer_idx])
-            # upsample
-            test_map.append(
-                F.interpolate(
-                    test_mask.unsqueeze(1), size=self.image_size, mode="bilinear", align_corners=True
-                ).squeeze()
-            )
-        # score aggregation
-        score_map = torch.zeros_like(test_map[0])
-        for layer_idx in range(len(self.pool_layers)):
-            score_map += test_map[layer_idx]
-        score_mask = score_map
-        # invert probs to anomaly scores
-        anomaly_map = score_mask.max() - score_mask
-
-        return anomaly_map
-
-    def __call__(self, **kwargs: Union[List[Tensor], List[int], List[List]]) -> Tensor:
-        """Returns anomaly_map.
-
-        Expects `distribution`, `height` and 'width' keywords to be passed explicitly
-
-        Example
-        >>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size),
-        >>>                                             pool_layers=pool_layers)
-        >>> output = self.anomaly_map_generator(distribution=dist, height=height, width=width)
-
-        Raises:
-            ValueError: `distribution`, `height` and 'width' keys are not found
-
-        Returns:
-            torch.Tensor: anomaly map
-        """
-        if not ("distribution" in kwargs and "height" in kwargs and "width" in kwargs):
-            raise KeyError(f"Expected keys `distribution`, `height` and `width`. Found {kwargs.keys()}")
-
-        # placate mypy
-        distribution: List[Tensor] = cast(List[Tensor], kwargs["distribution"])
-        height: List[int] = cast(List[int], kwargs["height"])
-        width: List[int] = cast(List[int], kwargs["width"])
-        return self.compute_anomaly_map(distribution, height, width)
-
-
-class CflowModel(nn.Module):
-    """CFLOW: Conditional Normalizing Flows."""
-
-    def __init__(self, hparams: Union[DictConfig, ListConfig]):
-        super().__init__()
-
-        self.backbone = getattr(torchvision.models, hparams.model.backbone)
-        self.fiber_batch_size = hparams.dataset.fiber_batch_size
-        self.condition_vector: int = hparams.model.condition_vector
-        self.dec_arch = hparams.model.decoder
-        self.pool_layers = hparams.model.layers
-
-        self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.pool_layers)
-        self.pool_dims = self.encoder.out_dims
-        self.decoders = nn.ModuleList(
-            [
-                cflow_head(
-                    condition_vector=self.condition_vector,
-                    coupling_blocks=hparams.model.coupling_blocks,
-                    clamp_alpha=hparams.model.clamp_alpha,
-                    n_features=pool_dim,
-                    permute_soft=hparams.model.soft_permutation,
-                )
-                for pool_dim in self.pool_dims
-            ]
-        )
-
-        # encoder model is fixed
-        for parameters in self.encoder.parameters():
-            parameters.requires_grad = False
-
-        self.anomaly_map_generator = AnomalyMapGenerator(
-            image_size=tuple(hparams.model.input_size), pool_layers=self.pool_layers
-        )
-
-    def forward(self, images):
-        """Forward-pass images into the network to extract encoder features and compute probability.
-
-        Args:
-            images: Batch of images.
-
-        Returns:
-            Predicted anomaly maps.
-
-        """
-
-        self.encoder.eval()
-        self.decoders.eval()
-        with torch.no_grad():
-            activation = self.encoder(images)
-
-        distribution = [torch.Tensor(0).to(images.device) for _ in self.pool_layers]
-
-        height: List[int] = []
-        width: List[int] = []
-        for layer_idx, layer in enumerate(self.pool_layers):
-            encoder_activations = activation[layer]  # BxCxHxW
-
-            batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
-            image_size = im_height * im_width
-            embedding_length = batch_size * image_size  # number of rows in the conditional vector
-
-            height.append(im_height)
-            width.append(im_width)
-            # repeats positional encoding for the entire batch 1 C H W to B C H W
-            pos_encoding = einops.repeat(
-                positional_encoding_2d(self.condition_vector, im_height, im_width).unsqueeze(0),
-                "b c h w-> (tile b) c h w",
-                tile=batch_size,
-            ).to(images.device)
-            c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c")  # BHWxP
-            e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c")  # BHWxC
-            decoder = self.decoders[layer_idx].to(images.device)
-
-            # Sometimes during validation, the last batch E / N is not a whole number. Hence we need to add 1.
-            # It is assumed that during training that E / N is a whole number as no errors were discovered during
-            # testing. In case it is observed in the future, we can use only this line and ensure that FIB is at
-            # least 1 or set `drop_last` in the dataloader to drop the last non-full batch.
- fiber_batches = embedding_length // self.fiber_batch_size + int( - embedding_length % self.fiber_batch_size > 0 - ) - - for batch_num in range(fiber_batches): # per-fiber processing - if batch_num < (fiber_batches - 1): - idx = torch.arange(batch_num * self.fiber_batch_size, (batch_num + 1) * self.fiber_batch_size) - else: # When non-full batch is encountered batch_num+1 * N will go out of bounds - idx = torch.arange(batch_num * self.fiber_batch_size, embedding_length) - c_p = c_r[idx] # NxP - e_p = e_r[idx] # NxC - # decoder returns the transformed variable z and the log Jacobian determinant - with torch.no_grad(): - p_u, log_jac_det = decoder(e_p, [c_p]) - # - decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det) - log_prob = decoder_log_prob / dim_feature_vector # likelihood per dim - distribution[layer_idx] = torch.cat((distribution[layer_idx], log_prob)) - - output = self.anomaly_map_generator(distribution=distribution, height=height, width=width) - self.decoders.train() - - return output.to(images.device) - - -class CflowLightning(AnomalyModule): - """PL Lightning Module for the CFLOW algorithm.""" - - def __init__(self, hparams): - super().__init__(hparams) - - self.model: CflowModel = CflowModel(hparams) - self.loss_val = 0 - self.automatic_optimization = False - - def configure_callbacks(self): - """Configure model-specific callbacks.""" - early_stopping = EarlyStopping( - monitor=self.hparams.model.early_stopping.metric, - patience=self.hparams.model.early_stopping.patience, - mode=self.hparams.model.early_stopping.mode, - ) - return [early_stopping] - - def configure_optimizers(self) -> torch.optim.Optimizer: - """Configures optimizers for each decoder. - - Returns: - Optimizer: Adam optimizer for each decoder - """ - decoders_parameters = [] - for decoder_idx in range(len(self.model.pool_layers)): - decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters())) - - optimizer = optim.Adam( - params=decoders_parameters, - lr=self.hparams.model.lr, - ) - return optimizer - - def training_step(self, batch, _): # pylint: disable=arguments-differ - """Training Step of CFLOW. - - For each batch, decoder layers are trained with a dynamic fiber batch size. - Training step is performed manually as multiple training steps are involved - per batch of input images - - Args: - batch: Input batch - _: Index of the batch. 
- - Returns: - Loss value for the batch - - """ - opt = self.optimizers() - self.model.encoder.eval() - - images = batch["image"] - activation = self.model.encoder(images) - avg_loss = torch.zeros([1], dtype=torch.float64).to(images.device) - - height = [] - width = [] - for layer_idx, layer in enumerate(self.model.pool_layers): - encoder_activations = activation[layer].detach() # BxCxHxW - - batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size() - image_size = im_height * im_width - embedding_length = batch_size * image_size # number of rows in the conditional vector - - height.append(im_height) - width.append(im_width) - # repeats positional encoding for the entire batch 1 C H W to B C H W - pos_encoding = einops.repeat( - positional_encoding_2d(self.model.condition_vector, im_height, im_width).unsqueeze(0), - "b c h w-> (tile b) c h w", - tile=batch_size, - ).to(images.device) - c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c") # BHWxP - e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c") # BHWxC - perm = torch.randperm(embedding_length) # BHW - decoder = self.model.decoders[layer_idx].to(images.device) - - fiber_batches = embedding_length // self.model.fiber_batch_size # number of fiber batches - assert fiber_batches > 0, "Make sure we have enough fibers, otherwise decrease N or batch-size!" - - for batch_num in range(fiber_batches): # per-fiber processing - opt.zero_grad() - if batch_num < (fiber_batches - 1): - idx = torch.arange( - batch_num * self.model.fiber_batch_size, (batch_num + 1) * self.model.fiber_batch_size - ) - else: # When non-full batch is encountered batch_num * N will go out of bounds - idx = torch.arange(batch_num * self.model.fiber_batch_size, embedding_length) - # get random vectors - c_p = c_r[perm[idx]] # NxP - e_p = e_r[perm[idx]] # NxC - # decoder returns the transformed variable z and the log Jacobian determinant - p_u, log_jac_det = decoder(e_p, [c_p]) - # - decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det) - log_prob = decoder_log_prob / dim_feature_vector # likelihood per dim - loss = -F.logsigmoid(log_prob) - self.manual_backward(loss.mean()) - opt.step() - avg_loss += loss.sum() - - return {"loss": avg_loss} - - def validation_step(self, batch, _): # pylint: disable=arguments-differ - """Validation Step of CFLOW. - - Similar to the training step, encoder features - are extracted from the CNN for each batch, and anomaly - map is computed. - - Args: - batch: Input batch - _: Index of the batch. - - Returns: - Dictionary containing images, anomaly maps, true labels and masks. - These are required in `validation_epoch_end` for feature concatenation. - - """ - batch["anomaly_maps"] = self.model(batch["image"]) - - return batch diff --git a/anomalib/models/cflow/torch_model.py b/anomalib/models/cflow/torch_model.py new file mode 100644 index 0000000000..7206030c28 --- /dev/null +++ b/anomalib/models/cflow/torch_model.py @@ -0,0 +1,130 @@ +"""PyTorch model for CFlow model implementation.""" + +# Copyright (C) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions
+# and limitations under the License.
+
+from typing import List, Union
+
+import einops
+import torch
+import torchvision
+from omegaconf import DictConfig, ListConfig
+from torch import nn
+
+from anomalib.models.cflow.anomaly_map import AnomalyMapGenerator
+from anomalib.models.cflow.utils import cflow_head, get_logp, positional_encoding_2d
+from anomalib.models.components import FeatureExtractor
+
+
+class CflowModel(nn.Module):
+    """CFLOW: Conditional Normalizing Flows."""
+
+    def __init__(self, hparams: Union[DictConfig, ListConfig]):
+        super().__init__()
+
+        self.backbone = getattr(torchvision.models, hparams.model.backbone)
+        self.fiber_batch_size = hparams.dataset.fiber_batch_size
+        self.condition_vector: int = hparams.model.condition_vector
+        self.dec_arch = hparams.model.decoder
+        self.pool_layers = hparams.model.layers
+
+        self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.pool_layers)
+        self.pool_dims = self.encoder.out_dims
+        self.decoders = nn.ModuleList(
+            [
+                cflow_head(
+                    condition_vector=self.condition_vector,
+                    coupling_blocks=hparams.model.coupling_blocks,
+                    clamp_alpha=hparams.model.clamp_alpha,
+                    n_features=pool_dim,
+                    permute_soft=hparams.model.soft_permutation,
+                )
+                for pool_dim in self.pool_dims
+            ]
+        )
+
+        # encoder model is fixed
+        for parameters in self.encoder.parameters():
+            parameters.requires_grad = False
+
+        self.anomaly_map_generator = AnomalyMapGenerator(
+            image_size=tuple(hparams.model.input_size), pool_layers=self.pool_layers
+        )
+
+    def forward(self, images):
+        """Forward-pass images into the network to extract encoder features and compute probability.
+
+        Args:
+            images: Batch of images.
+
+        Returns:
+            Predicted anomaly maps.
+
+        """
+
+        self.encoder.eval()
+        self.decoders.eval()
+        with torch.no_grad():
+            activation = self.encoder(images)
+
+        distribution = [torch.Tensor(0).to(images.device) for _ in self.pool_layers]
+
+        height: List[int] = []
+        width: List[int] = []
+        for layer_idx, layer in enumerate(self.pool_layers):
+            encoder_activations = activation[layer]  # BxCxHxW
+
+            batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
+            image_size = im_height * im_width
+            embedding_length = batch_size * image_size  # number of rows in the conditional vector
+
+            height.append(im_height)
+            width.append(im_width)
+            # repeats positional encoding for the entire batch 1 C H W to B C H W
+            pos_encoding = einops.repeat(
+                positional_encoding_2d(self.condition_vector, im_height, im_width).unsqueeze(0),
+                "b c h w-> (tile b) c h w",
+                tile=batch_size,
+            ).to(images.device)
+            c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c")  # BHWxP
+            e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c")  # BHWxC
+            decoder = self.decoders[layer_idx].to(images.device)
+
+            # Sometimes during validation, the last batch E / N is not a whole number. Hence we need to add 1.
+            # It is assumed that during training that E / N is a whole number as no errors were discovered during
+            # testing. In case it is observed in the future, we can use only this line and ensure that FIB is at
+            # least 1 or set `drop_last` in the dataloader to drop the last non-full batch.
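(A worked example of the ceiling division on the next line, with illustrative numbers:)

    embedding_length, fiber_batch_size = 1000, 64  # illustrative values
    fiber_batches = embedding_length // fiber_batch_size + int(embedding_length % fiber_batch_size > 0)
    assert fiber_batches == 16  # 15 full fiber batches plus one partial batch at validation time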
+ fiber_batches = embedding_length // self.fiber_batch_size + int( + embedding_length % self.fiber_batch_size > 0 + ) + + for batch_num in range(fiber_batches): # per-fiber processing + if batch_num < (fiber_batches - 1): + idx = torch.arange(batch_num * self.fiber_batch_size, (batch_num + 1) * self.fiber_batch_size) + else: # When non-full batch is encountered batch_num+1 * N will go out of bounds + idx = torch.arange(batch_num * self.fiber_batch_size, embedding_length) + c_p = c_r[idx] # NxP + e_p = e_r[idx] # NxC + # decoder returns the transformed variable z and the log Jacobian determinant + with torch.no_grad(): + p_u, log_jac_det = decoder(e_p, [c_p]) + # + decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det) + log_prob = decoder_log_prob / dim_feature_vector # likelihood per dim + distribution[layer_idx] = torch.cat((distribution[layer_idx], log_prob)) + + output = self.anomaly_map_generator(distribution=distribution, height=height, width=width) + self.decoders.train() + + return output.to(images.device) diff --git a/anomalib/models/cflow/backbone.py b/anomalib/models/cflow/utils.py similarity index 85% rename from anomalib/models/cflow/backbone.py rename to anomalib/models/cflow/utils.py index f4c0f13c6e..1029e6a669 100644 --- a/anomalib/models/cflow/backbone.py +++ b/anomalib/models/cflow/utils.py @@ -1,4 +1,4 @@ -"""Helper functions to create backbone model.""" +"""Helper functions for CFlow implementation.""" # Copyright (C) 2020 Intel Corporation # @@ -16,6 +16,7 @@ import math +import numpy as np import torch from torch import nn @@ -23,6 +24,22 @@ from anomalib.models.components.freia.modules import AllInOneBlock +def get_logp(dim_feature_vector: int, p_u: torch.Tensor, logdet_j: torch.Tensor) -> torch.Tensor: + """Returns the log likelihood estimation. + + Args: + dim_feature_vector (int): Dimensions of the condition vector + p_u (torch.Tensor): Random variable u + logdet_j (torch.Tensor): log of determinant of jacobian returned from the invertable decoder + + Returns: + torch.Tensor: Log probability + """ + ln_sqrt_2pi = -np.log(np.sqrt(2 * np.pi)) # ln(sqrt(2*pi)) + logp = dim_feature_vector * ln_sqrt_2pi - 0.5 * torch.sum(p_u**2, 1) + logdet_j + return logp + + def positional_encoding_2d(condition_vector: int, height: int, width: int) -> torch.Tensor: """Creates embedding to store relative position of the feature vector using sine and cosine functions. diff --git a/anomalib/models/dfkde/__init__.py b/anomalib/models/dfkde/__init__.py index 1479c15b3a..cc077649e2 100644 --- a/anomalib/models/dfkde/__init__.py +++ b/anomalib/models/dfkde/__init__.py @@ -13,3 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. + +from .lightning_model import DfkdeLightning + +__all__ = ["DfkdeLightning"] diff --git a/anomalib/models/dfkde/model.py b/anomalib/models/dfkde/lightning_model.py similarity index 100% rename from anomalib/models/dfkde/model.py rename to anomalib/models/dfkde/lightning_model.py diff --git a/anomalib/models/dfm/__init__.py b/anomalib/models/dfm/__init__.py index 3824007791..0514d279b3 100644 --- a/anomalib/models/dfm/__init__.py +++ b/anomalib/models/dfm/__init__.py @@ -13,3 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. 
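The utils.py rename above also moves `get_logp` next to the other CFlow helpers. A quick sanity check of the formula it implements, logp = D * ln(1/sqrt(2*pi)) - ||u||^2 / 2 + log|det J|, with illustrative values:

    import numpy as np
    import torch

    from anomalib.models.cflow.utils import get_logp

    dim = 3  # illustrative feature dimension
    p_u = torch.zeros(1, dim)  # u at the mode of a standard normal
    logdet_j = torch.zeros(1)  # zero log-Jacobian
    logp = get_logp(dim, p_u, logdet_j)
    expected = -0.5 * dim * np.log(2 * np.pi)  # log density of N(0, I) at u = 0
    assert torch.allclose(logp.float(), torch.tensor([float(expected)]))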
+ +from .lightning_model import DfmLightning + +__all__ = ["DfmLightning"] diff --git a/anomalib/models/dfm/model.py b/anomalib/models/dfm/lightning_model.py similarity index 98% rename from anomalib/models/dfm/model.py rename to anomalib/models/dfm/lightning_model.py index 444dca1030..91fb10b6c2 100644 --- a/anomalib/models/dfm/model.py +++ b/anomalib/models/dfm/lightning_model.py @@ -22,7 +22,7 @@ from anomalib.models.components import AnomalyModule -from .dfm_model import DFMModel +from .torch_model import DFMModel class DfmLightning(AnomalyModule): diff --git a/anomalib/models/dfm/dfm_model.py b/anomalib/models/dfm/torch_model.py similarity index 99% rename from anomalib/models/dfm/dfm_model.py rename to anomalib/models/dfm/torch_model.py index bab8785ee9..7014f101e3 100644 --- a/anomalib/models/dfm/dfm_model.py +++ b/anomalib/models/dfm/torch_model.py @@ -1,4 +1,4 @@ -"""Normality model of DFKDE.""" +"""PyTorch model for DFM model implementation.""" # Copyright (C) 2020 Intel Corporation # diff --git a/anomalib/models/ganomaly/__init__.py b/anomalib/models/ganomaly/__init__.py index a920c07451..08fd1ff415 100644 --- a/anomalib/models/ganomaly/__init__.py +++ b/anomalib/models/ganomaly/__init__.py @@ -14,6 +14,6 @@ # See the License for the specific language governing permissions # and limitations under the License. -from .model import GanomalyLightning +from .lightning_model import GanomalyLightning __all__ = ["GanomalyLightning"] diff --git a/anomalib/models/ganomaly/model.py b/anomalib/models/ganomaly/lightning_model.py similarity index 100% rename from anomalib/models/ganomaly/model.py rename to anomalib/models/ganomaly/lightning_model.py diff --git a/anomalib/models/padim/__init__.py b/anomalib/models/padim/__init__.py index d85459be9e..6002f79581 100644 --- a/anomalib/models/padim/__init__.py +++ b/anomalib/models/padim/__init__.py @@ -13,3 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. + +from .lightning_model import PadimLightning + +__all__ = ["PadimLightning"] diff --git a/anomalib/models/padim/anomaly_map.py b/anomalib/models/padim/anomaly_map.py new file mode 100644 index 0000000000..db363290ea --- /dev/null +++ b/anomalib/models/padim/anomaly_map.py @@ -0,0 +1,146 @@ +"""Anomaly Map Generator for the PaDiM model implementation.""" + +# Copyright (C) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +from typing import List, Tuple, Union + +import torch +import torch.nn.functional as F +from kornia.filters import gaussian_blur2d +from omegaconf import ListConfig +from torch import Tensor + + +class AnomalyMapGenerator: + """Generate Anomaly Heatmap. + + Args: + image_size (Union[ListConfig, Tuple]): Size of the input image. The anomaly map is upsampled to this dimension. + sigma (int, optional): Standard deviation for Gaussian Kernel. Defaults to 4. 
+    """
+
+    def __init__(self, image_size: Union[ListConfig, Tuple], sigma: int = 4):
+        self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
+        self.sigma = sigma
+
+    @staticmethod
+    def compute_distance(embedding: Tensor, stats: List[Tensor]) -> Tensor:
+        """Compute anomaly score to the patch in position(i,j) of a test image.
+
+        Ref: Equation (2), Section III-C of the paper.
+
+        Args:
+            embedding (Tensor): Embedding Vector
+            stats (List[Tensor]): Mean and Covariance Matrix of the multivariate Gaussian distribution
+
+        Returns:
+            Anomaly score of a test image via mahalanobis distance.
+        """
+
+        batch, channel, height, width = embedding.shape
+        embedding = embedding.reshape(batch, channel, height * width)
+
+        # calculate mahalanobis distances
+        mean, inv_covariance = stats
+        delta = (embedding - mean).permute(2, 0, 1)
+
+        distances = (torch.matmul(delta, inv_covariance) * delta).sum(2).permute(1, 0)
+        distances = distances.reshape(batch, height, width)
+        distances = torch.sqrt(distances)
+
+        return distances
+
+    def up_sample(self, distance: Tensor) -> Tensor:
+        """Up sample anomaly score to match the input image size.
+
+        Args:
+            distance (Tensor): Anomaly score computed via the mahalanobis distance.
+
+        Returns:
+            Resized distance matrix matching the input image size
+        """
+
+        score_map = F.interpolate(
+            distance.unsqueeze(1),
+            size=self.image_size,
+            mode="bilinear",
+            align_corners=False,
+        )
+        return score_map
+
+    def smooth_anomaly_map(self, anomaly_map: Tensor) -> Tensor:
+        """Apply gaussian smoothing to the anomaly map.
+
+        Args:
+            anomaly_map (Tensor): Anomaly score for the test image(s).
+
+        Returns:
+            Filtered anomaly scores
+        """
+
+        kernel_size = 2 * int(4.0 * self.sigma + 0.5) + 1
+        sigma = torch.as_tensor(self.sigma).to(anomaly_map.device)
+        anomaly_map = gaussian_blur2d(anomaly_map, (kernel_size, kernel_size), sigma=(sigma, sigma))
+
+        return anomaly_map
+
+    def compute_anomaly_map(self, embedding: Tensor, mean: Tensor, inv_covariance: Tensor) -> Tensor:
+        """Compute anomaly score.
+
+        Scores are calculated based on embedding vector, mean and inv_covariance of the multivariate gaussian
+        distribution.
+
+        Args:
+            embedding (Tensor): Embedding vector extracted from the test set.
+            mean (Tensor): Mean of the multivariate gaussian distribution
+            inv_covariance (Tensor): Inverse Covariance matrix of the multivariate gaussian distribution.
+
+        Returns:
+            Output anomaly score.
+        """
+
+        score_map = self.compute_distance(
+            embedding=embedding,
+            stats=[mean.to(embedding.device), inv_covariance.to(embedding.device)],
+        )
+        up_sampled_score_map = self.up_sample(score_map)
+        smoothed_anomaly_map = self.smooth_anomaly_map(up_sampled_score_map)
+
+        return smoothed_anomaly_map
+
+    def __call__(self, **kwds):
+        """Returns anomaly_map.
+
+        Expects `embedding`, `mean` and `covariance` keywords to be passed explicitly.
+
+        Example:
+        >>> anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)
+        >>> output = anomaly_map_generator(embedding=embedding, mean=mean, covariance=covariance)
+
+        Raises:
+            ValueError: `embedding`. `mean` or `covariance` keys are not found
+
+        Returns:
+            torch.Tensor: anomaly map
+        """
+
+        if not ("embedding" in kwds and "mean" in kwds and "inv_covariance" in kwds):
+            raise ValueError(f"Expected keys `embedding`, `mean` and `covariance`. Found {kwds.keys()}")
+
+        embedding: Tensor = kwds["embedding"]
+        mean: Tensor = kwds["mean"]
+        inv_covariance: Tensor = kwds["inv_covariance"]
+
+        return self.compute_anomaly_map(embedding, mean, inv_covariance)
diff --git a/anomalib/models/padim/lightning_model.py b/anomalib/models/padim/lightning_model.py
new file mode 100644
index 0000000000..8e4e20ce5f
--- /dev/null
+++ b/anomalib/models/padim/lightning_model.py
@@ -0,0 +1,101 @@
+"""PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization.
+
+Paper https://arxiv.org/abs/2011.08785
+"""
+
+# Copyright (C) 2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
+
+from typing import List, Union
+
+import torch
+from omegaconf import DictConfig, ListConfig
+from torch import Tensor
+
+from anomalib.models.components import AnomalyModule
+from anomalib.models.padim.torch_model import PadimModel
+
+__all__ = ["PadimLightning"]
+
+
+class PadimLightning(AnomalyModule):
+    """PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization.
+
+    Args:
+        hparams (Union[DictConfig, ListConfig]): Model params
+    """
+
+    def __init__(self, hparams: Union[DictConfig, ListConfig]):
+        super().__init__(hparams)
+        self.layers = hparams.model.layers
+        self.model: PadimModel = PadimModel(
+            layers=hparams.model.layers,
+            input_size=hparams.model.input_size,
+            tile_size=hparams.dataset.tiling.tile_size,
+            tile_stride=hparams.dataset.tiling.stride,
+            apply_tiling=hparams.dataset.tiling.apply,
+            backbone=hparams.model.backbone,
+        ).eval()
+
+        self.stats: List[Tensor] = []
+        self.embeddings: List[Tensor] = []
+
+    @staticmethod
+    def configure_optimizers():  # pylint: disable=arguments-differ
+        """PADIM doesn't require optimization, therefore returns no optimizers."""
+        return None
+
+    def training_step(self, batch, _batch_idx):  # pylint: disable=arguments-differ
+        """Training Step of PADIM. For each batch, hierarchical features are extracted from the CNN.
+
+        Args:
+            batch (Dict[str, Any]): Batch containing image filename, image, label and mask
+            _batch_idx: Index of the batch.
+
+        Returns:
+            Hierarchical feature map
+        """
+        self.model.feature_extractor.eval()
+        embedding = self.model(batch["image"])
+
+        # NOTE: `self.embedding` appends each batch embedding to
+        # store the training set embedding. We manually append these
+        # values mainly due to the new order of hooks introduced after PL v1.4.0
+        # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357
+        self.embeddings.append(embedding.cpu())
+
+    def on_validation_start(self) -> None:
+        """Fit a Gaussian to the embedding collected from the training set."""
+        # NOTE: Previous anomalib versions fit Gaussian at the end of the epoch.
+        # This is not possible anymore with PyTorch Lightning v1.4.0 since validation
+        # is run within train epoch.
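(A toy illustration, with made-up shapes, of the stacking performed on the next line:)

    import torch

    embeddings = [torch.randn(8, 100, 56, 56) for _ in range(3)]  # three training batches
    stacked = torch.vstack(embeddings)  # one tensor for the multivariate Gaussian fit
    assert stacked.shape == (24, 100, 56, 56)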
+ embeddings = torch.vstack(self.embeddings) + self.stats = self.model.gaussian.fit(embeddings) + + def validation_step(self, batch, _): # pylint: disable=arguments-differ + """Validation Step of PADIM. + + Similar to the training step, hierarchical features are extracted from the CNN for each batch. + + Args: + batch: Input batch + _: Index of the batch. + + Returns: + Dictionary containing images, features, true labels and masks. + These are required in `validation_epoch_end` for feature concatenation. + """ + + batch["anomaly_maps"] = self.model(batch["image"]) + return batch diff --git a/anomalib/models/padim/model.py b/anomalib/models/padim/model.py deleted file mode 100644 index 1d5544164b..0000000000 --- a/anomalib/models/padim/model.py +++ /dev/null @@ -1,346 +0,0 @@ -"""PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization. - -Paper https://arxiv.org/abs/2011.08785 -""" - -# Copyright (C) 2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. - -from random import sample -from typing import Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torchvision -from kornia.filters import gaussian_blur2d -from omegaconf import DictConfig, ListConfig -from torch import Tensor, nn - -from anomalib.models.components import ( - AnomalyModule, - FeatureExtractor, - MultiVariateGaussian, -) -from anomalib.pre_processing import Tiler - -__all__ = ["PadimLightning"] - - -DIMS = { - "resnet18": {"orig_dims": 448, "reduced_dims": 100, "emb_scale": 4}, - "wide_resnet50_2": {"orig_dims": 1792, "reduced_dims": 550, "emb_scale": 4}, -} - - -class PadimModel(nn.Module): - """Padim Module. - - Args: - layers (List[str]): Layers used for feature extraction - input_size (Tuple[int, int]): Input size for the model. - tile_size (Tuple[int, int]): Tile size - tile_stride (int): Stride for tiling - apply_tiling (bool, optional): Apply tiling. Defaults to False. - backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18". 
- """ - - def __init__( - self, - layers: List[str], - input_size: Tuple[int, int], - backbone: str = "resnet18", - apply_tiling: bool = False, - tile_size: Optional[Tuple[int, int]] = None, - tile_stride: Optional[int] = None, - ): - super().__init__() - self.backbone = getattr(torchvision.models, backbone) - self.layers = layers - self.apply_tiling = apply_tiling - self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers) - self.dims = DIMS[backbone] - # pylint: disable=not-callable - # Since idx is randomly selected, save it with model to get same results - self.register_buffer( - "idx", - torch.tensor(sample(range(0, DIMS[backbone]["orig_dims"]), DIMS[backbone]["reduced_dims"])), - ) - self.idx: Tensor - self.loss = None - self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size) - - n_features = DIMS[backbone]["reduced_dims"] - patches_dims = torch.tensor(input_size) / DIMS[backbone]["emb_scale"] - n_patches = patches_dims.ceil().prod().int().item() - self.gaussian = MultiVariateGaussian(n_features, n_patches) - - if apply_tiling: - assert tile_size is not None - assert tile_stride is not None - self.tiler = Tiler(tile_size, tile_stride) - - def forward(self, input_tensor: Tensor) -> Tensor: - """Forward-pass image-batch (N, C, H, W) into model to extract features. - - Args: - input_tensor: Image-batch (N, C, H, W) - input_tensor: Tensor: - - Returns: - Features from single/multiple layers. - - Example: - >>> x = torch.randn(32, 3, 224, 224) - >>> features = self.extract_features(input_tensor) - >>> features.keys() - dict_keys(['layer1', 'layer2', 'layer3']) - - >>> [v.shape for v in features.values()] - [torch.Size([32, 64, 56, 56]), - torch.Size([32, 128, 28, 28]), - torch.Size([32, 256, 14, 14])] - """ - - if self.apply_tiling: - input_tensor = self.tiler.tile(input_tensor) - with torch.no_grad(): - features = self.feature_extractor(input_tensor) - embeddings = self.generate_embedding(features) - if self.apply_tiling: - embeddings = self.tiler.untile(embeddings) - - if self.training: - output = embeddings - else: - output = self.anomaly_map_generator( - embedding=embeddings, mean=self.gaussian.mean, inv_covariance=self.gaussian.inv_covariance - ) - - return output - - def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor: - """Generate embedding from hierarchical feature map. - - Args: - features (Dict[str, Tensor]): Hierarchical feature map from a CNN (ResNet18 or WideResnet) - - Returns: - Embedding vector - """ - - embeddings = features[self.layers[0]] - for layer in self.layers[1:]: - layer_embedding = features[layer] - layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="nearest") - embeddings = torch.cat((embeddings, layer_embedding), 1) - - # subsample embeddings - idx = self.idx.to(embeddings.device) - embeddings = torch.index_select(embeddings, 1, idx) - return embeddings - - -class AnomalyMapGenerator: - """Generate Anomaly Heatmap. - - Args: - image_size (Union[ListConfig, Tuple]): Size of the input image. The anomaly map is upsampled to this dimension. - sigma (int, optional): Standard deviation for Gaussian Kernel. Defaults to 4. 
-    """
-
-    def __init__(self, image_size: Union[ListConfig, Tuple], sigma: int = 4):
-        self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
-        self.sigma = sigma
-
-    @staticmethod
-    def compute_distance(embedding: Tensor, stats: List[Tensor]) -> Tensor:
-        """Compute anomaly score to the patch in position(i,j) of a test image.
-
-        Ref: Equation (2), Section III-C of the paper.
-
-        Args:
-            embedding (Tensor): Embedding Vector
-            stats (List[Tensor]): Mean and Covariance Matrix of the multivariate Gaussian distribution
-
-        Returns:
-            Anomaly score of a test image via mahalanobis distance.
-        """
-
-        batch, channel, height, width = embedding.shape
-        embedding = embedding.reshape(batch, channel, height * width)
-
-        # calculate mahalanobis distances
-        mean, inv_covariance = stats
-        delta = (embedding - mean).permute(2, 0, 1)
-
-        distances = (torch.matmul(delta, inv_covariance) * delta).sum(2).permute(1, 0)
-        distances = distances.reshape(batch, height, width)
-        distances = torch.sqrt(distances)
-
-        return distances
-
-    def up_sample(self, distance: Tensor) -> Tensor:
-        """Up sample anomaly score to match the input image size.
-
-        Args:
-            distance (Tensor): Anomaly score computed via the mahalanobis distance.
-
-        Returns:
-            Resized distance matrix matching the input image size
-        """
-
-        score_map = F.interpolate(
-            distance.unsqueeze(1),
-            size=self.image_size,
-            mode="bilinear",
-            align_corners=False,
-        )
-        return score_map
-
-    def smooth_anomaly_map(self, anomaly_map: Tensor) -> Tensor:
-        """Apply gaussian smoothing to the anomaly map.
-
-        Args:
-            anomaly_map (Tensor): Anomaly score for the test image(s).
-
-        Returns:
-            Filtered anomaly scores
-        """
-
-        kernel_size = 2 * int(4.0 * self.sigma + 0.5) + 1
-        sigma = torch.as_tensor(self.sigma).to(anomaly_map.device)
-        anomaly_map = gaussian_blur2d(anomaly_map, (kernel_size, kernel_size), sigma=(sigma, sigma))
-
-        return anomaly_map
-
-    def compute_anomaly_map(self, embedding: Tensor, mean: Tensor, inv_covariance: Tensor) -> Tensor:
-        """Compute anomaly score.
-
-        Scores are calculated based on embedding vector, mean and inv_covariance of the multivariate gaussian
-        distribution.
-
-        Args:
-            embedding (Tensor): Embedding vector extracted from the test set.
-            mean (Tensor): Mean of the multivariate gaussian distribution
-            inv_covariance (Tensor): Inverse Covariance matrix of the multivariate gaussian distribution.
-
-        Returns:
-            Output anomaly score.
-        """
-
-        score_map = self.compute_distance(
-            embedding=embedding,
-            stats=[mean.to(embedding.device), inv_covariance.to(embedding.device)],
-        )
-        up_sampled_score_map = self.up_sample(score_map)
-        smoothed_anomaly_map = self.smooth_anomaly_map(up_sampled_score_map)
-
-        return smoothed_anomaly_map
-
-    def __call__(self, **kwds):
-        """Returns anomaly_map.
-
-        Expects `embedding`, `mean` and `covariance` keywords to be passed explicitly.
-
-        Example:
-        >>> anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)
-        >>> output = anomaly_map_generator(embedding=embedding, mean=mean, covariance=covariance)
-
-        Raises:
-            ValueError: `embedding`. `mean` or `covariance` keys are not found
-
-        Returns:
-            torch.Tensor: anomaly map
-        """
-
-        if not ("embedding" in kwds and "mean" in kwds and "inv_covariance" in kwds):
-            raise ValueError(f"Expected keys `embedding`, `mean` and `covariance`. Found {kwds.keys()}")
-
-        embedding: Tensor = kwds["embedding"]
-        mean: Tensor = kwds["mean"]
-        inv_covariance: Tensor = kwds["inv_covariance"]
-
-        return self.compute_anomaly_map(embedding, mean, inv_covariance)
-
-
-class PadimLightning(AnomalyModule):
-    """PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization.
-
-    Args:
-        hparams (Union[DictConfig, ListConfig]): Model params
-    """
-
-    def __init__(self, hparams: Union[DictConfig, ListConfig]):
-        super().__init__(hparams)
-        self.layers = hparams.model.layers
-        self.model: PadimModel = PadimModel(
-            layers=hparams.model.layers,
-            input_size=hparams.model.input_size,
-            tile_size=hparams.dataset.tiling.tile_size,
-            tile_stride=hparams.dataset.tiling.stride,
-            apply_tiling=hparams.dataset.tiling.apply,
-            backbone=hparams.model.backbone,
-        ).eval()
-
-        self.stats: List[Tensor] = []
-        self.embeddings: List[Tensor] = []
-
-    @staticmethod
-    def configure_optimizers():  # pylint: disable=arguments-differ
-        """PADIM doesn't require optimization, therefore returns no optimizers."""
-        return None
-
-    def training_step(self, batch, _batch_idx):  # pylint: disable=arguments-differ
-        """Training Step of PADIM. For each batch, hierarchical features are extracted from the CNN.
-
-        Args:
-            batch (Dict[str, Any]): Batch containing image filename, image, label and mask
-            _batch_idx: Index of the batch.
-
-        Returns:
-            Hierarchical feature map
-        """
-        self.model.feature_extractor.eval()
-        embedding = self.model(batch["image"])
-
-        # NOTE: `self.embedding` appends each batch embedding to
-        # store the training set embedding. We manually append these
-        # values mainly due to the new order of hooks introduced after PL v1.4.0
-        # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357
-        self.embeddings.append(embedding.cpu())
-
-    def on_validation_start(self) -> None:
-        """Fit a Gaussian to the embedding collected from the training set."""
-        # NOTE: Previous anomalib versions fit Gaussian at the end of the epoch.
-        # This is not possible anymore with PyTorch Lightning v1.4.0 since validation
-        # is run within train epoch.
-        embeddings = torch.vstack(self.embeddings)
-        self.stats = self.model.gaussian.fit(embeddings)
-
-    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
-        """Validation Step of PADIM.
-
-        Similar to the training step, hierarchical features are extracted from the CNN for each batch.
-
-        Args:
-            batch: Input batch
-            _: Index of the batch.
-
-        Returns:
-            Dictionary containing images, features, true labels and masks.
-            These are required in `validation_epoch_end` for feature concatenation.
-        """
-
-        batch["anomaly_maps"] = self.model(batch["image"])
-        return batch
diff --git a/anomalib/models/padim/torch_model.py b/anomalib/models/padim/torch_model.py
new file mode 100644
index 0000000000..4a393feede
--- /dev/null
+++ b/anomalib/models/padim/torch_model.py
@@ -0,0 +1,140 @@
+"""PyTorch model for the PaDiM model implementation."""
+
+# Copyright (C) 2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
+
+from random import sample
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import Tensor, nn
+
+from anomalib.models.components import FeatureExtractor, MultiVariateGaussian
+from anomalib.models.padim.anomaly_map import AnomalyMapGenerator
+from anomalib.pre_processing import Tiler
+
+DIMS = {
+    "resnet18": {"orig_dims": 448, "reduced_dims": 100, "emb_scale": 4},
+    "wide_resnet50_2": {"orig_dims": 1792, "reduced_dims": 550, "emb_scale": 4},
+}
+
+
+class PadimModel(nn.Module):
+    """Padim Module.
+
+    Args:
+        layers (List[str]): Layers used for feature extraction
+        input_size (Tuple[int, int]): Input size for the model.
+        tile_size (Tuple[int, int]): Tile size
+        tile_stride (int): Stride for tiling
+        apply_tiling (bool, optional): Apply tiling. Defaults to False.
+        backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18".
+    """
+
+    def __init__(
+        self,
+        layers: List[str],
+        input_size: Tuple[int, int],
+        backbone: str = "resnet18",
+        apply_tiling: bool = False,
+        tile_size: Optional[Tuple[int, int]] = None,
+        tile_stride: Optional[int] = None,
+    ):
+        super().__init__()
+        self.backbone = getattr(torchvision.models, backbone)
+        self.layers = layers
+        self.apply_tiling = apply_tiling
+        self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers)
+        self.dims = DIMS[backbone]
+        # pylint: disable=not-callable
+        # Since idx is randomly selected, save it with model to get same results
+        self.register_buffer(
+            "idx",
+            torch.tensor(sample(range(0, DIMS[backbone]["orig_dims"]), DIMS[backbone]["reduced_dims"])),
+        )
+        self.idx: Tensor
+        self.loss = None
+        self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)
+
+        n_features = DIMS[backbone]["reduced_dims"]
+        patches_dims = torch.tensor(input_size) / DIMS[backbone]["emb_scale"]
+        n_patches = patches_dims.ceil().prod().int().item()
+        self.gaussian = MultiVariateGaussian(n_features, n_patches)
+
+        if apply_tiling:
+            assert tile_size is not None
+            assert tile_stride is not None
+            self.tiler = Tiler(tile_size, tile_stride)
+
+    def forward(self, input_tensor: Tensor) -> Tensor:
+        """Forward-pass image-batch (N, C, H, W) into model to extract features.
+
+        Args:
+            input_tensor: Image-batch (N, C, H, W)
+            input_tensor: Tensor:
+
+        Returns:
+            Features from single/multiple layers.
+
+        Example:
+            >>> x = torch.randn(32, 3, 224, 224)
+            >>> features = self.extract_features(input_tensor)
+            >>> features.keys()
+            dict_keys(['layer1', 'layer2', 'layer3'])
+
+            >>> [v.shape for v in features.values()]
+            [torch.Size([32, 64, 56, 56]),
+            torch.Size([32, 128, 28, 28]),
+            torch.Size([32, 256, 14, 14])]
+        """
+
+        if self.apply_tiling:
+            input_tensor = self.tiler.tile(input_tensor)
+        with torch.no_grad():
+            features = self.feature_extractor(input_tensor)
+            embeddings = self.generate_embedding(features)
+        if self.apply_tiling:
+            embeddings = self.tiler.untile(embeddings)
+
+        if self.training:
+            output = embeddings
+        else:
+            output = self.anomaly_map_generator(
+                embedding=embeddings, mean=self.gaussian.mean, inv_covariance=self.gaussian.inv_covariance
+            )
+
+        return output
+
+    def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor:
+        """Generate embedding from hierarchical feature map.
+ + Args: + features (Dict[str, Tensor]): Hierarchical feature map from a CNN (ResNet18 or WideResnet) + + Returns: + Embedding vector + """ + + embeddings = features[self.layers[0]] + for layer in self.layers[1:]: + layer_embedding = features[layer] + layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="nearest") + embeddings = torch.cat((embeddings, layer_embedding), 1) + + # subsample embeddings + idx = self.idx.to(embeddings.device) + embeddings = torch.index_select(embeddings, 1, idx) + return embeddings diff --git a/anomalib/models/patchcore/__init__.py b/anomalib/models/patchcore/__init__.py index 547a90b152..5c67dde955 100644 --- a/anomalib/models/patchcore/__init__.py +++ b/anomalib/models/patchcore/__init__.py @@ -13,3 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. + +from .lightning_model import PatchcoreLightning + +__all__ = ["PatchcoreLightning"] diff --git a/anomalib/models/patchcore/anomaly_map.py b/anomalib/models/patchcore/anomaly_map.py new file mode 100644 index 0000000000..15eda34f56 --- /dev/null +++ b/anomalib/models/patchcore/anomaly_map.py @@ -0,0 +1,100 @@ +"""Anomaly Map Generator for the PatchCore model implementation.""" + +# Copyright (C) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +from typing import Tuple, Union + +import torch +import torch.nn.functional as F +from kornia.filters import gaussian_blur2d +from omegaconf import ListConfig + + +class AnomalyMapGenerator: + """Generate Anomaly Heatmap.""" + + def __init__( + self, + input_size: Union[ListConfig, Tuple], + sigma: int = 4, + ) -> None: + self.input_size = input_size + self.sigma = sigma + + def compute_anomaly_map(self, patch_scores: torch.Tensor, feature_map_shape: torch.Size) -> torch.Tensor: + """Pixel Level Anomaly Heatmap. + + Args: + patch_scores (torch.Tensor): Patch-level anomaly scores + feature_map_shape (torch.Size): 2-D feature map shape (width, height) + + Returns: + torch.Tensor: Map of the pixel-level anomaly scores + """ + width, height = feature_map_shape + batch_size = len(patch_scores) // (width * height) + + anomaly_map = patch_scores[:, 0].reshape((batch_size, 1, width, height)) + anomaly_map = F.interpolate(anomaly_map, size=(self.input_size[0], self.input_size[1])) + + kernel_size = 2 * int(4.0 * self.sigma + 0.5) + 1 + anomaly_map = gaussian_blur2d(anomaly_map, (kernel_size, kernel_size), sigma=(self.sigma, self.sigma)) + + return anomaly_map + + @staticmethod + def compute_anomaly_score(patch_scores: torch.Tensor) -> torch.Tensor: + """Compute Image-Level Anomaly Score. 
+ + Args: + patch_scores (torch.Tensor): Patch-level anomaly scores + Returns: + torch.Tensor: Image-level anomaly scores + """ + max_scores = torch.argmax(patch_scores[:, 0]) + confidence = torch.index_select(patch_scores, 0, max_scores) + weights = 1 - (torch.max(torch.exp(confidence)) / torch.sum(torch.exp(confidence))) + score = weights * torch.max(patch_scores[:, 0]) + return score + + def __call__(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Returns anomaly_map and anomaly_score. + + Expects `patch_scores` keyword to be passed explicitly + Expects `feature_map_shape` keyword to be passed explicitly + + Example + >>> anomaly_map_generator = AnomalyMapGenerator(input_size=input_size) + >>> map, score = anomaly_map_generator(patch_scores=numpy_array, feature_map_shape=feature_map_shape) + + Raises: + ValueError: If `patch_scores` key is not found + + Returns: + Tuple[torch.Tensor, torch.Tensor]: anomaly_map, anomaly_score + """ + + if "patch_scores" not in kwargs: + raise ValueError(f"Expected key `patch_scores`. Found {kwargs.keys()}") + + if "feature_map_shape" not in kwargs: + raise ValueError(f"Expected key `feature_map_shape`. Found {kwargs.keys()}") + + patch_scores = kwargs["patch_scores"] + feature_map_shape = kwargs["feature_map_shape"] + + anomaly_map = self.compute_anomaly_map(patch_scores, feature_map_shape) + anomaly_score = self.compute_anomaly_score(patch_scores) + return anomaly_map, anomaly_score diff --git a/anomalib/models/patchcore/lightning_model.py b/anomalib/models/patchcore/lightning_model.py new file mode 100644 index 0000000000..dd8b4e8668 --- /dev/null +++ b/anomalib/models/patchcore/lightning_model.py @@ -0,0 +1,108 @@ +"""Towards Total Recall in Industrial Anomaly Detection. + +Paper https://arxiv.org/abs/2106.08265. +""" + +# Copyright (C) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +from typing import List + +import torch +from torch import Tensor + +from anomalib.models.components import AnomalyModule +from anomalib.models.patchcore.torch_model import PatchcoreModel + + +class PatchcoreLightning(AnomalyModule): + """PatchcoreLightning Module to train PatchCore algorithm. + + Args: + layers (List[str]): Layers used for feature extraction + input_size (Tuple[int, int]): Input size for the model. + tile_size (Tuple[int, int]): Tile size + tile_stride (int): Stride for tiling + backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18". + apply_tiling (bool, optional): Apply tiling. Defaults to False. + """ + + def __init__(self, hparams) -> None: + super().__init__(hparams) + + self.model: PatchcoreModel = PatchcoreModel( + layers=hparams.model.layers, + input_size=hparams.model.input_size, + tile_size=hparams.dataset.tiling.tile_size, + tile_stride=hparams.dataset.tiling.stride, + backbone=hparams.model.backbone, + apply_tiling=hparams.dataset.tiling.apply, + ) + self.embeddings: List[Tensor] = [] + + def configure_optimizers(self) -> None: + """Configure optimizers. 
+ + Returns: + None: Do not set optimizers by returning None. + """ + return None + + def training_step(self, batch, _batch_idx): # pylint: disable=arguments-differ + """Generate feature embedding of the batch. + + Args: + batch (Dict[str, Any]): Batch containing image filename, image, label and mask + _batch_idx (int): Batch Index + + Returns: + Dict[str, np.ndarray]: Embedding Vector + """ + self.model.feature_extractor.eval() + embedding = self.model(batch["image"]) + + # NOTE: `self.embedding` appends each batch embedding to + # store the training set embedding. We manually append these + # values mainly due to the new order of hooks introduced after PL v1.4.0 + # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357 + self.embeddings.append(embedding) + + def on_validation_start(self) -> None: + """Apply subsampling to the embedding collected from the training set.""" + # NOTE: Previous anomalib versions fit subsampling at the end of the epoch. + # This is not possible anymore with PyTorch Lightning v1.4.0 since validation + # is run within train epoch. + print("Aggregating the embedding extracted from the training set.") + embeddings = torch.vstack(self.embeddings) + + sampling_ratio = self.hparams.model.coreset_sampling_ratio + self.model.subsample_embedding(embeddings, sampling_ratio) + + def validation_step(self, batch, _): # pylint: disable=arguments-differ + """Get batch of anomaly maps from input image batch. + + Args: + batch (Dict[str, Any]): Batch containing image filename, + image, label and mask + _ (int): Batch Index + + Returns: + Dict[str, Any]: Image filenames, test images, GT and predicted label/masks + """ + + anomaly_maps, anomaly_score = self.model(batch["image"]) + batch["anomaly_maps"] = anomaly_maps + batch["pred_scores"] = anomaly_score.unsqueeze(0) + + return batch diff --git a/anomalib/models/patchcore/model.py b/anomalib/models/patchcore/model.py deleted file mode 100644 index 5358c8c594..0000000000 --- a/anomalib/models/patchcore/model.py +++ /dev/null @@ -1,334 +0,0 @@ -"""Towards Total Recall in Industrial Anomaly Detection. - -Paper https://arxiv.org/abs/2106.08265. -""" - -# Copyright (C) 2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. - -from typing import Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torchvision -from kornia.filters import gaussian_blur2d -from omegaconf import ListConfig -from torch import Tensor, nn - -from anomalib.models.components import ( - AnomalyModule, - DynamicBufferModule, - FeatureExtractor, - KCenterGreedy, -) -from anomalib.pre_processing import Tiler - - -class AnomalyMapGenerator: - """Generate Anomaly Heatmap.""" - - def __init__( - self, - input_size: Union[ListConfig, Tuple], - sigma: int = 4, - ) -> None: - self.input_size = input_size - self.sigma = sigma - - def compute_anomaly_map(self, patch_scores: torch.Tensor, feature_map_shape: torch.Size) -> torch.Tensor: - """Pixel Level Anomaly Heatmap. 
- - Args: - patch_scores (torch.Tensor): Patch-level anomaly scores - feature_map_shape (torch.Size): 2-D feature map shape (width, height) - - Returns: - torch.Tensor: Map of the pixel-level anomaly scores - """ - width, height = feature_map_shape - batch_size = len(patch_scores) // (width * height) - - anomaly_map = patch_scores[:, 0].reshape((batch_size, 1, width, height)) - anomaly_map = F.interpolate(anomaly_map, size=(self.input_size[0], self.input_size[1])) - - kernel_size = 2 * int(4.0 * self.sigma + 0.5) + 1 - anomaly_map = gaussian_blur2d(anomaly_map, (kernel_size, kernel_size), sigma=(self.sigma, self.sigma)) - - return anomaly_map - - @staticmethod - def compute_anomaly_score(patch_scores: torch.Tensor) -> torch.Tensor: - """Compute Image-Level Anomaly Score. - - Args: - patch_scores (torch.Tensor): Patch-level anomaly scores - Returns: - torch.Tensor: Image-level anomaly scores - """ - max_scores = torch.argmax(patch_scores[:, 0]) - confidence = torch.index_select(patch_scores, 0, max_scores) - weights = 1 - (torch.max(torch.exp(confidence)) / torch.sum(torch.exp(confidence))) - score = weights * torch.max(patch_scores[:, 0]) - return score - - def __call__(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Returns anomaly_map and anomaly_score. - - Expects `patch_scores` keyword to be passed explicitly - Expects `feature_map_shape` keyword to be passed explicitly - - Example - >>> anomaly_map_generator = AnomalyMapGenerator(input_size=input_size) - >>> map, score = anomaly_map_generator(patch_scores=numpy_array, feature_map_shape=feature_map_shape) - - Raises: - ValueError: If `patch_scores` key is not found - - Returns: - Tuple[torch.Tensor, torch.Tensor]: anomaly_map, anomaly_score - """ - - if "patch_scores" not in kwargs: - raise ValueError(f"Expected key `patch_scores`. Found {kwargs.keys()}") - - if "feature_map_shape" not in kwargs: - raise ValueError(f"Expected key `feature_map_shape`. Found {kwargs.keys()}") - - patch_scores = kwargs["patch_scores"] - feature_map_shape = kwargs["feature_map_shape"] - - anomaly_map = self.compute_anomaly_map(patch_scores, feature_map_shape) - anomaly_score = self.compute_anomaly_score(patch_scores) - return anomaly_map, anomaly_score - - -class PatchcoreModel(DynamicBufferModule, nn.Module): - """Patchcore Module.""" - - def __init__( - self, - layers: List[str], - input_size: Tuple[int, int], - backbone: str = "wide_resnet50_2", - apply_tiling: bool = False, - tile_size: Optional[Tuple[int, int]] = None, - tile_stride: Optional[int] = None, - ) -> None: - super().__init__() - - self.backbone = getattr(torchvision.models, backbone) - self.layers = layers - self.input_size = input_size - self.apply_tiling = apply_tiling - - self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers) - self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1) - self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size) - - if apply_tiling: - assert tile_size is not None - assert tile_stride is not None - self.tiler = Tiler(tile_size, tile_stride) - - self.register_buffer("memory_bank", torch.Tensor()) - self.memory_bank: torch.Tensor - - def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Return Embedding during training, or a tuple of anomaly map and anomaly score during testing. - - Steps performed: - 1. Get features from a CNN. - 2. Generate embedding based on the features. - 3. Compute anomaly map in test mode. 
- - Args: - input_tensor (Tensor): Input tensor - - Returns: - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Embedding for training, - anomaly map and anomaly score for testing. - """ - if self.apply_tiling: - input_tensor = self.tiler.tile(input_tensor) - - with torch.no_grad(): - features = self.feature_extractor(input_tensor) - - features = {layer: self.feature_pooler(feature) for layer, feature in features.items()} - embedding = self.generate_embedding(features) - - if self.apply_tiling: - embedding = self.tiler.untile(embedding) - - feature_map_shape = embedding.shape[-2:] - embedding = self.reshape_embedding(embedding) - - if self.training: - output = embedding - else: - patch_scores = self.nearest_neighbors(embedding=embedding, n_neighbors=9) - anomaly_map, anomaly_score = self.anomaly_map_generator( - patch_scores=patch_scores, feature_map_shape=feature_map_shape - ) - output = (anomaly_map, anomaly_score) - - return output - - def generate_embedding(self, features: Dict[str, Tensor]) -> torch.Tensor: - """Generate embedding from hierarchical feature map. - - Args: - features: Hierarchical feature map from a CNN (ResNet18 or WideResnet) - features: Dict[str:Tensor]: - - Returns: - Embedding vector - """ - - embeddings = features[self.layers[0]] - for layer in self.layers[1:]: - layer_embedding = features[layer] - layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="nearest") - embeddings = torch.cat((embeddings, layer_embedding), 1) - - return embeddings - - @staticmethod - def reshape_embedding(embedding: Tensor) -> Tensor: - """Reshape Embedding. - - Reshapes Embedding to the following format: - [Batch, Embedding, Patch, Patch] to [Batch*Patch*Patch, Embedding] - - Args: - embedding (Tensor): Embedding tensor extracted from CNN features. - - Returns: - Tensor: Reshaped embedding tensor. - """ - embedding_size = embedding.size(1) - embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size) - return embedding - - def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None: - """Subsample embedding based on coreset sampling and store to memory. - - Args: - embedding (np.ndarray): Embedding tensor from the CNN - sampling_ratio (float): Coreset sampling ratio - """ - - # Coreset Subsampling - print("Creating CoreSet Sampler via k-Center Greedy") - sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio) - print("Getting the coreset from the main embedding.") - coreset = sampler.sample_coreset() - print("Assigning the coreset as the memory bank.") - self.memory_bank = coreset - - def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor: - """Nearest Neighbours using brute force method and euclidean norm. - - Args: - embedding (Tensor): Features to compare the distance with the memory bank. - n_neighbors (int): Number of neighbors to look at - - Returns: - Tensor: Patch scores. - """ - distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm - patch_scores, _ = distances.topk(k=n_neighbors, largest=False, dim=1) - return patch_scores - - -class PatchcoreLightning(AnomalyModule): - """PatchcoreLightning Module to train PatchCore algorithm. - - Args: - layers (List[str]): Layers used for feature extraction - input_size (Tuple[int, int]): Input size for the model. - tile_size (Tuple[int, int]): Tile size - tile_stride (int): Stride for tiling - backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18". 
- apply_tiling (bool, optional): Apply tiling. Defaults to False. - """ - - def __init__(self, hparams) -> None: - super().__init__(hparams) - - self.model: PatchcoreModel = PatchcoreModel( - layers=hparams.model.layers, - input_size=hparams.model.input_size, - tile_size=hparams.dataset.tiling.tile_size, - tile_stride=hparams.dataset.tiling.stride, - backbone=hparams.model.backbone, - apply_tiling=hparams.dataset.tiling.apply, - ) - self.embeddings: List[Tensor] = [] - - def configure_optimizers(self) -> None: - """Configure optimizers. - - Returns: - None: Do not set optimizers by returning None. - """ - return None - - def training_step(self, batch, _batch_idx): # pylint: disable=arguments-differ - """Generate feature embedding of the batch. - - Args: - batch (Dict[str, Any]): Batch containing image filename, image, label and mask - _batch_idx (int): Batch Index - - Returns: - Dict[str, np.ndarray]: Embedding Vector - """ - self.model.feature_extractor.eval() - embedding = self.model(batch["image"]) - - # NOTE: `self.embedding` appends each batch embedding to - # store the training set embedding. We manually append these - # values mainly due to the new order of hooks introduced after PL v1.4.0 - # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357 - self.embeddings.append(embedding) - - def on_validation_start(self) -> None: - """Apply subsampling to the embedding collected from the training set.""" - # NOTE: Previous anomalib versions fit subsampling at the end of the epoch. - # This is not possible anymore with PyTorch Lightning v1.4.0 since validation - # is run within train epoch. - print("Aggregating the embedding extracted from the training set.") - embeddings = torch.vstack(self.embeddings) - - sampling_ratio = self.hparams.model.coreset_sampling_ratio - self.model.subsample_embedding(embeddings, sampling_ratio) - - def validation_step(self, batch, _): # pylint: disable=arguments-differ - """Get batch of anomaly maps from input image batch. - - Args: - batch (Dict[str, Any]): Batch containing image filename, - image, label and mask - _ (int): Batch Index - - Returns: - Dict[str, Any]: Image filenames, test images, GT and predicted label/masks - """ - - anomaly_maps, anomaly_score = self.model(batch["image"]) - batch["anomaly_maps"] = anomaly_maps - batch["pred_scores"] = anomaly_score.unsqueeze(0) - - return batch diff --git a/anomalib/models/patchcore/torch_model.py b/anomalib/models/patchcore/torch_model.py new file mode 100644 index 0000000000..ba2847a84e --- /dev/null +++ b/anomalib/models/patchcore/torch_model.py @@ -0,0 +1,169 @@ +"""PyTorch model for the PatchCore model implementation.""" + +# Copyright (C) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
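For orientation, the refactored `PatchcoreModel` below can be exercised end to end on its own. A minimal sketch, assuming pretrained backbone weights can be downloaded; the layer selection, input size, batch shapes, and sampling ratio mirror the defaults but are illustrative here:

    import torch

    from anomalib.models.patchcore.torch_model import PatchcoreModel

    model = PatchcoreModel(layers=["layer2", "layer3"], input_size=(224, 224))

    # Train mode: forward returns per-patch embeddings of shape (B*P*P, E).
    model.train()
    embedding = model(torch.randn(2, 3, 224, 224))

    # Build the memory bank from the collected training embeddings (10% coreset).
    model.subsample_embedding(embedding, sampling_ratio=0.1)

    # Eval mode: forward returns an anomaly map and an image-level score.
    model.eval()
    anomaly_map, anomaly_score = model(torch.randn(1, 3, 224, 224))

Note that `subsample_embedding` must run before the first eval-mode forward pass, since `nearest_neighbors` scores patches against the `memory_bank` buffer.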
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import Tensor, nn
+
+from anomalib.models.components import (
+    DynamicBufferModule,
+    FeatureExtractor,
+    KCenterGreedy,
+)
+from anomalib.models.patchcore.anomaly_map import AnomalyMapGenerator
+from anomalib.pre_processing import Tiler
+
+
+class PatchcoreModel(DynamicBufferModule, nn.Module):
+    """Patchcore Module."""
+
+    def __init__(
+        self,
+        layers: List[str],
+        input_size: Tuple[int, int],
+        backbone: str = "wide_resnet50_2",
+        apply_tiling: bool = False,
+        tile_size: Optional[Tuple[int, int]] = None,
+        tile_stride: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        self.backbone = getattr(torchvision.models, backbone)
+        self.layers = layers
+        self.input_size = input_size
+        self.apply_tiling = apply_tiling
+
+        self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers)
+        self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
+        self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)
+
+        if apply_tiling:
+            assert tile_size is not None
+            assert tile_stride is not None
+            self.tiler = Tiler(tile_size, tile_stride)
+
+        self.register_buffer("memory_bank", torch.Tensor())
+        self.memory_bank: torch.Tensor
+
+    def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Return Embedding during training, or a tuple of anomaly map and anomaly score during testing.
+
+        Steps performed:
+        1. Get features from a CNN.
+        2. Generate embedding based on the features.
+        3. Compute anomaly map in test mode.
+
+        Args:
+            input_tensor (Tensor): Input tensor
+
+        Returns:
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Embedding for training,
+                anomaly map and anomaly score for testing.
+        """
+        if self.apply_tiling:
+            input_tensor = self.tiler.tile(input_tensor)
+
+        with torch.no_grad():
+            features = self.feature_extractor(input_tensor)
+
+        features = {layer: self.feature_pooler(feature) for layer, feature in features.items()}
+        embedding = self.generate_embedding(features)
+
+        if self.apply_tiling:
+            embedding = self.tiler.untile(embedding)
+
+        feature_map_shape = embedding.shape[-2:]
+        embedding = self.reshape_embedding(embedding)
+
+        if self.training:
+            output = embedding
+        else:
+            patch_scores = self.nearest_neighbors(embedding=embedding, n_neighbors=9)
+            anomaly_map, anomaly_score = self.anomaly_map_generator(
+                patch_scores=patch_scores, feature_map_shape=feature_map_shape
+            )
+            output = (anomaly_map, anomaly_score)
+
+        return output
+
+    def generate_embedding(self, features: Dict[str, Tensor]) -> torch.Tensor:
+        """Generate embedding from hierarchical feature map.
+
+        Args:
+            features (Dict[str, Tensor]): Hierarchical feature map from a CNN (ResNet18 or WideResNet)
+
+        Returns:
+            Embedding vector
+        """
+
+        embeddings = features[self.layers[0]]
+        for layer in self.layers[1:]:
+            layer_embedding = features[layer]
+            layer_embedding = F.interpolate(layer_embedding, size=embeddings.shape[-2:], mode="nearest")
+            embeddings = torch.cat((embeddings, layer_embedding), 1)
+
+        return embeddings
+
+    @staticmethod
+    def reshape_embedding(embedding: Tensor) -> Tensor:
+        """Reshape Embedding.
+
+        Reshapes Embedding to the following format:
+        [Batch, Embedding, Patch, Patch] to [Batch*Patch*Patch, Embedding]
+
+        Args:
+            embedding (Tensor): Embedding tensor extracted from CNN features.
+
+        Returns:
+            Tensor: Reshaped embedding tensor.
+        """
+        embedding_size = embedding.size(1)
+        embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
+        return embedding
+
+    def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None:
+        """Subsample embedding based on coreset sampling and store to memory.
+
+        Args:
+            embedding (Tensor): Embedding tensor from the CNN
+            sampling_ratio (float): Coreset sampling ratio
+        """
+
+        # Coreset Subsampling
+        print("Creating CoreSet Sampler via k-Center Greedy")
+        sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio)
+        print("Getting the coreset from the main embedding.")
+        coreset = sampler.sample_coreset()
+        print("Assigning the coreset as the memory bank.")
+        self.memory_bank = coreset
+
+    def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor:
+        """Nearest Neighbours using brute force method and euclidean norm.
+
+        Args:
+            embedding (Tensor): Features to compare the distance with the memory bank.
+            n_neighbors (int): Number of neighbors to look at
+
+        Returns:
+            Tensor: Patch scores.
+        """
+        distances = torch.cdist(embedding, self.memory_bank, p=2.0)  # euclidean norm
+        patch_scores, _ = distances.topk(k=n_neighbors, largest=False, dim=1)
+        return patch_scores
diff --git a/anomalib/models/stfpm/__init__.py b/anomalib/models/stfpm/__init__.py
index 4da5c046ca..e52c48b498 100644
--- a/anomalib/models/stfpm/__init__.py
+++ b/anomalib/models/stfpm/__init__.py
@@ -14,9 +14,6 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.
 
-from .model import (  # noqa  # pylint: disable=unused-import
-    AnomalyMapGenerator,
-    Loss,
-    StfpmLightning,
-    STFPMModel,
-)
+from .lightning_model import StfpmLightning
+
+__all__ = ["StfpmLightning"]
diff --git a/anomalib/models/stfpm/anomaly_map.py b/anomalib/models/stfpm/anomaly_map.py
new file mode 100644
index 0000000000..a087d90744
--- /dev/null
+++ b/anomalib/models/stfpm/anomaly_map.py
@@ -0,0 +1,98 @@
+"""Anomaly Map Generator for the STFPM model implementation."""
+
+# Copyright (C) 2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
+
+from typing import Dict, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from omegaconf import ListConfig
+from torch import Tensor
+
+
+class AnomalyMapGenerator:
+    """Generate Anomaly Heatmap."""
+
+    def __init__(
+        self,
+        image_size: Union[ListConfig, Tuple],
+    ):
+        self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True)
+        self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
+
+    def compute_layer_map(self, teacher_features: Tensor, student_features: Tensor) -> Tensor:
+        """Compute the layer map based on cosine similarity.
+
+        Args:
+            teacher_features (Tensor): Teacher features
+            student_features (Tensor): Student features
+
+        Returns:
+            Anomaly score based on cosine similarity.
+        """
+        norm_teacher_features = F.normalize(teacher_features)
+        norm_student_features = F.normalize(student_features)
+
+        layer_map = 0.5 * torch.norm(norm_teacher_features - norm_student_features, p=2, dim=-3, keepdim=True) ** 2
+        layer_map = F.interpolate(layer_map, size=self.image_size, align_corners=False, mode="bilinear")
+        return layer_map
+
+    def compute_anomaly_map(
+        self, teacher_features: Dict[str, Tensor], student_features: Dict[str, Tensor]
+    ) -> torch.Tensor:
+        """Compute the overall anomaly map via element-wise product of the interpolated layer maps.
+
+        Args:
+            teacher_features (Dict[str, Tensor]): Teacher features
+            student_features (Dict[str, Tensor]): Student features
+
+        Returns:
+            Final anomaly map
+        """
+        batch_size = list(teacher_features.values())[0].shape[0]
+        anomaly_map = torch.ones(batch_size, 1, self.image_size[0], self.image_size[1])
+        for layer in teacher_features.keys():
+            layer_map = self.compute_layer_map(teacher_features[layer], student_features[layer])
+            anomaly_map = anomaly_map.to(layer_map.device)
+            anomaly_map *= layer_map
+
+        return anomaly_map
+
+    def __call__(self, **kwds: Dict[str, Tensor]) -> torch.Tensor:
+        """Returns anomaly map.
+
+        Expects `teacher_features` and `student_features` keywords to be passed explicitly.
+
+        Example:
+            >>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size))
+            >>> output = self.anomaly_map_generator(
+                    teacher_features=teacher_features,
+                    student_features=student_features
+                )
+
+        Raises:
+            ValueError: `teacher_features` and `student_features` keys are not found
+
+        Returns:
+            torch.Tensor: anomaly map
+        """
+
+        if not ("teacher_features" in kwds and "student_features" in kwds):
+            raise ValueError(f"Expected keys `teacher_features` and `student_features`. Found {kwds.keys()}")
+
+        teacher_features: Dict[str, Tensor] = kwds["teacher_features"]
+        student_features: Dict[str, Tensor] = kwds["student_features"]
+
+        return self.compute_anomaly_map(teacher_features, student_features)
diff --git a/anomalib/models/stfpm/lightning_model.py b/anomalib/models/stfpm/lightning_model.py
new file mode 100644
index 0000000000..64bb9c7ecf
--- /dev/null
+++ b/anomalib/models/stfpm/lightning_model.py
@@ -0,0 +1,102 @@
+"""STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection.
+
+https://arxiv.org/abs/2103.04257
+"""
+
+# Copyright (C) 2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
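The kwargs contract of the `AnomalyMapGenerator` above can be exercised in isolation. A minimal sketch; the layer names and feature shapes are hypothetical stand-ins for a resnet18-style pyramid:

    import torch

    from anomalib.models.stfpm.anomaly_map import AnomalyMapGenerator

    # Hypothetical teacher/student pyramids (keys must match between the two).
    teacher_features = {
        "layer1": torch.randn(4, 64, 64, 64),
        "layer2": torch.randn(4, 128, 32, 32),
        "layer3": torch.randn(4, 256, 16, 16),
    }
    student_features = {key: torch.randn_like(val) for key, val in teacher_features.items()}

    generator = AnomalyMapGenerator(image_size=(256, 256))
    anomaly_map = generator(teacher_features=teacher_features, student_features=student_features)
    assert anomaly_map.shape == (4, 1, 256, 256)

Each layer map is upsampled to the full image size and the maps are multiplied element-wise, so a region only scores high when several pyramid levels agree.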
+
+import torch
+from pytorch_lightning.callbacks import EarlyStopping
+from torch import optim
+
+from anomalib.models.components import AnomalyModule
+from anomalib.models.stfpm.torch_model import STFPMModel
+
+__all__ = ["StfpmLightning"]
+
+
+class StfpmLightning(AnomalyModule):
+    """PL Lightning Module for the STFPM algorithm."""
+
+    def __init__(self, hparams):
+        super().__init__(hparams)
+
+        self.model = STFPMModel(
+            layers=hparams.model.layers,
+            input_size=hparams.model.input_size,
+            tile_size=hparams.dataset.tiling.tile_size,
+            tile_stride=hparams.dataset.tiling.stride,
+            backbone=hparams.model.backbone,
+            apply_tiling=hparams.dataset.tiling.apply,
+        )
+        self.loss_val = 0
+
+    def configure_callbacks(self):
+        """Configure model-specific callbacks."""
+        early_stopping = EarlyStopping(
+            monitor=self.hparams.model.early_stopping.metric,
+            patience=self.hparams.model.early_stopping.patience,
+            mode=self.hparams.model.early_stopping.mode,
+        )
+        return [early_stopping]
+
+    def configure_optimizers(self) -> torch.optim.Optimizer:
+        """Configure optimizers by creating an SGD optimizer.
+
+        Returns:
+            (Optimizer): SGD optimizer
+        """
+        return optim.SGD(
+            params=self.model.student_model.parameters(),
+            lr=self.hparams.model.lr,
+            momentum=self.hparams.model.momentum,
+            weight_decay=self.hparams.model.weight_decay,
+        )
+
+    def training_step(self, batch, _):  # pylint: disable=arguments-differ
+        """Training Step of STFPM.
+
+        For each batch, teacher and student features are extracted from the CNN.
+
+        Args:
+            batch (Tensor): Input batch
+            _: Index of the batch.
+
+        Returns:
+            Dictionary containing the loss value.
+        """
+        self.model.teacher_model.eval()
+        teacher_features, student_features = self.model.forward(batch["image"])
+        loss = self.loss_val + self.model.loss(teacher_features, student_features)
+        self.loss_val = 0
+        return {"loss": loss}
+
+    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
+        """Validation Step of STFPM.
+
+        Similar to the training step, student/teacher features are extracted from the CNN for each batch, and
+        anomaly map is computed.
+
+        Args:
+            batch (Tensor): Input batch
+            _: Index of the batch.
+
+        Returns:
+            Dictionary containing images, anomaly maps, true labels and masks.
+            These are required in `validation_epoch_end` for feature concatenation.
+        """
+        batch["anomaly_maps"] = self.model(batch["image"])
+
+        return batch
diff --git a/anomalib/models/stfpm/model.py b/anomalib/models/stfpm/model.py
deleted file mode 100644
index ed5e7f2b2f..0000000000
--- a/anomalib/models/stfpm/model.py
+++ /dev/null
@@ -1,313 +0,0 @@
-"""STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection.
-
-https://arxiv.org/abs/2103.04257
-"""
-
-# Copyright (C) 2020 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions
-# and limitations under the License.
- -from typing import Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torchvision -from omegaconf import ListConfig -from pytorch_lightning.callbacks import EarlyStopping -from torch import Tensor, nn, optim - -from anomalib.models.components import AnomalyModule, FeatureExtractor -from anomalib.pre_processing import Tiler - -__all__ = ["Loss", "AnomalyMapGenerator", "STFPMModel", "StfpmLightning"] - - -class Loss(nn.Module): - """Feature Pyramid Loss This class implmenents the feature pyramid loss function proposed in STFPM paper. - - Example: - >>> from anomalib.core.model.feature_extractor import FeatureExtractor - >>> from anomalib.models.stfpm.model import Loss - >>> from torchvision.models import resnet18 - - >>> layers = ['layer1', 'layer2', 'layer3'] - >>> teacher_model = FeatureExtractor(model=resnet18(pretrained=True), layers=layers) - >>> student_model = FeatureExtractor(model=resnet18(pretrained=False), layers=layers) - >>> loss = Loss() - - >>> inp = torch.rand((4, 3, 256, 256)) - >>> teacher_features = teacher_model(inp) - >>> student_features = student_model(inp) - >>> loss(student_features, teacher_features) - tensor(51.2015, grad_fn=) - """ - - def __init__(self): - super().__init__() - self.mse_loss = nn.MSELoss(reduction="sum") - - def compute_layer_loss(self, teacher_feats: Tensor, student_feats: Tensor) -> Tensor: - """Compute layer loss based on Equation (1) in Section 3.2 of the paper. - - Args: - teacher_feats (Tensor): Teacher features - student_feats (Tensor): Student features - - Returns: - L2 distance between teacher and student features. - """ - - height, width = teacher_feats.shape[2:] - - norm_teacher_features = F.normalize(teacher_feats) - norm_student_features = F.normalize(student_feats) - layer_loss = (0.5 / (width * height)) * self.mse_loss(norm_teacher_features, norm_student_features) - - return layer_loss - - def forward(self, teacher_features: Dict[str, Tensor], student_features: Dict[str, Tensor]) -> Tensor: - """Compute the overall loss via the weighted average of the layer losses computed by the cosine similarity. - - Args: - teacher_features (Dict[str, Tensor]): Teacher features - student_features (Dict[str, Tensor]): Student features - - Returns: - Total loss, which is the weighted average of the layer losses. - """ - - layer_losses: List[Tensor] = [] - for layer in teacher_features.keys(): - loss = self.compute_layer_loss(teacher_features[layer], student_features[layer]) - layer_losses.append(loss) - - total_loss = torch.stack(layer_losses).sum() - - return total_loss - - -class AnomalyMapGenerator: - """Generate Anomaly Heatmap.""" - - def __init__( - self, - image_size: Union[ListConfig, Tuple], - ): - self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True) - self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size) - - def compute_layer_map(self, teacher_features: Tensor, student_features: Tensor) -> Tensor: - """Compute the layer map based on cosine similarity. - - Args: - teacher_features (Tensor): Teacher features - student_features (Tensor): Student features - - Returns: - Anomaly score based on cosine similarity. 
- """ - norm_teacher_features = F.normalize(teacher_features) - norm_student_features = F.normalize(student_features) - - layer_map = 0.5 * torch.norm(norm_teacher_features - norm_student_features, p=2, dim=-3, keepdim=True) ** 2 - layer_map = F.interpolate(layer_map, size=self.image_size, align_corners=False, mode="bilinear") - return layer_map - - def compute_anomaly_map( - self, teacher_features: Dict[str, Tensor], student_features: Dict[str, Tensor] - ) -> torch.Tensor: - """Compute the overall anomaly map via element-wise production the interpolated anomaly maps. - - Args: - teacher_features (Dict[str, Tensor]): Teacher features - student_features (Dict[str, Tensor]): Student features - - Returns: - Final anomaly map - """ - batch_size = list(teacher_features.values())[0].shape[0] - anomaly_map = torch.ones(batch_size, 1, self.image_size[0], self.image_size[1]) - for layer in teacher_features.keys(): - layer_map = self.compute_layer_map(teacher_features[layer], student_features[layer]) - anomaly_map = anomaly_map.to(layer_map.device) - anomaly_map *= layer_map - - return anomaly_map - - def __call__(self, **kwds: Dict[str, Tensor]) -> torch.Tensor: - """Returns anomaly map. - - Expects `teach_features` and `student_features` keywords to be passed explicitly. - - Example: - >>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size)) - >>> output = self.anomaly_map_generator( - teacher_features=teacher_features, - student_features=student_features - ) - - Raises: - ValueError: `teach_features` and `student_features` keys are not found - - Returns: - torch.Tensor: anomaly map - """ - - if not ("teacher_features" in kwds and "student_features" in kwds): - raise ValueError(f"Expected keys `teacher_features` and `student_features. Found {kwds.keys()}") - - teacher_features: Dict[str, Tensor] = kwds["teacher_features"] - student_features: Dict[str, Tensor] = kwds["student_features"] - - return self.compute_anomaly_map(teacher_features, student_features) - - -class STFPMModel(nn.Module): - """STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection. - - Args: - layers (List[str]): Layers used for feature extraction - input_size (Tuple[int, int]): Input size for the model. - tile_size (Tuple[int, int]): Tile size - tile_stride (int): Stride for tiling - backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18". - apply_tiling (bool, optional): Apply tiling. Defaults to False. - """ - - def __init__( - self, - layers: List[str], - input_size: Tuple[int, int], - backbone: str = "resnet18", - apply_tiling: bool = False, - tile_size: Optional[Tuple[int, int]] = None, - tile_stride: Optional[int] = None, - ): - super().__init__() - self.backbone = getattr(torchvision.models, backbone) - self.apply_tiling = apply_tiling - self.teacher_model = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=layers) - self.student_model = FeatureExtractor(backbone=self.backbone(pretrained=False), layers=layers) - - # teacher model is fixed - for parameters in self.teacher_model.parameters(): - parameters.requires_grad = False - - self.loss = Loss() - if self.apply_tiling: - assert tile_size is not None - assert tile_stride is not None - self.tiler = Tiler(tile_size, tile_stride) - self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(tile_size)) - else: - self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(input_size)) - - def forward(self, images): - """Forward-pass images into the network. 
- - During the training mode the model extracts the features from the teacher and student networks. - During the evaluation mode, it returns the predicted anomaly map. - - Args: - images (Tensor): Batch of images. - - Returns: - Teacher and student features when in training mode, otherwise the predicted anomaly maps. - """ - if self.apply_tiling: - images = self.tiler.tile(images) - teacher_features: Dict[str, Tensor] = self.teacher_model(images) - student_features: Dict[str, Tensor] = self.student_model(images) - if self.training: - output = teacher_features, student_features - else: - output = self.anomaly_map_generator(teacher_features=teacher_features, student_features=student_features) - if self.apply_tiling: - output = self.tiler.untile(output) - - return output - - -class StfpmLightning(AnomalyModule): - """PL Lightning Module for the STFPM algorithm.""" - - def __init__(self, hparams): - super().__init__(hparams) - - self.model = STFPMModel( - layers=hparams.model.layers, - input_size=hparams.model.input_size, - tile_size=hparams.dataset.tiling.tile_size, - tile_stride=hparams.dataset.tiling.stride, - backbone=hparams.model.backbone, - apply_tiling=hparams.dataset.tiling.apply, - ) - self.loss_val = 0 - - def configure_callbacks(self): - """Configure model-specific callbacks.""" - early_stopping = EarlyStopping( - monitor=self.hparams.model.early_stopping.metric, - patience=self.hparams.model.early_stopping.patience, - mode=self.hparams.model.early_stopping.mode, - ) - return [early_stopping] - - def configure_optimizers(self) -> torch.optim.Optimizer: - """Configure optimizers by creating an SGD optimizer. - - Returns: - (Optimizer): SGD optimizer - """ - return optim.SGD( - params=self.model.student_model.parameters(), - lr=self.hparams.model.lr, - momentum=self.hparams.model.momentum, - weight_decay=self.hparams.model.weight_decay, - ) - - def training_step(self, batch, _): # pylint: disable=arguments-differ - """Training Step of STFPM. - - For each batch, teacher and student and teacher features are extracted from the CNN. - - Args: - batch (Tensor): Input batch - _: Index of the batch. - - Returns: - Hierarchical feature map - """ - self.model.teacher_model.eval() - teacher_features, student_features = self.model.forward(batch["image"]) - loss = self.loss_val + self.model.loss(teacher_features, student_features) - self.loss_val = 0 - return {"loss": loss} - - def validation_step(self, batch, _): # pylint: disable=arguments-differ - """Validation Step of STFPM. - - Similar to the training step, student/teacher features are extracted from the CNN for each batch, and - anomaly map is computed. - - Args: - batch (Tensor): Input batch - _: Index of the batch. - - Returns: - Dictionary containing images, anomaly maps, true labels and masks. - These are required in `validation_epoch_end` for feature concatenation. - """ - batch["anomaly_maps"] = self.model(batch["image"]) - - return batch diff --git a/anomalib/models/stfpm/torch_model.py b/anomalib/models/stfpm/torch_model.py new file mode 100644 index 0000000000..1e1b650e78 --- /dev/null +++ b/anomalib/models/stfpm/torch_model.py @@ -0,0 +1,156 @@ +"""PyTorch model for the STFPM model implementation.""" + +# Copyright (C) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
+
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import Tensor, nn
+
+from anomalib.models.components import FeatureExtractor
+from anomalib.models.stfpm.anomaly_map import AnomalyMapGenerator
+from anomalib.pre_processing import Tiler
+
+
+class Loss(nn.Module):
+    """Feature Pyramid Loss. This class implements the feature pyramid loss function proposed in the STFPM paper.
+
+    Example:
+        >>> from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor
+        >>> from anomalib.models.stfpm.torch_model import Loss
+        >>> from torchvision.models import resnet18
+
+        >>> layers = ['layer1', 'layer2', 'layer3']
+        >>> teacher_model = FeatureExtractor(model=resnet18(pretrained=True), layers=layers)
+        >>> student_model = FeatureExtractor(model=resnet18(pretrained=False), layers=layers)
+        >>> loss = Loss()
+
+        >>> inp = torch.rand((4, 3, 256, 256))
+        >>> teacher_features = teacher_model(inp)
+        >>> student_features = student_model(inp)
+        >>> loss(student_features, teacher_features)
+        tensor(51.2015, grad_fn=<SumBackward0>)
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.mse_loss = nn.MSELoss(reduction="sum")
+
+    def compute_layer_loss(self, teacher_feats: Tensor, student_feats: Tensor) -> Tensor:
+        """Compute layer loss based on Equation (1) in Section 3.2 of the paper.
+
+        Args:
+            teacher_feats (Tensor): Teacher features
+            student_feats (Tensor): Student features
+
+        Returns:
+            L2 distance between teacher and student features.
+        """
+
+        height, width = teacher_feats.shape[2:]
+
+        norm_teacher_features = F.normalize(teacher_feats)
+        norm_student_features = F.normalize(student_feats)
+        layer_loss = (0.5 / (width * height)) * self.mse_loss(norm_teacher_features, norm_student_features)
+
+        return layer_loss
+
+    def forward(self, teacher_features: Dict[str, Tensor], student_features: Dict[str, Tensor]) -> Tensor:
+        """Compute the overall loss as the sum of the scaled layer-wise losses.
+
+        Args:
+            teacher_features (Dict[str, Tensor]): Teacher features
+            student_features (Dict[str, Tensor]): Student features
+
+        Returns:
+            Total loss, which is the sum of the scaled layer-wise losses.
+        """
+
+        layer_losses: List[Tensor] = []
+        for layer in teacher_features.keys():
+            loss = self.compute_layer_loss(teacher_features[layer], student_features[layer])
+            layer_losses.append(loss)
+
+        total_loss = torch.stack(layer_losses).sum()
+
+        return total_loss
+
+
+class STFPMModel(nn.Module):
+    """STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection.
+
+    Args:
+        layers (List[str]): Layers used for feature extraction
+        input_size (Tuple[int, int]): Input size for the model.
+        tile_size (Tuple[int, int]): Tile size
+        tile_stride (int): Stride for tiling
+        backbone (str, optional): Pre-trained model backbone. Defaults to "resnet18".
+        apply_tiling (bool, optional): Apply tiling. Defaults to False.
+ """ + + def __init__( + self, + layers: List[str], + input_size: Tuple[int, int], + backbone: str = "resnet18", + apply_tiling: bool = False, + tile_size: Optional[Tuple[int, int]] = None, + tile_stride: Optional[int] = None, + ): + super().__init__() + self.backbone = getattr(torchvision.models, backbone) + self.apply_tiling = apply_tiling + self.teacher_model = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=layers) + self.student_model = FeatureExtractor(backbone=self.backbone(pretrained=False), layers=layers) + + # teacher model is fixed + for parameters in self.teacher_model.parameters(): + parameters.requires_grad = False + + self.loss = Loss() + if self.apply_tiling: + assert tile_size is not None + assert tile_stride is not None + self.tiler = Tiler(tile_size, tile_stride) + self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(tile_size)) + else: + self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(input_size)) + + def forward(self, images): + """Forward-pass images into the network. + + During the training mode the model extracts the features from the teacher and student networks. + During the evaluation mode, it returns the predicted anomaly map. + + Args: + images (Tensor): Batch of images. + + Returns: + Teacher and student features when in training mode, otherwise the predicted anomaly maps. + """ + if self.apply_tiling: + images = self.tiler.tile(images) + teacher_features: Dict[str, Tensor] = self.teacher_model(images) + student_features: Dict[str, Tensor] = self.student_model(images) + if self.training: + output = teacher_features, student_features + else: + output = self.anomaly_map_generator(teacher_features=teacher_features, student_features=student_features) + if self.apply_tiling: + output = self.tiler.untile(output) + + return output
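Finally, a usage sketch for the torch model as a whole: in train mode `forward` returns the two feature pyramids that feed `Loss`, while in eval mode it returns the anomaly map directly. Assumes pretrained resnet18 weights can be downloaded; layer names and shapes are illustrative:

    import torch

    from anomalib.models.stfpm.torch_model import STFPMModel

    model = STFPMModel(layers=["layer1", "layer2", "layer3"], input_size=(256, 256))
    images = torch.randn(2, 3, 256, 256)

    # Train mode: a (teacher_features, student_features) pair for the loss.
    model.train()
    teacher_features, student_features = model(images)
    loss = model.loss(teacher_features, student_features)

    # Eval mode: the predicted anomaly map, upsampled to the input size.
    model.eval()
    anomaly_map = model(images)  # shape: (2, 1, 256, 256)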