diff --git a/docs/source/images/reverse_distillation/results/0.png b/docs/source/images/reverse_distillation/results/0.png
index e50fffab8d..0f47c23851 100644
Binary files a/docs/source/images/reverse_distillation/results/0.png and b/docs/source/images/reverse_distillation/results/0.png differ
diff --git a/docs/source/images/reverse_distillation/results/1.png b/docs/source/images/reverse_distillation/results/1.png
index a20409b4be..e2e6cf1b89 100644
Binary files a/docs/source/images/reverse_distillation/results/1.png and b/docs/source/images/reverse_distillation/results/1.png differ
diff --git a/docs/source/images/reverse_distillation/results/2.png b/docs/source/images/reverse_distillation/results/2.png
index 140a8d1922..8fd57118a9 100644
Binary files a/docs/source/images/reverse_distillation/results/2.png and b/docs/source/images/reverse_distillation/results/2.png differ
diff --git a/src/anomalib/models/reverse_distillation/README.md b/src/anomalib/models/reverse_distillation/README.md
index 31ffeb9bea..444615007d 100644
--- a/src/anomalib/models/reverse_distillation/README.md
+++ b/src/anomalib/models/reverse_distillation/README.md
@@ -20,79 +20,32 @@ During testing, a similar step is followed but this time the cosine distance bet
 ## Benchmark

-All results gathered with seed `42`.
-
-Note: Early Stopping (with patience 3) was enabled during training.
+All results gathered with seed `42`, train batch size `16`.

 ## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad)

 ### Image-Level AUC

-|            | ResNet 18 | Wide ResNet 50 |
-| :--------- | --------: | -------------: |
-| Bottle     | 0.998     | 0.992          |
-| Cable      | 0.982     | 0.583          |
-| Capsule    | 0.864     | 0.78           |
-| Carpet     | 0.996     | 0.539          |
-| Grid       | 0.941     | 0.975          |
-| Hazelnut   | 0.978     | 0.817          |
-| Leather    | 0.878     | 1              |
-| Metal_nut  | 0.999     | 0.929          |
-| Pill       | 0.944     | 0.553          |
-| Screw      | 0.778     | 0.86           |
-| Tile       | 0.833     | 0.513          |
-| Toothbrush | 0.967     | 0.7            |
-| Transistor | 0.928     | 0.829          |
-| Wood       | 0.989     | 0.993          |
-| Zipper     | 0.968     | 0.787          |
-| Average    | 0.936     | 0.79           |
+|                | Avg   | Carpet | Grid  | Leather | Tile  | Wood  | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill  | Screw | Toothbrush | Transistor | Zipper |
+| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
+| Wide ResNet-50 | 0.985 | 0.984  | 1.000 | 1.000   | 1.000 | 0.997 | 1.000  | 0.966 | 0.974   | 1.000    | 1.000     | 0.972 | 0.985 | 0.953      | 0.970      | 0.978  |

 ### Pixel-Level AUC

-|            | ResNet 18 | Wide ResNet 50 |
-| :--------- | --------: | -------------: |
-| Bottle     | 0.981     | 0.985          |
-| Cable      | 0.965     | 0.794          |
-| Capsule    | 0.983     | 0.986          |
-| Carpet     | 0.989     | 0.99           |
-| Grid       | 0.964     | 0.99           |
-| Hazelnut   | 0.988     | 0.983          |
-| Leather    | 0.984     | 0.995          |
-| Metal_nut  | 0.971     | 0.979          |
-| Pill       | 0.975     | 0.977          |
-| Screw      | 0.987     | 0.989          |
-| Tile       | 0.867     | 0.953          |
-| Toothbrush | 0.99      | 0.979          |
-| Transistor | 0.84      | 0.853          |
-| Wood       | 0.939     | 0.958          |
-| Zipper     | 0.988     | 0.959          |
-| Average    | 0.961     | 0.958          |
+|                | Avg   | Carpet | Grid  | Leather | Tile  | Wood  | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill  | Screw | Toothbrush | Transistor | Zipper |
+| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
+| Wide ResNet-50 | 0.969 | 0.988  | 0.992 | 0.991   | 0.954 | 0.947 | 0.984  | 0.964 | 0.987   | 0.988    | 0.969     | 0.975 | 0.996 | 0.991      | 0.893      | 0.984  |

 ### Image F1 Score

-|            | ResNet 18 | Wide ResNet 50 |
-| :--------- | --------: | -------------: |
-| Bottle     | 0.95      | 0.959          |
-| Cable      | 0.911     | 0.76           |
-| Capsule    | 0.933     | 0.905          |
-| Carpet     | 0.965     | 0.864          |
-| Grid       | 0.964     | 0.945          |
-| Hazelnut   | 0.909     | 0.901          |
-| Leather    | 0.896     | 0.989          |
-| Metal_nut  | 0.995     | 0.939          |
-| Pill       | 0.931     | 0.922          |
-| Screw      | 0.88      | 0.891          |
-| Tile       | 0.88      | 0.836          |
-| Toothbrush | 0.933     | 0.833          |
-| Transistor | 0.769     | 0.744          |
-| Wood       | 0.966     | 0.948          |
-| Zipper     | 0.944     | 0.926          |
-| Average    | 0.922     | 0.891          |
+|                | Avg   | Carpet | Grid  | Leather | Tile  | Wood  | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill  | Screw | Toothbrush | Transistor | Zipper |
+| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
+| Wide ResNet-50 | 0.976 | 0.977  | 1.000 | 1.000   | 0.994 | 0.992 | 0.984  | 0.930 | 0.982   | 1.000    | 1.000     | 0.967 | 0.963 | 0.952      | 0.927      | 0.975  |

 ### Sample Results

-![Sample Result 1](../../../docs/source/images/reverse_distillation/results/0.png "Sample Result 1")
+![Sample Result 1](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/reverse_distillation/results/0.png "Sample Result 1")

-![Sample Result 2](../../../docs/source/images/reverse_distillation/results/1.png "Sample Result 2")
+![Sample Result 2](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/reverse_distillation/results/1.png "Sample Result 2")

-![Sample Result 3](../../../docs/source/images/reverse_distillation/results/2.png "Sample Result 3")
+![Sample Result 3](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/reverse_distillation/results/2.png "Sample Result 3")
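The README rewrite above replaces the per-category ResNet-18/Wide-ResNet-50 tables with a single Wide ResNet-50 row gathered with seed `42` and train batch size `16`. A minimal sketch of reproducing such a run, assuming the `tools/train.py`-style entry points (`get_configurable_parameters`, `get_datamodule`, `get_model`) of this anomalib version; the dataset category is illustrative:

```python
from pytorch_lightning import Trainer, seed_everything

from anomalib.config import get_configurable_parameters
from anomalib.data import get_datamodule
from anomalib.models import get_model

# Loads src/anomalib/models/reverse_distillation/config.yaml by model name.
config = get_configurable_parameters(model_name="reverse_distillation")
config.dataset.category = "carpet"  # illustrative; any MVTec AD category works
seed_everything(42)  # the benchmark above was gathered with seed 42

datamodule = get_datamodule(config)
model = get_model(config)

# config.trainer carries the PyTorch Lightning Trainer arguments from the YAML.
trainer = Trainer(**config.trainer)
trainer.fit(model=model, datamodule=datamodule)
trainer.test(model=model, datamodule=datamodule)
```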
diff --git a/src/anomalib/models/reverse_distillation/components/bottleneck.py b/src/anomalib/models/reverse_distillation/components/bottleneck.py
index a04f6a8174..f424900341 100644
--- a/src/anomalib/models/reverse_distillation/components/bottleneck.py
+++ b/src/anomalib/models/reverse_distillation/components/bottleneck.py
@@ -76,10 +76,10 @@ def __init__(
         self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2)
         self.bn3 = norm_layer(256 * block.expansion)

-        # This is present in the paper but not in the original code. With some initial experiments, removing this leads
-        # to better results
-        # self.conv4 = conv1x1(256 * block.expansion * 3, 256 * block.expansion * 3, 1)  # x3 as we concatenate 3 layers
-        # self.bn4 = norm_layer(256 * block.expansion * 3)
+        # self.conv4 and self.bn4 are from the original code:
+        # https://github.com/hq-deng/RD4AD/blob/6554076872c65f8784f6ece8cfb39ce77e1aee12/resnet.py#L412
+        self.conv4 = conv1x1(1024 * block.expansion, 512 * block.expansion, 1)
+        self.bn4 = norm_layer(512 * block.expansion)

         for module in self.modules():
             if isinstance(module, nn.Conv2d):
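The `conv4`/`bn4` layers restored above mirror the one-class bottleneck embedding of the official RD4AD implementation linked in the new comment. A minimal sketch of the channel mapping these two layers define, assuming `block.expansion == 4` as in (Wide) ResNet-50; batch and spatial sizes are illustrative, and `conv1x1` is torchvision's bias-free 1x1 convolution helper:

```python
import torch
from torch import nn

expansion = 4  # torchvision Bottleneck.expansion for (Wide) ResNet-50

# conv1x1(1024 * expansion, 512 * expansion, 1) from the diff above is a
# 1x1 convolution projecting 4096 input channels down to 2048.
conv4 = nn.Conv2d(1024 * expansion, 512 * expansion, kernel_size=1, stride=1, bias=False)
bn4 = nn.BatchNorm2d(512 * expansion)

dummy = torch.randn(2, 1024 * expansion, 8, 8)  # illustrative (N, C, H, W)
out = bn4(conv4(dummy))
print(out.shape)  # torch.Size([2, 2048, 8, 8])
```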
diff --git a/src/anomalib/models/reverse_distillation/config.yaml b/src/anomalib/models/reverse_distillation/config.yaml
index 8724556725..87f49f1373 100644
--- a/src/anomalib/models/reverse_distillation/config.yaml
+++ b/src/anomalib/models/reverse_distillation/config.yaml
@@ -4,7 +4,7 @@ dataset:
   path: ./datasets/MVTec
   category: bottle
   task: segmentation
-  train_batch_size: 32
+  train_batch_size: 16
   eval_batch_size: 32
   inference_batch_size: 32
   num_workers: 8
@@ -35,21 +35,16 @@ model:
     - layer1
     - layer2
     - layer3
-  early_stopping:
-    patience: 3
-    metric: pixel_AUROC
-    mode: max
   beta1: 0.5
-  beta2: 0.99
+  beta2: 0.999
   normalization_method: min_max # options: [null, min_max, cdf]
-  anomaly_map_mode: multiply
+  anomaly_map_mode: add # options: [add, multiply]

 metrics:
   image:
     - F1Score
     - AUROC
   pixel:
-    - F1Score
     - AUROC
   threshold:
     method: adaptive #options: [adaptive, manual]
@@ -85,7 +80,7 @@ trainer:
   enable_progress_bar: true
   overfit_batches: 0.0
   track_grad_norm: -1
-  check_val_every_n_epoch: 2 # Don't validate before extracting features.
+  check_val_every_n_epoch: 200 # Don't validate before extracting features.
   fast_dev_run: false
   accumulate_grad_batches: 1
   max_epochs: 200
diff --git a/src/anomalib/models/reverse_distillation/lightning_model.py b/src/anomalib/models/reverse_distillation/lightning_model.py
index 5489daab5b..8cb65545bb 100644
--- a/src/anomalib/models/reverse_distillation/lightning_model.py
+++ b/src/anomalib/models/reverse_distillation/lightning_model.py
@@ -134,7 +134,7 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
         self.save_hyperparameters(hparams)

     def configure_callbacks(self) -> list[EarlyStopping]:
-        """Configure model-specific callbacks.
+        """Configure model-specific non-mandatory callbacks.

         Note:
             This method is used for the existing CLI.
@@ -142,9 +142,12 @@ def configure_callbacks(self) -> list[EarlyStopping]:
             deprecated, and callbacks will be configured from either
             config.yaml file or from CLI.
         """
-        early_stopping = EarlyStopping(
-            monitor=self.hparams.model.early_stopping.metric,
-            patience=self.hparams.model.early_stopping.patience,
-            mode=self.hparams.model.early_stopping.mode,
-        )
-        return [early_stopping]
+        callbacks = []
+        if "early_stopping" in self.hparams.model:
+            early_stopping = EarlyStopping(
+                monitor=self.hparams.model.early_stopping.metric,
+                patience=self.hparams.model.early_stopping.patience,
+                mode=self.hparams.model.early_stopping.mode,
+            )
+            callbacks.append(early_stopping)
+        return callbacks
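Because `early_stopping` was dropped from the default `config.yaml`, `configure_callbacks` above now treats the section as optional instead of assuming it exists. A minimal sketch of that membership test on an OmegaConf config; the config contents here are hypothetical stand-ins for the real hparams:

```python
from omegaconf import OmegaConf

# Hypothetical model section without early_stopping, as in the updated YAML.
hparams = OmegaConf.create({"model": {"name": "reverse_distillation"}})
print("early_stopping" in hparams.model)  # False -> no EarlyStopping callback

# Adding the section back re-enables the callback path shown in the diff.
hparams.model.early_stopping = {"metric": "pixel_AUROC", "mode": "max", "patience": 3}
print("early_stopping" in hparams.model)  # True
```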
""" - early_stopping = EarlyStopping( - monitor=self.hparams.model.early_stopping.metric, - patience=self.hparams.model.early_stopping.patience, - mode=self.hparams.model.early_stopping.mode, - ) - return [early_stopping] + callbacks = [] + if "early_stopping" in self.hparams.model: + early_stopping = EarlyStopping( + monitor=self.hparams.model.early_stopping.metric, + patience=self.hparams.model.early_stopping.patience, + mode=self.hparams.model.early_stopping.mode, + ) + callbacks.append(early_stopping) + return callbacks diff --git a/src/anomalib/models/reverse_distillation/loss.py b/src/anomalib/models/reverse_distillation/loss.py index e58955f175..de2d102464 100644 --- a/src/anomalib/models/reverse_distillation/loss.py +++ b/src/anomalib/models/reverse_distillation/loss.py @@ -15,6 +15,10 @@ class ReverseDistillationLoss(nn.Module): def forward(self, encoder_features: list[Tensor], decoder_features: list[Tensor]) -> Tensor: """Computes cosine similarity loss based on features from encoder and decoder. + Based on the official code: + https://github.com/hq-deng/RD4AD/blob/6554076872c65f8784f6ece8cfb39ce77e1aee12/main.py#L33C25-L33C25 + Calculates loss from flattened arrays of features, see https://github.com/hq-deng/RD4AD/issues/22 + Args: encoder_features (list[Tensor]): List of features extracted from encoder decoder_features (list[Tensor]): List of features extracted from decoder @@ -23,8 +27,13 @@ def forward(self, encoder_features: list[Tensor], decoder_features: list[Tensor] Tensor: Cosine similarity loss """ cos_loss = torch.nn.CosineSimilarity() - losses = list(map(cos_loss, encoder_features, decoder_features)) loss_sum = 0 - for loss in losses: - loss_sum += torch.mean(1 - loss) # mean of cosine distance + for encoder_feature, decoder_feature in zip(encoder_features, decoder_features): + loss_sum += torch.mean( + 1 + - cos_loss( + encoder_feature.view(encoder_feature.shape[0], -1), + decoder_feature.view(decoder_feature.shape[0], -1), + ) + ) return loss_sum diff --git a/src/anomalib/models/reverse_distillation/torch_model.py b/src/anomalib/models/reverse_distillation/torch_model.py index a0b56277dc..b10a730481 100644 --- a/src/anomalib/models/reverse_distillation/torch_model.py +++ b/src/anomalib/models/reverse_distillation/torch_model.py @@ -21,6 +21,9 @@ class ReverseDistillationModel(nn.Module): """Reverse Distillation Model. + To reproduce results in the paper, use torchvision model for the encoder: + self.encoder = torchvision.models.wide_resnet50_2(pretrained=True) + Args: backbone (str): Name of the backbone used for encoder and decoder input_size (tuple[int, int]): Size of input image