Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AUPRO metric #991

Merged
merged 9 commits into from
Apr 5, 2023
55 changes: 45 additions & 10 deletions src/anomalib/utils/metrics/aupro.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
process_group: Any | None = None,
dist_sync_fn: Callable | None = None,
fpr_limit: float = 0.3,
inspection_mask: Tensor | None = None,
) -> None:
super().__init__(
compute_on_step=compute_on_step,
Expand All @@ -49,6 +50,12 @@ def __init__(
self.add_state("preds", default=[], dist_reduce_fx="cat") # pylint: disable=not-callable
self.add_state("target", default=[], dist_reduce_fx="cat") # pylint: disable=not-callable
self.register_buffer("fpr_limit", torch.tensor(fpr_limit))
self.register_buffer("inspection_mask", inspection_mask)

if (self.inspection_mask is not None) and (
not torch.equal(self.inspection_mask, self.inspection_mask.type(torch.bool))
):
raise ValueError("inspection_mask must be a binary Tensor")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Update state with new values.
Expand All @@ -60,21 +67,17 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.target.append(target)
self.preds.append(preds)

def _compute(self) -> tuple[Tensor, Tensor]:
"""Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit.

It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall
PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction.
def perform_cca(self) -> Tensor:
"""Perform the Connected Component Analysis on the self.target tensor.

Raises:
ValueError: ValueError is raised if self.target doesn't conform with requirements imposed by kornia for
connected component analysis.

Returns:
tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
Tensor: Components labeled from 0 to N.
"""
target = dim_zero_cat(self.target)
preds = dim_zero_cat(self.preds)

# check and prepare target for labeling via kornia
if target.min() < 0 or target.max() > 1:
Expand All @@ -89,9 +92,27 @@ def _compute(self) -> tuple[Tensor, Tensor]:
else:
cca = connected_components_cpu(target)

preds = preds.flatten()
cca = cca.flatten()
target = target.flatten()
return cca

def compute_pro(self, cca: Tensor) -> tuple[Tensor, Tensor]:
"""Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit.

It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall
PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction.
If self.inspection_mask is not None, points that are not in the inspection mask will be filtered out.

Returns:
tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
"""
if self.inspection_mask is not None:
inspection_mask = self.inspection_mask.type(torch.bool).repeat(len(self.target), 1)
target = dim_zero_cat(self.target).flatten()[inspection_mask.flatten()]
preds = dim_zero_cat(self.preds).flatten()[inspection_mask.flatten()]
cca = cca.flatten()[inspection_mask.flatten()]
else:
target = dim_zero_cat(self.target).flatten()
preds = dim_zero_cat(self.preds).flatten()
cca = cca.flatten()

# compute the global fpr-size
fpr: Tensor = roc(preds, target)[0] # only need fpr
Expand Down Expand Up @@ -154,6 +175,20 @@ def _compute(self) -> tuple[Tensor, Tensor]:
fpr /= labels.size(0)
return fpr, tpr

def _compute(self) -> tuple[Tensor, Tensor]:
"""Compute the PRO curve.

First step is to perform the Connected Component Analysis,
Second step is to compute the PRO curve. Points that are outside the inspection mask will
be filtered out.

Returns:
tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
"""

cca = self.perform_cca()
return self.compute_pro(cca=cca)

def compute(self) -> Tensor:
"""Fist compute PRO curve, then compute and scale area under the curve.

Expand Down