
[Feature] Add CharRecallPrecision for OCR Task #96

Merged · 6 commits · Mar 14, 2023
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst
@@ -50,6 +50,7 @@ Metrics
ROUGE
NaturalImageQualityEvaluator
Perplexity
CharRecallPrecision
KeypointEndPointError
KeypointAUC
KeypointNME
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst
@@ -50,6 +50,7 @@ Metrics
ROUGE
NaturalImageQualityEvaluator
Perplexity
CharRecallPrecision
KeypointEndPointError
KeypointAUC
KeypointNME
4 changes: 3 additions & 1 deletion mmeval/metrics/__init__.py
@@ -6,6 +6,7 @@
from .ava_map import AVAMeanAP
from .average_precision import AveragePrecision
from .bleu import BLEU
from .char_recall_precision import CharRecallPrecision
from .coco_detection import COCODetection
from .connectivity_error import ConnectivityError
from .dota_map import DOTAMeanAP
@@ -46,7 +47,8 @@
    'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError',
    'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
    'WordAccuracy', 'PrecisionRecallF1score',
    'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score',
    'CharRecallPrecision'
]

_deprecated_msg = (
105 changes: 105 additions & 0 deletions mmeval/metrics/char_recall_precision.py
@@ -0,0 +1,105 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from difflib import SequenceMatcher
from typing import Dict, Sequence, Tuple

from mmeval.core import BaseMetric


class CharRecallPrecision(BaseMetric):
    r"""Calculate the char-level recall & precision.

    Args:
        letter_case (str): There are three options to alter letter case:

            - unchanged: Do not change prediction texts and labels.
            - upper: Convert prediction texts and labels into uppercase
              characters.
            - lower: Convert prediction texts and labels into lowercase
              characters.

            Usually, this only takes effect on English characters.
            Defaults to 'unchanged'.
        invalid_symbol (str): A regular expression used to filter out
            invalid or unwanted characters. Defaults to
            '[^A-Za-z0-9\u4e00-\u9fa5]'.
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.

    Examples:
        >>> from mmeval import CharRecallPrecision
        >>> metric = CharRecallPrecision()
        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
        {'recall': 0.6, 'precision': 0.8571428571428571}
        >>> metric = CharRecallPrecision(letter_case='upper')
        >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
        {'recall': 0.7, 'precision': 1.0}
    """

    def __init__(self,
                 letter_case: str = 'unchanged',
                 invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
                 **kwargs):
        super().__init__(**kwargs)
        assert letter_case in ['unchanged', 'upper', 'lower']
        self.letter_case = letter_case
        self.invalid_symbol = re.compile(invalid_symbol)

    def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Process one batch of data and predictions.

        Args:
            predictions (list[str]): The prediction texts.
            groundtruths (list[str]): The ground truth texts.
        """
        for pred, label in zip(predictions, groundtruths):
            if self.letter_case in ['upper', 'lower']:
                pred = getattr(pred, self.letter_case)()
                label = getattr(label, self.letter_case)()
            valid_label = self.invalid_symbol.sub('', label)
            valid_pred = self.invalid_symbol.sub('', pred)
            # counts used to compute char-level recall & precision
            true_positive_char_num = self._cal_true_positive_char(
                valid_pred, valid_label)
            self._results.append(
                (len(valid_label), len(valid_pred), true_positive_char_num))

    def compute_metric(self, results: Sequence[Tuple[int, int, int]]) -> Dict:
        """Compute the metrics from processed results.

        Args:
            results (list[tuple]): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are the corresponding results.
        """
        gt_sum, pred_sum, true_positive_sum = 0.0, 0.0, 0.0
        for gt, pred, true_positive in results:
            gt_sum += gt
            pred_sum += pred
            true_positive_sum += true_positive
        char_recall = true_positive_sum / max(gt_sum, 1.0)
        char_precision = true_positive_sum / max(pred_sum, 1.0)
        metric_results = {}
        metric_results['recall'] = char_recall
        metric_results['precision'] = char_precision
        return metric_results

    def _cal_true_positive_char(self, pred: str, gt: str) -> int:
        """Calculate the number of correctly predicted characters.

        Args:
            pred (str): Prediction text.
            gt (str): Ground truth text.

        Returns:
            int: The number of true-positive characters.
        """
        all_opt = SequenceMatcher(None, pred, gt)
        true_positive_char_num = 0
        for opt, _, _, s2, e2 in all_opt.get_opcodes():
            if opt == 'equal':
                true_positive_char_num += (e2 - s2)
        return true_positive_char_num
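
Side note for reviewers: the true-positive count falls straight out of `difflib.SequenceMatcher` opcodes. A minimal standalone sketch of the idea (the helper name and sample strings here are illustrative, not part of this PR):

```python
from difflib import SequenceMatcher


def true_positive_chars(pred: str, gt: str) -> int:
    """Count ground-truth characters covered by 'equal' opcode spans."""
    matcher = SequenceMatcher(None, pred, gt)
    # get_opcodes() yields (tag, i1, i2, j1, j2); for 'equal' tags,
    # gt[j1:j2] matches pred[i1:i2], so j2 - j1 chars are true positives.
    return sum(j2 - j1 for tag, i1, i2, j1, j2 in matcher.get_opcodes()
               if tag == 'equal')


# 'helL' vs 'hello': only 'hel' matches case-sensitively, so 3 true
# positives over 5 gt chars (recall 3/5) and 4 pred chars (precision 3/4).
assert true_positive_chars('helL', 'hello') == 3
```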
28 changes: 28 additions & 0 deletions tests/test_metrics/test_char_recall_precision.py
@@ -0,0 +1,28 @@
import pytest

from mmeval import CharRecallPrecision


def test_init():
    with pytest.raises(AssertionError):
        CharRecallPrecision(letter_case='fake')


@pytest.mark.parametrize(
    argnames=['letter_case', 'recall', 'precision'],
    argvalues=[
        ('lower', 0.7, 1),
        ('upper', 0.7, 1),
        ('unchanged', 0.6, 6.0 / 7),
    ])
def test_char_recall_precision_metric(letter_case, recall, precision):
    metric = CharRecallPrecision(letter_case=letter_case)
    res = metric(['helL', 'HEL'], ['hello', 'HELLO'])
    assert abs(res['recall'] - recall) < 1e-7
    assert abs(res['precision'] - precision) < 1e-7
    metric.reset()
    for pred, label in zip(['helL', 'HEL'], ['hello', 'HELLO']):
        metric.add([pred], [label])
    res = metric.compute()
    assert abs(res['recall'] - recall) < 1e-7
    assert abs(res['precision'] - precision) < 1e-7
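
The parametrized expectations above can be reproduced by hand with the metric's functional call. A quick sanity check, assuming `mmeval` with this PR is installed (not part of the test file):

```python
from mmeval import CharRecallPrecision

# With 'lower' case folding: 'hell' vs 'hello' -> 4 TP, 'hel' vs 'hello'
# -> 3 TP; recall = 7 / 10 = 0.7, precision = 7 / 7 = 1.0.
metric = CharRecallPrecision(letter_case='lower')
print(metric(['helL', 'HEL'], ['hello', 'HELLO']))
# {'recall': 0.7, 'precision': 1.0}
```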