diff --git a/anomalib/models/__init__.py b/anomalib/models/__init__.py index 766bd253e2..3ee3d44932 100644 --- a/anomalib/models/__init__.py +++ b/anomalib/models/__init__.py @@ -42,7 +42,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule: Returns: AnomalyModule: Anomaly Model """ - model_list: List[str] = ["cflow", "dfkde", "dfm", "fastflow", "ganomaly", "padim", "patchcore", "stfpm"] + model_list: List[str] = ["cflow", "dfkde", "dfm", "draem", "fastflow", "ganomaly", "padim", "patchcore", "stfpm"] model: AnomalyModule if config.model.name in model_list: diff --git a/anomalib/models/draem/LICENSE b/anomalib/models/draem/LICENSE new file mode 100644 index 0000000000..7025d18fb4 --- /dev/null +++ b/anomalib/models/draem/LICENSE @@ -0,0 +1,29 @@ +Copyright (c) 2022 Intel Corporation +SPDX-License-Identifier: Apache-2.0 + +Some files in this folder are based on the original DRAEM implementation by VitjanZ + +Original license: +---------------- + + MIT License + + Copyright (c) 2021 VitjanZ + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/anomalib/models/draem/README.md b/anomalib/models/draem/README.md new file mode 100644 index 0000000000..8780574c98 --- /dev/null +++ b/anomalib/models/draem/README.md @@ -0,0 +1,22 @@ +# DRÆM – A discriminatively trained reconstruction embedding for surface anomaly detection + +This is the implementation of the [DRAEM](https://arxiv.org/pdf/2108.07610v2.pdf) paper. + +Model Type: Segmentation + +## Description + +DRAEM is a reconstruction based algorithm that consists of a reconstructive subnetwork and a discriminative subnetwork. DRAEM is trained on simulated anomaly images, generated by augmenting normal input images from the training set with a random Perlin noise mask extracted from an unrelated source of image data. The reconstructive subnetwork is an autoencoder architecture that is trained to reconstruct the original input images from the augmented images. The reconstructive submodel is trained using a combination of L2 loss and Structural Similarity loss. The input of the discriminative subnetwork consists of the channel-wise concatenation of the (augmented) input image and the output of the reconstructive subnetwork. The output of the discriminative subnetwork is an anomaly map that contains the predicted anomaly scores for each pixel location. The discriminative subnetwork is trained using Focal Loss. 
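In code, a single training iteration boils down to the sketch below. It uses the `Augmenter`, `DraemModel` and `DraemLoss` classes introduced by this model; the batch tensor and image size are placeholders.

```python
import torch

from anomalib.models.draem.loss import DraemLoss
from anomalib.models.draem.torch_model import DraemModel
from anomalib.models.draem.utils import Augmenter

augmenter = Augmenter()  # no anomaly source path -> Perlin noise is used as the anomaly source
model = DraemModel()     # reconstructive + discriminative subnetworks
criterion = DraemLoss()  # L2 + SSIM on the reconstruction, focal loss on the predicted mask

batch = torch.rand(8, 3, 256, 256)  # placeholder for a batch of normal training images

# Simulate anomalies on the normal images and keep the pixel-accurate ground truth masks.
augmented, mask = augmenter.augment_batch(batch)

# In training mode the model returns the reconstruction and the pixel-wise class scores.
reconstruction, prediction = model(augmented)

# The loss compares the reconstruction against the original (non-augmented) images
# and the predicted mask against the simulated ground truth mask.
loss = criterion(batch, reconstruction, mask, prediction)
loss.backward()
```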
+ +For optimal results, DRAEM requires specifying the path to a folder of image data that will be used as the source of the anomalous pixel regions in the simulated anomaly images. The path can be specified by editing the value of the `model.anomaly_source_path` parameter in the `config.yaml` file. The authors of the original paper recommend using the [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/) dataset as anomaly source. + +## Architecture +![DRAEM Architecture](../../../docs/source/images/draem/architecture.png "DRAEM Architecture") + +## Usage + +`python tools/train.py --model draem` + +## Benchmark + +Benchmarking results are not yet available for this algorithm. Please check again later. diff --git a/anomalib/models/draem/__init__.py b/anomalib/models/draem/__init__.py new file mode 100644 index 0000000000..68091b1f91 --- /dev/null +++ b/anomalib/models/draem/__init__.py @@ -0,0 +1,8 @@ +"""DRAEM model.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import DraemLightning + +__all__ = ["DraemLightning"] diff --git a/anomalib/models/draem/config.yaml b/anomalib/models/draem/config.yaml new file mode 100644 index 0000000000..d619be26e0 --- /dev/null +++ b/anomalib/models/draem/config.yaml @@ -0,0 +1,102 @@ +dataset: + name: mvtec #options: [mvtec, btech, folder] + format: mvtec + path: ./datasets/MVTec + category: bottle + task: segmentation + image_size: 256 + train_batch_size: 8 + test_batch_size: 32 + num_workers: 8 + transform_config: + train: ./anomalib/models/draem/transform_config.yaml + val: ./anomalib/models/draem/transform_config.yaml + create_validation_set: false + tiling: + apply: false + tile_size: null + stride: null + remove_border_count: 0 + use_random_tiling: False + random_tile_count: 16 + +model: + name: draem + anomaly_source_path: null # optional, e.g. ./datasets/dtd + lr: 0.0001 + early_stopping: + patience: 50 + metric: pixel_AUROC + mode: max + normalization_method: min_max # options: [none, min_max, cdf] + +metrics: + image: + - F1Score + - AUROC + pixel: + - F1Score + - AUROC + threshold: + image_default: 3 + pixel_default: 3 + adaptive: true + +project: + seed: 42 + path: ./results + log_images_to: ["local"] + logger: false # options: [tensorboard, wandb, csv] or combinations. + +optimization: + openvino: + apply: false + +# PL Trainer Args. Don't add extra parameter here. 
+trainer: + accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto"> + accumulate_grad_batches: 1 + amp_backend: native + auto_lr_find: false + auto_scale_batch_size: false + auto_select_gpus: false + benchmark: false + check_val_every_n_epoch: 1 + default_root_dir: null + detect_anomaly: false + deterministic: false + devices: 1 + enable_checkpointing: true + enable_model_summary: true + enable_progress_bar: true + fast_dev_run: false + gpus: null # Set automatically + gradient_clip_val: 0 + ipus: null + limit_predict_batches: 1.0 + limit_test_batches: 1.0 + limit_train_batches: 1.0 + limit_val_batches: 1.0 + log_every_n_steps: 50 + log_gpu_memory: null + max_epochs: 100 + max_steps: -1 + max_time: null + min_epochs: null + min_steps: null + move_metrics_to_cpu: false + multiple_trainloader_mode: max_size_cycle + num_nodes: 1 + num_processes: null + num_sanity_val_steps: 0 + overfit_batches: 0.0 + plugins: null + precision: 32 + profiler: null + reload_dataloaders_every_n_epochs: 0 + replace_sampler_ddp: true + strategy: null + sync_batchnorm: false + tpu_cores: null + track_grad_norm: -1 + val_check_interval: 1.0 diff --git a/anomalib/models/draem/lightning_model.py b/anomalib/models/draem/lightning_model.py new file mode 100644 index 0000000000..72656869dc --- /dev/null +++ b/anomalib/models/draem/lightning_model.py @@ -0,0 +1,108 @@ +"""DRÆM – A discriminatively trained reconstruction embedding for surface anomaly detection. + +Paper https://arxiv.org/abs/2108.07610 +""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Optional, Union + +import torch +from omegaconf import DictConfig, ListConfig +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.utilities.cli import MODEL_REGISTRY + +from anomalib.models.components import AnomalyModule +from anomalib.models.draem.loss import DraemLoss +from anomalib.models.draem.torch_model import DraemModel +from anomalib.models.draem.utils import Augmenter + +logger = logging.getLogger(__name__) + +__all__ = ["Draem", "DraemLightning"] + + +@MODEL_REGISTRY +class Draem(AnomalyModule): + """DRÆM: A discriminatively trained reconstruction embedding for surface anomaly detection. + + Args: + anomaly_source_path (Optional[str]): Path to folder that contains the anomaly source images. Random noise will + be used if left empty. + """ + + def __init__(self, anomaly_source_path: Optional[str] = None): + super().__init__() + + self.augmenter = Augmenter(anomaly_source_path) + self.model = DraemModel() + self.loss = DraemLoss() + + def training_step(self, batch, _): # pylint: disable=arguments-differ + """Training Step of DRAEM. + + Feeds the original image and the simulated anomaly + image through the network and computes the training loss. + + Args: + batch (Dict[str, Any]): Batch containing image filename, image, label and mask + + Returns: + Loss dictionary + """ + input_image = batch["image"] + # Apply corruption to input image + augmented_image, anomaly_mask = self.augmenter.augment_batch(input_image) + # Generate model prediction + reconstruction, prediction = self.model(augmented_image) + # Compute loss + loss = self.loss(input_image, reconstruction, anomaly_mask, prediction) + return {"loss": loss} + + def validation_step(self, batch, _): + """Validation step of DRAEM. The Softmax predictions of the anomalous class are used as anomaly map. 
+ + Args: + batch: Batch of input images + + Returns: + Dictionary to which predicted anomaly maps have been added. + """ + prediction = self.model(batch["image"]) + batch["anomaly_maps"] = prediction[:, 1, :, :] + return batch + + +class DraemLightning(Draem): + """DRÆM: A discriminatively trained reconstruction embedding for surface anomaly detection. + + Args: + hparams (Union[DictConfig, ListConfig]): Model parameters + """ + + def __init__(self, hparams: Union[DictConfig, ListConfig]): + super().__init__(anomaly_source_path=hparams.model.anomaly_source_path) + self.hparams: Union[DictConfig, ListConfig] # type: ignore + self.save_hyperparameters(hparams) + + def configure_callbacks(self): + """Configure model-specific callbacks. + + Note: + This method is used for the existing CLI. + When PL CLI is introduced, configure callback method will be + deprecated, and callbacks will be configured from either + config.yaml file or from CLI. + """ + early_stopping = EarlyStopping( + monitor=self.hparams.model.early_stopping.metric, + patience=self.hparams.model.early_stopping.patience, + mode=self.hparams.model.early_stopping.mode, + ) + return [early_stopping] + + def configure_optimizers(self): # pylint: disable=arguments-differ + """Configure the Adam optimizer.""" + return torch.optim.Adam(params=self.model.parameters(), lr=self.hparams.model.lr) diff --git a/anomalib/models/draem/loss.py b/anomalib/models/draem/loss.py new file mode 100644 index 0000000000..1e05bfa073 --- /dev/null +++ b/anomalib/models/draem/loss.py @@ -0,0 +1,29 @@ +"""Loss function for the DRAEM model implementation.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from kornia.losses import FocalLoss, SSIMLoss +from torch import nn + + +class DraemLoss(nn.Module): + """Overall loss function of the DRAEM model. + + The total loss consists of the sum of the L2 loss and Focal loss between the reconstructed image and the input + image, and the Structural Similarity loss between the predicted and GT anomaly masks. + """ + + def __init__(self): + super().__init__() + + self.l2_loss = nn.modules.loss.MSELoss() + self.focal_loss = FocalLoss(alpha=1, reduction="mean") + self.ssim_loss = SSIMLoss(window_size=11) + + def forward(self, input_image, reconstruction, anomaly_mask, prediction): + """Compute the loss over a batch for the DRAEM model.""" + l2_loss_val = self.l2_loss(reconstruction, input_image) + focal_loss_val = self.focal_loss(prediction, anomaly_mask.squeeze(1).long()) + ssim_loss_val = self.ssim_loss(reconstruction, input_image) * 2 + return l2_loss_val + ssim_loss_val + focal_loss_val diff --git a/anomalib/models/draem/torch_model.py b/anomalib/models/draem/torch_model.py new file mode 100644 index 0000000000..9da0dab87c --- /dev/null +++ b/anomalib/models/draem/torch_model.py @@ -0,0 +1,488 @@ +"""PyTorch model for the DRAEM model implementation.""" + +# Original Code +# Copyright (c) 2021 VitjanZ +# https://github.com/VitjanZ/DRAEM. 
+# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Tuple, Union + +import torch +from torch import Tensor, nn + + +class DraemModel(nn.Module): + """DRAEM PyTorch model consisting of the reconstructive and discriminative sub networks.""" + + def __init__(self): + super().__init__() + self.reconstructive_subnetwork = ReconstructiveSubNetwork() + self.discriminative_subnetwork = DiscriminativeSubNetwork(in_channels=6, out_channels=2) + + def forward(self, batch: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Compute the reconstruction and anomaly mask from an input image. + + Args: + x (Tensor): batch of input images + + Returns: + Predicted confidence values of the anomaly mask. During training the reconstructed input images are + returned as well. + """ + reconstruction = self.reconstructive_subnetwork(batch) + concatenated_inputs = torch.cat([batch, reconstruction], axis=1) + prediction = self.discriminative_subnetwork(concatenated_inputs) + if self.training: + return reconstruction, prediction + return torch.softmax(prediction, dim=1) + + +class ReconstructiveSubNetwork(nn.Module): + """Autoencoder model that encodes and reconstructs the input image. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width=128): + super().__init__() + self.encoder = EncoderReconstructive(in_channels, base_width) + self.decoder = DecoderReconstructive(base_width, out_channels=out_channels) + + def forward(self, batch: Tensor) -> Tensor: + """Encode and reconstruct the input images. + + Args: + batch (Tensor): Batch of input images + + Returns: + Batch of reconstructed images. + """ + encoded = self.encoder(batch) + decoded = self.decoder(encoded) + return decoded + + +class DiscriminativeSubNetwork(nn.Module): + """Discriminative model that predicts the anomaly mask from the original image and its reconstruction. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width: int = 64): + super().__init__() + self.encoder_segment = EncoderDiscriminative(in_channels, base_width) + self.decoder_segment = DecoderDiscriminative(base_width, out_channels=out_channels) + + def forward(self, batch: Tensor) -> Tensor: + """Generate the predicted anomaly masks for a batch of input images. + + Args: + batch (Tensor): Batch of inputs consisting of the concatenation of the original images + and their reconstructions. + + Returns: + Activations of the output layer corresponding to the normal and anomalous class scores on the pixel level. + """ + act1, act2, act3, act4, act5, act6 = self.encoder_segment(batch) + segmentation = self.decoder_segment(act1, act2, act3, act4, act5, act6) + return segmentation + + +class EncoderDiscriminative(nn.Module): + """Encoder part of the discriminator network. + + Args: + in_channels (int): Number of input channels. + base_width (int): Base dimensionality of the layers of the autoencoder. 
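
    Example:
        A rough shape check, assuming a 256x256 input with the 6 channels produced
        by concatenating an image and its reconstruction:

        >>> import torch
        >>> encoder = EncoderDiscriminative(in_channels=6, base_width=64)
        >>> activations = encoder(torch.rand(1, 6, 256, 256))
        >>> [act.shape[1] for act in activations]
        [64, 128, 256, 512, 512, 512]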
+ """ + + def __init__(self, in_channels: int, base_width: int): + super().__init__() + self.block1 = nn.Sequential( + nn.Conv2d(in_channels, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.mp1 = nn.Sequential(nn.MaxPool2d(2)) + self.block2 = nn.Sequential( + nn.Conv2d(base_width, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + self.mp2 = nn.Sequential(nn.MaxPool2d(2)) + self.block3 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + self.mp3 = nn.Sequential(nn.MaxPool2d(2)) + self.block4 = nn.Sequential( + nn.Conv2d(base_width * 4, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + self.mp4 = nn.Sequential(nn.MaxPool2d(2)) + self.block5 = nn.Sequential( + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + + self.mp5 = nn.Sequential(nn.MaxPool2d(2)) + self.block6 = nn.Sequential( + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + + def forward(self, batch: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Convert the inputs to the salient space by running them through the encoder network. + + Args: + batch (Tensor): Batch of inputs consisting of the concatenation of the original images + and their reconstructions. + + Returns: + Computed feature maps for each of the layers in the encoder sub network. + """ + act1 = self.block1(batch) + mp1 = self.mp1(act1) + act2 = self.block2(mp1) + mp2 = self.mp3(act2) + act3 = self.block3(mp2) + mp3 = self.mp3(act3) + act4 = self.block4(mp3) + mp4 = self.mp4(act4) + act5 = self.block5(mp4) + mp5 = self.mp5(act5) + act6 = self.block6(mp5) + return act1, act2, act3, act4, act5, act6 + + +class DecoderDiscriminative(nn.Module): + """Decoder part of the discriminator network. + + Args: + base_width (int): Base dimensionality of the layers of the autoencoder. + out_channels (int): Number of output channels. 
+ """ + + def __init__(self, base_width: int, out_channels: int = 1): + super().__init__() + + self.up_b = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + self.db_b = nn.Sequential( + nn.Conv2d(base_width * (8 + 8), base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + + self.up1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 8, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + self.db1 = nn.Sequential( + nn.Conv2d(base_width * (4 + 8), base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + + self.up2 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 4, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + self.db2 = nn.Sequential( + nn.Conv2d(base_width * (2 + 4), base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + + self.up3 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 2, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.db3 = nn.Sequential( + nn.Conv2d(base_width * (2 + 1), base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + + self.up4 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.db4 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + + self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1)) + + def forward(self, act1: Tensor, act2: Tensor, act3: Tensor, act4: Tensor, act5: Tensor, act6: Tensor) -> Tensor: + """Computes predicted anomaly class scores from the intermediate outputs of the encoder sub network. + + Args: + act1 (Tensor): Encoder activations of the first block of convolutional layers. + act2 (Tensor): Encoder activations of the second block of convolutional layers. + act3 (Tensor): Encoder activations of the third block of convolutional layers. + act4 (Tensor): Encoder activations of the fourth block of convolutional layers. + act5 (Tensor): Encoder activations of the fifth block of convolutional layers. + act6 (Tensor): Encoder activations of the sixth block of convolutional layers. + + Returns: + Predicted anomaly class scores per pixel. 
+ """ + up_b = self.up_b(act6) + cat_b = torch.cat((up_b, act5), dim=1) + db_b = self.db_b(cat_b) + + up1 = self.up1(db_b) + cat1 = torch.cat((up1, act4), dim=1) + db1 = self.db1(cat1) + + up2 = self.up2(db1) + cat2 = torch.cat((up2, act3), dim=1) + db2 = self.db2(cat2) + + up3 = self.up3(db2) + cat3 = torch.cat((up3, act2), dim=1) + db3 = self.db3(cat3) + + up4 = self.up4(db3) + cat4 = torch.cat((up4, act1), dim=1) + db4 = self.db4(cat4) + + out = self.fin_out(db4) + return out + + +class EncoderReconstructive(nn.Module): + """Encoder part of the reconstructive network. + + Args: + in_channels (int): Number of input channels. + base_width (int): Base dimensionality of the layers of the autoencoder. + """ + + def __init__(self, in_channels: int, base_width: int): + super().__init__() + self.block1 = nn.Sequential( + nn.Conv2d(in_channels, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.mp1 = nn.Sequential(nn.MaxPool2d(2)) + self.block2 = nn.Sequential( + nn.Conv2d(base_width, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + self.mp2 = nn.Sequential(nn.MaxPool2d(2)) + self.block3 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + self.mp3 = nn.Sequential(nn.MaxPool2d(2)) + self.block4 = nn.Sequential( + nn.Conv2d(base_width * 4, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + self.mp4 = nn.Sequential(nn.MaxPool2d(2)) + self.block5 = nn.Sequential( + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + + def forward(self, batch: Tensor) -> Tensor: + """Encode a batch of input images to the salient space. + + Args: + batch (Tensor): Batch of input images. + + Returns: + Feature maps extracted from the bottleneck layer. + """ + act1 = self.block1(batch) + mp1 = self.mp1(act1) + act2 = self.block2(mp1) + mp2 = self.mp3(act2) + act3 = self.block3(mp2) + mp3 = self.mp3(act3) + act4 = self.block4(mp3) + mp4 = self.mp4(act4) + act5 = self.block5(mp4) + return act5 + + +class DecoderReconstructive(nn.Module): + """Decoder part of the reconstructive network. + + Args: + base_width (int): Base dimensionality of the layers of the autoencoder. + out_channels (int): Number of output channels. 
+ """ + + def __init__(self, base_width: int, out_channels: int = 1): + super().__init__() + + self.up1 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + ) + self.db1 = nn.Sequential( + nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 8), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 8, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + + self.up2 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + ) + self.db2 = nn.Sequential( + nn.Conv2d(base_width * 4, base_width * 4, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 4), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 4, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + + self.up3 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + ) + # cat with base*1 + self.db3 = nn.Sequential( + nn.Conv2d(base_width * 2, base_width * 2, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 2), + nn.ReLU(inplace=True), + nn.Conv2d(base_width * 2, base_width * 1, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width * 1), + nn.ReLU(inplace=True), + ) + + self.up4 = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + self.db4 = nn.Sequential( + nn.Conv2d(base_width * 1, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + nn.Conv2d(base_width, base_width, kernel_size=3, padding=1), + nn.BatchNorm2d(base_width), + nn.ReLU(inplace=True), + ) + + self.fin_out = nn.Sequential(nn.Conv2d(base_width, out_channels, kernel_size=3, padding=1)) + + def forward(self, act5: Tensor) -> Tensor: + """Reconstruct the image from the activations of the bottleneck layer. + + Args: + act5 (Tensor): Activations of the bottleneck layer. + + Returns: + Batch of reconstructed images. 
+ """ + up1 = self.up1(act5) + db1 = self.db1(up1) + + up2 = self.up2(db1) + db2 = self.db2(up2) + + up3 = self.up3(db2) + db3 = self.db3(up3) + + up4 = self.up4(db3) + db4 = self.db4(up4) + + out = self.fin_out(db4) + return out diff --git a/anomalib/models/draem/transform_config.yaml b/anomalib/models/draem/transform_config.yaml new file mode 100644 index 0000000000..5a379ef762 --- /dev/null +++ b/anomalib/models/draem/transform_config.yaml @@ -0,0 +1,26 @@ +{ + "__version__": "1.1.0", + "transform": + { + "__class_fullname__": "Compose", + "p": 1.0, + "transforms": + [ + { + "__class_fullname__": "ToFloat", + "always_apply": false, + "p": 1.0, + "max_value": null, + }, + { + "__class_fullname__": "ToTensorV2", + "always_apply": true, + "p": 1.0, + "transpose_mask": false, + }, + ], + "bbox_params": null, + "keypoint_params": null, + "additional_targets": {}, + }, +} diff --git a/anomalib/models/draem/utils/__init__.py b/anomalib/models/draem/utils/__init__.py new file mode 100644 index 0000000000..dde7003813 --- /dev/null +++ b/anomalib/models/draem/utils/__init__.py @@ -0,0 +1,8 @@ +"""Helpers for the DRAEM model implementation.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .augmenter import Augmenter + +__all__ = ["Augmenter"] diff --git a/anomalib/models/draem/utils/augmenter.py b/anomalib/models/draem/utils/augmenter.py new file mode 100644 index 0000000000..7ae49c1315 --- /dev/null +++ b/anomalib/models/draem/utils/augmenter.py @@ -0,0 +1,147 @@ +"""Augmenter module to generates out-of-distribution samples for the DRAEM implementation.""" + +# Original Code +# Copyright (c) 2021 VitjanZ +# https://github.com/VitjanZ/DRAEM. +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import glob +import random +from typing import Optional, Tuple + +import cv2 +import imgaug.augmenters as iaa +import numpy as np +import torch +from torch import Tensor +from torchvision.datasets.folder import IMG_EXTENSIONS + +from anomalib.models.draem.utils.perlin import rand_perlin_2d_np + + +class Augmenter: + """Class that generates noisy augmentations of input images. + + Args: + anomaly_source_path (Optional[str]): Path to a folder of images that will be used as source of the anomalous + noise. If not specified, random noise will be used instead. + """ + + def __init__(self, anomaly_source_path: Optional[str] = None): + + self.anomaly_source_paths = [] + if anomaly_source_path is not None: + for img_ext in IMG_EXTENSIONS: + self.anomaly_source_paths.extend(glob.glob(anomaly_source_path + "/**/*" + img_ext, recursive=True)) + + self.augmenters = [ + iaa.GammaContrast((0.5, 2.0), per_channel=True), + iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)), + iaa.pillike.EnhanceSharpness(), + iaa.AddToHueAndSaturation((-50, 50), per_channel=True), + iaa.Solarize(0.5, threshold=(32, 128)), + iaa.Posterize(), + iaa.Invert(), + iaa.pillike.Autocontrast(), + iaa.pillike.Equalize(), + iaa.Affine(rotate=(-45, 45)), + ] + self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))]) + + def rand_augmenter(self) -> iaa.Sequential: + """Selects 3 random transforms that will be applied to the anomaly source images. + + Returns: + A selection of 3 transforms. 
+ """ + aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False) + aug = iaa.Sequential([self.augmenters[aug_ind[0]], self.augmenters[aug_ind[1]], self.augmenters[aug_ind[2]]]) + return aug + + def generate_perturbation( + self, height: int, width: int, anomaly_source_path: Optional[str] + ) -> Tuple[np.ndarray, np.ndarray]: + """Generate an image containing a random anomalous perturbation using a source image. + + Args: + height (int): height of the generated image. + width: (int): width of the generated image. + anomaly_source_path (Optional[str]): Path to an image file. If not provided, random noise will be used + instead. + + Returns: + Image containing a random anomalous perturbation, and the corresponding ground truth anomaly mask. + """ + # Generate random perlin noise + perlin_scale = 6 + min_perlin_scale = 0 + + perlin_scalex = 2 ** random.randint(min_perlin_scale, perlin_scale) + perlin_scaley = 2 ** random.randint(min_perlin_scale, perlin_scale) + + perlin_noise = rand_perlin_2d_np((height, width), (perlin_scalex, perlin_scaley)) + perlin_noise = self.rot(image=perlin_noise) + + # Create mask from perlin noise + mask = np.where(perlin_noise > 0.5, np.ones_like(perlin_noise), np.zeros_like(perlin_noise)) + mask = np.expand_dims(mask, axis=2).astype(np.float32) + + # Load anomaly source image + if anomaly_source_path: + anomaly_source_img = cv2.imread(anomaly_source_path) + anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(width, height)) + else: # if no anomaly source is specified, we use the perlin noise as anomalous source + anomaly_source_img = np.expand_dims(perlin_noise, 2).repeat(3, 2) + anomaly_source_img = (anomaly_source_img * 255).astype(np.uint8) + + # Augment anomaly source image + aug = self.rand_augmenter() + anomaly_img_augmented = aug(image=anomaly_source_img) + + # Create anomalous perturbation that we will apply to the image + perturbation = anomaly_img_augmented.astype(np.float32) * mask / 255.0 + + return perturbation, mask + + def augment_batch(self, batch: Tensor) -> Tuple[Tensor, Tensor]: + """Generate anomalous augmentations for a batch of input images. + + Args: + batch (Tensor): Batch of input images + + Returns: + - Augmented image to which anomalous perturbations have been added. + - Ground truth masks corresponding to the anomalous perturbations. 
+ """ + batch_size, channels, height, width = batch.shape + + # Collect perturbations + perturbations_list = [] + masks_list = [] + for _ in range(batch_size): + if random.random() > 0.5: # include 50% normal samples + perturbations_list.append(torch.zeros((channels, height, width))) + masks_list.append(torch.zeros((1, height, width))) + else: + anomaly_source_path = ( + random.sample(self.anomaly_source_paths, 1)[0] if len(self.anomaly_source_paths) > 0 else None + ) + perturbation, mask = self.generate_perturbation(height, width, anomaly_source_path) + perturbations_list.append(Tensor(perturbation).permute((2, 0, 1))) + masks_list.append(Tensor(mask).permute((2, 0, 1))) + + perturbations = torch.stack(perturbations_list).to(batch.device) + masks = torch.stack(masks_list).to(batch.device) + + # Apply perturbations batch wise + beta = torch.rand(batch_size) * 0.8 + beta = beta.view(batch_size, 1, 1, 1).expand_as(batch).to(batch.device) + + augmented_batch = batch * (1 - masks) + (1 - beta) * perturbations + beta * batch * (masks) + + return augmented_batch, masks diff --git a/anomalib/models/draem/utils/perlin.py b/anomalib/models/draem/utils/perlin.py new file mode 100644 index 0000000000..0c7a72f394 --- /dev/null +++ b/anomalib/models/draem/utils/perlin.py @@ -0,0 +1,134 @@ +"""Helper functions for generating Perlin noise.""" + +# Original Code +# Copyright (c) 2021 VitjanZ +# https://github.com/VitjanZ/DRAEM. +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=invalid-name + +import math + +import numpy as np +import torch + + +def lerp_np(x, y, w): + """Helper function.""" + fin_out = (y - x) * w + x + return fin_out + + +def rand_perlin_2d_octaves_np(shape, res, octaves=1, persistence=0.5): + """Generate Perlin noise parameterized by the octaves method. Numpy version.""" + noise = np.zeros(shape) + frequency = 1 + amplitude = 1 + for _ in range(octaves): + noise += amplitude * generate_perlin_noise_2d(shape, (frequency * res[0], frequency * res[1])) + frequency *= 2 + amplitude *= persistence + return noise + + +def generate_perlin_noise_2d(shape, res): + """Fractal perlin noise.""" + + def f(t): + return 6 * t**5 - 15 * t**4 + 10 * t**3 + + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 + # Gradients + angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1) + gradients = np.dstack((np.cos(angles), np.sin(angles))) + g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) + g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) + g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) + g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) + # Ramps + n00 = np.sum(grid * g00, 2) + n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) + n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) + n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) + # Interpolation + t = f(grid) + n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 + n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 + return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) + + +def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): + """Generate a random image containing Perlin noise. 
Numpy version.""" + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1]].transpose(1, 2, 0) % 1 + + angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1) + gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1) + + def tile_grads(slice1, slice2): + return np.repeat(np.repeat(gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]], d[0], axis=0), d[1], axis=1) + + def dot(grad, shift): + return ( + np.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), axis=-1) + * grad[: shape[0], : shape[1]] + ).sum(axis=-1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) + t = fade(grid[: shape[0], : shape[1]]) + return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1]) + + +def rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): + """Generate a random image containing Perlin noise. PyTorch version.""" + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + + grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1 + angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) + gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) + + def tile_grads(slice1, slice2): + return ( + gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] + .repeat_interleave(d[0], 0) + .repeat_interleave(d[1], 1) + ) + + def dot(grad, shift): + return ( + torch.stack( + (grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1 + ) + * grad[: shape[0], : shape[1]] + ).sum(dim=-1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) + t = fade(grid[: shape[0], : shape[1]]) + return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) + + +def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): + """Generate Perlin noise parameterized by the octaves method. 
PyTorch version.""" + noise = torch.zeros(shape) + frequency = 1 + amplitude = 1 + for _ in range(octaves): + noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1])) + frequency *= 2 + amplitude *= persistence + return noise diff --git a/anomalib/pre_processing/pre_process.py b/anomalib/pre_processing/pre_process.py index fd10fc10f6..b8ec079aa5 100644 --- a/anomalib/pre_processing/pre_process.py +++ b/anomalib/pre_processing/pre_process.py @@ -104,13 +104,7 @@ def get_transforms(self) -> A.Compose: transforms: A.Compose if self.config is None and self.image_size is not None: - if isinstance(self.image_size, int): - height, width = self.image_size, self.image_size - elif isinstance(self.image_size, tuple): - height, width = self.image_size - else: - raise ValueError("``image_size`` could be either int or Tuple[int, int]") - + height, width = self._get_height_and_width() transforms = A.Compose( [ A.Resize(height=height, width=width, always_apply=True), @@ -131,8 +125,23 @@ def get_transforms(self) -> A.Compose: if isinstance(transforms[-1], ToTensorV2): transforms = A.Compose(transforms[:-1]) + # always resize to specified image size + if not any(isinstance(transform, A.Resize) for transform in transforms) and self.image_size is not None: + height, width = self._get_height_and_width() + transforms = A.Compose([A.Resize(height=height, width=width, always_apply=True), transforms]) + return transforms def __call__(self, *args, **kwargs): """Return transformed arguments.""" return self.transforms(*args, **kwargs) + + def _get_height_and_width(self) -> Tuple[Optional[int], Optional[int]]: + """Extract height and width from image size attribute.""" + if isinstance(self.image_size, int): + return self.image_size, self.image_size + if isinstance(self.image_size, tuple): + return int(self.image_size[0]), int(self.image_size[1]) + if self.image_size is None: + return None, None + raise ValueError("``image_size`` could be either int or Tuple[int, int]") diff --git a/docs/source/images/draem/architecture.png b/docs/source/images/draem/architecture.png new file mode 100644 index 0000000000..a791020b63 Binary files /dev/null and b/docs/source/images/draem/architecture.png differ diff --git a/requirements/base.txt b/requirements/base.txt index 97eaec8409..a4fe088a2f 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -12,3 +12,4 @@ torchtext>=0.9.1 wandb==0.12.17 matplotlib>=3.4.3 gradio>=2.9.4 +imgaug==0.4.0 diff --git a/tests/pre_merge/models/test_model_premerge.py b/tests/pre_merge/models/test_model_premerge.py index c3fe3edfda..8e36f7a1e5 100644 --- a/tests/pre_merge/models/test_model_premerge.py +++ b/tests/pre_merge/models/test_model_premerge.py @@ -31,6 +31,7 @@ class TestModel: ("cflow", False), ("dfkde", False), ("dfm", False), + ("draem", False), ("fastflow", False), ("ganomaly", False), ("padim", False), diff --git a/third-party-programs.txt b/third-party-programs.txt index 09433d8408..f15e7a26fc 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -26,3 +26,7 @@ terms are listed below. 3. FastFlowModel Copyright (c) 2022 @gathierry, https://github.com/gathierry/FastFlow SPDX-License-Identifier: Apache-2.0 + +4. Torch models and utils of the Draem module (anomalib.models.draem) + Copyright (c) 2021 VitjanZ, https://github.com/VitjanZ/DRAEM. + SPDX-License-Identifier: MIT
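For completeness, a short sketch of the inference path of the new torch model: in eval mode it returns softmax class scores, and the anomalous-class channel is used as the pixel-level anomaly map, mirroring `validation_step` above. The input tensor is a placeholder.

```python
import torch

from anomalib.models.draem.torch_model import DraemModel

model = DraemModel().eval()
images = torch.rand(4, 3, 256, 256)  # placeholder batch of test images

with torch.no_grad():
    prediction = model(images)        # softmax class scores, shape (4, 2, 256, 256)

anomaly_maps = prediction[:, 1, ...]  # anomalous-class channel, shape (4, 256, 256)
```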