Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AUPRO metric #991

Merged
merged 9 commits into from
Apr 5, 2023
39 changes: 29 additions & 10 deletions src/anomalib/utils/metrics/aupro.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,17 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
self.target.append(target)
self.preds.append(preds)

def _compute(self) -> tuple[Tensor, Tensor]:
"""Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit.

It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall
PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction.
def perform_cca(self) -> Tensor:
"""Perform the Connected Component Analysis on the self.target tensor.

Raises:
ValueError: ValueError is raised if self.target doesn't conform with requirements imposed by kornia for
connected component analysis.

Returns:
tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
Tensor: Components labeled from 0 to N.
"""
target = dim_zero_cat(self.target)
preds = dim_zero_cat(self.preds)

# check and prepare target for labeling via kornia
if target.min() < 0 or target.max() > 1:
Expand All @@ -89,9 +85,17 @@ def _compute(self) -> tuple[Tensor, Tensor]:
else:
cca = connected_components_cpu(target)

preds = preds.flatten()
cca = cca.flatten()
target = target.flatten()
return cca

def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tensor, Tensor]:
"""Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit.

It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall
PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction.

Returns:
tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
"""

# compute the global fpr-size
fpr: Tensor = roc(preds, target)[0] # only need fpr
Expand Down Expand Up @@ -154,6 +158,21 @@ def _compute(self) -> tuple[Tensor, Tensor]:
fpr /= labels.size(0)
return fpr, tpr

def _compute(self) -> tuple[Tensor, Tensor]:
"""Compute the PRO curve.

Perform the Connected Component Analysis first then compute the PRO curve.

Returns:
tuple[Tensor, Tensor]: tuple containing final fpr and tpr values.
"""

cca = self.perform_cca().flatten()
target = dim_zero_cat(self.target).flatten()
preds = dim_zero_cat(self.preds).flatten()

return self.compute_pro(cca=cca, target=target, preds=preds)

def compute(self) -> Tensor:
"""Fist compute PRO curve, then compute and scale area under the curve.

Expand Down