
Refactor feature extraction key #748

Merged
13 changes: 7 additions & 6 deletions anomalib/models/cflow/config.yaml
@@ -16,12 +16,13 @@ dataset:
 
 model:
   name: cflow
-  backbone: wide_resnet50_2
-  pre_trained: true
-  layers:
-    - layer2
-    - layer3
-    - layer4
+  feature_extractor:
+    backbone: wide_resnet50_2
+    pre_trained: true
+    layers:
+      - layer2
+      - layer3
+      - layer4
   decoder: freia-cflow
   condition_vector: 128
   coupling_blocks: 8
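A minimal sketch (not part of the diff) of how the new nested feature_extractor node reaches the wrapper. It assumes the config is loaded with OmegaConf, as anomalib's entry point scripts do; the values mirror the YAML above.

# Sketch: the nested config node arrives as a DictConfig when loaded via OmegaConf.
from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "model": {
            "feature_extractor": {
                "backbone": "wide_resnet50_2",
                "pre_trained": True,
                "layers": ["layer2", "layer3", "layer4"],
            }
        }
    }
)

# The FeatureExtractor wrapper later converts this DictConfig to
# TimmFeatureExtractorParams because it contains a "layers" key.
print(config.model.feature_extractor.backbone)  # wide_resnet50_2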
18 changes: 8 additions & 10 deletions anomalib/models/cflow/lightning_model.py
@@ -6,7 +6,7 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Tuple, Union
+from typing import Tuple, Union
 
 import einops
 import torch
@@ -19,6 +19,10 @@
 from anomalib.models.cflow.torch_model import CflowModel
 from anomalib.models.cflow.utils import get_logp, positional_encoding_2d
 from anomalib.models.components import AnomalyModule
+from anomalib.models.components.feature_extractors import (
+    TimmFeatureExtractorParams,
+    TorchFXFeatureExtractorParams,
+)
 
 __all__ = ["Cflow", "CflowLightning"]
 
@@ -30,9 +34,7 @@ class Cflow(AnomalyModule):
     def __init__(
         self,
         input_size: Tuple[int, int],
-        backbone: str,
-        layers: List[str],
-        pre_trained: bool = True,
+        feature_extractor: Union[TimmFeatureExtractorParams, TorchFXFeatureExtractorParams],
         fiber_batch_size: int = 64,
         decoder: str = "freia-cflow",
         condition_vector: int = 128,
@@ -45,9 +47,7 @@
 
         self.model: CflowModel = CflowModel(
             input_size=input_size,
-            backbone=backbone,
-            pre_trained=pre_trained,
-            layers=layers,
+            feature_extractor=feature_extractor,
             fiber_batch_size=fiber_batch_size,
             decoder=decoder,
             condition_vector=condition_vector,
@@ -183,9 +183,7 @@ class CflowLightning(Cflow):
     def __init__(self, hparams: Union[DictConfig, ListConfig]) -> None:
        super().__init__(
             input_size=hparams.model.input_size,
-            backbone=hparams.model.backbone,
-            layers=hparams.model.layers,
-            pre_trained=hparams.model.pre_trained,
+            feature_extractor=hparams.model.feature_extractor,
             fiber_batch_size=hparams.model.fiber_batch_size,
             decoder=hparams.model.decoder,
             condition_vector=hparams.model.condition_vector,
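A hedged construction sketch (not part of the diff) showing the refactored Cflow signature in use. The argument names come from the hunk above; the input size is arbitrary, and the code assumes this PR's branch is installed.

# Sketch: constructing Cflow with the new feature_extractor argument.
from anomalib.models.cflow.lightning_model import Cflow
from anomalib.models.components.feature_extractors import TimmFeatureExtractorParams

params = TimmFeatureExtractorParams(
    backbone="wide_resnet50_2",
    layers=["layer2", "layer3", "layer4"],
)
model = Cflow(input_size=(256, 256), feature_extractor=params)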
16 changes: 8 additions & 8 deletions anomalib/models/cflow/torch_model.py
@@ -3,7 +3,7 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Tuple
+from typing import List, Tuple, Union
 
 import einops
 import torch
@@ -12,6 +12,10 @@
 from anomalib.models.cflow.anomaly_map import AnomalyMapGenerator
 from anomalib.models.cflow.utils import cflow_head, get_logp, positional_encoding_2d
 from anomalib.models.components import FeatureExtractor
+from anomalib.models.components.feature_extractors import (
+    TimmFeatureExtractorParams,
+    TorchFXFeatureExtractorParams,
+)
 
 
 class CflowModel(nn.Module):
@@ -20,9 +24,7 @@ class CflowModel(nn.Module):
     def __init__(
         self,
         input_size: Tuple[int, int],
-        backbone: str,
-        layers: List[str],
-        pre_trained: bool = True,
+        feature_extractor: Union[TimmFeatureExtractorParams, TorchFXFeatureExtractorParams],
         fiber_batch_size: int = 64,
         decoder: str = "freia-cflow",
         condition_vector: int = 128,
@@ -32,13 +34,11 @@
     ):
         super().__init__()
 
-        self.backbone = backbone
         self.fiber_batch_size = fiber_batch_size
         self.condition_vector: int = condition_vector
         self.dec_arch = decoder
-        self.pool_layers = layers
-
-        self.encoder = FeatureExtractor(backbone=self.backbone, layers=self.pool_layers, pre_trained=pre_trained)
+        self.encoder = FeatureExtractor(feature_extractor)
+        self.pool_layers = self.encoder.layers
         self.pool_dims = self.encoder.out_dims
         self.decoders = nn.ModuleList(
             [
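A companion sketch (not part of the diff): the same model backed by the torchfx extractor instead. It assumes TorchFXFeatureExtractorParams accepts backbone and return_nodes, as the new wrapper's docstring states; layer names and input size are illustrative.

# Sketch: CflowModel with a torchfx-based extractor via the unified key.
from anomalib.models.cflow.torch_model import CflowModel
from anomalib.models.components.feature_extractors import TorchFXFeatureExtractorParams

params = TorchFXFeatureExtractorParams(
    backbone="resnet18",
    return_nodes=["layer1", "layer2", "layer3"],
)
model = CflowModel(input_size=(256, 256), feature_extractor=params)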
11 changes: 6 additions & 5 deletions anomalib/models/components/feature_extractors/__init__.py
@@ -3,14 +3,15 @@
 # Copyright (C) 2022 Intel Corporation

[Review comment, Contributor] Should this module be called feature_extraction? That would be more in line with other modules under models.components (e.g. dimensionality_reduction and sampling).
 # SPDX-License-Identifier: Apache-2.0
 
-from .timm import FeatureExtractor, TimmFeatureExtractor
-from .torchfx import BackboneParams, TorchFXFeatureExtractor
-from .utils import dryrun_find_featuremap_dims
+from .feature_extractor import FeatureExtractor
+from .timm import TimmFeatureExtractor, TimmFeatureExtractorParams
+from .torchfx import TorchFXFeatureExtractor, TorchFXFeatureExtractorParams
 
 __all__ = [
-    "BackboneParams",
-    "dryrun_find_featuremap_dims",
     "FeatureExtractor",
     "TimmFeatureExtractor",
+    "TimmFeatureExtractorParams",
     "TorchFXFeatureExtractor",
+    "TorchFXFeatureExtractorParams",
 ]
159 changes: 159 additions & 0 deletions anomalib/models/components/feature_extractors/feature_extractor.py
@@ -0,0 +1,159 @@
"""Convenience wrapper for feature extraction methods."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Tuple, Union

import torch
from omegaconf import DictConfig
from torch import Tensor, nn

from .timm import TimmFeatureExtractor, TimmFeatureExtractorParams
from .torchfx import TorchFXFeatureExtractor, TorchFXFeatureExtractorParams


class FeatureExtractor(nn.Module):
"""Convenience wrapper for feature extractors.

Selects either timm or torchfx feature extractor based on the arguments passed.

If you want to use timm feature extractor, you need to pass the following arguments:
backbone,layers,pre_trained,requires_grad

If you want to use torchfx feature extractor, you need to pass the following arguments:
backbone,return_nodes,weights,requires_grad

Example:
Using Timm

>>> from anomalib.models.components import FeatureExtractor
>>> FeatureExtractor(backbone="resnet18",layers=["layer1","layer2","layer3"])
TimmFeatureExtractor will be removed in 2023.1
FeatureExtractor(
(feature_extractor): TimmFeatureExtractor(
(feature_extractor): FeatureListNet(
...

Using TorchFX

>>> FeatureExtractor(backbone="resnet18",return_nodes=["layer1","layer2","layer3"])
FeatureExtractor(
(feature_extractor): TorchFXFeatureExtractor(
(feature_extractor): ResNet(

Using Backbone params

>>> from anomalib.models.components.feature_extractors import TorchFXFeatureExtractorParams
>>> from torchvision.models.efficientnet import EfficientNet_B5_Weights
>>> params = TorchFXFeatureExtractorParams(backbone="efficientnet_b5",
... return_nodes=["features.6.8"],
... weights=EfficientNet_B5_Weights.DEFAULT
... )
>>> FeatureExtractor(params)
FeatureExtractor(
(feature_extractor): TorchFXFeatureExtractor(
(feature_extractor): EfficientNet(
...
"""

    def __init__(self, *args: Union[TimmFeatureExtractorParams, TorchFXFeatureExtractorParams, DictConfig], **kwargs):
        super().__init__()

        # Check whether the parameters were passed as keyword arguments or as a single
        # dictionary/dataclass argument.
        feature_extractor_params = self._get_feature_extractor_params(args, kwargs)
        self.feature_extractor = self._assign_feature_extractor(feature_extractor_params)
        self.layers = (
            self.feature_extractor.layers
            if isinstance(self.feature_extractor, TimmFeatureExtractor)
            else self.feature_extractor.return_nodes
        )
        self._out_dims: List[int]

    def _get_feature_extractor_params(self, args, kwargs):
        """Performs validation checks and converts the arguments to the correct data type.

        Checks whether the arguments are passed as keyword arguments or as a single argument
        of type dictionary or dataclass. If the checks pass, returns the feature extractor
        parameters as a dataclass.

        The feature extractor expects either args or kwargs, not both.

        Args:
            args (Union[TimmFeatureExtractorParams, TorchFXFeatureExtractorParams, DictConfig]): Feature extractor
                parameters.
            kwargs (Dict[str, Any]): Feature extractor parameters as keyword arguments.
        """
        # Note: ``kwargs`` is always a dict here, so it is compared for emptiness
        # rather than against None.
        if len(args) == 1 and not kwargs:
            feature_extractor_params = self._convert_datatype(args[0])
        elif not args and kwargs:
            feature_extractor_params = self._convert_datatype(kwargs)
        else:
            raise ValueError(
                "Pass the parameters either as keyword arguments or as a single argument of type"
                " TimmFeatureExtractorParams or TorchFXFeatureExtractorParams, not both."
            )
        return feature_extractor_params

    def _convert_datatype(
        self,
        feature_extractor_params: Union[TimmFeatureExtractorParams, TorchFXFeatureExtractorParams, DictConfig, Dict],
    ):
        """Converts the parameters to a params dataclass.

        When the config is loaded from entry point scripts, the arguments arrive as a DictConfig.

        Args:
            feature_extractor_params: Feature extractor parameters to convert.

        Returns:
            Union[TimmFeatureExtractorParams, TorchFXFeatureExtractorParams]: Converted feature extractor parameters.
        """
        if isinstance(feature_extractor_params, (DictConfig, dict)):
            # The presence of a "layers" key distinguishes timm params from torchfx params.
            if "layers" in feature_extractor_params:
                feature_extractor_params = TimmFeatureExtractorParams(**feature_extractor_params)
            else:
                feature_extractor_params = TorchFXFeatureExtractorParams(**feature_extractor_params)
        if not isinstance(feature_extractor_params, (TimmFeatureExtractorParams, TorchFXFeatureExtractorParams)):
            raise ValueError(f"Unknown feature extractor params type: {type(feature_extractor_params)}")
        return feature_extractor_params

    def _assign_feature_extractor(
        self, feature_extractor_params: Union[TimmFeatureExtractorParams, TorchFXFeatureExtractorParams]
    ) -> Union[TimmFeatureExtractor, TorchFXFeatureExtractor]:
        """Assigns the feature extractor based on the arguments passed."""
        if isinstance(feature_extractor_params, TimmFeatureExtractorParams):
            feature_extractor = TimmFeatureExtractor(**vars(feature_extractor_params))
        else:
            feature_extractor = TorchFXFeatureExtractor(**vars(feature_extractor_params))
        return feature_extractor

    def forward(self, inputs: Tensor) -> Dict[str, Tensor]:
        """Returns the feature maps from the selected feature extractor, keyed by layer name."""
        return self.feature_extractor(inputs)

    @property
    def out_dims(self) -> List[int]:
        """Returns the number of channels of the requested layers."""
        if not hasattr(self, "_out_dims"):
            if isinstance(self.feature_extractor, TimmFeatureExtractor):
                self._out_dims = self.feature_extractor.out_dims
            else:
                # Run a small tensor through the model to get the output dimensions.
                self._out_dims = [val["num_features"] for val in self.dryrun_find_featuremap_dims((1, 1)).values()]
        return self._out_dims

    def dryrun_find_featuremap_dims(self, input_shape: Tuple[int, int]) -> Dict[str, Dict]:
        """Dry run an empty image to get the feature map tensors' dimensions (num_features, resolution).

        Args:
            input_shape (Tuple[int, int]): Shape of the input image.

        Returns:
            Dict[str, Dict]: mapping of ``layer -> dimensions dict``.
                Each ``dimensions dict`` has two keys: ``num_features`` (int) and ``resolution`` (Tuple[int, int]).
        """
        dryrun_input = torch.empty(1, 3, *input_shape)
        dryrun_features = self.feature_extractor(dryrun_input)
        return {
            layer: {"num_features": dryrun_features[layer].shape[1], "resolution": dryrun_features[layer].shape[2:]}
            for layer in self.layers
        }
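A usage sketch (not part of the diff) for the wrapper above: keyword arguments containing "layers" are routed to the timm extractor by _convert_datatype, and forward returns a dict of feature maps keyed by layer name.

# Sketch: exercising the new FeatureExtractor wrapper end to end.
import torch

from anomalib.models.components import FeatureExtractor

extractor = FeatureExtractor(backbone="resnet18", layers=["layer1", "layer2", "layer3"])
features = extractor(torch.rand(1, 3, 256, 256))
for layer_name, feature_map in features.items():
    print(layer_name, tuple(feature_map.shape))
print(extractor.out_dims)  # channel counts for the requested layers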
27 changes: 12 additions & 15 deletions anomalib/models/components/feature_extractors/timm.py
@@ -8,6 +8,7 @@
 
 import logging
 import warnings
+from dataclasses import dataclass
 from typing import Dict, List
 
 import timm
@@ -17,6 +18,16 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class TimmFeatureExtractorParams:
+    """Used for serializing the Timm Feature Extractor."""
+
+    backbone: str
+    layers: List[str]
+    pre_trained: bool = True
+    requires_grad: bool = False
+
+
 class TimmFeatureExtractor(nn.Module):
     """Extract features from a CNN.
 
@@ -44,6 +55,7 @@ class TimmFeatureExtractor(nn.Module):
 
     def __init__(self, backbone: str, layers: List[str], pre_trained: bool = True, requires_grad: bool = False):
         super().__init__()
+        logger.warning("TimmFeatureExtractor will be removed in 2023.1")
         self.backbone = backbone
         self.layers = layers
         self.idx = self._map_layer_to_idx()
@@ -56,7 +68,6 @@ def __init__(self, backbone: str, layers: List[str], pre_trained: bool = True, requires_grad: bool = False):
             out_indices=self.idx,
         )
         self.out_dims = self.feature_extractor.feature_info.channels()
-        self._features = {layer: torch.empty(0) for layer in self.layers}
 
     def _map_layer_to_idx(self, offset: int = 3) -> List[int]:
         """Maps set of layer names to indices of model.
@@ -100,17 +111,3 @@ def forward(self, inputs: Tensor) -> Dict[str, Tensor]:
         with torch.no_grad():
             features = dict(zip(self.layers, self.feature_extractor(inputs)))
         return features
-
-
-class FeatureExtractor(TimmFeatureExtractor):
-    """Compatibility wrapper for the old FeatureExtractor class.
-
-    See :class:`anomalib.models.components.feature_extractors.timm.TimmFeatureExtractor` for more details.
-    """
-
-    def __init__(self, *args, **kwargs):
-        logger.warning(
-            "FeatureExtractor is deprecated. Use TimmFeatureExtractor instead."
-            " Both FeatureExtractor and TimmFeatureExtractor will be removed in version 2023.1"
-        )
-        super().__init__(*args, **kwargs)
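A round-trip sketch (not part of the diff) for the new params dataclass: vars() splats it back into TimmFeatureExtractor, which now logs the removal warning added in this file. Assumes timm is installed; pre_trained defaults to True, so weights are downloaded.

# Sketch: serializable params round-trip into the timm extractor.
from anomalib.models.components.feature_extractors import (
    TimmFeatureExtractor,
    TimmFeatureExtractorParams,
)

params = TimmFeatureExtractorParams(backbone="resnet18", layers=["layer1", "layer2"])
extractor = TimmFeatureExtractor(**vars(params))  # warns: removed in 2023.1
print(extractor.out_dims)  # channel counts from timm's feature_info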