-
Notifications
You must be signed in to change notification settings - Fork 657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor model implementations #225
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
4397654
refactor CFlow implementation
djdameln aa73b20
refactor DFM implementation
djdameln a25f4ea
refactor PADIM implementation
djdameln b62856f
refactor PatchCore implementation
djdameln effb119
refactor STFPM implementation
djdameln afcd4c4
revert model tests
djdameln 8827b51
Merge branch 'development' into da/refactor/model-file-structure
djdameln f887b45
remove unintentionally committed file
djdameln 5399d64
model.py -> lightning_model.py
djdameln File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
"""Anomaly Map Generator for CFlow model implementation.""" | ||
|
||
# Copyright (C) 2020 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions | ||
# and limitations under the License. | ||
|
||
from typing import List, Tuple, Union, cast | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from omegaconf import ListConfig | ||
from torch import Tensor | ||
|
||
|
||
class AnomalyMapGenerator: | ||
"""Generate Anomaly Heatmap.""" | ||
|
||
def __init__( | ||
self, | ||
image_size: Union[ListConfig, Tuple], | ||
pool_layers: List[str], | ||
): | ||
self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True) | ||
self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size) | ||
self.pool_layers: List[str] = pool_layers | ||
|
||
def compute_anomaly_map( | ||
self, distribution: Union[List[Tensor], List[List]], height: List[int], width: List[int] | ||
) -> Tensor: | ||
"""Compute the layer map based on likelihood estimation. | ||
|
||
Args: | ||
distribution: Probability distribution for each decoder block | ||
height: blocks height | ||
width: blocks width | ||
|
||
Returns: | ||
Final Anomaly Map | ||
|
||
""" | ||
|
||
test_map: List[Tensor] = [] | ||
for layer_idx in range(len(self.pool_layers)): | ||
test_norm = torch.tensor(distribution[layer_idx], dtype=torch.double) # pylint: disable=not-callable | ||
test_norm -= torch.max(test_norm) # normalize likelihoods to (-Inf:0] by subtracting a constant | ||
test_prob = torch.exp(test_norm) # convert to probs in range [0:1] | ||
test_mask = test_prob.reshape(-1, height[layer_idx], width[layer_idx]) | ||
# upsample | ||
test_map.append( | ||
F.interpolate( | ||
test_mask.unsqueeze(1), size=self.image_size, mode="bilinear", align_corners=True | ||
).squeeze() | ||
) | ||
# score aggregation | ||
score_map = torch.zeros_like(test_map[0]) | ||
for layer_idx in range(len(self.pool_layers)): | ||
score_map += test_map[layer_idx] | ||
score_mask = score_map | ||
# invert probs to anomaly scores | ||
anomaly_map = score_mask.max() - score_mask | ||
|
||
return anomaly_map | ||
|
||
def __call__(self, **kwargs: Union[List[Tensor], List[int], List[List]]) -> Tensor: | ||
"""Returns anomaly_map. | ||
|
||
Expects `distribution`, `height` and 'width' keywords to be passed explicitly | ||
|
||
Example | ||
>>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size), | ||
>>> pool_layers=pool_layers) | ||
>>> output = self.anomaly_map_generator(distribution=dist, height=height, width=width) | ||
|
||
Raises: | ||
ValueError: `distribution`, `height` and 'width' keys are not found | ||
|
||
Returns: | ||
torch.Tensor: anomaly map | ||
""" | ||
if not ("distribution" in kwargs and "height" in kwargs and "width" in kwargs): | ||
raise KeyError(f"Expected keys `distribution`, `height` and `width`. Found {kwargs.keys()}") | ||
|
||
# placate mypy | ||
distribution: List[Tensor] = cast(List[Tensor], kwargs["distribution"]) | ||
height: List[int] = cast(List[int], kwargs["height"]) | ||
width: List[int] = cast(List[int], kwargs["width"]) | ||
return self.compute_anomaly_map(distribution, height, width) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
"""CFLOW: Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows. | ||
|
||
https://arxiv.org/pdf/2107.12571v1.pdf | ||
""" | ||
|
||
# Copyright (C) 2020 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions | ||
# and limitations under the License. | ||
|
||
import einops | ||
import torch | ||
import torch.nn.functional as F | ||
from pytorch_lightning.callbacks import EarlyStopping | ||
from torch import optim | ||
|
||
from anomalib.models.cflow.torch_model import CflowModel | ||
from anomalib.models.cflow.utils import get_logp, positional_encoding_2d | ||
from anomalib.models.components import AnomalyModule | ||
|
||
__all__ = ["CflowLightning"] | ||
|
||
|
||
class CflowLightning(AnomalyModule): | ||
"""PL Lightning Module for the CFLOW algorithm.""" | ||
|
||
def __init__(self, hparams): | ||
super().__init__(hparams) | ||
|
||
self.model: CflowModel = CflowModel(hparams) | ||
self.loss_val = 0 | ||
self.automatic_optimization = False | ||
|
||
def configure_callbacks(self): | ||
"""Configure model-specific callbacks.""" | ||
early_stopping = EarlyStopping( | ||
monitor=self.hparams.model.early_stopping.metric, | ||
patience=self.hparams.model.early_stopping.patience, | ||
mode=self.hparams.model.early_stopping.mode, | ||
) | ||
return [early_stopping] | ||
|
||
def configure_optimizers(self) -> torch.optim.Optimizer: | ||
"""Configures optimizers for each decoder. | ||
|
||
Returns: | ||
Optimizer: Adam optimizer for each decoder | ||
""" | ||
decoders_parameters = [] | ||
for decoder_idx in range(len(self.model.pool_layers)): | ||
decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters())) | ||
|
||
optimizer = optim.Adam( | ||
params=decoders_parameters, | ||
lr=self.hparams.model.lr, | ||
) | ||
return optimizer | ||
|
||
def training_step(self, batch, _): # pylint: disable=arguments-differ | ||
"""Training Step of CFLOW. | ||
|
||
For each batch, decoder layers are trained with a dynamic fiber batch size. | ||
Training step is performed manually as multiple training steps are involved | ||
per batch of input images | ||
|
||
Args: | ||
batch: Input batch | ||
_: Index of the batch. | ||
|
||
Returns: | ||
Loss value for the batch | ||
|
||
""" | ||
opt = self.optimizers() | ||
self.model.encoder.eval() | ||
|
||
images = batch["image"] | ||
activation = self.model.encoder(images) | ||
avg_loss = torch.zeros([1], dtype=torch.float64).to(images.device) | ||
|
||
height = [] | ||
width = [] | ||
for layer_idx, layer in enumerate(self.model.pool_layers): | ||
encoder_activations = activation[layer].detach() # BxCxHxW | ||
|
||
batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size() | ||
image_size = im_height * im_width | ||
embedding_length = batch_size * image_size # number of rows in the conditional vector | ||
|
||
height.append(im_height) | ||
width.append(im_width) | ||
# repeats positional encoding for the entire batch 1 C H W to B C H W | ||
pos_encoding = einops.repeat( | ||
positional_encoding_2d(self.model.condition_vector, im_height, im_width).unsqueeze(0), | ||
"b c h w-> (tile b) c h w", | ||
tile=batch_size, | ||
).to(images.device) | ||
c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c") # BHWxP | ||
e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c") # BHWxC | ||
perm = torch.randperm(embedding_length) # BHW | ||
decoder = self.model.decoders[layer_idx].to(images.device) | ||
|
||
fiber_batches = embedding_length // self.model.fiber_batch_size # number of fiber batches | ||
assert fiber_batches > 0, "Make sure we have enough fibers, otherwise decrease N or batch-size!" | ||
|
||
for batch_num in range(fiber_batches): # per-fiber processing | ||
opt.zero_grad() | ||
if batch_num < (fiber_batches - 1): | ||
idx = torch.arange( | ||
batch_num * self.model.fiber_batch_size, (batch_num + 1) * self.model.fiber_batch_size | ||
) | ||
else: # When non-full batch is encountered batch_num * N will go out of bounds | ||
idx = torch.arange(batch_num * self.model.fiber_batch_size, embedding_length) | ||
# get random vectors | ||
c_p = c_r[perm[idx]] # NxP | ||
e_p = e_r[perm[idx]] # NxC | ||
# decoder returns the transformed variable z and the log Jacobian determinant | ||
p_u, log_jac_det = decoder(e_p, [c_p]) | ||
# | ||
decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det) | ||
log_prob = decoder_log_prob / dim_feature_vector # likelihood per dim | ||
loss = -F.logsigmoid(log_prob) | ||
self.manual_backward(loss.mean()) | ||
opt.step() | ||
avg_loss += loss.sum() | ||
|
||
return {"loss": avg_loss} | ||
|
||
def validation_step(self, batch, _): # pylint: disable=arguments-differ | ||
"""Validation Step of CFLOW. | ||
|
||
Similar to the training step, encoder features | ||
are extracted from the CNN for each batch, and anomaly | ||
map is computed. | ||
|
||
Args: | ||
batch: Input batch | ||
_: Index of the batch. | ||
|
||
Returns: | ||
Dictionary containing images, anomaly maps, true labels and masks. | ||
These are required in `validation_epoch_end` for feature concatenation. | ||
|
||
""" | ||
batch["anomaly_maps"] = self.model(batch["image"]) | ||
|
||
return batch |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this import needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was added by Alexander. I think it's needed for NNCF support
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah looks like it. But I don't understand how/why it's needed. It is not used anywhere