diff --git a/src/anomalib/utils/metrics/aupro.py b/src/anomalib/utils/metrics/aupro.py index 1cfc8293a8..ba8608fd3d 100644 --- a/src/anomalib/utils/metrics/aupro.py +++ b/src/anomalib/utils/metrics/aupro.py @@ -60,21 +60,17 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.target.append(target) self.preds.append(preds) - def _compute(self) -> tuple[Tensor, Tensor]: - """Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit. - - It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall - PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction. + def perform_cca(self) -> Tensor: + """Perform the Connected Component Analysis on the self.target tensor. Raises: ValueError: ValueError is raised if self.target doesn't conform with requirements imposed by kornia for connected component analysis. Returns: - tuple[Tensor, Tensor]: tuple containing final fpr and tpr values. + Tensor: Components labeled from 0 to N. """ target = dim_zero_cat(self.target) - preds = dim_zero_cat(self.preds) # check and prepare target for labeling via kornia if target.min() < 0 or target.max() > 1: @@ -89,9 +85,17 @@ def _compute(self) -> tuple[Tensor, Tensor]: else: cca = connected_components_cpu(target) - preds = preds.flatten() - cca = cca.flatten() - target = target.flatten() + return cca + + def compute_pro(self, cca: Tensor, target: Tensor, preds: Tensor) -> tuple[Tensor, Tensor]: + """Compute the pro/fpr value-pairs until the fpr specified by self.fpr_limit. + + It leverages the fact that the overlap corresponds to the tpr, and thus computes the overall + PRO curve by aggregating per-region tpr/fpr values produced by ROC-construction. + + Returns: + tuple[Tensor, Tensor]: tuple containing final fpr and tpr values. + """ # compute the global fpr-size fpr: Tensor = roc(preds, target)[0] # only need fpr @@ -154,6 +158,21 @@ def _compute(self) -> tuple[Tensor, Tensor]: fpr /= labels.size(0) return fpr, tpr + def _compute(self) -> tuple[Tensor, Tensor]: + """Compute the PRO curve. + + Perform the Connected Component Analysis first then compute the PRO curve. + + Returns: + tuple[Tensor, Tensor]: tuple containing final fpr and tpr values. + """ + + cca = self.perform_cca().flatten() + target = dim_zero_cat(self.target).flatten() + preds = dim_zero_cat(self.preds).flatten() + + return self.compute_pro(cca=cca, target=target, preds=preds) + def compute(self) -> Tensor: """Fist compute PRO curve, then compute and scale area under the curve.