[Feature] Add WordAccuracy for OCR Task (#94)
* [Feature] Add WordAccuracy for OCR Task

* add api and fix comment

* fix doc comment

* fix comment

* fix comment
Harold-lkk authored Mar 6, 2023
1 parent a05c685 commit a03e046
Showing 5 changed files with 142 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst
@@ -53,3 +53,4 @@ Metrics
KeypointEndPointError
KeypointAUC
KeypointNME
WordAccuracy
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst
@@ -53,3 +53,4 @@ Metrics
KeypointEndPointError
KeypointAUC
KeypointNME
WordAccuracy
4 changes: 3 additions & 1 deletion mmeval/metrics/__init__.py
@@ -32,6 +32,7 @@
from .snr import SignalNoiseRatio
from .ssim import StructuralSimilarity
from .voc_map import VOCMeanAP
from .word_accuracy import WordAccuracy

__all__ = [
    'Accuracy', 'MeanIoU', 'VOCMeanAP', 'OIDMeanAP', 'EndPointError',
@@ -42,7 +43,8 @@
    'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP',
    'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError',
    'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError',
    'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
    'WordAccuracy'
]

_deprecated_msg = (
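With the export added to mmeval/metrics/__init__.py above, the metric is reachable from the package root. A trivial check, assuming an installed mmeval:

import mmeval

metric = mmeval.WordAccuracy()  # exported at the package root by this commit
print(metric(['Hi!'], ['hi']))
# {'ignore_case_symbol_accuracy': 1.0}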
107 changes: 107 additions & 0 deletions mmeval/metrics/word_accuracy.py
@@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import Dict, List, Sequence, Tuple, Union

from mmeval.core import BaseMetric


class WordAccuracy(BaseMetric):
    r"""Calculate the word-level accuracy.

    Args:
        mode (str or list[str]): Options are:

            - 'exact': Accuracy at the word level.
            - 'ignore_case': Accuracy at the word level, ignoring letter
              case.
            - 'ignore_case_symbol': Accuracy at the word level, ignoring
              letter case and symbols (the default metric for academic
              evaluation).

            If mode is a list, each metric in the list is calculated
            separately. Defaults to 'ignore_case_symbol'.
        invalid_symbol (str): A regular expression matching the invalid or
            irrelevant characters to filter out.
            Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'.
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.

    Examples:
        >>> from mmeval import WordAccuracy
        >>> metric = WordAccuracy()
        >>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
        {'ignore_case_symbol_accuracy': 1.0}
        >>> metric = WordAccuracy(mode=['exact', 'ignore_case',
        ...                             'ignore_case_symbol'])
        >>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
        {'accuracy': 0.333333333,
         'ignore_case_accuracy': 0.666666667,
         'ignore_case_symbol_accuracy': 1.0}
    """

    def __init__(self,
                 mode: Union[str, Sequence[str]] = 'ignore_case_symbol',
                 invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
                 **kwargs):
        super().__init__(**kwargs)
        self.invalid_symbol = re.compile(invalid_symbol)
        assert isinstance(mode, (str, list))
        if isinstance(mode, str):
            mode = [mode]
        assert all(isinstance(item, str) for item in mode)
        self.mode = set(mode)  # type: ignore
        assert set(self.mode).issubset(
            {'exact', 'ignore_case', 'ignore_case_symbol'})

    def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Process one batch of data and predictions.

        Args:
            predictions (list[str]): The prediction texts.
            groundtruths (list[str]): The ground truth texts.
        """
        for pred, label in zip(predictions, groundtruths):
            num, ignore_case_num, ignore_case_symbol_num = 0, 0, 0
            if 'exact' in self.mode:
                num = pred == label
            if 'ignore_case' in self.mode or \
                    'ignore_case_symbol' in self.mode:
                pred_lower = pred.lower()
                label_lower = label.lower()
                ignore_case_num = pred_lower == label_lower
            if 'ignore_case_symbol' in self.mode:
                label_lower_ignore = self.invalid_symbol.sub('', label_lower)
                pred_lower_ignore = self.invalid_symbol.sub('', pred_lower)
                ignore_case_symbol_num = \
                    label_lower_ignore == pred_lower_ignore
            self._results.append(
                (num, ignore_case_num, ignore_case_symbol_num))

    def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[tuple[int, int, int]]): The processed results
                of each batch.

        Returns:
            dict[str, float]: The computed metrics, with one key per
            requested mode:

            - accuracy (float): Accuracy at the word level.
            - ignore_case_accuracy (float): Accuracy at the word level,
              ignoring letter case.
            - ignore_case_symbol_accuracy (float): Accuracy at the word
              level, ignoring letter case and symbols.
        """
        metric_results = {}
        gt_word_num = max(len(results), 1.0)
        exact_sum, ignore_case_sum, ignore_case_symbol_sum = 0.0, 0.0, 0.0
        for exact, ignore_case, ignore_case_symbol in results:
            exact_sum += exact
            ignore_case_sum += ignore_case
            ignore_case_symbol_sum += ignore_case_symbol
        if 'exact' in self.mode:
            metric_results['accuracy'] = exact_sum / gt_word_num
        if 'ignore_case' in self.mode:
            metric_results[
                'ignore_case_accuracy'] = ignore_case_sum / gt_word_num
        if 'ignore_case_symbol' in self.mode:
            metric_results['ignore_case_symbol_accuracy'] = \
                ignore_case_symbol_sum / gt_word_num
        return metric_results
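Because invalid_symbol is an ordinary regular expression, the filtering applied in 'ignore_case_symbol' mode can be tailored to other character sets. A minimal sketch, assuming an installed mmeval; the digits-only pattern is an illustration, not part of this commit:

from mmeval import WordAccuracy

# A hypothetical pattern that keeps digits only, so letters and
# punctuation are stripped from both strings before comparison.
metric = WordAccuracy(mode='ignore_case_symbol', invalid_symbol='[^0-9]')
print(metric(['tel:123'], ['(123)']))
# {'ignore_case_symbol_accuracy': 1.0} -- both strings reduce to '123'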
30 changes: 30 additions & 0 deletions tests/test_metrics/test_word_accuracy.py
@@ -0,0 +1,30 @@
import pytest

from mmeval import WordAccuracy


def test_init():
    with pytest.raises(AssertionError):
        WordAccuracy(mode=1)
    with pytest.raises(AssertionError):
        WordAccuracy(mode=[1, 2])
    with pytest.raises(AssertionError):
        WordAccuracy(mode='micro')
    # Duplicate modes are deduplicated when stored as a set.
    metric = WordAccuracy(mode=['ignore_case', 'ignore_case', 'exact'])
    assert metric.mode == {'ignore_case', 'exact'}


def test_word_accuracy():
    metric = WordAccuracy(mode=['exact', 'ignore_case', 'ignore_case_symbol'])
    res = metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
    assert abs(res['accuracy'] - 1. / 3) < 1e-7
    assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7
    assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7
    metric.reset()
    for pred, label in zip(['hello', 'hello', 'hello'],
                           ['hello', 'HELLO', '$HELLO$']):
        metric.add([pred], [label])
    res = metric.compute()
    assert abs(res['accuracy'] - 1. / 3) < 1e-7
    assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7
    assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7
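Since add records one tuple of match flags per word and compute_metric divides the summed flags by the total word count, the accuracies are micro-averaged over all words regardless of how they were batched. A small sketch, assuming an installed mmeval:

from mmeval import WordAccuracy

metric = WordAccuracy(mode='ignore_case_symbol')
# Two simulated batches standing in for a real OCR evaluation loop.
metric.add(['hello', 'world'], ['HELLO', 'world!'])
metric.add(['mmeval'], ['mm eval'])
print(metric.compute())
# {'ignore_case_symbol_accuracy': 1.0} -- all 3 words match after
# lower-casing and symbol filtering.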
