Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add OneMinusNormEditDistance for OCR Task #95

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Metrics
DOTAMeanAP
ROUGE
NaturalImageQualityEvaluator
OneMinusNormEditDistance
Perplexity
KeypointEndPointError
KeypointAUC
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Metrics
DOTAMeanAP
ROUGE
NaturalImageQualityEvaluator
OneMinusNormEditDistance
Perplexity
KeypointEndPointError
KeypointAUC
Expand Down
4 changes: 3 additions & 1 deletion mmeval/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .mse import MeanSquaredError
from .niqe import NaturalImageQualityEvaluator
from .oid_map import OIDMeanAP
from .one_minus_norm_edit_distance import OneMinusNormEditDistance
from .pck_accuracy import JhmdbPCKAccuracy, MpiiPCKAccuracy, PCKAccuracy
from .perplexity import Perplexity
from .precision_recall_f1score import (MultiLabelPrecisionRecallF1score,
Expand All @@ -46,7 +47,8 @@
'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError',
'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
'WordAccuracy', 'PrecisionRecallF1score',
'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score'
'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score',
'OneMinusNormEditDistance'
]

_deprecated_msg = (
Expand Down
86 changes: 86 additions & 0 deletions mmeval/metrics/one_minus_norm_edit_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import TYPE_CHECKING, Dict, List, Sequence

from mmeval.core import BaseMetric
from mmeval.utils import try_import

if TYPE_CHECKING:
from rapidfuzz.distance import Levenshtein
else:
distance = try_import('rapidfuzz.distance')
if distance is not None:
Levenshtein = distance.Levenshtein


class OneMinusNormEditDistance(BaseMetric):
    r"""One minus NED metric for text recognition task.

    Each prediction/ground-truth pair is optionally case-converted, stripped
    of every character matching ``invalid_symbol`` and then compared with the
    normalized Levenshtein distance. The reported value is one minus the mean
    of those per-pair distances.

    Args:
        letter_case (str): There are three options to alter the letter cases

            - unchanged: Do not change prediction texts and labels.
            - upper: Convert prediction texts and labels into uppercase
              characters.
            - lower: Convert prediction texts and labels into lowercase
              characters.

            Usually, it only works for English characters. Defaults to
            'unchanged'.
        invalid_symbol (str): A regular expression to filter out invalid or
            not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'.
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.

    Raises:
        ImportError: If ``rapidfuzz`` is not installed.
        AssertionError: If ``letter_case`` is not one of the three options.

    Examples:
        >>> from mmeval import OneMinusNormEditDistance
        >>> metric = OneMinusNormEditDistance()
        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
        {'1-N.E.D': 0.6}
        >>> metric = OneMinusNormEditDistance(letter_case='upper')
        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
        {'1-N.E.D': 0.7}
    """

    def __init__(self,
                 letter_case: str = 'unchanged',
                 invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
                 **kwargs):
        super().__init__(**kwargs)

        # Review note (zhouzaida, Mar 14, 2023): rapidfuzz is not added in
        # runtime.txt, so without this check the metric would only throw an
        # error later, when `add` is called. Fail early with a clear message.
        if try_import('rapidfuzz.distance') is None:
            raise ImportError(
                f'For availability of {self.__class__.__name__},'
                ' please install rapidfuzz first.')

        # Kept as an assertion so callers that expect AssertionError for an
        # unsupported option keep working.
        assert letter_case in ['unchanged', 'upper', 'lower'], (
            "letter_case must be one of 'unchanged', 'upper' or 'lower', "
            f'but got {letter_case!r}')
        self.letter_case = letter_case
        # Compiled once here because the pattern is applied to every text.
        self.invalid_symbol = re.compile(invalid_symbol)

    def add(self, predictions: Sequence[str], groundtruths: Sequence[str]):  # type: ignore # yapf: disable # noqa: E501
        """Process one batch of data and predictions.

        One normalized edit distance per (prediction, ground truth) pair is
        appended to ``self._results``.

        Args:
            predictions (list[str]): The prediction texts.
            groundtruths (list[str]): The ground truth texts.
        """
        # The option is fixed per instance, so decide once per batch.
        convert_case = self.letter_case in ['upper', 'lower']
        for pred, label in zip(predictions, groundtruths):
            if convert_case:
                # `letter_case` doubles as the name of the str method.
                pred = getattr(pred, self.letter_case)()
                label = getattr(label, self.letter_case)()
            label = self.invalid_symbol.sub('', label)
            pred = self.invalid_symbol.sub('', pred)
            norm_ed = Levenshtein.normalized_distance(pred, label)
            self._results.append(norm_ed)

    def compute_metric(self, results: List[float]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[float]): The normalized edit distance of every
                processed text pair.

        Returns:
            dict[str, float]: A dict with a single entry.

            - 1-N.E.D (float): One minus the normalized edit distance.
        """
        gt_word_num = len(results)
        norm_ed_sum = sum(results)
        # max(1.0, ...) guards against division by zero when `results` is
        # empty; in that case the mean distance is 0 and the score is 1.0.
        normalized_edit_distance = norm_ed_sum / max(1.0, gt_word_num)
        metric_results = {}
        metric_results['1-N.E.D'] = 1.0 - normalized_edit_distance
        return metric_results
1 change: 1 addition & 0 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
opencv-python!=4.5.5.62,!=4.5.5.64
pycocotools
rapidfuzz
shapely
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
[codespell]
skip = *.ipynb
quiet-level = 3
ignore-words-list = dota, rouge
ignore-words-list = dota, rouge, ned

[mypy]
allow_redefinition = True
22 changes: 22 additions & 0 deletions tests/test_metrics/test_one_minus_norm_edit_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

from mmeval import OneMinusNormEditDistance


def test_init():
    """An unsupported ``letter_case`` value must be rejected on creation."""
    unsupported_case = 'fake'
    with pytest.raises(AssertionError):
        OneMinusNormEditDistance(letter_case=unsupported_case)


@pytest.mark.parametrize(
    argnames=['letter_case', 'expected'],
    argvalues=[('unchanged', 0.6), ('upper', 0.7), ('lower', 0.7)])
def test_one_minus_norm_edit_distance_metric(letter_case, expected):
    """Check the score for a direct call and for per-sample accumulation."""
    pred_texts = ['helL', 'HEL']
    gt_texts = ['hello', 'HELLO']
    tolerance = 1e-7

    metric = OneMinusNormEditDistance(letter_case=letter_case)

    # Whole batch passed by calling the metric object directly.
    batch_score = metric(pred_texts, gt_texts)['1-N.E.D']
    assert abs(batch_score - expected) < tolerance

    # Same samples fed one at a time through `add`, then `compute`.
    metric.reset()
    for single_pred, single_gt in zip(pred_texts, gt_texts):
        metric.add([single_pred], [single_gt])
    accumulated_score = metric.compute()['1-N.E.D']
    assert abs(accumulated_score - expected) < tolerance