Skip to content

Commit

Permalink
feat(mmeval/classification): add accuracy metric
Browse files Browse the repository at this point in the history
  • Loading branch information
ice-tong committed Sep 29, 2022
1 parent 21ef6a6 commit 29177a2
Show file tree
Hide file tree
Showing 5 changed files with 396 additions and 0 deletions.
1 change: 1 addition & 0 deletions mmeval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# flake8: noqa

from .classification import *
from .core import *
from .segmentation import *
from .version import __version__
5 changes: 5 additions & 0 deletions mmeval/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .accuracy import Accuracy

__all__ = ['Accuracy']
276 changes: 276 additions & 0 deletions mmeval/classification/accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np
from typing import Dict, List, Optional, Sequence, Tuple, Union, overload

from mmeval.core.base_metric import BaseMetric
from mmeval.core.dispatcher import dispatch

try:
import torch
except ImportError:
torch = None


def _torch_topk(inputs: 'torch.Tensor',
k: int,
dim: Optional[int] = None) -> Tuple:
"""Invoke the PyTorch topk."""
return inputs.topk(k, dim=dim)


def _numpy_topk(inputs: np.ndarray,
k: int,
axis: Optional[int] = None) -> Tuple:
"""A implementation of numpy top-k.
This implementation returns the values and indices of the k largest
elements along a given axis.
Args:
inputs (nump.ndarray): The input numpy array.
k (int): The k in `top-k`.
axis (int, optional): The axis to sort along.
Returns:
tuple: The values and indices of the k largest elements.
Note:
If PyTorch is available, the ``_torch_topk`` would be used.
"""
if torch is not None:
values, indices = _torch_topk(torch.from_numpy(inputs), k, dim=axis)
return values.numpy(), indices.numpy()

indices = np.argsort(inputs, axis=axis)
indices = np.take(indices, np.arange(k), axis=axis)
values = np.take_along_axis(inputs, indices, axis=axis)
return values, indices


NUMPY_IMPL_HINTS = Tuple[Union[np.ndarray, np.int64], np.int64]
TORCH_IMPL_HINTS = Tuple['torch.Tensor', 'torch.Tensor']


class Accuracy(BaseMetric):
"""Top-k accuracy evaluation metric.
This metric computes the accuracy based on the given topk and thresholds.
Currently, there are 2 implementations of this metric: NumPy and PyTorch.
Which implementation to use is determined by the type of the calling
parameters. e.g. `numpy.ndarray` or `torch.Tensor`.
Args:
topk (int | Sequence[int]): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
thrs (Sequence[float | None] | float | None): Predictions with scores
under the thresholds are considered negative. None means no
thresholds. Defaults to 0.
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
Examples:
>>> from mmeval import Accuracy
>>> accuracy = Accuracy()
Use NumPy implementation:
>>> import numpy as np
>>> labels = np.asarray([0, 1, 2, 3])
>>> preds = np.asarray([0, 2, 1, 3])
>>> accuracy(preds, labels)
{'top1': 0.5}
Use PyTorch implementation:
>>> import torch
>>> labels = torch.Tensor([0, 1, 2, 3])
>>> preds = torch.Tensor([0, 2, 1, 3])
>>> accuracy(preds, labels)
{'top1': 0.5}
Computing top-k accuracy with specified threold:
>>> labels = np.asarray([0, 1, 2, 3])
>>> preds = np.asarray([
[0.7, 0.1, 0.1, 0.1],
[0.1, 0.3, 0.4, 0.2],
[0.3, 0.4, 0.2, 0.1],
[0.0, 0.0, 0.1, 0.9]])
>>> accuracy = Accuracy(topk=(1, 2, 3))
>>> accuracy(preds, labels)
{'top1': 0.5, 'top2': 0.75, 'top3': 1.0}
>>> accuracy = Accuracy(topk=2, thrs=(0.1, 0.5))
>>> accuracy(preds, labels)
{'top2_thr-0.10': 0.75, 'top2_thr-0.50': 0.5}
Accumulate batch:
>>> for i in range(10):
... labels = torch.randint(0, 4, size=(100, ))
... predicts = torch.randint(0, 4, size=(100, ))
... accuracy.add(predicts, labels)
>>> accuracy.compute() # doctest: +SKIP
"""

def __init__(self,
topk: Union[int, Sequence[int]] = (1, ),
thrs: Union[float, Sequence[Union[float, None]], None] = 0.,
**kwargs) -> None:
super().__init__(**kwargs)

if isinstance(topk, int):
self.topk = (topk, )
else:
self.topk = tuple(topk) # type: ignore

if isinstance(thrs, float) or thrs is None:
self.thrs = (thrs, )
else:
self.thrs = tuple(thrs) # type: ignore

def add(self, predictions: Sequence, labels: Sequence) -> None: # type: ignore # yapf: disable # noqa: E501
"""Add the intermediate results to `self._results`.
Args:
predictions (Sequence): Predictions from the model. It can be
labels (N, ), or scores of every class (N, C).
labels (Sequence): The ground truth labels. It should be (N, ).
"""
for pred, label in zip(predictions, labels):
self._results.append((pred, label))

def _format_metric_results(self, results_per_topk: List[List]) -> Dict:
"""Format the given metric results into a dictionary.
Args:
results_per_topk (list): A list of per topk and thrs accuracy.
Returns:
dict: The formatted dictionary.
"""
metric_results = {}
for k, result_per_topk in zip(self.topk, results_per_topk):
for thr, result_per_thr in zip(self.thrs, result_per_topk):
name = f'top{k}'
if len(self.thrs) > 1:
name += '_no-thr' if thr is None else f'_thr-{thr:.2f}'
metric_results[name] = float(result_per_thr)
return metric_results

@overload # type: ignore
@dispatch
def _compute_metric(self, predictions: Sequence['torch.Tensor'],
labels: Sequence['torch.Tensor']) -> List[List]:
"""A PyTorch implementation that computes the accuracy metric."""
# Concatenating the intermediate results arcoss all ranks.
labels = torch.stack(labels)
predictions = torch.stack(predictions)
total_length = labels.size(0)

# In the case where the prediction is a label (N, ), the accuracy is
# calculated directly without considering `topk` and `thrs`.
if predictions.ndim == 1:
correct = (predictions.int() == labels).sum(0, keepdim=True)
acc = correct.float() / total_length
return [[acc, ], ] # yapf: disable

# compute the max topk
maxk = max(self.topk)
# NOTE: The torch.topk is non-deterministic with duplicates values.
# See: https://github.com/pytorch/pytorch/issues/27542
pred_score, pred_label = _torch_topk(predictions, maxk, dim=1)
pred_label = pred_label.t()

# Broadcast `labels` to the shape of `pred_label` and then compute
# correct tensor.
correct = (pred_label == labels.view(1, -1).expand_as(pred_label))

# compute the accuracy corresponding to all topk and thrs
results_per_topk = []
for k in self.topk:
results_per_thr = []
for thr in self.thrs:
# Only prediction socres larger than thr are counted as correct
if thr is not None:
thr_correct = correct & (pred_score.t() > thr)
else:
thr_correct = correct
topk_thr_correct = thr_correct[:k].reshape(-1).sum(
0, keepdim=True)
acc = topk_thr_correct.float() / total_length
results_per_thr.append(acc)
results_per_topk.append(results_per_thr)
return results_per_topk

@dispatch
def _compute_metric(self, predictions: Sequence[Union[np.ndarray,
np.int64]],
labels: Sequence[np.int64]) -> List[List]:
"""A NumPy implementation that computes the accuracy metric."""
# Concatenating the intermediate results arcoss all ranks.
predictions = np.stack(predictions)
labels = np.stack(labels)
total_length = labels.size

# In the case where the prediction is a label (N, ), the accuracy is
# calculated directly without considering `topk` and `thrs`.
if predictions.ndim == 1:
predictions = predictions.astype(np.int32)
correct = (predictions == labels).sum(0, keepdims=True)
acc = correct / total_length
return [[acc, ], ] # yapf: disable

# compute the max topk
maxk = max(self.topk)
pred_score, pred_label = _numpy_topk(predictions, maxk, 1)
pred_label = pred_label.T

# broadcast `labels` to the shape of `pred_label`
labels = np.broadcast_to(labels.reshape(1, -1), pred_label.shape)
# compute correct tensor
correct = (pred_label == labels)

# compute the accuracy corresponding to all topk and thrs
results_per_topk = []
for k in self.topk:
results_per_thr = []
for thr in self.thrs:
# Only socres greater than thr are counted as correct.
if thr is not None:
thr_correct = correct & (pred_score.T > thr)
else:
thr_correct = correct
topk_thr_correct = thr_correct[:k].reshape(-1).sum(
0, keepdims=True)
acc = topk_thr_correct / total_length
results_per_thr.append(acc)
results_per_topk.append(results_per_thr)
return results_per_topk

def compute_metric(
self, results: List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS]]
) -> Dict[str, float]:
"""Compute the accuracy metric.
Currently, there are 2 implementations of this method: NumPy and
PyTorch. Which implementation to use is determined by the type of the
calling parameters. e.g. `numpy.ndarray` or `torch.Tensor`.
This method would be invoked in `BaseMetric.compute` after distributed
synchronization.
Args:
results (List[Union[NUMPY_IMPL_HINTS, TORCH_IMPL_HINTS]]): A list
of tuples that consisting the prediction and label. This list
has already been synced across all ranks.
Returns:
Dict[str, float]: The computed accuracy metric.
"""
predictions = [res[0] for res in results]
labels = [res[1] for res in results]
metric_results = self._compute_metric(predictions, labels)
return self._format_metric_results(metric_results)
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
[codespell]
skip = *.ipynb
quiet-level = 3

[mypy]
allow_redefinition = True
Loading

0 comments on commit 29177a2

Please sign in to comment.