Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added cflow algorithm #47

Merged
merged 24 commits into from
Jan 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
44cd135
Added cflow algorithm
blakshma Dec 22, 2021
efece38
added cflow reqs
samet-akcay Dec 23, 2021
04d1610
changing default image size
blakshma Dec 25, 2021
52399fa
revert image size and delete resume from checkpoint
blakshma Dec 25, 2021
9f95b20
removed reference to github
blakshma Dec 25, 2021
82192aa
reducing batch size for training and inference
blakshma Dec 26, 2021
6b632bf
decreasing fiber batch size
blakshma Dec 26, 2021
fa1b4df
Added cflow algorithm
blakshma Dec 22, 2021
7e3a369
added cflow reqs
samet-akcay Dec 23, 2021
cdcf489
changing default image size
blakshma Dec 25, 2021
5c7240d
revert image size and delete resume from checkpoint
blakshma Dec 25, 2021
8219029
removed reference to github
blakshma Dec 25, 2021
8d4c17a
reducing batch size for training and inference
blakshma Dec 26, 2021
83e1bb2
decreasing fiber batch size
blakshma Dec 26, 2021
1e0d605
reducing patience for quicker convergence
blakshma Jan 2, 2022
8547c39
Merge branch 'algo/barath/cflow' of github.com:openvinotoolkit/anomal…
blakshma Jan 2, 2022
5971542
updating config file
blakshma Jan 4, 2022
3df4924
Merge branch 'development' of github.com:openvinotoolkit/anomalib int…
samet-akcay Jan 5, 2022
860c490
fixing checkpoint issue, added normalization param
blakshma Jan 5, 2022
2308c64
rolling back to resnet18
blakshma Jan 5, 2022
f1eb5f0
perform metric computation on cpu
djdameln Jan 7, 2022
3977a5d
normalization on cpu
djdameln Jan 7, 2022
c155a21
round performance comparison
samet-akcay Jan 9, 2022
30b6b51
Merge branch 'fix/da/cpu-metric-computation' of github.com:openvinoto…
samet-akcay Jan 9, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions anomalib/core/callbacks/min_max_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def on_predict_batch_end(
@staticmethod
def _normalize_batch(outputs, pl_module):
"""Normalize a batch of predictions."""
stats = pl_module.min_max
stats = pl_module.min_max.cpu()
outputs["pred_scores"] = normalize(
outputs["pred_scores"], pl_module.image_threshold.value, stats.min, stats.max
outputs["pred_scores"], pl_module.image_threshold.value.cpu(), stats.min, stats.max
)
if "anomaly_maps" in outputs.keys():
outputs["anomaly_maps"] = normalize(
outputs["anomaly_maps"], pl_module.pixel_threshold.value, stats.min, stats.max
outputs["anomaly_maps"], pl_module.pixel_threshold.value.cpu(), stats.min, stats.max
)
20 changes: 15 additions & 5 deletions anomalib/core/model/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@ def __init__(self, params: Union[DictConfig, ListConfig]):
self.loss: Tensor
self.callbacks: List[Callback]

self.image_threshold = AdaptiveThreshold(self.hparams.model.threshold.image_default)
self.image_threshold = AdaptiveThreshold(self.hparams.model.threshold.image_default).cpu()
self.pixel_threshold = AdaptiveThreshold(self.hparams.model.threshold.pixel_default)

self.training_distribution = AnomalyScoreDistribution()
self.min_max = MinMax()
self.training_distribution = AnomalyScoreDistribution().cpu()
self.min_max = MinMax().cpu()

self.model: nn.Module

# metrics
auroc = AUROC(num_classes=1, pos_label=1, compute_on_step=False)
f1_score = F1(num_classes=1, compute_on_step=False)
self.image_metrics = MetricCollection([auroc, f1_score], prefix="image_")
self.pixel_metrics = self.image_metrics.clone(prefix="pixel_")
self.image_metrics = MetricCollection([auroc, f1_score], prefix="image_").cpu()
self.pixel_metrics = self.image_metrics.clone(prefix="pixel_").cpu()

def forward(self, batch): # pylint: disable=arguments-differ
"""Forward-pass input tensor to the module.
Expand Down Expand Up @@ -111,11 +111,13 @@ def test_step(self, batch, _): # pylint: disable=arguments-differ

def validation_step_end(self, val_step_outputs): # pylint: disable=arguments-differ
"""Called at the end of each validation step."""
self._outputs_to_cpu(val_step_outputs)
self._post_process(val_step_outputs)
return val_step_outputs

def test_step_end(self, test_step_outputs): # pylint: disable=arguments-differ
"""Called at the end of each test step."""
self._outputs_to_cpu(test_step_outputs)
self._post_process(test_step_outputs)
return test_step_outputs

Expand Down Expand Up @@ -152,8 +154,10 @@ def _compute_adaptive_threshold(self, outputs):

def _collect_outputs(self, image_metric, pixel_metric, outputs):
for output in outputs:
image_metric.cpu()
image_metric.update(output["pred_scores"], output["label"].int())
if "mask" in output.keys() and "anomaly_maps" in output.keys():
pixel_metric.cpu()
pixel_metric.update(output["anomaly_maps"].flatten(), output["mask"].flatten().int())

def _post_process(self, outputs):
Expand All @@ -163,6 +167,12 @@ def _post_process(self, outputs):
outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values
)

def _outputs_to_cpu(self, output):
# for output in outputs:
for key, value in output.items():
if isinstance(value, Tensor):
output[key] = value.cpu()

def _log_metrics(self):
"""Log computed performance metrics."""
self.log_dict(self.image_metrics)
Expand Down
7 changes: 7 additions & 0 deletions anomalib/core/model/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,17 @@ def __init__(self, backbone: nn.Module, layers: Iterable[str]):
self.backbone = backbone
self.layers = layers
self._features = {layer: torch.empty(0) for layer in self.layers}
self.out_dims = []

for layer_id in layers:
layer = dict([*self.backbone.named_modules()])[layer_id]
layer.register_forward_hook(self.get_features(layer_id))
# get output dimension of features if available
layer_modules = [*layer.modules()]
for idx in reversed(range(len(layer_modules))):
if hasattr(layer_modules[idx], "out_channels"):
self.out_dims.append(layer_modules[idx].out_channels)
break

def get_features(self, layer_id: str) -> Callable:
"""Get layer features.
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
AnomalyModule: Anomaly Model
"""
openvino_model_list: List[str] = ["stfpm"]
torch_model_list: List[str] = ["padim", "stfpm", "dfkde", "dfm", "patchcore"]
torch_model_list: List[str] = ["padim", "stfpm", "dfkde", "dfm", "patchcore", "cflow"]
model: AnomalyModule

if config.openvino:
Expand Down
3 changes: 3 additions & 0 deletions anomalib/models/cflow/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows

This is the implementation of the [CW-AD](https://arxiv.org/pdf/2107.12571v1.pdf) paper.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have sample results? It would be nice to include them here for consistency with other tasks. Here is an example https://github.com/openvinotoolkit/anomalib/blob/development/anomalib/models/stfpm/README.md
Otherwise, can you create a ticket for it to address this in a different PR.

18 changes: 18 additions & 0 deletions anomalib/models/cflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows.
[CW-AD](https://arxiv.org/pdf/2107.12571v1.pdf)
"""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
100 changes: 100 additions & 0 deletions anomalib/models/cflow/backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Helper functions to create backbone model."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import math

import FrEIA.framework as Ff
import FrEIA.modules as Fm
import torch
from FrEIA.framework.sequence_inn import SequenceINN
from torch import nn


def positional_encoding_2d(condition_vector: int, height: int, width: int) -> torch.Tensor:
"""Creates embedding to store relative position of the feature vector using sine and cosine functions.
Args:
condition_vector (int): Length of the condition vector
height (int): H of the positions
width (int): W of the positions
Raises:
ValueError: Cannot generate encoding with conditional vector length not as multiple of 4
Returns:
torch.Tensor: condition_vector x HEIGHT x WIDTH position matrix
"""
if condition_vector % 4 != 0:
raise ValueError(f"Cannot use sin/cos positional encoding with odd dimension (got dim={condition_vector})")
pos_encoding = torch.zeros(condition_vector, height, width)
# Each dimension use half of condition_vector
condition_vector = condition_vector // 2
div_term = torch.exp(torch.arange(0.0, condition_vector, 2) * -(math.log(1e4) / condition_vector))
pos_w = torch.arange(0.0, width).unsqueeze(1)
pos_h = torch.arange(0.0, height).unsqueeze(1)
pos_encoding[0:condition_vector:2, :, :] = (
torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
)
pos_encoding[1:condition_vector:2, :, :] = (
torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
)
pos_encoding[condition_vector::2, :, :] = (
torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
)
pos_encoding[condition_vector + 1 :: 2, :, :] = (
torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
)
return pos_encoding


def subnet_fc(dims_in: int, dims_out: int):
"""Subnetwork which predicts the affine coefficients.
Args:
dims_in (int): input dimensions
dims_out (int): output dimensions
Returns:
nn.Sequential: Feed-forward subnetwork
"""
return nn.Sequential(nn.Linear(dims_in, 2 * dims_in), nn.ReLU(), nn.Linear(2 * dims_in, dims_out))


def cflow_head(condition_vector: int, coupling_blocks: int, clamp_alpha: float, n_features: int) -> SequenceINN:
"""Create invertible decoder network.
Args:
condition_vector (int): length of the condition vector
coupling_blocks (int): number of coupling blocks to build the decoder
clamp_alpha (float): clamping value to avoid exploding values
n_features (int): number of decoder features
Returns:
SequenceINN: decoder network block
"""
coder = Ff.SequenceINN(n_features)
print("CNF coder:", n_features)
for _ in range(coupling_blocks):
coder.append(
Fm.AllInOneBlock,
cond=0,
cond_shape=(condition_vector,),
subnet_constructor=subnet_fc,
affine_clamping=clamp_alpha,
global_affine_type="SOFTPLUS",
permute_soft=True,
)
return coder
96 changes: 96 additions & 0 deletions anomalib/models/cflow/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
dataset:
name: mvtec
format: mvtec
path: ./datasets/MVTec
url: ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz
category: leather
task: segmentation
label_format: None
image_size: 256
train_batch_size: 16
test_batch_size: 16
inference_batch_size: 16
fiber_batch_size: 64
num_workers: 36

model:
name: cflow
backbone: resnet18
layers:
- layer2
- layer3
- layer4
decoder: freia-cflow
condition_vector: 128
coupling_blocks: 8
clamp_alpha: 1.9
lr: 0.0001
early_stopping:
patience: 3
metric: pixel_AUROC
mode: max
normalization_method: min_max # options: [null, min_max, cdf]
threshold:
image_default: 0
pixel_default: 0
adaptive: true

project:
seed: 0
path: ./results
log_images_to: [local]
logger: false
save_to_csv: false

# PL Trainer Args. Don't add extra parameter here.
trainer:
accelerator: null
accumulate_grad_batches: 1
amp_backend: native
amp_level: O2
auto_lr_find: false
auto_scale_batch_size: false
auto_select_gpus: false
benchmark: false
check_val_every_n_epoch: 1
checkpoint_callback: true
default_root_dir: null
deterministic: true
distributed_backend: null
fast_dev_run: false
flush_logs_every_n_steps: 100
gpus: 1
gradient_clip_val: 0
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
log_gpu_memory: null
max_epochs: 50
max_steps: null
min_epochs: null
min_steps: null
move_metrics_to_cpu: false
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 0
overfit_batches: 0.0
plugins: null
precision: 32
prepare_data_per_node: true
process_position: 0
profiler: null
progress_bar_refresh_rate: null
reload_dataloaders_every_epoch: false
replace_sampler_ddp: true
stochastic_weight_avg: false
sync_batchnorm: false
terminate_on_nan: false
tpu_cores: null
track_grad_norm: -1
truncated_bptt_steps: null
val_check_interval: 1.0
weights_save_path: null
weights_summary: top
Loading