diff --git a/mmeval/metrics/one_minus_norm_edit_distance.py b/mmeval/metrics/one_minus_norm_edit_distance.py index 5c5e5be8..ddd58df5 100644 --- a/mmeval/metrics/one_minus_norm_edit_distance.py +++ b/mmeval/metrics/one_minus_norm_edit_distance.py @@ -21,17 +21,17 @@ class OneMinusNormEditDistance(BaseMetric): - unchanged: Do not change prediction texts and labels. - upper: Convert prediction texts and labels into uppercase - characters. + characters. - lower: Convert prediction texts and labels into lowercase - characters. + characters. Usually, it only works for English characters. Defaults to 'unchanged'. invalid_symbol (str): A regular expression to filter out invalid or - not cared characters. Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'. + not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'. **kwargs: Keyword parameters passed to :class:`BaseMetric`. - Example: + Examples: >>> from mmeval import OneMinusNormEditDistance >>> metric = OneMinusNormEditDistance() >>> metric(['helL', 'HEL'], ['hello', 'HELLO']) @@ -43,7 +43,7 @@ class OneMinusNormEditDistance(BaseMetric): def __init__(self, letter_case: str = 'unchanged', - invalid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]', + invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]', **kwargs): super().__init__(**kwargs) @@ -51,14 +51,14 @@ def __init__(self, self.letter_case = letter_case self.invalid_symbol = re.compile(invalid_symbol) - def add(self, predictions: Sequence[str], labels: Sequence[str]): # type: ignore # yapf: disable # noqa: E501 + def add(self, predictions: Sequence[str], groundtruths: Sequence[str]): # type: ignore # yapf: disable # noqa: E501 """Process one batch of data and predictions. Args: predictions (list[str]): The prediction texts. - labels (list[str]): The ground truth texts. + groundtruths (list[str]): The ground truth texts. """ - for pred, label in zip(predictions, labels): + for pred, label in zip(predictions, groundtruths): if self.letter_case in ['upper', 'lower']: pred = getattr(pred, self.letter_case)() label = getattr(label, self.letter_case)() @@ -80,6 +80,6 @@ def compute_metric(self, results: List[float]) -> Dict: gt_word_num = len(results) norm_ed_sum = sum(results) normalized_edit_distance = norm_ed_sum / max(1.0, gt_word_num) - eval_res = {} - eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance - return eval_res + metric_results = {} + metric_results['1-N.E.D'] = 1.0 - normalized_edit_distance + return metric_results