diff --git a/README.md b/README.md
index 1acff842e..5906cc808 100644
--- a/README.md
+++ b/README.md
@@ -151,6 +151,7 @@ Supported algorithms:
- [x] [ABINet](configs/textrecog/abinet/README.md) (CVPR'2021)
- [x] [ASTER](configs/textrecog/aster/README.md) (TPAMI'2018)
- [x] [CRNN](configs/textrecog/crnn/README.md) (TPAMI'2016)
+- [x] [MAERec](configs/textrecog/maerec/README.md) (ICCV'2023)
- [x] [MASTER](configs/textrecog/master/README.md) (PR'2021)
- [x] [NRTR](configs/textrecog/nrtr/README.md) (ICDAR'2019)
- [x] [RobustScanner](configs/textrecog/robust_scanner/README.md) (ECCV'2020)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index c38839637..177cb93b1 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -150,6 +150,7 @@ mim install -e .
- [x] [ABINet](configs/textrecog/abinet/README.md) (CVPR'2021)
- [x] [ASTER](configs/textrecog/aster/README.md) (TPAMI'2018)
- [x] [CRNN](configs/textrecog/crnn/README.md) (TPAMI'2016)
+- [x] [MAERec](configs/textrecog/maerec/README.md) (ICCV'2023)
- [x] [MASTER](configs/textrecog/master/README.md) (PR'2021)
- [x] [NRTR](configs/textrecog/nrtr/README.md) (ICDAR'2019)
- [x] [RobustScanner](configs/textrecog/robust_scanner/README.md) (ECCV'2020)
diff --git a/configs/textrecog/_base_/datasets/union14m_benchmark.py b/configs/textrecog/_base_/datasets/union14m_benchmark.py
new file mode 100644
index 000000000..007e4f878
--- /dev/null
+++ b/configs/textrecog/_base_/datasets/union14m_benchmark.py
@@ -0,0 +1,65 @@
+union14m_root = 'data/Union14M-L/'
+union14m_benchmark_root = 'data/Union14M-L/Union14M-Benchmarks'
+
+union14m_benchmark_artistic = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/artistic'),
+ ann_file=f'{union14m_benchmark_root}/artistic/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_contextless = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/contextless'),
+ ann_file=f'{union14m_benchmark_root}/contextless/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_curve = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/curve'),
+ ann_file=f'{union14m_benchmark_root}/curve/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_incomplete = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/incomplete'),
+ ann_file=f'{union14m_benchmark_root}/incomplete/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_incomplete_ori = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/incomplete_ori'),
+ ann_file=f'{union14m_benchmark_root}/incomplete_ori/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_multi_oriented = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/multi_oriented'),
+ ann_file=f'{union14m_benchmark_root}/multi_oriented/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_multi_words = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/multi_words'),
+ ann_file=f'{union14m_benchmark_root}/multi_words/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_salient = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_benchmark_root}/salient'),
+ ann_file=f'{union14m_benchmark_root}/salient/annotation.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_benchmark_general = dict(
+ type='OCRDataset',
+ data_prefix=dict(img_path=f'{union14m_root}/'),
+ ann_file=f'{union14m_benchmark_root}/general/annotation.json',
+ test_mode=True,
+ pipeline=None)
diff --git a/configs/textrecog/_base_/datasets/union14m_train.py b/configs/textrecog/_base_/datasets/union14m_train.py
new file mode 100644
index 000000000..a91f2b104
--- /dev/null
+++ b/configs/textrecog/_base_/datasets/union14m_train.py
@@ -0,0 +1,38 @@
+union14m_data_root = 'data/Union14M-L/'
+
+union14m_challenging = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/train_challenging.json',
+ test_mode=True,
+ pipeline=None)
+
+union14m_hard = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/train_hard.json',
+ pipeline=None)
+
+union14m_medium = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/train_medium.json',
+ pipeline=None)
+
+union14m_normal = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/train_normal.json',
+ pipeline=None)
+
+union14m_easy = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/train_easy.json',
+ pipeline=None)
+
+union14m_val = dict(
+ type='OCRDataset',
+ data_root=union14m_data_root,
+ ann_file='train_annos/mmocr1.0/val_annos.json',
+ pipeline=None)
diff --git a/configs/textrecog/_base_/schedules/schedule_adamw_cos_10e.py b/configs/textrecog/_base_/schedules/schedule_adamw_cos_10e.py
new file mode 100644
index 000000000..4f5c32a32
--- /dev/null
+++ b/configs/textrecog/_base_/schedules/schedule_adamw_cos_10e.py
@@ -0,0 +1,21 @@
+# optimizer
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW',
+ lr=4e-4,
+ betas=(0.9, 0.999),
+ eps=1e-08,
+ weight_decay=0.01))
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=10, val_interval=1)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# learning policy
+param_scheduler = [
+ dict(
+ type='CosineAnnealingLR',
+ T_max=10,
+ eta_min=4e-6,
+ convert_to_iter_based=True)
+]
diff --git a/configs/textrecog/maerec/README.md b/configs/textrecog/maerec/README.md
new file mode 100644
index 000000000..18b3b87c7
--- /dev/null
+++ b/configs/textrecog/maerec/README.md
@@ -0,0 +1,80 @@
+# MAERec
+
+> [Revisiting Scene Text Recognition: A Data Perspective](https://arxiv.org/abs/2307.08723)
+
+
+
+## Abstract
+
+This paper aims to re-assess scene text recognition (STR) from a data-oriented perspective. We begin by revisiting the six commonly used benchmarks in STR and observe a trend of performance saturation, whereby only 2.91% of the benchmark images cannot be accurately recognized by an ensemble of 13 representative models. While these results are impressive and suggest that STR could be considered solved, however, we argue that this is primarily due to the less challenging nature of the common benchmarks, thus concealing the underlying issues that STR faces. To this end, we consolidate a large-scale real STR dataset, namely Union14M, which comprises 4 million labeled images and 10 million unlabeled images, to assess the performance of STR models in more complex real-world scenarios. Our experiments demonstrate that the 13 models can only achieve an average accuracy of 66.53% on the 4 million labeled images, indicating that STR still faces numerous challenges in the real world. By analyzing the error patterns of the 13 models, we identify seven open challenges in STR and develop a challenge-driven benchmark consisting of eight distinct subsets to facilitate further progress in the field. Our exploration demonstrates that STR is far from being solved and leveraging data may be a promising solution. In this regard, we find that utilizing the 10 million unlabeled images through self-supervised pre-training can significantly improve the robustness of STR model in real-world scenarios and leads to state-of-the-art performance.
+
+
+
+
+
+## Dataset
+
+### Train Dataset
+
+| trainset | instance_num | repeat_num | source |
+| :--------------------------------------------------------------: | :----------: | :--------: | :----: |
+| [Union14M](https://github.com/Mountchicken/Union14M#34-download) | 3230742 | 1 | real |
+
+### Test Dataset
+
+- On six common benchmarks
+
+ | testset | instance_num | type |
+ | :-----: | :----------: | :-------: |
+ | IIIT5K | 3000 | regular |
+ | SVT | 647 | regular |
+ | IC13 | 1015 | regular |
+ | IC15 | 2077 | irregular |
+ | SVTP | 645 | irregular |
+ | CT80 | 288 | irregular |
+
+- On Union14M-Benchmark
+
+ | testset | instance_num | type |
+ | :------------: | :----------: | :------------------: |
+ | Artistic | 900 | Unsolved Challenge |
+ | Curve | 2426 | Unsolved Challenge |
+ | Multi-Oriented | 1369 | Unsolved Challenge |
+ | Contextless | 779 | Additional Challenge |
+ | Multi-Words | 829 | Additional Challenge |
+ | Salient | 1585 | Additional Challenge |
+ | Incomplete | 1495 | Additional Challenge |
+ | General | 400,000 | - |
+
+## Results and Models
+
+- Evaluated on six common benchmarks
+
+ | Methods | Backbone | | Regular Text | | | | Irregular Text | | download |
+ | :---------------------------------------------: | :----------------------------------------------: | :----: | :----------: | :-------: | :-: | :-------: | :------------: | :--: | :----------------------------------------------: |
+ | | | IIIT5K | SVT | IC13-1015 | | IC15-2077 | SVTP | CT80 | |
+ | [MAERec-S](configs/textrecog/maerec/maerec_s_union14m.py) | [ViT-Small (Pretrained on Union14M-U)](https://github.com/Mountchicken/Union14M#51-pre-training) | 98.0 | 97.6 | 96.8 | | 87.1 | 93.2 | 97.9 | [model](https://download.openmmlab.com/mmocr/textrecog/mae/mae_union14m/maerec_s_union14m-a9a157e5.pth) |
+ | [MAERec-B](configs/textrecog/maerec/maerec_b_union14m.py) | [ViT-Base (Pretrained on Union14M-U)](https://github.com/Mountchicken/Union14M#51-pre-training) | 98.5 | 98.1 | 97.8 | | 89.5 | 94.4 | 98.6 | [model](https://download.openmmlab.com/mmocr/textrecog/mae/mae_union14m/maerec_b_union14m-4b98d1b4.pth) |
+
+- Evaluated on Union14M-Benchmark
+
+ | Methods | Backbone | | Unsolved Challenges | | | | | Additional Challenges | | General | download |
+ | ----------------------------------- | ------------------------------------- | ----- | ------------------- | -------- | ----------- | --- | ------- | --------------------- | ---------- | ------- | ------------------------------------- |
+ | | | Curve | Multi-Oriented | Artistic | Contextless | | Salient | Multi-Words | Incomplete | General | |
+ | [MAERec-S](configs/textrecog/maerec/maerec_s_union14m.py) | [ViT-Small (Pretrained on Union14M-U)](https://github.com/Mountchicken/Union14M#51-pre-training) | 81.4 | 71.4 | 72.0 | 82.0 | | 78.5 | 82.4 | 2.7 | 82.5 | [model](https://download.openmmlab.com/mmocr/textrecog/mae/mae_union14m/maerec_s_union14m-a9a157e5.pth) |
+ | [MAERec-B](configs/textrecog/maerec/maerec_b_union14m.py) | [ViT-Base (Pretrained on Union14M-U)](https://github.com/Mountchicken/Union14M#51-pre-training) | 88.8 | 83.9 | 80.0 | 85.5 | | 84.9 | 87.5 | 2.6 | 85.8 | [model](https://download.openmmlab.com/mmocr/textrecog/mae/mae_union14m/maerec_b_union14m-4b98d1b4.pth) |
+
+- **To train with MAERec, you need to download pretrained ViT weight and load it in the config file. Check [here](https://github.com/Mountchicken/Union14M/blob/main/docs/finetune.md) for instructions**
+
+## Citation
+
+```bibtex
+@misc{jiang2023revisiting,
+ title={Revisiting Scene Text Recognition: A Data Perspective},
+ author={Qing Jiang and Jiapeng Wang and Dezhi Peng and Chongyu Liu and Lianwen Jin},
+ year={2023},
+ eprint={2307.08723},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
diff --git a/configs/textrecog/maerec/_base_marec_vit_s.py b/configs/textrecog/maerec/_base_marec_vit_s.py
new file mode 100644
index 000000000..06febd088
--- /dev/null
+++ b/configs/textrecog/maerec/_base_marec_vit_s.py
@@ -0,0 +1,159 @@
+dictionary = dict(
+ type='Dictionary',
+ dict_file= # noqa
+ '{{ fileDirname }}/../../../dicts/english_digits_symbols_space.txt',
+ with_padding=True,
+ with_unknown=True,
+ same_start_end=True,
+ with_start=True,
+ with_end=True)
+
+model = dict(
+ type='MAERec',
+ backbone=dict(
+ type='VisionTransformer',
+ img_size=(32, 128),
+ patch_size=(4, 4),
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ pretrained=None),
+ decoder=dict(
+ type='MAERecDecoder',
+ n_layers=6,
+ d_embedding=384,
+ n_head=8,
+ d_model=384,
+ d_inner=384 * 4,
+ d_k=48,
+ d_v=48,
+ postprocessor=dict(type='AttentionPostprocessor'),
+ module_loss=dict(
+ type='CEModuleLoss', reduction='mean', ignore_first_char=True),
+ max_seq_len=48,
+ dictionary=dictionary),
+ data_preprocessor=dict(
+ type='TextRecogDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375]))
+
+train_pipeline = [
+ dict(type='LoadImageFromFile', ignore_empty=True, min_size=0),
+ dict(type='LoadOCRAnnotations', with_text=True),
+ dict(type='Resize', scale=(128, 32)),
+ dict(
+ type='RandomApply',
+ prob=0.5,
+ transforms=[
+ dict(
+ type='RandomChoice',
+ transforms=[
+ dict(
+ type='RandomRotate',
+ max_angle=15,
+ ),
+ dict(
+ type='TorchVisionWrapper',
+ op='RandomAffine',
+ degrees=15,
+ translate=(0.3, 0.3),
+ scale=(0.5, 2.),
+ shear=(-45, 45),
+ ),
+ dict(
+ type='TorchVisionWrapper',
+ op='RandomPerspective',
+ distortion_scale=0.5,
+ p=1,
+ ),
+ ])
+ ],
+ ),
+ dict(
+ type='RandomApply',
+ prob=0.25,
+ transforms=[
+ dict(type='PyramidRescale'),
+ dict(
+ type='mmdet.Albu',
+ transforms=[
+ dict(type='GaussNoise', var_limit=(20, 20), p=0.5),
+ dict(type='MotionBlur', blur_limit=7, p=0.5),
+ ]),
+ ]),
+ dict(
+ type='RandomApply',
+ prob=0.25,
+ transforms=[
+ dict(
+ type='TorchVisionWrapper',
+ op='ColorJitter',
+ brightness=0.5,
+ saturation=0.5,
+ contrast=0.5,
+ hue=0.1),
+ ]),
+ dict(
+ type='PackTextRecogInputs',
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='Resize', scale=(128, 32)),
+ # add loading annotation after ``Resize`` because ground truth
+ # does not need to do resize data transform
+ dict(type='LoadOCRAnnotations', with_text=True),
+ dict(
+ type='PackTextRecogInputs',
+ meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio'))
+]
+
+tta_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(
+ type='ConditionApply',
+ true_transforms=[
+ dict(
+ type='ImgAugWrapper',
+ args=[dict(cls='Rot90', k=0, keep_size=False)])
+ ],
+ condition="results['img_shape'][1] None:
+ super().__init__(
+ module_loss=module_loss,
+ postprocessor=postprocessor,
+ dictionary=dictionary,
+ init_cfg=init_cfg,
+ max_seq_len=max_seq_len)
+
+ self.padding_idx = self.dictionary.padding_idx
+ self.start_idx = self.dictionary.start_idx
+ self.max_seq_len = max_seq_len
+
+ self.trg_word_emb = nn.Embedding(
+ self.dictionary.num_classes,
+ d_embedding,
+ padding_idx=self.padding_idx)
+
+ self.position_enc = PositionalEncoding(
+ d_embedding, n_position=n_position)
+ self.dropout = nn.Dropout(p=dropout)
+
+ self.layer_stack = ModuleList([
+ TFDecoderLayer(
+ d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+ for _ in range(n_layers)
+ ])
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ pred_num_class = self.dictionary.num_classes
+ self.classifier = nn.Linear(d_model, pred_num_class)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def _get_target_mask(self, trg_seq: torch.Tensor) -> torch.Tensor:
+ """Generate mask for target sequence.
+
+ Args:
+ trg_seq (torch.Tensor): Input text sequence. Shape :math:`(N, T)`.
+
+ Returns:
+ Tensor: Target mask. Shape :math:`(N, T, T)`.
+ E.g.:
+ seq = torch.Tensor([[1, 2, 0, 0]]), pad_idx = 0, then
+ target_mask =
+ torch.Tensor([[[True, False, False, False],
+ [True, True, False, False],
+ [True, True, False, False],
+ [True, True, False, False]]])
+ """
+
+ pad_mask = (trg_seq != self.padding_idx).unsqueeze(-2)
+
+ len_s = trg_seq.size(1)
+ subsequent_mask = 1 - torch.triu(
+ torch.ones((len_s, len_s), device=trg_seq.device), diagonal=1)
+ subsequent_mask = subsequent_mask.unsqueeze(0).bool()
+
+ return pad_mask & subsequent_mask
+
+ def _get_source_mask(self, src_seq: torch.Tensor,
+ valid_ratios: Sequence[float]) -> torch.Tensor:
+ """Generate mask for source sequence.
+
+ Args:
+ src_seq (torch.Tensor): Image sequence. Shape :math:`(N, T, C)`.
+ valid_ratios (list[float]): The valid ratio of input image. For
+ example, if the width of the original image is w1 and the width
+ after padding is w2, then valid_ratio = w1/w2. Source mask is
+ used to cover the area of the padding region.
+
+ Returns:
+ Tensor or None: Source mask. Shape :math:`(N, T)`. The region of
+ padding area are False, and the rest are True.
+ """
+
+ N, T, _ = src_seq.size()
+ mask = None
+ if len(valid_ratios) > 0:
+ mask = src_seq.new_zeros((N, T), device=src_seq.device)
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_width = min(T, math.ceil(T * valid_ratio))
+ mask[i, :valid_width] = 1
+
+ return mask
+
+ def _attention(self,
+ trg_seq: torch.Tensor,
+ src: torch.Tensor,
+ src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """A wrapped process for transformer based decoder including text
+ embedding, position embedding, N x transformer decoder and a LayerNorm
+ operation.
+
+ Args:
+ trg_seq (Tensor): Target sequence in. Shape :math:`(N, T)`.
+ src (Tensor): Source sequence from encoder in shape
+ Shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
+ src_mask (Tensor, Optional): Mask for source sequence.
+ Shape :math:`(N, T)`. Defaults to None.
+
+ Returns:
+ Tensor: Output sequence from transformer decoder.
+ Shape :math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
+ """
+
+ trg_embedding = self.trg_word_emb(trg_seq)
+ trg_pos_encoded = self.position_enc(trg_embedding)
+ trg_mask = self._get_target_mask(trg_seq)
+ tgt_seq = self.dropout(trg_pos_encoded)
+
+ output = tgt_seq
+ for dec_layer in self.layer_stack:
+ output = dec_layer(
+ output,
+ src,
+ self_attn_mask=trg_mask,
+ dec_enc_attn_mask=src_mask)
+ output = self.layer_norm(output)
+
+ return output
+
+ def forward_train(self,
+ feat: Optional[torch.Tensor] = None,
+ out_enc: torch.Tensor = None,
+ data_samples: Sequence[TextRecogDataSample] = None
+ ) -> torch.Tensor:
+ """Forward for training. Source mask will be used here.
+
+ Args:
+ feat (Tensor, optional): Unused.
+ out_enc (Tensor): Encoder output of shape : math:`(N, T, D_m)`
+ where :math:`D_m` is ``d_model``. Defaults to None.
+ data_samples (list[TextRecogDataSample]): Batch of
+ TextRecogDataSample, containing gt_text and valid_ratio
+ information. Defaults to None.
+
+ Returns:
+ Tensor: The raw logit tensor. Shape :math:`(N, T, C)` where
+ :math:`C` is ``num_classes``.
+ """
+ valid_ratios = []
+ for data_sample in data_samples:
+ valid_ratios.append(data_sample.get('valid_ratio'))
+ src_mask = self._get_source_mask(feat, valid_ratios)
+ trg_seq = []
+ for data_sample in data_samples:
+ trg_seq.append(data_sample.gt_text.padded_indexes.to(feat.device))
+ trg_seq = torch.stack(trg_seq, dim=0)
+ attn_output = self._attention(trg_seq, feat, src_mask=src_mask)
+ outputs = self.classifier(attn_output)
+
+ return outputs
+
+ def forward_test(self,
+ feat: Optional[torch.Tensor] = None,
+ out_enc: torch.Tensor = None,
+ data_samples: Sequence[TextRecogDataSample] = None
+ ) -> torch.Tensor:
+ """Forward for testing.
+
+ Args:
+ feat (Tensor, optional): Unused.
+ out_enc (Tensor): Encoder output of shape:
+ math:`(N, T, D_m)` where :math:`D_m` is ``d_model``.
+ Defaults to None.
+ data_samples (list[TextRecogDataSample]): Batch of
+ TextRecogDataSample, containing gt_text and valid_ratio
+ information. Defaults to None.
+
+ Returns:
+ Tensor: Character probabilities. of shape
+ :math:`(N, self.max_seq_len, C)` where :math:`C` is
+ ``num_classes``.
+ """
+ valid_ratios = []
+ for data_sample in data_samples:
+ valid_ratios.append(data_sample.get('valid_ratio'))
+ src_mask = self._get_source_mask(feat, valid_ratios)
+ N = feat.size(0)
+ init_target_seq = torch.full((N, self.max_seq_len + 1),
+ self.padding_idx,
+ device=feat.device,
+ dtype=torch.long)
+ # bsz * seq_len
+ init_target_seq[:, 0] = self.start_idx
+
+ outputs = []
+ for step in range(0, self.max_seq_len):
+ decoder_output = self._attention(
+ init_target_seq, feat, src_mask=src_mask)
+ # bsz * seq_len * C
+ step_result = self.classifier(decoder_output[:, step, :])
+ # bsz * num_classes
+ outputs.append(step_result)
+ _, step_max_index = torch.max(step_result, dim=-1)
+ init_target_seq[:, step + 1] = step_max_index
+
+ outputs = torch.stack(outputs, dim=1)
+
+ return self.softmax(outputs)
diff --git a/mmocr/models/textrecog/recognizers/__init__.py b/mmocr/models/textrecog/recognizers/__init__.py
index d9016492d..d517d6fbd 100644
--- a/mmocr/models/textrecog/recognizers/__init__.py
+++ b/mmocr/models/textrecog/recognizers/__init__.py
@@ -5,6 +5,7 @@
from .crnn import CRNN
from .encoder_decoder_recognizer import EncoderDecoderRecognizer
from .encoder_decoder_recognizer_tta import EncoderDecoderRecognizerTTAModel
+from .maerec import MAERec
from .master import MASTER
from .nrtr import NRTR
from .robust_scanner import RobustScanner
@@ -15,5 +16,5 @@
__all__ = [
'BaseRecognizer', 'EncoderDecoderRecognizer', 'CRNN', 'SARNet', 'NRTR',
'RobustScanner', 'SATRN', 'ABINet', 'MASTER', 'SVTR', 'ASTER',
- 'EncoderDecoderRecognizerTTAModel'
+ 'EncoderDecoderRecognizerTTAModel', 'MAERec'
]
diff --git a/mmocr/models/textrecog/recognizers/maerec.py b/mmocr/models/textrecog/recognizers/maerec.py
new file mode 100644
index 000000000..788978f18
--- /dev/null
+++ b/mmocr/models/textrecog/recognizers/maerec.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmocr.registry import MODELS
+from .encoder_decoder_recognizer import EncoderDecoderRecognizer
+
+
+@MODELS.register_module()
+class MAERec(EncoderDecoderRecognizer):
+ """Implementation of MAERec."""
diff --git a/model-index.yml b/model-index.yml
index 563372c26..2a227cee0 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -10,6 +10,7 @@ Import:
- configs/textrecog/abinet/metafile.yml
- configs/textrecog/aster/metafile.yml
- configs/textrecog/crnn/metafile.yml
+ - configs/textrecog/maerec/metafile.yml
- configs/textrecog/master/metafile.yml
- configs/textrecog/nrtr/metafile.yml
- configs/textrecog/svtr/metafile.yml
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 52a9eec3c..e39d7328b 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -7,3 +7,4 @@ pyclipper
pycocotools
rapidfuzz>=2.0.0
scikit-image
+timm==0.9.2