Skip to content

Commit

Permalink
fix comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Harold-lkk committed Mar 13, 2023
1 parent e0a934e commit dbeb3ea
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions mmeval/metrics/one_minus_norm_edit_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ class OneMinusNormEditDistance(BaseMetric):
- unchanged: Do not change prediction texts and labels.
- upper: Convert prediction texts and labels into uppercase
characters.
characters.
- lower: Convert prediction texts and labels into lowercase
characters.
characters.
Usually, it only works for English characters. Defaults to
'unchanged'.
invalid_symbol (str): A regular expression to filter out invalid or
not cared characters. Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'.
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
Example:
Examples:
>>> from mmeval import OneMinusNormEditDistance
>>> metric = OneMinusNormEditDistance()
>>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
Expand All @@ -43,22 +43,22 @@ class OneMinusNormEditDistance(BaseMetric):

def __init__(self,
letter_case: str = 'unchanged',
invalid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
**kwargs):
super().__init__(**kwargs)

assert letter_case in ['unchanged', 'upper', 'lower']
self.letter_case = letter_case
self.invalid_symbol = re.compile(invalid_symbol)

def add(self, predictions: Sequence[str], labels: Sequence[str]): # type: ignore # yapf: disable # noqa: E501
def add(self, predictions: Sequence[str], groundtruths: Sequence[str]): # type: ignore # yapf: disable # noqa: E501
"""Process one batch of data and predictions.
Args:
predictions (list[str]): The prediction texts.
labels (list[str]): The ground truth texts.
groundtruths (list[str]): The ground truth texts.
"""
for pred, label in zip(predictions, labels):
for pred, label in zip(predictions, groundtruths):
if self.letter_case in ['upper', 'lower']:
pred = getattr(pred, self.letter_case)()
label = getattr(label, self.letter_case)()
Expand All @@ -75,11 +75,12 @@ def compute_metric(self, results: List[float]) -> Dict:
Returns:
dict[str, float]: Nested dicts as results.
- 1-N.E.D (float): One minus the normalized edit distance.
- 1-N.E.D (float): One minus the normalized edit distance.
"""
gt_word_num = len(results)
norm_ed_sum = sum(results)
normalized_edit_distance = norm_ed_sum / max(1.0, gt_word_num)
eval_res = {}
eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance
return eval_res
metric_results = {}
metric_results['1-N.E.D'] = 1.0 - normalized_edit_distance
return metric_results

0 comments on commit dbeb3ea

Please sign in to comment.