From 0bf97060d2d6c140e682aaf21db144da500c3aa5 Mon Sep 17 00:00:00 2001
From: Ashwin Vaidya
Date: Fri, 25 Nov 2022 08:42:10 +0100
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Add=20torchfx=20feature=20extractor?=
 =?UTF-8?q?=20(#675)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add torchfx feature extractor

* Fix docstrings

* Support loading models from custom class

* Convert torchfx feature extractor to class for consistency

* Update tests

* Add FeatureExtractor method for backward compatibility

* fix imports

* Copy changes from #714 for tests

* Revert naming to FeatureExtractor

* Revert removing eval

* Use backbone params for torchfx feature extractor

* Remove unnecessary class parameter

* Address PR comments

* Pass locally defined class + add tests for it

* Remove hparams from dummy model

* Address codacy issues

Co-authored-by: Ashwin Vaidya
Co-authored-by: Samet Akcay
---
 anomalib/models/components/__init__.py        |  14 +-
 .../components/feature_extractors/__init__.py |  11 +-
 .../{feature_extractor.py => timm.py}         |  32 +++-
 .../components/feature_extractors/torchfx.py  | 174 ++++++++++++++++++
 .../components/feature_extractors/utils.py    |   4 +-
 anomalib/models/dfkde/torch_model.py          |   3 +-
 anomalib/models/stfpm/loss.py                 |   2 +-
 tests/helpers/dummy.py                        |  26 ++-
 .../models/test_feature_extractor.py          |  44 ++++-
 .../export_callback/dummy_lightning_model.py  |  30 +--
 tools/train.py                                |   2 +-
 11 files changed, 291 insertions(+), 51 deletions(-)
 rename anomalib/models/components/feature_extractors/{feature_extractor.py => timm.py} (77%)
 create mode 100644 anomalib/models/components/feature_extractors/torchfx.py

diff --git a/anomalib/models/components/__init__.py b/anomalib/models/components/__init__.py
index 5d4399ec87..95f77b6017 100644
--- a/anomalib/models/components/__init__.py
+++ b/anomalib/models/components/__init__.py
@@ -5,7 +5,11 @@

 from .base import AnomalyModule, DynamicBufferModule
 from .dimensionality_reduction import PCA, SparseRandomProjection
-from .feature_extractors import FeatureExtractor
+from .feature_extractors import (
+    FeatureExtractor,
+    TimmFeatureExtractor,
+    TorchFXFeatureExtractor,
+)
 from .filters import GaussianBlur2d
 from .sampling import KCenterGreedy
 from .stats import GaussianKDE, MultiVariateGaussian
@@ -13,11 +17,13 @@
 __all__ = [
     "AnomalyModule",
     "DynamicBufferModule",
-    "PCA",
-    "SparseRandomProjection",
     "FeatureExtractor",
-    "KCenterGreedy",
     "GaussianKDE",
     "GaussianBlur2d",
+    "KCenterGreedy",
     "MultiVariateGaussian",
+    "PCA",
+    "SparseRandomProjection",
+    "TimmFeatureExtractor",
+    "TorchFXFeatureExtractor",
 ]
diff --git a/anomalib/models/components/feature_extractors/__init__.py b/anomalib/models/components/feature_extractors/__init__.py
index 100e5e234d..af1fc5d499 100644
--- a/anomalib/models/components/feature_extractors/__init__.py
+++ b/anomalib/models/components/feature_extractors/__init__.py
@@ -3,7 +3,14 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-from .feature_extractor import FeatureExtractor
+from .timm import FeatureExtractor, TimmFeatureExtractor
+from .torchfx import BackboneParams, TorchFXFeatureExtractor
 from .utils import dryrun_find_featuremap_dims

-__all__ = ["FeatureExtractor", "dryrun_find_featuremap_dims"]
+__all__ = [
+    "BackboneParams",
+    "dryrun_find_featuremap_dims",
+    "FeatureExtractor",
+    "TimmFeatureExtractor",
+    "TorchFXFeatureExtractor",
+]
diff --git a/anomalib/models/components/feature_extractors/feature_extractor.py b/anomalib/models/components/feature_extractors/timm.py
similarity index 77%
rename from anomalib/models/components/feature_extractors/feature_extractor.py
rename to anomalib/models/components/feature_extractors/timm.py
index 81c13e9c99..3a12baf337 100644
--- a/anomalib/models/components/feature_extractors/feature_extractor.py
+++ b/anomalib/models/components/feature_extractors/timm.py
@@ -6,6 +6,7 @@
 # Copyright (C) 2022 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

+import logging
 import warnings
 from typing import Dict, List

@@ -13,8 +14,10 @@
 import torch
 from torch import Tensor, nn

+logger = logging.getLogger(__name__)

-class FeatureExtractor(nn.Module):
+
+class TimmFeatureExtractor(nn.Module):
     """Extract features from a CNN.

     Args:
@@ -27,9 +30,9 @@ class FeatureExtractor(nn.Module):

     Example:
         >>> import torch
-        >>> from anomalib.core.model.feature_extractor import FeatureExtractor
+        >>> from anomalib.models.components.feature_extractors import TimmFeatureExtractor

-        >>> model = FeatureExtractor(model="resnet18", layers=['layer1', 'layer2', 'layer3'])
+        >>> model = TimmFeatureExtractor(backbone="resnet18", layers=['layer1', 'layer2', 'layer3'])
         >>> input = torch.rand((32, 3, 256, 256))
         >>> features = model(input)

@@ -81,20 +84,33 @@ def _map_layer_to_idx(self, offset: int = 3) -> List[int]:

         return idx

-    def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]:
+    def forward(self, inputs: Tensor) -> Dict[str, Tensor]:
         """Forward-pass input tensor into the CNN.

         Args:
-            input_tensor (Tensor): Input tensor
+            inputs (Tensor): Input tensor

         Returns:
             Feature map extracted from the CNN
         """
         if self.requires_grad:
-            features = dict(zip(self.layers, self.feature_extractor(input_tensor)))
+            features = dict(zip(self.layers, self.feature_extractor(inputs)))
         else:
             self.feature_extractor.eval()
             with torch.no_grad():
-                features = dict(zip(self.layers, self.feature_extractor(input_tensor)))
-
+                features = dict(zip(self.layers, self.feature_extractor(inputs)))
         return features
+
+
+class FeatureExtractor(TimmFeatureExtractor):
+    """Compatibility wrapper for the old FeatureExtractor class.
+
+    See :class:`anomalib.models.components.feature_extractors.timm.TimmFeatureExtractor` for more details.
+    """
+
+    def __init__(self, *args, **kwargs):
+        logger.warning(
+            "FeatureExtractor is deprecated. Use TimmFeatureExtractor instead."
+            " Both FeatureExtractor and TimmFeatureExtractor will be removed in version 2023.1"
+        )
+        super().__init__(*args, **kwargs)
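The compatibility wrapper above keeps existing imports working while logging a deprecation warning, so migrating a call site is a rename. A minimal doctest-style sketch, assuming a ResNet-style backbone whose layer names match:

    >>> import torch
    >>> from anomalib.models.components.feature_extractors import TimmFeatureExtractor
    >>> model = TimmFeatureExtractor(backbone="resnet18", layers=["layer1", "layer2", "layer3"])
    >>> features = model(torch.rand((32, 3, 256, 256)))
    >>> features["layer1"].shape
    torch.Size([32, 64, 64, 64])
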
diff --git a/anomalib/models/components/feature_extractors/torchfx.py b/anomalib/models/components/feature_extractors/torchfx.py
new file mode 100644
index 0000000000..4766daeed0
--- /dev/null
+++ b/anomalib/models/components/feature_extractors/torchfx.py
@@ -0,0 +1,174 @@
+"""Feature Extractor based on TorchFX."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import importlib
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+from torch import Tensor, nn
+from torch.fx.graph_module import GraphModule
+from torchvision.models._api import WeightsEnum
+from torchvision.models.feature_extraction import create_feature_extractor
+
+
+@dataclass
+class BackboneParams:
+    """Used for serializing the backbone."""
+
+    class_path: Union[str, nn.Module]
+    init_args: Dict = field(default_factory=dict)
+
+
+class TorchFXFeatureExtractor:
+    """Extract features from a CNN.
+
+    Args:
+        backbone (Union[str, BackboneParams, Dict, nn.Module]): The backbone to which the feature extraction hooks are
+            attached. If the name is provided, the model is loaded from torchvision. Otherwise, the model class can be
+            provided and it will try to load the weights from the provided weights file.
+        return_nodes (List[str]): List of layer names of the backbone to which the hooks are attached.
+            You can find the names of these nodes by using the ``get_graph_node_names`` function.
+        weights (Optional[Union[WeightsEnum, str]]): Weights enum to use for the model. Torchvision models require
+            ``WeightsEnum``. These enums are defined in ``torchvision.models.<model>``. You can pass the weights
+            path for custom models.
+        requires_grad (bool): Models like ``stfpm`` use the feature extractor for training. In such cases we should
+            set ``requires_grad`` to ``True``. Default is ``False``.
+
+    Example:
+        With torchvision models:
+
+            >>> import torch
+            >>> from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor
+            >>> from torchvision.models.efficientnet import EfficientNet_B5_Weights
+            >>> feature_extractor = TorchFXFeatureExtractor(
+            ...     backbone="efficientnet_b5",
+            ...     return_nodes=["features.6.8"],
+            ...     weights=EfficientNet_B5_Weights.DEFAULT,
+            ... )
+            >>> input = torch.rand((32, 3, 256, 256))
+            >>> features = feature_extractor(input)
+            >>> [layer for layer in features.keys()]
+            ["features.6.8"]
+            >>> [feature.shape for feature in features.values()]
+            [torch.Size([32, 304, 8, 8])]
+
+        With custom models:
+
+            >>> from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor
+            >>> feature_extractor = TorchFXFeatureExtractor(
+            ...     "path.to.CustomModel", ["linear_relu_stack.3"], weights="path/to/weights.pth"
+            ... )
+            >>> input = torch.randn(1, 1, 28, 28)
+            >>> features = feature_extractor(input)
+            >>> [layer for layer in features.keys()]
+            ["linear_relu_stack.3"]
+    """
+
+    def __init__(
+        self,
+        backbone: Union[str, BackboneParams, Dict, nn.Module],
+        return_nodes: List[str],
+        weights: Optional[Union[WeightsEnum, str]] = None,
+        requires_grad: bool = False,
+    ):
+        if isinstance(backbone, dict):
+            backbone = BackboneParams(**backbone)
+        elif not isinstance(backbone, BackboneParams):  # if str or nn.Module
+            backbone = BackboneParams(class_path=backbone)
+
+        self.feature_extractor = self.initialize_feature_extractor(backbone, return_nodes, weights, requires_grad)
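+        # Because the dict branch above normalizes a plain dict to BackboneParams,
+        # a config-style call is equally valid; a minimal sketch (assumes
+        # torchvision's resnet18 and its "layer1" node name):
+        #   TorchFXFeatureExtractor(
+        #       backbone={"class_path": "resnet18", "init_args": {}},
+        #       return_nodes=["layer1"],
+        #   )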
+
+    def initialize_feature_extractor(
+        self,
+        backbone: BackboneParams,
+        return_nodes: List[str],
+        weights: Optional[Union[WeightsEnum, str]] = None,
+        requires_grad: bool = False,
+    ) -> Union[GraphModule, nn.Module]:
+        """Extract features from a CNN.
+
+        Args:
+            backbone (BackboneParams): The backbone to which the feature extraction hooks are attached.
+                If the name is provided, the model is loaded from torchvision. Otherwise, the model class can be
+                provided and it will try to load the weights from the provided weights file.
+            return_nodes (List[str]): List of layer names of the backbone to which the hooks are attached.
+                You can find the names of these nodes by using the ``get_graph_node_names`` function.
+            weights (Optional[Union[WeightsEnum, str]]): Weights enum to use for the model. Torchvision models require
+                ``WeightsEnum``. These enums are defined in ``torchvision.models.<model>``. You can pass the weights
+                path for custom models.
+            requires_grad (bool): Models like ``stfpm`` use the feature extractor for training. In such cases we should
+                set ``requires_grad`` to ``True``. Default is ``False``.
+
+        Returns:
+            Feature Extractor based on TorchFX.
+        """
+        if isinstance(backbone.class_path, str):
+            backbone_class = self._get_backbone_class(backbone.class_path)
+            backbone_model = backbone_class(weights=weights, **backbone.init_args)
+        else:
+            backbone_class = backbone.class_path
+            backbone_model = backbone_class(**backbone.init_args)
+        if isinstance(weights, WeightsEnum):  # torchvision models
+            feature_extractor = create_feature_extractor(model=backbone_model, return_nodes=return_nodes)
+        else:
+            if weights is not None:
+                assert isinstance(weights, str), "Weights should point to a path"
+                model_weights = torch.load(weights)
+                if "state_dict" in model_weights:
+                    model_weights = model_weights["state_dict"]
+                backbone_model.load_state_dict(model_weights)
+            feature_extractor = create_feature_extractor(backbone_model, return_nodes)
+
+        if not requires_grad:
+            feature_extractor.eval()
+            for param in feature_extractor.parameters():
+                param.requires_grad_(False)
+
+        return feature_extractor
+
+    @staticmethod
+    def _get_backbone_class(backbone: str) -> Callable[..., nn.Module]:
+        """Get the backbone class from the provided path.
+
+        If only the model name is provided, it will try to load the model from torchvision.
+
+        Example:
+            >>> from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor
+            >>> TorchFXFeatureExtractor._get_backbone_class("efficientnet_b5")
+            <function torchvision.models.efficientnet.efficientnet_b5(...) -> torchvision.models.efficientnet.EfficientNet>
+
+            >>> TorchFXFeatureExtractor._get_backbone_class("path.to.CustomModel")
+            <class 'path.to.CustomModel'>
+
+        Args:
+            backbone (str): Path to the backbone class.
+
+        Returns:
+            Backbone class.
+        """
+        try:
+            if len(backbone.split(".")) > 1:
+                # assumes that the entire class path is provided
+                models = importlib.import_module(".".join(backbone.split(".")[:-1]))
+                backbone_class = getattr(models, backbone.split(".")[-1])
+            else:
+                models = importlib.import_module("torchvision.models")
+                backbone_class = getattr(models, backbone)
+        except ModuleNotFoundError as exception:
+            raise ModuleNotFoundError(
+                f"Backbone {backbone} not found in torchvision.models nor in {backbone} module."
+            ) from exception
+
+        return backbone_class
+
+    def __call__(self, inputs: Tensor) -> Dict[str, Tensor]:
+        """Extract features from the input."""
+        return self.feature_extractor(inputs)
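Because ``return_nodes`` must match graph node names exactly, it can help to list them before wiring up the extractor. A minimal doctest-style sketch using torchvision's ``get_graph_node_names``; the names shown assume a standard resnet18 and may vary across torchvision versions:

    >>> from torchvision.models import resnet18
    >>> from torchvision.models.feature_extraction import get_graph_node_names
    >>> train_nodes, eval_nodes = get_graph_node_names(resnet18())
    >>> [name for name in eval_nodes if name.startswith("layer1")][:3]
    ['layer1.0.conv1', 'layer1.0.bn1', 'layer1.0.relu']
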
diff --git a/anomalib/models/components/feature_extractors/utils.py b/anomalib/models/components/feature_extractors/utils.py
index 0efd1011fd..e72051bf78 100644
--- a/anomalib/models/components/feature_extractors/utils.py
+++ b/anomalib/models/components/feature_extractors/utils.py
@@ -4,9 +4,7 @@

 import torch

-from anomalib.models.components.feature_extractors.feature_extractor import (
-    FeatureExtractor,
-)
+from .timm import FeatureExtractor


 def dryrun_find_featuremap_dims(
diff --git a/anomalib/models/dfkde/torch_model.py b/anomalib/models/dfkde/torch_model.py
index 399cda1290..5662aad095 100644
--- a/anomalib/models/dfkde/torch_model.py
+++ b/anomalib/models/dfkde/torch_model.py
@@ -48,8 +48,7 @@ def __init__(
         self.threshold_steepness = threshold_steepness
         self.threshold_offset = threshold_offset

-        _backbone = backbone
-        self.feature_extractor = FeatureExtractor(backbone=_backbone, pre_trained=pre_trained, layers=layers).eval()
+        self.feature_extractor = FeatureExtractor(backbone=backbone, pre_trained=pre_trained, layers=layers).eval()

         self.pca_model = PCA(n_components=self.n_components)
         self.kde_model = GaussianKDE()
diff --git a/anomalib/models/stfpm/loss.py b/anomalib/models/stfpm/loss.py
index 8f60ab2ec8..c675baad5c 100644
--- a/anomalib/models/stfpm/loss.py
+++ b/anomalib/models/stfpm/loss.py
@@ -14,7 +14,7 @@ class STFPMLoss(nn.Module):
     """Feature Pyramid Loss. This class implements the feature pyramid loss function proposed in the STFPM paper.

     Example:
-        >>> from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor
+        >>> from anomalib.models.components.feature_extractors import FeatureExtractor
         >>> from anomalib.models.stfpm.loss import STFPMLoss
         >>> from torchvision.models import resnet18
diff --git a/tests/helpers/dummy.py b/tests/helpers/dummy.py
index a98875046e..8db63dc989 100644
--- a/tests/helpers/dummy.py
+++ b/tests/helpers/dummy.py
@@ -4,6 +4,7 @@

 import pytorch_lightning as pl
 import torch
+import torch.nn.functional as F
 from torch import nn
 from torch.utils.data import DataLoader, Dataset

@@ -30,7 +31,30 @@ def test_dataloader(self) -> DataLoader:


 class DummyModel(nn.Module):
-    pass
+    """Creates a very basic CNN model to fit image data for a classification task.
+
+    The test uses this to check if this model is converted to OpenVINO IR.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 32, 3)
+        self.conv2 = nn.Conv2d(32, 32, 5)
+        self.conv3 = nn.Conv2d(32, 1, 7)
+        self.fc1 = nn.Linear(400, 256)
+        self.fc2 = nn.Linear(256, 10)
+
+    def forward(self, x):
+        batch_size, _, _, _ = x.size()
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = x.view(batch_size, -1)
+        x = self.fc1(x)
+        x = F.dropout(x, p=0.2)
+        x = self.fc2(x)
+        x = F.log_softmax(x, dim=1)
+        return x


 class DummyLogger(AnomalibTensorBoardLogger):
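For reference, the DummyModel shapes line up as follows: with a 256x256 input, the three unpadded convolutions with kernel sizes 3, 5, and 7 shrink each spatial dimension by 2, 4, and 6 pixels respectively, so conv3 emits a single-channel 244x244 map. This is exactly the torch.Size((32, 1, 244, 244)) asserted in the TorchFX test below.
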
diff --git a/tests/pre_merge/models/test_feature_extractor.py b/tests/pre_merge/models/test_feature_extractor.py
index cfe456a6ca..8cad82f581 100644
--- a/tests/pre_merge/models/test_feature_extractor.py
+++ b/tests/pre_merge/models/test_feature_extractor.py
@@ -1,12 +1,17 @@
+from tempfile import TemporaryDirectory
 from typing import Tuple

 import pytest
 import torch
+from torchvision.models.efficientnet import EfficientNet_B5_Weights
+from torchvision.models.resnet import ResNet18_Weights

 from anomalib.models.components.feature_extractors import (
     FeatureExtractor,
+    TorchFXFeatureExtractor,
     dryrun_find_featuremap_dims,
 )
+from tests.helpers.dummy import DummyModel


 class TestFeatureExtractor:
@@ -18,7 +23,7 @@ class TestFeatureExtractor:
         "pretrained",
         [True, False],
     )
-    def test_feature_extraction(self, backbone, pretrained):
+    def test_timm_feature_extraction(self, backbone, pretrained):
         layers = ["layer1", "layer2", "layer3"]
         model = FeatureExtractor(backbone=backbone, layers=layers, pre_trained=pretrained)
         test_input = torch.rand((32, 3, 256, 256))
@@ -39,6 +44,43 @@ def test_feature_extraction(self, backbone, pretrained):
         else:
             pass

+    def test_torchfx_feature_extraction(self):
+        model = TorchFXFeatureExtractor("resnet18", ["layer1", "layer2", "layer3"])
+        test_input = torch.rand((32, 3, 256, 256))
+        features = model(test_input)
+        assert features["layer1"].shape == torch.Size((32, 64, 64, 64))
+        assert features["layer2"].shape == torch.Size((32, 128, 32, 32))
+        assert features["layer3"].shape == torch.Size((32, 256, 16, 16))
+
+        # Test if model can be loaded by using just its name
+        model = TorchFXFeatureExtractor(
+            backbone="efficientnet_b5", return_nodes=["features.6.8"], weights=EfficientNet_B5_Weights.DEFAULT
+        )
+        features = model(test_input)
+        assert features["features.6.8"].shape == torch.Size((32, 304, 8, 8))
+
+        # Test if model can be loaded by using entire class path
+        model = TorchFXFeatureExtractor(
+            backbone="torchvision.models.resnet18",
+            return_nodes=["layer1", "layer2", "layer3"],
+            weights=ResNet18_Weights.DEFAULT,
+        )
+        features = model(test_input)
+        assert features["layer1"].shape == torch.Size((32, 64, 64, 64))
+        assert features["layer2"].shape == torch.Size((32, 128, 32, 32))
+        assert features["layer3"].shape == torch.Size((32, 256, 16, 16))
+
+        # Test if local model can be loaded using string of weights path
+        with TemporaryDirectory() as tmpdir:
+            torch.save(DummyModel().state_dict(), tmpdir + "/dummy_model.pt")
+            model = TorchFXFeatureExtractor(
+                backbone=DummyModel,
+                weights=tmpdir + "/dummy_model.pt",
+                return_nodes=["conv3"],
+            )
+            features = model(test_input)
+            assert features["conv3"].shape == torch.Size((32, 1, 244, 244))
+

 @pytest.mark.parametrize(
     "backbone",
diff --git a/tests/pre_merge/utils/callbacks/export_callback/dummy_lightning_model.py b/tests/pre_merge/utils/callbacks/export_callback/dummy_lightning_model.py
index c9e0929181..fe7a338565 100644
--- a/tests/pre_merge/utils/callbacks/export_callback/dummy_lightning_model.py
+++ b/tests/pre_merge/utils/callbacks/export_callback/dummy_lightning_model.py
@@ -1,7 +1,6 @@
 from typing import Union

 import pytorch_lightning as pl
-import torch.nn.functional as F
 from omegaconf import DictConfig, ListConfig
 from torch import nn, optim
 from torch.utils.data import DataLoader
@@ -14,6 +13,7 @@
     AnomalyScoreThreshold,
     MinMax,
 )
+from tests.helpers.dummy import DummyModel


 class FakeDataModule(pl.LightningDataModule):
@@ -45,32 +45,6 @@ def test_dataloader(self):
         )


-class DummyModel(nn.Module):
-    """Creates a very basic CNN model to fit image data for classification task
-    The test uses this to check if this model is converted to OpenVINO IR."""
-
-    def __init__(self, hparams: Union[DictConfig, ListConfig]):
-        super().__init__()
-        self.hparams = hparams
-        self.conv1 = nn.Conv2d(3, 32, 3)
-        self.conv2 = nn.Conv2d(32, 32, 5)
-        self.conv3 = nn.Conv2d(32, 1, 7)
-        self.fc1 = nn.Linear(400, 256)
-        self.fc2 = nn.Linear(256, 10)
-
-    def forward(self, x):
-        batch_size, _, _, _ = x.size()
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.conv3(x)
-        x = x.view(batch_size, -1)
-        x = self.fc1(x)
-        x = F.dropout(x, p=self.hparams.model.dropout)
-        x = self.fc2(x)
-        x = F.log_softmax(x, dim=1)
-        return x
-
-
 class DummyLightningModule(pl.LightningModule):
     """A dummy model which fits the torchvision FakeData dataset."""

@@ -93,7 +67,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
         self.training_distribution = AnomalyScoreDistribution().cpu()
         self.min_max = MinMax().cpu()

-        self.model = DummyModel(hparams)
+        self.model = DummyModel()

     def training_step(self, batch, _):
         x, y = batch
diff --git a/tools/train.py b/tools/train.py
index 0e5daa3b10..0bb00c8179 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,4 +1,4 @@
-"""Anomalib Traning Script.
+"""Anomalib Training Script.

 This script reads the name of the model or config file from command line, train/test the anomaly model to get quantitative and qualitative