diff --git a/docs/en/api/metrics.rst b/docs/en/api/metrics.rst
index b5c4f231..c0d7761f 100644
--- a/docs/en/api/metrics.rst
+++ b/docs/en/api/metrics.rst
@@ -53,3 +53,4 @@ Metrics
    KeypointEndPointError
    KeypointAUC
    KeypointNME
+   WordAccuracy
diff --git a/docs/zh_cn/api/metrics.rst b/docs/zh_cn/api/metrics.rst
index b5c4f231..c0d7761f 100644
--- a/docs/zh_cn/api/metrics.rst
+++ b/docs/zh_cn/api/metrics.rst
@@ -53,3 +53,4 @@ Metrics
    KeypointEndPointError
    KeypointAUC
    KeypointNME
+   WordAccuracy
diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py
index d63fed6e..9ae21aef 100644
--- a/mmeval/metrics/__init__.py
+++ b/mmeval/metrics/__init__.py
@@ -32,6 +32,7 @@
 from .snr import SignalNoiseRatio
 from .ssim import StructuralSimilarity
 from .voc_map import VOCMeanAP
+from .word_accuracy import WordAccuracy
 
 __all__ = [
     'Accuracy', 'MeanIoU', 'VOCMeanAP', 'OIDMeanAP', 'EndPointError',
@@ -42,7 +43,8 @@
     'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP',
     'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError',
     'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError',
-    'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator'
+    'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
+    'WordAccuracy'
 ]
 
 _deprecated_msg = (
diff --git a/mmeval/metrics/word_accuracy.py b/mmeval/metrics/word_accuracy.py
new file mode 100644
index 00000000..4a867872
--- /dev/null
+++ b/mmeval/metrics/word_accuracy.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+from typing import Dict, List, Sequence, Tuple, Union
+
+from mmeval.core import BaseMetric
+
+
+class WordAccuracy(BaseMetric):
+    r"""Calculate the word-level accuracy.
+
+    Args:
+        mode (str or list[str]): Options are:
+
+            - 'exact': Accuracy at the word level.
+            - 'ignore_case': Accuracy at the word level, ignoring letter
+              case.
+            - 'ignore_case_symbol': Accuracy at the word level, ignoring
+              letter case and symbols (the default metric for academic
+              evaluation).
+
+            If mode is a list, each metric in the list is calculated
+            separately. Defaults to 'ignore_case_symbol'.
+        invalid_symbol (str): A regular expression matching invalid or
+            ignored characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'.
+        **kwargs: Keyword parameters passed to :class:`BaseMetric`.
+
+    Examples:
+        >>> from mmeval import WordAccuracy
+        >>> metric = WordAccuracy()
+        >>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
+        {'ignore_case_symbol_accuracy': 1.0}
+        >>> metric = WordAccuracy(mode=['exact', 'ignore_case',
+        ...                             'ignore_case_symbol'])
+        >>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
+        {'accuracy': 0.333333333,
+         'ignore_case_accuracy': 0.666666667,
+         'ignore_case_symbol_accuracy': 1.0}
+    """
+
+    def __init__(self,
+                 mode: Union[str, Sequence[str]] = 'ignore_case_symbol',
+                 invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.invalid_symbol = re.compile(invalid_symbol)
+        assert isinstance(mode, (str, list))
+        if isinstance(mode, str):
+            mode = [mode]
+        assert all(isinstance(item, str) for item in mode)
+        # Duplicate modes are collapsed by the set.
+        self.mode = set(mode)  # type: ignore
+        assert self.mode.issubset(
+            {'exact', 'ignore_case', 'ignore_case_symbol'})
+
+    def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None:  # type: ignore # yapf: disable # noqa: E501
+        """Process one batch of data and predictions.
+
+        Args:
+            predictions (list[str]): The prediction texts.
+            groundtruths (list[str]): The ground truth texts.
+        """
+        for pred, label in zip(predictions, groundtruths):
+            # One match flag per mode; booleans are summed as 0/1 later.
+            num, ignore_case_num, ignore_case_symbol_num = 0, 0, 0
+            if 'exact' in self.mode:
+                num = pred == label
+            if 'ignore_case' in self.mode or 'ignore_case_symbol' in self.mode:
+                pred_lower = pred.lower()
+                label_lower = label.lower()
+                ignore_case_num = pred_lower == label_lower
+            if 'ignore_case_symbol' in self.mode:
+                label_lower_ignore = self.invalid_symbol.sub('', label_lower)
+                pred_lower_ignore = self.invalid_symbol.sub('', pred_lower)
+                ignore_case_symbol_num = \
+                    label_lower_ignore == pred_lower_ignore
+            self._results.append(
+                (num, ignore_case_num, ignore_case_symbol_num))
+
+    def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list[tuple]): The processed results of each batch.
+
+        Returns:
+            dict[str, float]: The computed metrics. Possible keys are:
+
+            - accuracy (float): Accuracy at the word level.
+            - ignore_case_accuracy (float): Accuracy at the word level,
+              ignoring letter case.
+            - ignore_case_symbol_accuracy (float): Accuracy at the word
+              level, ignoring letter case and symbols.
+        """
+        metric_results = {}
+        # Guard against division by zero when no samples were added.
+        gt_word_num = max(len(results), 1.0)
+        exact_sum, ignore_case_sum, ignore_case_symbol_sum = 0.0, 0.0, 0.0
+        for exact, ignore_case, ignore_case_symbol in results:
+            exact_sum += exact
+            ignore_case_sum += ignore_case
+            ignore_case_symbol_sum += ignore_case_symbol
+        if 'exact' in self.mode:
+            metric_results['accuracy'] = exact_sum / gt_word_num
+        if 'ignore_case' in self.mode:
+            metric_results[
+                'ignore_case_accuracy'] = ignore_case_sum / gt_word_num
+        if 'ignore_case_symbol' in self.mode:
+            metric_results['ignore_case_symbol_accuracy'] = \
+                ignore_case_symbol_sum / gt_word_num
+        return metric_results
diff --git a/tests/test_metrics/test_word_accuracy.py b/tests/test_metrics/test_word_accuracy.py
new file mode 100644
index 00000000..db9ee883
--- /dev/null
+++ b/tests/test_metrics/test_word_accuracy.py
@@ -0,0 +1,30 @@
+import pytest
+
+from mmeval import WordAccuracy
+
+
+def test_init():
+    with pytest.raises(AssertionError):
+        WordAccuracy(mode=1)
+    with pytest.raises(AssertionError):
+        WordAccuracy(mode=[1, 2])
+    with pytest.raises(AssertionError):
+        WordAccuracy(mode='micro')
+    metric = WordAccuracy(mode=['ignore_case', 'ignore_case', 'exact'])
+    assert metric.mode == {'ignore_case', 'exact'}
+
+
+def test_word_accuracy():
+    metric = WordAccuracy(mode=['exact', 'ignore_case', 'ignore_case_symbol'])
+    res = metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
+    assert abs(res['accuracy'] - 1. / 3) < 1e-7
+    assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7
+    assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7
+    metric.reset()
+    for pred, label in zip(['hello', 'hello', 'hello'],
+                           ['hello', 'HELLO', '$HELLO$']):
+        metric.add([pred], [label])
+    res = metric.compute()
+    assert abs(res['accuracy'] - 1. / 3) < 1e-7
+    assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7
+    assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7