Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Detection task type #732

Merged
merged 29 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
64008c3
add basic support for detection task
djdameln Nov 24, 2022
bcd03d4
use enum for task type
djdameln Nov 24, 2022
ca427fb
formatting
djdameln Nov 25, 2022
69b83b4
small bugfix
djdameln Nov 25, 2022
30c4368
add unit tests for bounding box conversion
djdameln Nov 25, 2022
3c0cfec
update error message
djdameln Nov 28, 2022
037c1e5
use as_tensor
djdameln Nov 28, 2022
abea835
typing and docstring
djdameln Nov 28, 2022
c060333
explicit keyword arguments
djdameln Nov 28, 2022
bf573d1
simplify bbox handling in video dataset
djdameln Nov 28, 2022
b7f1b66
docstring consistency
djdameln Nov 28, 2022
7f60ea2
add missing licenses
djdameln Nov 28, 2022
eb87358
add whitespace for readability
djdameln Nov 28, 2022
4c3a6b1
add missing license
djdameln Nov 28, 2022
cec6138
Update anomalib/data/utils/boxes.py
djdameln Nov 28, 2022
d13ce5b
Revert "Update anomalib/data/utils/boxes.py"
djdameln Nov 28, 2022
0e0dc80
add test case for custom collate function
djdameln Nov 28, 2022
5ead1ad
docstring
djdameln Nov 28, 2022
44812d6
add integration tests for detection dataloading
djdameln Nov 29, 2022
d9304aa
extend and clean up datamodules tests
djdameln Nov 29, 2022
caf0867
add detection task type to visualizer tests
djdameln Nov 29, 2022
67312fc
Merge branch 'feature/datamodules' into da/detection-task-type
djdameln Nov 29, 2022
d63a7b7
only show pred_boxes during inference
djdameln Nov 30, 2022
7ec5fa4
add detection support for torch inference
djdameln Nov 30, 2022
d74bf41
add detection support for openvino inference
djdameln Nov 30, 2022
39cf0ac
test inference for all task types
djdameln Dec 1, 2022
f3d00d8
pylint
djdameln Dec 1, 2022
9962e8c
merge latest changes
djdameln Dec 5, 2022
5a055f2
merge feature branch
djdameln Dec 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions anomalib/data/base/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,28 @@
from pandas import DataFrame
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, default_collate

from anomalib.data.base.dataset import AnomalibDataset
from anomalib.data.utils import ValSplitMode, random_split

logger = logging.getLogger(__name__)


def collate_fn(batch):
    """Collate a batch of samples, keeping bounding boxes as a plain list.

    The number of boxes can differ from sample to sample, so the ``"boxes"`` entries cannot be stacked
    into a single tensor. They are gathered into a list instead, while every other key is passed through
    ``default_collate``.

    Args:
        batch (list): Samples returned by the dataset. When the samples are dictionaries, an optional
            ``"boxes"`` key receives the special treatment described above.

    Returns:
        A dictionary of collated values when the samples are dictionaries, otherwise whatever
        ``default_collate`` returns for the batch.
    """
    elem = batch[0]
    if not isinstance(elem, dict):
        return default_collate(batch)

    out_dict = {}
    if "boxes" in elem:
        # Read the boxes without popping them: the sample dictionaries belong to the caller and
        # must stay intact in case the dataset hands out the same objects again.
        out_dict["boxes"] = [item["boxes"] for item in batch]

    # collate all remaining entries normally
    out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem if key != "boxes"})
    return out_dict


class AnomalibDataModule(LightningDataModule, ABC):
"""Base Anomalib data module.

Expand Down Expand Up @@ -101,8 +115,20 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:

def val_dataloader(self) -> EVAL_DATALOADERS:
"""Get validation dataloader."""
return DataLoader(self.val_data, shuffle=False, batch_size=self.eval_batch_size, num_workers=self.num_workers)
return DataLoader(
self.val_data,
shuffle=False,
batch_size=self.eval_batch_size,
num_workers=self.num_workers,
collate_fn=collate_fn,
)
djdameln marked this conversation as resolved.
Show resolved Hide resolved

def test_dataloader(self) -> EVAL_DATALOADERS:
"""Get test dataloader."""
return DataLoader(self.test_data, shuffle=False, batch_size=self.eval_batch_size, num_workers=self.num_workers)
return DataLoader(
self.test_data,
shuffle=False,
batch_size=self.eval_batch_size,
num_workers=self.num_workers,
collate_fn=collate_fn,
)
djdameln marked this conversation as resolved.
Show resolved Hide resolved
17 changes: 11 additions & 6 deletions anomalib/data/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
from torch import Tensor
from torch.utils.data import Dataset

from anomalib.data.utils import read_image
from anomalib.data.utils import masks_to_boxes, read_image
from anomalib.pre_processing import PreProcessor

_EXPECTED_COLS_CLASSIFICATION = ["image_path", "split"]
_EXPECTED_COLS_SEGMENTATION = _EXPECTED_COLS_CLASSIFICATION + ["mask_path"]
_EXPECTED_COLS_PERTASK = {
"classification": _EXPECTED_COLS_CLASSIFICATION,
"segmentation": _EXPECTED_COLS_SEGMENTATION,
"detection": _EXPECTED_COLS_SEGMENTATION,
}

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -107,16 +108,16 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
"""

image_path = self._samples.iloc[index].image_path
image = read_image(image_path)
mask_path = self._samples.iloc[index].mask_path
label_index = self._samples.iloc[index].label_index

image = read_image(image_path)
item = dict(image_path=image_path, label=label_index)

if self.task == "classification":
pre_processed = self.pre_process(image=image)
elif self.task == "segmentation":
mask_path = self._samples.iloc[index].mask_path

item["image"] = pre_processed["image"]
elif self.task in ["detection", "segmentation"]:
# Only Anomalous (1) images have masks in anomaly datasets
# Therefore, create empty mask for Normal (0) images.
if label_index == 0:
Expand All @@ -126,11 +127,15 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:

pre_processed = self.pre_process(image=image, mask=mask)

item["image"] = pre_processed["image"]
item["mask_path"] = mask_path
item["mask"] = pre_processed["mask"]

if self.task == "detection":
# create boxes from masks for detection task
item["boxes"] = masks_to_boxes(item["mask"])[0]
else:
raise ValueError(f"Unknown task type: {self.task}")
item["image"] = pre_processed["image"]

return item

Expand Down
8 changes: 7 additions & 1 deletion anomalib/data/base/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch import Tensor

from anomalib.data.base.dataset import AnomalibDataset
from anomalib.data.utils import masks_to_boxes
from anomalib.data.utils.video import ClipsIndexer
from anomalib.pre_processing import PreProcessor

Expand Down Expand Up @@ -74,9 +75,14 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
self.pre_process(image=frame.numpy(), mask=mask) for frame, mask in zip(item["image"], item["mask"])
]
item["image"] = torch.stack([item["image"] for item in processed_frames]).squeeze(0)
mask = item["mask"]
mask = Tensor(item["mask"])
djdameln marked this conversation as resolved.
Show resolved Hide resolved
item["mask"] = torch.stack([item["mask"] for item in processed_frames]).squeeze(0)
item["label"] = Tensor([1 in frame for frame in mask]).int().squeeze(0)
item["boxes"] = [
torch.empty((0, 4)) if frame.max() == 0 else masks_to_boxes(frame)
for frame in item["mask"].view((1, 1) + item["mask"].shape[-2:])
]
item["boxes"] = item["boxes"][0] if len(item["boxes"]) == 1 else item["boxes"]
else:
item["image"] = torch.stack(
[self.pre_process(image=frame.numpy())["image"] for frame in item["image"]]
Expand Down
4 changes: 4 additions & 0 deletions anomalib/data/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .boxes import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
from .download import DownloadProgressBar, hash_check
from .generators import random_2d_perlin
from .image import (
Expand All @@ -25,4 +26,7 @@
"concatenate_datasets",
"Split",
"ValSplitMode",
"masks_to_boxes",
"boxes_to_masks",
"boxes_to_anomaly_maps",
]
87 changes: 87 additions & 0 deletions anomalib/data/utils/boxes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Helper functions for processing bounding box detections and annotations."""

djdameln marked this conversation as resolved.
Show resolved Hide resolved
from typing import List, Tuple

import torch
from torch import Tensor

from anomalib.utils.cv import connected_components_cpu, connected_components_gpu


def masks_to_boxes(masks: Tensor) -> List[Tensor]:
    """Convert a batch of segmentation masks to bounding box coordinates.

    Each connected component found in a mask yields one bounding box. Component label 0 is skipped,
    i.e. it is treated as background.

    Args:
        masks (Tensor): Input tensor of shape (B, 1, H, W), (B, H, W) or (H, W)

    Returns:
        List[Tensor]: A list of length B where each element is a tensor of shape (N, 4) containing the bounding box
            coordinates of the objects in the masks in xyxy format. The maximum coordinates are the indices of the
            last pixel belonging to the component. The boxes live on the same device as ``masks``.
    """
    # reshape to (B, 1, H, W); reshape (unlike view) also accepts non-contiguous input
    masks = masks.reshape((-1, 1) + masks.shape[-2:]).float()

    # NOTE(review): both helpers are assumed to return a label map of shape (B, 1, H, W) - confirm
    if masks.is_cuda:
        batch_comps = connected_components_gpu(masks).squeeze(1)
    else:
        batch_comps = connected_components_cpu(masks).squeeze(1)

    batch_boxes = []
    for im_comps in batch_comps:
        labels = torch.unique(im_comps)
        im_boxes = []
        for label in labels[labels != 0]:
            y_loc, x_loc = torch.where(im_comps == label)
            box = torch.stack([x_loc.min(), y_loc.min(), x_loc.max(), y_loc.max()])
            # keep the float dtype of the previous implementation, but stay on the mask's device
            im_boxes.append(box.to(dtype=torch.float32, device=masks.device))
        batch_boxes.append(torch.stack(im_boxes) if im_boxes else torch.empty((0, 4), device=masks.device))
    return batch_boxes


def boxes_to_masks(boxes: List[Tensor], image_size: Tuple[int, int]) -> Tensor:
    """Convert bounding boxes to segmentation masks.

    Args:
        boxes (List[Tensor]): A list of length B where each element is a tensor of shape (N, 4) containing the bounding
            box coordinates of the regions of interest in xyxy format.
        image_size (Tuple[int, int]): Image size of the output masks in (H, W) format.

    Returns:
        Tensor: Tensor of shape (B, H, W) in which each slice is a binary mask showing the pixels contained by a
            bounding box. The box coordinates are truncated to integers and used as slice bounds, so the row and
            column at ``y_2`` / ``x_2`` are not filled. The masks are created on the device of the first box
            tensor (default device for an empty list).
    """
    image_size = tuple(image_size)  # also accept lists and torch.Size
    device = boxes[0].device if len(boxes) > 0 else None
    masks = torch.zeros((len(boxes),) + image_size, device=device)
    for im_idx, im_boxes in enumerate(boxes):
        for box in im_boxes:
            x_1, y_1, x_2, y_2 = box.int().tolist()
            masks[im_idx, y_1:y_2, x_1:x_2] = 1
    return masks


def boxes_to_anomaly_maps(boxes: List[Tensor], scores: List[Tensor], image_size: Tuple[int, int]) -> Tensor:
    """Convert bounding box coordinates to anomaly heatmaps.

    Args:
        boxes (List[Tensor]): A list of length B where each element is a tensor of shape (N, 4) containing the bounding
            box coordinates of the regions of interest in xyxy format.
        scores (List[Tensor]): A list of length B where each element is a 1D tensor of length N containing the anomaly
            scores for each region of interest.
        image_size (Tuple[int, int]): Image size of the output masks in (H, W) format.

    Returns:
        Tensor: Tensor of shape (B, H, W). The pixel locations within each bounding box are collectively assigned the
            anomaly score of the bounding box. In the case of overlapping bounding boxes, the highest score is used.
            Images without any bounding box yield an all-zero map. The maps are created on the device of the first
            box tensor (default device for an empty list).
    """
    image_size = tuple(image_size)  # also accept lists and torch.Size
    device = boxes[0].device if len(boxes) > 0 else None
    anomaly_maps = torch.zeros((len(boxes),) + image_size, device=device)

    for im_idx, (im_boxes, im_scores) in enumerate(zip(boxes, scores)):
        if len(im_boxes) == 0:
            # max() over an empty dimension raises, and the map for this image is already all zeros
            continue

        # one layer per box, so that overlapping boxes can be resolved with a max over the layers
        im_map = torch.zeros((im_boxes.shape[0],) + image_size, device=device)
        for box_idx, (box, score) in enumerate(zip(im_boxes, im_scores)):
            x_1, y_1, x_2, y_2 = box.int().tolist()
            im_map[box_idx, y_1:y_2, x_1:x_2] = score
        anomaly_maps[im_idx] = im_map.max(dim=0).values
    return anomaly_maps
31 changes: 24 additions & 7 deletions anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@

import logging
from abc import ABC
from typing import Any, List, Optional, OrderedDict
from typing import Any, Dict, List, Optional, OrderedDict
from warnings import warn

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks.base import Callback
from torch import Tensor, nn
from torchmetrics import Metric

from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
from anomalib.post_processing import ThresholdMethod
from anomalib.utils.metrics import (
AnomalibMetricCollection,
Expand Down Expand Up @@ -82,6 +84,8 @@ def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int
outputs["pred_labels"] = outputs["pred_scores"] >= self.image_threshold.value
if "anomaly_maps" in outputs.keys():
outputs["pred_masks"] = outputs["anomaly_maps"] >= self.pixel_threshold.value
if "pred_boxes" not in outputs.keys():
outputs["pred_boxes"] = masks_to_boxes(outputs["pred_masks"])
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
return outputs

def test_step(self, batch, _): # pylint: disable=arguments-differ
Expand Down Expand Up @@ -155,15 +159,28 @@ def _collect_outputs(image_metric, pixel_metric, outputs):
def _post_process(outputs):
"""Compute labels based on model predictions."""
if "pred_scores" not in outputs and "anomaly_maps" in outputs:
# infer image scores from anomaly maps
outputs["pred_scores"] = (
outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values
)

@staticmethod
def _outputs_to_cpu(output):
for key, value in output.items():
if isinstance(value, Tensor):
output[key] = value.cpu()
elif "pred_scores" not in outputs and "boxes_scores" in outputs:
# infer image score from bbox confidence scores
outputs["pred_scores"] = torch.stack([scores.max() for scores in outputs["boxes_scores"]])
if "pred_boxes" in outputs and "anomaly_maps" not in outputs:
djdameln marked this conversation as resolved.
Show resolved Hide resolved
# create anomaly maps from bbox predictions for thresholding and evaluation
image_size = tuple(outputs["image"].shape[-2:])
outputs["anomaly_maps"] = boxes_to_anomaly_maps(outputs["pred_boxes"], outputs["boxes_scores"], image_size)
outputs["mask"] = boxes_to_masks(outputs["boxes"], image_size)

def _outputs_to_cpu(self, output):
    """Recursively move all tensors contained in ``output`` to the CPU.

    Dictionaries are updated in place (every value is replaced by its converted counterpart), lists are
    rebuilt as new lists and tensors are moved with ``Tensor.cpu()``. Any other object is returned untouched.
    """
    if isinstance(output, Tensor):
        return output.cpu()
    if isinstance(output, list):
        return [self._outputs_to_cpu(entry) for entry in output]
    if isinstance(output, dict):
        # only existing keys are reassigned, so the dictionary can be updated while iterating
        for name in output:
            output[name] = self._outputs_to_cpu(output[name])
    return output

def _log_metrics(self):
"""Log computed performance metrics."""
Expand Down
19 changes: 19 additions & 0 deletions anomalib/post_processing/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,22 @@ def compute_mask(anomaly_map: np.ndarray, threshold: float, kernel_size: int = 4
mask *= 255

return mask


def draw_boxes(image: np.ndarray, boxes: np.ndarray, is_ground_truth: bool = False) -> np.ndarray:
    """Draw bounding boxes on an image.

    Args:
        image (np.ndarray): Source image.
        boxes (np.ndarray): 2D array of shape (N, 4) where each row contains the xyxy coordinates of a bounding box.
            ``None`` is accepted and treated as "no boxes", since ground truth boxes are optional in the visualizer.
        is_ground_truth (bool): Flag indicating if the boxes are ground truth. When true, boxes will be drawn in red,
            otherwise in blue.

    Returns:
        np.ndarray: Image showing the bounding boxes drawn on top of the source image.
    """
    if boxes is None:
        return image

    color = (255, 0, 0) if is_ground_truth else (0, 0, 255)
    for box in boxes:
        # `np.int` was removed in NumPy 1.24; cast with the builtin and hand plain Python ints to OpenCV
        x_1, y_1, x_2, y_2 = box.astype(int).tolist()
        image = cv2.rectangle(image, (x_1, y_1), (x_2, y_2), color=color, thickness=2)
    return image
22 changes: 21 additions & 1 deletion anomalib/post_processing/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from anomalib.post_processing.post_process import (
add_anomalous_label,
add_normal_label,
draw_boxes,
superimpose_anomaly_map,
)

Expand All @@ -31,6 +32,8 @@ class ImageResult:
anomaly_map: Optional[np.ndarray] = None
gt_mask: Optional[np.ndarray] = None
pred_mask: Optional[np.ndarray] = None
gt_boxes: Optional[np.ndarray] = None
pred_boxes: Optional[np.ndarray] = None

heat_map: np.ndarray = field(init=False)
segmentations: np.ndarray = field(init=False)
Expand Down Expand Up @@ -60,7 +63,7 @@ def __init__(self, mode: str, task: str) -> None:
if mode not in ["full", "simple"]:
raise ValueError(f"Unknown visualization mode: {mode}. Please choose one of ['full', 'simple']")
self.mode = mode
if task not in ["classification", "segmentation"]:
if task not in ["classification", "segmentation", "detection"]:
raise ValueError(f"Unknown task type: {mode}. Please choose one of ['classification', 'segmentation']")
self.task = task

Expand Down Expand Up @@ -90,6 +93,8 @@ def visualize_batch(self, batch: Dict) -> Iterator[np.ndarray]:
anomaly_map=batch["anomaly_maps"][i].cpu().numpy() if "anomaly_maps" in batch else None,
pred_mask=batch["pred_masks"][i].squeeze().int().cpu().numpy() if "pred_masks" in batch else None,
gt_mask=batch["mask"][i].squeeze().int().cpu().numpy() if "mask" in batch else None,
gt_boxes=batch["boxes"][i].cpu().numpy() if "boxes" in batch else None,
pred_boxes=batch["pred_boxes"][i].cpu().numpy() if "pred_boxes" in batch else None,
)
yield self.visualize_image(image_result)

Expand Down Expand Up @@ -122,6 +127,16 @@ def _visualize_full(self, image_result: ImageResult) -> np.ndarray:
An image showing the full set of visualizations for the input image.
"""
visualization = ImageGrid()
if self.task == "detection":
assert image_result.pred_boxes is not None
visualization.add_image(image_result.image, "Image")
if image_result.gt_boxes is not None:
gt_image = draw_boxes(np.copy(image_result.image), image_result.gt_boxes, is_ground_truth=True)
visualization.add_image(image=gt_image, color_map="gray", title="Ground Truth")
else:
visualization.add_image(image_result.image, "Image")
pred_image = draw_boxes(np.copy(image_result.image), image_result.pred_boxes, is_ground_truth=False)
visualization.add_image(pred_image, "Predictions")
if self.task == "segmentation":
assert image_result.pred_mask is not None
visualization.add_image(image_result.image, "Image")
Expand Down Expand Up @@ -151,6 +166,11 @@ def _visualize_simple(self, image_result: ImageResult) -> np.ndarray:
Returns:
An image showing the simple visualization for the input image.
"""
if self.task == "detection":
# return image with bounding boxes augmented
image_with_boxes = draw_boxes(image=image_result.image, boxes=image_result.gt_boxes, is_ground_truth=True)
image_with_boxes = draw_boxes(image=image_with_boxes, boxes=image_result.pred_boxes, is_ground_truth=False)
return image_with_boxes
if self.task == "segmentation":
visualization = mark_boundaries(
image_result.heat_map, image_result.pred_mask, color=(1, 0, 0), mode="thick"
Expand Down
Loading