Skip to content

Commit

Permalink
Refactor AUPRO metric (#991)
Browse files Browse the repository at this point in the history
* enable inspection_mask for PRO metric

* apply review comments on refactor

* nit docstring

* fix bad refactoring
  • Loading branch information
triet1102 authored Apr 5, 2023
1 parent 27adbae commit 730bdb6
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions src/anomalib/utils/metrics/aupro.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,17 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.target.append(target)
self.preds.append(preds)

def _compute(self) -> tuple[Tensor, Tensor]:
"""Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit.
It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall
PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction.
def perform_cca(self) -> Tensor:
"""Perform the Connected Component Analysis on the self.target tensor.
Raises:
ValueError: ValueError is raised if self.target doesn't conform with requirements imposed by kornia for
connected component analysis.
Returns:
tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
Tensor: Components labeled from 0 to N.
"""
target = dim_zero_cat(self.target)
preds = dim_zero_cat(self.preds)

# check and prepare target for labeling via kornia
if target.min() < 0 or target.max() > 1:
Expand All @@ -89,9 +85,17 @@ def _compute(self) -> tuple[Tensor, Tensor]:
else:
cca = connected_components_cpu(target)

preds = preds.flatten()
cca = cca.flatten()
target = target.flatten()
return cca

def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tensor, Tensor]:
"""Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit.
It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall
PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction.
Returns:
tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
"""

# compute the global fpr-size
fpr: Tensor = roc(preds, target)[0] # only need fpr
Expand Down Expand Up @@ -154,6 +158,21 @@ def _compute(self) -> tuple[Tensor, Tensor]:
fpr /= labels.size(0)
return fpr, tpr

def _compute(self) -> tuple[Tensor, Tensor]:
"""Compute the PRO curve.
Perform the Connected Component Analysis first then compute the PRO curve.
Returns:
tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
"""

cca = self.perform_cca().flatten()
target = dim_zero_cat(self.target).flatten()
preds = dim_zero_cat(self.preds).flatten()

return self.compute_pro(cca=cca, target=target, preds=preds)

def compute(self) -> Tensor:
"""Fist compute PRO curve, then compute and scale area under the curve.
Expand Down

0 comments on commit 730bdb6

Please sign in to comment.