[Feature] Add WordAccuracy for OCR Task #94

Merged · 5 commits · Mar 6, 2023
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst
@@ -53,3 +53,4 @@ Metrics
KeypointEndPointError
KeypointAUC
KeypointNME
WordAccuracy
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst
@@ -53,3 +53,4 @@ Metrics
KeypointEndPointError
KeypointAUC
KeypointNME
WordAccuracy
4 changes: 3 additions & 1 deletion mmeval/metrics/__init__.py
@@ -32,6 +32,7 @@
from .snr import SignalNoiseRatio
from .ssim import StructuralSimilarity
from .voc_map import VOCMeanAP
from .word_accuracy import WordAccuracy

__all__ = [
'Accuracy', 'MeanIoU', 'VOCMeanAP', 'OIDMeanAP', 'EndPointError',
@@ -42,7 +43,8 @@
'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP',
'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError',
'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError',
'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator'
'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
'WordAccuracy'
]

_deprecated_msg = (
107 changes: 107 additions & 0 deletions mmeval/metrics/word_accuracy.py
@@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import Dict, List, Sequence, Tuple, Union

from mmeval.core import BaseMetric


class WordAccuracy(BaseMetric):
r"""Calculate the word level accuracy.

Args:
mode (str or list[str]): Options are:

- 'exact': Accuracy at word level.
- 'ignore_case': Accuracy at word level, ignoring letter
case.
        - 'ignore_case_symbol': Accuracy at word level, ignoring
          letter case and symbols (the default metric for academic
          evaluation).

If mode is a list, then metrics in mode will be calculated
separately. Defaults to 'ignore_case_symbol'.
        invalid_symbol (str): A regular expression used to filter out
            invalid or unwanted characters. Defaults to
            '[^A-Za-z0-9\u4e00-\u9fa5]'.
**kwargs: Keyword parameters passed to :class:`BaseMetric`.

Examples:
>>> from mmeval import WordAccuracy
>>> metric = WordAccuracy()
>>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
{'ignore_case_symbol_accuracy': 1.0}
        >>> metric = WordAccuracy(mode=['exact', 'ignore_case',
        ...                             'ignore_case_symbol'])
>>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
{'accuracy': 0.333333333,
'ignore_case_accuracy': 0.666666667,
'ignore_case_symbol_accuracy': 1.0}
"""

def __init__(self,
mode: Union[str, Sequence[str]] = 'ignore_case_symbol',
invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
**kwargs):
        super().__init__(**kwargs)
        assert isinstance(mode, (str, list))
        if isinstance(mode, str):
            mode = [mode]
        assert all(isinstance(item, str) for item in mode)
        self.mode = set(mode)  # type: ignore
        assert self.mode.issubset(
            {'exact', 'ignore_case', 'ignore_case_symbol'})
        self.invalid_symbol = re.compile(invalid_symbol)

def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501
"""Process one batch of data and predictions.

Args:
predictions (list[str]): The prediction texts.
groundtruths (list[str]): The ground truth texts.
"""
        for pred, label in zip(predictions, groundtruths):
            # Per-sample match flags for each mode; modes that were not
            # requested stay 0 and are ignored in compute_metric.
            num, ignore_case_num, ignore_case_symbol_num = 0, 0, 0
            if 'exact' in self.mode:
                num = pred == label
            if 'ignore_case' in self.mode or 'ignore_case_symbol' in self.mode:
                pred_lower = pred.lower()
                label_lower = label.lower()
                ignore_case_num = pred_lower == label_lower
            if 'ignore_case_symbol' in self.mode:
                # Strip characters matched by invalid_symbol before
                # comparing, e.g. '$HELLO$' -> 'hello'.
                label_lower_ignore = self.invalid_symbol.sub('', label_lower)
                pred_lower_ignore = self.invalid_symbol.sub('', pred_lower)
                ignore_case_symbol_num = (
                    pred_lower_ignore == label_lower_ignore)
            self._results.append(
                (num, ignore_case_num, ignore_case_symbol_num))

def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict:
"""Compute the metrics from processed results.

Args:
            results (list[tuple]): The processed results of each batch;
                each tuple holds the three per-sample match flags.

Returns:
dict[str, float]: Nested dicts as results. Provided keys are:

- accuracy (float): Accuracy at word level.
- ignore_case_accuracy (float): Accuracy at word level, ignoring
letter case.
- ignore_case_symbol_accuracy (float): Accuracy at word level,
ignoring letter case and symbol.
"""
metric_results = {}
gt_word_num = max(len(results), 1.0)
exact_sum, ignore_case_sum, ignore_case_symbol_sum = 0.0, 0.0, 0.0
for exact, ignore_case, ignore_case_symbol in results:
exact_sum += exact
ignore_case_sum += ignore_case
ignore_case_symbol_sum += ignore_case_symbol
if 'exact' in self.mode:
metric_results['accuracy'] = exact_sum / gt_word_num
if 'ignore_case' in self.mode:
metric_results[
'ignore_case_accuracy'] = ignore_case_sum / gt_word_num
if 'ignore_case_symbol' in self.mode:
metric_results['ignore_case_symbol_accuracy'] =\
ignore_case_symbol_sum / gt_word_num
return metric_results
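
For intuition, here is a minimal standalone sketch of the normalization each mode applies, using only the standard `re` module and the same default pattern; the sample strings come from the docstring example:

```python
import re

# Default pattern: drop everything that is not an ASCII letter, a digit,
# or a CJK unified ideograph (U+4E00-U+9FA5).
invalid_symbol = re.compile('[^A-Za-z0-9\u4e00-\u9fa5]')

pred, label = 'hello', '$HELLO$'

exact = pred == label                        # False: case and '$' differ
ignore_case = pred.lower() == label.lower()  # False: '$' still differs
ignore_case_symbol = (invalid_symbol.sub('', pred.lower()) ==
                      invalid_symbol.sub('', label.lower()))  # True

print(exact, ignore_case, ignore_case_symbol)  # False False True
```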
30 changes: 30 additions & 0 deletions tests/test_metrics/test_word_accuracy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from mmeval import WordAccuracy


def test_init():
with pytest.raises(AssertionError):
WordAccuracy(mode=1)
with pytest.raises(AssertionError):
WordAccuracy(mode=[1, 2])
with pytest.raises(AssertionError):
WordAccuracy(mode='micro')
metric = WordAccuracy(mode=['ignore_case', 'ignore_case', 'exact'])
    assert metric.mode == {'ignore_case', 'exact'}


def test_word_accuracy():
metric = WordAccuracy(mode=['exact', 'ignore_case', 'ignore_case_symbol'])
res = metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
assert abs(res['accuracy'] - 1. / 3) < 1e-7
assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7
assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7
metric.reset()
for pred, label in zip(['hello', 'hello', 'hello'],
['hello', 'HELLO', '$HELLO$']):
metric.add([pred], [label])
res = metric.compute()
assert abs(res['accuracy'] - 1. / 3) < 1e-7
assert abs(res['ignore_case_accuracy'] - 2. / 3) < 1e-7
assert abs(res['ignore_case_symbol_accuracy'] - 1.0) < 1e-7
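
The incremental `add`/`compute`/`reset` API exercised above also supports batch-wise accumulation; a short usage sketch (the batch contents below are invented for illustration):

```python
from mmeval import WordAccuracy

metric = WordAccuracy(mode=['exact', 'ignore_case_symbol'])

# Feed predictions batch by batch, as a recognition eval loop would.
for preds, gts in [(['hello', 'world'], ['hello', 'WORLD!']),
                   (['foo'], ['bar'])]:
    metric.add(preds, gts)

print(metric.compute())
# -> {'accuracy': 0.333..., 'ignore_case_symbol_accuracy': 0.666...}

metric.reset()  # clear accumulated results before the next run
```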