Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Detection] Compute box score when generating boxes from masks #828

Merged
merged 6 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion anomalib/data/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:

if self.task == TaskType.DETECTION:
# create boxes from masks for detection task
item["boxes"] = masks_to_boxes(item["mask"])[0]
boxes, _ = masks_to_boxes(item["mask"])
item["boxes"] = boxes[0]
else:
raise ValueError(f"Unknown task type: {self.task}")

Expand Down
2 changes: 1 addition & 1 deletion anomalib/data/base/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
item["mask"] = torch.stack([item["mask"] for item in processed_frames]).squeeze(0)
item["label"] = Tensor([1 in frame for frame in mask]).int().squeeze(0)
if self.task == TaskType.DETECTION:
item["boxes"] = masks_to_boxes(item["mask"])
item["boxes"], _ = masks_to_boxes(item["mask"])
item["boxes"] = item["boxes"][0] if len(item["boxes"]) == 1 else item["boxes"]
else:
item["image"] = torch.stack(
Expand Down
25 changes: 19 additions & 6 deletions anomalib/data/utils/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,54 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import List, Tuple
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from anomalib.utils.cv import connected_components_cpu, connected_components_gpu


def masks_to_boxes(masks: Tensor) -> List[Tensor]:
def masks_to_boxes(masks: Tensor, anomaly_maps: Optional[Tensor] = None) -> Tuple[List[Tensor], List[Tensor]]:
"""Convert a batch of segmentation masks to bounding box coordinates.

Args:
masks (Tensor): Input tensor of shape (B, 1, H, W), (B, H, W) or (H, W)
anomaly_maps (Optional[Tensor], optional): Anomaly maps of shape (B, 1, H, W), (B, H, W) or (H, W) which are
used to determine an anomaly score for the converted bounding boxes.

Returns:
List[Tensor]: A list of length B where each element is a tensor of shape (N, 4) containing the bounding box
coordinates of the objects in the masks in xyxy format.
List[Tensor]: A list of length B where each element is a tensor of length (N) containing an anomaly score for
each of the converted boxes.
"""
masks = masks.view((-1, 1) + masks.shape[-2:]) # reshape to (B, 1, H, W)
masks = masks.float()
height, width = masks.shape[-2:]
masks = masks.view((-1, 1, height, width)).float() # reshape to (B, 1, H, W) and cast to float
if anomaly_maps is not None:
anomaly_maps = anomaly_maps.view((-1,) + masks.shape[-2:])

if masks.is_cuda:
batch_comps = connected_components_gpu(masks).squeeze(1)
else:
batch_comps = connected_components_cpu(masks).squeeze(1)

batch_boxes = []
for im_comps in batch_comps:
batch_scores = []
for im_idx, im_comps in enumerate(batch_comps):
labels = torch.unique(im_comps)
im_boxes = []
im_scores = []
for label in labels[labels != 0]:
y_loc, x_loc = torch.where(im_comps == label)
# add box
im_boxes.append(Tensor([torch.min(x_loc), torch.min(y_loc), torch.max(x_loc), torch.max(y_loc)]))
if anomaly_maps is not None:
im_scores.append(torch.max(anomaly_maps[im_idx, y_loc, x_loc]))
batch_boxes.append(torch.stack(im_boxes) if len(im_boxes) > 0 else torch.empty((0, 4)))
return batch_boxes
batch_scores.append(torch.stack(im_scores) if len(im_scores) > 0 else torch.empty(0))

return batch_boxes, batch_scores


def boxes_to_masks(boxes: List[Tensor], image_size: Tuple[int, int]) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion anomalib/deploy/inferencers/torch_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def post_process(self, predictions: Tensor, meta_data: Optional[Union[Dict, Dict
pred_mask = cv2.resize(pred_mask, (image_width, image_height))

if self.config.dataset.task == TaskType.DETECTION:
pred_boxes = masks_to_boxes(torch.from_numpy(pred_mask))[0].numpy()
pred_boxes = masks_to_boxes(torch.from_numpy(pred_mask))[0][0].numpy()
box_labels = np.ones(pred_boxes.shape[0])
else:
pred_boxes = None
Expand Down
6 changes: 4 additions & 2 deletions anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@ def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int
if "anomaly_maps" in outputs.keys():
outputs["pred_masks"] = outputs["anomaly_maps"] >= self.pixel_threshold.value
if "pred_boxes" not in outputs.keys():
outputs["pred_boxes"] = masks_to_boxes(outputs["pred_masks"])
outputs["pred_boxes"], outputs["box_scores"] = masks_to_boxes(
outputs["pred_masks"], outputs["anomaly_maps"]
)
outputs["box_labels"] = [torch.ones(boxes.shape[0]) for boxes in outputs["pred_boxes"]]
# apply thresholding to boxes
if "box_scores" in outputs:
if "box_scores" in outputs and "box_labels" not in outputs:
# apply threshold to assign normal/anomalous label to boxes
is_anomalous = [scores > self.pixel_threshold.value for scores in outputs["box_scores"]]
outputs["box_labels"] = [labels.int() for labels in is_anomalous]
Expand Down
73 changes: 71 additions & 2 deletions tests/pre_merge/datasets/test_bounding_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,71 @@ def input_masks():
return torch.stack(masks)


@pytest.fixture
def input_maps():
masks = []
masks.append( # normal and tiny shapes
Tensor(
[
[0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
[1, 1, 2, 0, 0, 1, 2, 2, 1, 0],
[1, 2, 4, 0, 0, 2, 3, 4, 2, 0],
[1, 2, 3, 0, 3, 4, 6, 5, 3, 0],
[1, 1, 2, 3, 4, 4, 5, 3, 1, 0],
[0, 1, 1, 2, 3, 4, 4, 3, 1, 1],
[0, 1, 2, 2, 2, 3, 3, 2, 1, 0],
[1, 2, 3, 3, 3, 2, 2, 1, 0, 0],
[2, 3, 5, 4, 2, 1, 1, 0, 0, 0],
[1, 2, 3, 3, 2, 1, 0, 0, 0, 0],
]
)
)
masks.append( # shapes at edge of image
Tensor(
[
[0.4, 0.2, 0, 0, 0, 0, 0, 0, 0, 0],
[0.3, 0.1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 99999],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
)
)
masks.append( # diagonally touching shapes
Tensor(
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
)
)
masks.append(torch.zeros((10, 10))) # empty mask
return torch.stack(masks)


@pytest.fixture
def target_scores():
return [
Tensor([4, 6, 5]),
Tensor([0.4, 99999]),
Tensor([1]),
Tensor(),
]


@pytest.fixture
def target_boxes():
boxes = []
Expand Down Expand Up @@ -241,7 +306,7 @@ def target_anomaly_maps():

class TestMasksToBoxes:
def test_output(self, input_masks, target_boxes):
out_boxes = masks_to_boxes(input_masks)
out_boxes, _ = masks_to_boxes(input_masks)
assert [out_box == target_box for out_box, target_box in zip(out_boxes, target_boxes)]

@pytest.mark.parametrize(
Expand All @@ -255,11 +320,15 @@ def test_output(self, input_masks, target_boxes):
),
) # (B, 1, H, W)
def test_input_shapes(self, masks):
out_boxes = masks_to_boxes(masks)
out_boxes, _ = masks_to_boxes(masks)
target_length = 1 if masks.dim() == 2 else masks.shape[0]
assert len(out_boxes) == target_length
assert out_boxes[0].shape == torch.Size((5, 4))

def test_box_scores(self, input_masks, input_maps, target_scores):
_, out_scores = masks_to_boxes(input_masks, input_maps)
assert all(torch.all(out == target) for out, target in zip(out_scores, target_scores))


class TestBoxesToMasks:
def test_output(self, input_boxes, target_masks):
Expand Down