From d56155c82df3b0a4e859b692acc7fd9a26d760d3 Mon Sep 17 00:00:00 2001 From: Tong Gao Date: Tue, 7 Mar 2023 20:08:25 +0800 Subject: [PATCH] [Feature] Support lmdb format in Dataset Preparer (#1762) * [Dataset Preparer] Support lmdb format * fix * fix * fix * fix * fix * readme * readme --- dataset_zoo/icdar2015/textrecog.py | 4 +- .../data_prepare/dataset_preparer.md | 58 +++++++- .../data_prepare/dataset_preparer.md | 64 ++++++-- .../preparers/config_generators/base.py | 11 ++ .../textrecog_config_generator.py | 31 +++- mmocr/datasets/preparers/dumpers/__init__.py | 6 +- .../datasets/preparers/dumpers/lmdb_dumper.py | 140 ++++++++++++++++++ .../obtainers/naive_data_obtainer.py | 2 +- tools/dataset_converters/prepare_dataset.py | 57 ++++++- 9 files changed, 347 insertions(+), 26 deletions(-) create mode 100644 mmocr/datasets/preparers/dumpers/lmdb_dumper.py diff --git a/dataset_zoo/icdar2015/textrecog.py b/dataset_zoo/icdar2015/textrecog.py index daecdf906..181f7d614 100644 --- a/dataset_zoo/icdar2015/textrecog.py +++ b/dataset_zoo/icdar2015/textrecog.py @@ -61,7 +61,9 @@ parser=dict(type='ICDARTxtTextRecogAnnParser', encoding='utf-8-sig'), packer=dict(type='TextRecogPacker'), dumper=dict(type='JsonDumper')) -delete = ['annotations'] +delete = [ + 'annotations', 'ic15_textrecog_train_img_gt', 'ic15_textrecog_test_img' +] config_generator = dict( type='TextRecogConfigGenerator', test_anns=[ diff --git a/docs/en/user_guides/data_prepare/dataset_preparer.md b/docs/en/user_guides/data_prepare/dataset_preparer.md index af520ecc5..0e00a544d 100644 --- a/docs/en/user_guides/data_prepare/dataset_preparer.md +++ b/docs/en/user_guides/data_prepare/dataset_preparer.md @@ -15,11 +15,15 @@ Only one line of command is needed to complete the data download, decompression, python tools/dataset_converters/prepare_dataset.py [$DATASET_NAME] --task [$TASK] --nproc [$NPROC] ``` -| ARGS | Type | Description | -| ------------ | ---- | ----------------------------------------------------------------------------------------------------------------------------------------- | -| dataset_name | str | (required) dataset name. | -| --task | str | Convert the dataset to the format of a specified task supported by MMOCR. options are: 'textdet', 'textrecog', 'textspotting', and 'kie'. | -| --nproc | int | Number of processes to be used. Defaults to 4. | +| ARGS | Type | Description | +| ------------------ | ---- | ----------------------------------------------------------------------------------------------------------------------------------------- | +| dataset_name | str | (required) dataset name. | +| --nproc | int | Number of processes to be used. Defaults to 4. | +| --task | str | Convert the dataset to the format of a specified task supported by MMOCR. options are: 'textdet', 'textrecog', 'textspotting', and 'kie'. | +| --splits | str | Splits of the dataset to be prepared. Multiple splits can be accepted. Defaults to `train val test`. | +| --lmdb | str | Store the data in LMDB format. Only valid when the task is `textrecog`. | +| --overwrite-cfg | str | Whether to overwrite the dataset config file if it already exists in `configs/{task}/_base_/datasets`. | +| --dataset-zoo-path | str | Path to the dataset config file. If not specified, the default path is `./dataset_zoo`. | For example, the following command shows how to use the script to prepare the ICDAR2015 dataset for text detection task. @@ -37,6 +41,44 @@ To check the supported datasets of Dataset Preparer, please refer to [Dataset Zo ## Advanced Usage +### LMDB Format + +In text recognition tasks, we usually use LMDB format to store data to speed up data loading. When using the `prepare_dataset.py` script to prepare data, you can store data to the LMDB format by the `--lmdb` parameter. For example: + +```bash +python tools/dataset_converters/prepare_dataset.py icdar2015 --task textrecog --lmdb +``` + +As soon as the dataset is prepared, Dataset Preparer will generate `icdar2015_lmdb.py` in the `configs/textrecog/_base_/datasets/` directory. You can inherit this file and point the `dataloader` to the LMDB dataset. Moreover, the LMDB dataset needs to be loaded by [`LoadImageFromNDArray`](mmocr.datasets.transforms.LoadImageFromNDArray), thus you also need to modify `pipeline`. + +For example, if we want to change the training set of `configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py` to icdar2015 generated before, we need to perform the following modifications: + +1. Modify `configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py`: + + ```python + _base_ = [ + '../_base_/datasets/icdar2015_lmdb.py', # point to icdar2015 lmdb dataset + ... + ] + + train_list = [_base_.icdar2015_lmdb_textrecog_train] + ... + ``` + +2. Modify `train_pipeline` in `configs/textrecog/crnn/_base_crnn_mini-vgg.py`, change `LoadImageFromFile` to `LoadImageFromNDArray`: + + ```python + train_pipeline = [ + dict( + type='LoadImageFromNDArray', + color_type='grayscale', + file_client_args=file_client_args, + ignore_empty=True, + min_size=2), + ... + ] + ``` + ### Configuration of Dataset Preparer Dataset preparer uses a modular design to enhance extensibility, which allows users to extend it to other public or private datasets easily. The configuration files of the dataset preparers are stored in the `dataset_zoo/`, where all the configs of currently supported datasets can be found here. The directory structure is as follows: @@ -95,6 +137,10 @@ Data: It is not mandatory to use the metafile in the dataset preparation process (so users can ignore this file when preparing private datasets), but in order to better understand the information of each public dataset, we recommend that users read the metafile before preparing the dataset, which will help to understand whether the datasets meet their needs. +```{warning} +The following section is outdated as of MMOCR 1.0.0rc6. +``` + #### Config of Dataset Preparer Next, we will introduce the conventional fields and usage of the dataset preparer configuration files. @@ -186,7 +232,7 @@ Therefore, we provide two built-in gatherers, `pair_gather` and `mono_gather`, t When the image and annotation file are matched, the original annotations will be parsed. Since the annotation format is usually varied from dataset to dataset, the parsers are usually dataset related. Then, the parser will pack the required data into the MMOCR format. -Finally, we can specify the dumpers to decide the data format. Currently, we only support `JsonDumper` and `WildreceiptOpensetDumper`, where the former is used to save the data in the standard MMOCR Json format, and the latter is used to save the data in the Wildreceipt format. In the future, we plan to support `LMDBDumper` to save the annotation files in LMDB format. +Finally, we can specify the dumpers to decide the data format. Currently, we support `JsonDumper`, `WildreceiptOpensetDumper`, and `TextRecogLMDBDumper`. They are used to save the data in the standard MMOCR Json format, Wildreceipt format, and the LMDB format commonly used in academia in the field of text recognition, respectively. ### Use DataPreparer to prepare customized dataset diff --git a/docs/zh_cn/user_guides/data_prepare/dataset_preparer.md b/docs/zh_cn/user_guides/data_prepare/dataset_preparer.md index a5f9b40ec..8a89dd30f 100644 --- a/docs/zh_cn/user_guides/data_prepare/dataset_preparer.md +++ b/docs/zh_cn/user_guides/data_prepare/dataset_preparer.md @@ -1,4 +1,4 @@ -# 数据准备 (Beta) +c# 数据准备 (Beta) ```{note} Dataset Preparer 目前仍处在公测阶段,欢迎尝鲜试用!如遇到任何问题,请及时向我们反馈。 @@ -11,16 +11,18 @@ MMOCR 提供了统一的一站式数据集准备脚本 `prepare_dataset.py`。 仅需一行命令即可完成数据的下载、解压、格式转换,及基础配置的生成。 ```bash -python tools/dataset_converters/prepare_dataset.py [$DATASET_NAME] [--task $TASK] [--nproc $NPROC] [--overwrite-cfg] [--dataset-zoo-path $DATASET_ZOO_PATH] +python tools/dataset_converters/prepare_dataset.py [-h] [--nproc NPROC] [--task {textdet,textrecog,textspotting,kie}] [--splits SPLITS [SPLITS ...]] [--lmdb] [--overwrite-cfg] [--dataset-zoo-path DATASET_ZOO_PATH] datasets [datasets ...] ``` -| 参数 | 类型 | 说明 | -| ------------------ | ---- | ----------------------------------------------------------------------------------------------------- | -| dataset_name | str | (必须)需要准备的数据集名称。 | -| --task | str | 将数据集格式转换为指定任务的 MMOCR 格式。可选项为: 'textdet', 'textrecog', 'textspotting' 和 'kie'。 | -| --nproc | str | 使用的进程数,默认为 4。 | -| --overwrite-cfg | str | 若数据集的基础配置已经在 `configs/{task}/_base_/datasets` 中存在,依然重写该配置 | -| --dataset-zoo-path | str | 存放数据库配置文件的路径。若不指定,则默认为 `./dataset_zoo` | +| 参数 | 类型 | 说明 | +| ------------------ | -------------------------- | ----------------------------------------------------------------------------------------------------- | +| dataset_name | str | (必须)需要准备的数据集名称。 | +| --nproc | str | 使用的进程数,默认为 4。 | +| --task | str | 将数据集格式转换为指定任务的 MMOCR 格式。可选项为: 'textdet', 'textrecog', 'textspotting' 和 'kie'。 | +| --splits | \['train', 'val', 'test'\] | 希望准备的数据集分割,可以接受多个参数。默认为 `train val test`。 | +| --lmdb | str | 把数据储存为 LMDB 格式,仅当任务为 `textrecog` 时生效。 | +| --overwrite-cfg | str | 若数据集的基础配置已经在 `configs/{task}/_base_/datasets` 中存在,依然重写该配置 | +| --dataset-zoo-path | str | 存放数据库配置文件的路径。若不指定,则默认为 `./dataset_zoo` | 例如,以下命令展示了如何使用该脚本为 ICDAR2015 数据集准备文本检测任务所需的数据。 @@ -38,6 +40,44 @@ python tools/dataset_converters/prepare_dataset.py icdar2015 totaltext --task te ## 进阶用法 +### LMDB 格式 + +在文本识别任务中,我们通常使用 LMDB 格式来存储数据,以加快数据的读取速度。在使用 `prepare_dataset.py` 脚本准备数据时,可以通过 `--lmdb` 参数来指定将数据转换为 LMDB 格式。例如: + +```bash +python tools/dataset_converters/prepare_dataset.py icdar2015 --task textrecog --lmdb +``` + +数据集准备完成后,Dataset Preparer 会在 `configs/textrecog/_base_/datasets/` 中生成 `icdar2015_lmdb.py` 配置。你可以继承该配置,并将 `dataloader` 指向 LMDB 数据集。然而,LMDB 数据集的读取需要配合 [`LoadImageFromNDArray`](mmocr.datasets.transforms.LoadImageFromNDArray),因此你也同样需要修改 `pipeline`。 + +例如,我们想要将 `configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py` 的训练集改为刚刚生成的 icdar2015,则需要作如下修改: + +1. 修改 `configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py`: + + ```python + _base_ = [ + '../_base_/datasets/icdar2015_lmdb.py', # 指向 icdar2015 lmdb 数据集 + ... # 省略 + ] + + train_list = [_base_.icdar2015_lmdb_textrecog_train] + ... + ``` + +2. 修改 `configs/textrecog/crnn/_base_crnn_mini-vgg.py` 中的 `train_pipeline`, 将 `LoadImageFromFile` 改为 `LoadImageFromNDArray`: + + ```python + train_pipeline = [ + dict( + type='LoadImageFromNDArray', + color_type='grayscale', + file_client_args=file_client_args, + ignore_empty=True, + min_size=2), + ... + ] + ``` + ### 数据集配置 数据集自动化准备脚本使用了模块化的设计,极大地增强了扩展性,用户能够很方便地配置其他公开数据集或私有数据集。数据集自动化准备脚本的配置文件被统一存储在 `dataset_zoo/` 目录下,用户可以在该目录下找到所有已由 MMOCR 官方支持的数据集准备脚本配置文件。该文件夹的目录结构如下: @@ -96,6 +136,10 @@ Data: 该文件在数据集准备过程中并不是强制要求的(因此用户在使用添加自己的私有数据集时可以忽略该文件),但为了用户更好地了解各个公开数据集的信息,我们建议用户在使用数据集准备脚本前阅读对应的元文件信息,以了解该数据集的特征是否符合用户需求。 +```{warning} +自 MMOCR 1.0.0rc6 起,接下来的章节可能会与实际实现有所出入。 +``` + #### 数据集准备脚本配置文件 下面,我们将介绍数据集准备脚本配置文件 `textXXX.py` 的默认字段与使用方法。 @@ -235,7 +279,7 @@ OCR 数据集通常有两种标注保存形式,一种为多个标注文件对 ###### `dumper` -之后,我们可以通过指定不同的 dumper 来决定要将数据保存为何种格式。目前,我们仅支持 `JsonDumper` 与 `WildreceiptOpensetDumper`,其中,前者用于将数据保存为标准的 MMOCR Json 格式,而后者用于将数据保存为 Wildreceipt 格式。未来,我们计划支持 `LMDBDumper` 用于保存 LMDB 格式的标注文件。 +之后,我们可以通过指定不同的 dumper 来决定要将数据保存为何种格式。目前,我们支持 `JsonDumper`, `WildreceiptOpensetDumper`,及 `TextRecogLMDBDumper`。他们分别用于将数据保存为标准的 MMOCR Json 格式、Wildreceipt 格式,及文本识别领域学术界常用的 LMDB 格式。 ###### `delete` diff --git a/mmocr/datasets/preparers/config_generators/base.py b/mmocr/datasets/preparers/config_generators/base.py index 6139ba4f7..ba3811a42 100644 --- a/mmocr/datasets/preparers/config_generators/base.py +++ b/mmocr/datasets/preparers/config_generators/base.py @@ -81,6 +81,17 @@ def _prepare_anns(self, train_anns: Optional[List[Dict]], ' None!') for ann_dict in ann_list: assert 'ann_file' in ann_dict + suffix = ann_dict['ann_file'].split('.')[-1] + if suffix == 'json': + dataset_type = 'OCRDataset' + elif suffix == 'lmdb': + assert self.task == 'textrecog', \ + 'LMDB format only works for textrecog now.' + dataset_type = 'RecogLMDBDataset' + else: + raise NotImplementedError( + 'ann file only supports JSON file or LMDB file') + ann_dict['dataset_type'] = dataset_type if ann_dict.get('dataset_postfix', ''): key = f'{self.dataset_name}_{ann_dict["dataset_postfix"]}_{self.task}_{split}' # noqa else: diff --git a/mmocr/datasets/preparers/config_generators/textrecog_config_generator.py b/mmocr/datasets/preparers/config_generators/textrecog_config_generator.py index 23ce3b374..bb8b62625 100644 --- a/mmocr/datasets/preparers/config_generators/textrecog_config_generator.py +++ b/mmocr/datasets/preparers/config_generators/textrecog_config_generator.py @@ -36,19 +36,38 @@ class TextRecogConfigGenerator(BaseDatasetConfigGenerator): Example: It generates a dataset config like: - >>> ic15_rec_data_root = 'data/icdar2015/' + >>> icdar2015_textrecog_data_root = 'data/icdar2015/' >>> icdar2015_textrecog_train = dict( >>> type='OCRDataset', - >>> data_root=ic15_rec_data_root, + >>> data_root=icdar2015_textrecog_data_root, >>> ann_file='textrecog_train.json', - >>> test_mode=False, >>> pipeline=None) >>> icdar2015_textrecog_test = dict( >>> type='OCRDataset', - >>> data_root=ic15_rec_data_root, + >>> data_root=icdar2015_textrecog_data_root, >>> ann_file='textrecog_test.json', >>> test_mode=True, >>> pipeline=None) + + It generates a lmdb format dataset config like: + >>> icdar2015_lmdb_textrecog_data_root = 'data/icdar2015' + >>> icdar2015_lmdb_textrecog_train = dict( + >>> type='RecogLMDBDataset', + >>> data_root=icdar2015_lmdb_textrecog_data_root, + >>> ann_file='textrecog_train.lmdb', + >>> pipeline=None) + >>> icdar2015_lmdb_textrecog_test = dict( + >>> type='RecogLMDBDataset', + >>> data_root=icdar2015_lmdb_textrecog_data_root, + >>> ann_file='textrecog_test.lmdb', + >>> test_mode=True, + >>> pipeline=None) + >>> icdar2015_lmdb_1811_textrecog_test = dict( + >>> type='RecogLMDBDataset', + >>> data_root=icdar2015_lmdb_textrecog_data_root, + >>> ann_file='textrecog_test_1811.lmdb', + >>> test_mode=True, + >>> pipeline=None) """ def __init__( @@ -100,8 +119,8 @@ def _gen_dataset_config(self) -> str: cfg = '' for key_name, ann_dict in self.anns.items(): cfg += f'\n{key_name} = dict(\n' - cfg += ' type=\'OCRDataset\',\n' - cfg += ' data_root=' + f'{self.dataset_name}_{self.task}_data_root,\n' # noqa: E501 + cfg += f' type=\'{ann_dict["dataset_type"]}\',\n' + cfg += f' data_root={self.dataset_name}_{self.task}_data_root,\n' # noqa: E501 cfg += f' ann_file=\'{ann_dict["ann_file"]}\',\n' if ann_dict['split'] in ['test', 'val']: cfg += ' test_mode=True,\n' diff --git a/mmocr/datasets/preparers/dumpers/__init__.py b/mmocr/datasets/preparers/dumpers/__init__.py index 1a73468ef..ed3dda486 100644 --- a/mmocr/datasets/preparers/dumpers/__init__.py +++ b/mmocr/datasets/preparers/dumpers/__init__.py @@ -1,6 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import BaseDumper from .json_dumper import JsonDumper +from .lmdb_dumper import TextRecogLMDBDumper from .wild_receipt_openset_dumper import WildreceiptOpensetDumper -__all__ = ['BaseDumper', 'JsonDumper', 'WildreceiptOpensetDumper'] +__all__ = [ + 'BaseDumper', 'JsonDumper', 'WildreceiptOpensetDumper', + 'TextRecogLMDBDumper' +] diff --git a/mmocr/datasets/preparers/dumpers/lmdb_dumper.py b/mmocr/datasets/preparers/dumpers/lmdb_dumper.py new file mode 100644 index 000000000..9cd49d17f --- /dev/null +++ b/mmocr/datasets/preparers/dumpers/lmdb_dumper.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from typing import Dict, List + +import cv2 +import lmdb +import mmengine +import numpy as np + +from mmocr.registry import DATA_DUMPERS +from .base import BaseDumper + + +@DATA_DUMPERS.register_module() +class TextRecogLMDBDumper(BaseDumper): + """Text recognition LMDB format dataset dumper. + + Args: + task (str): Task type. Options are 'textdet', 'textrecog', + 'textspotter', and 'kie'. It is usually set automatically and users + do not need to set it manually in config file in most cases. + split (str): It' s the partition of the datasets. Options are 'train', + 'val' or 'test'. It is usually set automatically and users do not + need to set it manually in config file in most cases. Defaults to + None. + data_root (str): The root directory of the image and + annotation. It is usually set automatically and users do not need + to set it manually in config file in most cases. Defaults to None. + batch_size (int): Number of files written to the cache each time. + Defaults to 1000. + encoding (str): Label encoding method. Defaults to 'utf-8'. + lmdb_map_size (int): Maximum size database may grow to. Defaults to + 1099511627776. + verify (bool): Whether to check the validity of every image. Defaults + to True. + """ + + def __init__(self, + task: str, + split: str, + data_root: str, + batch_size: int = 1000, + encoding: str = 'utf-8', + lmdb_map_size: int = 1099511627776, + verify: bool = True) -> None: + assert task == 'textrecog', \ + f'TextRecogLMDBDumper only works with textrecog, but got {task}' + super().__init__(task=task, split=split, data_root=data_root) + self.batch_size = batch_size + self.encoding = encoding + self.lmdb_map_size = lmdb_map_size + self.verify = verify + + def check_image_is_valid(self, imageBin): + if imageBin is None: + return False + imageBuf = np.frombuffer(imageBin, dtype=np.uint8) + img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) + imgH, imgW = img.shape[0], img.shape[1] + if imgH * imgW == 0: + return False + return True + + def write_cache(self, env, cache): + with env.begin(write=True) as txn: + cursor = txn.cursor() + cursor.putmulti(cache, dupdata=False, overwrite=True) + + def parser_pack_instance(self, instance: Dict): + """parser an packed MMOCR format textrecog instance. + Args: + instance (Dict): An packed MMOCR format textrecog instance. + For example, + { + "instance": [ + { + "text": "Hello" + } + ], + "img_path": "img1.jpg" + } + """ + assert isinstance(instance, + Dict), 'Element of data_list must be a dict' + assert 'img_path' in instance and 'instances' in instance, \ + 'Element of data_list must have the following keys: ' \ + f'img_path and instances, but got {instance.keys()}' + assert isinstance(instance['instances'], List) and len( + instance['instances']) == 1 + assert 'text' in instance['instances'][0] + + img_path = instance['img_path'] + text = instance['instances'][0]['text'] + return img_path, text + + def dump(self, data: Dict) -> None: + """Dump data to LMDB format.""" + + # create lmdb env + output_dirname = f'{self.task}_{self.split}.lmdb' + output = osp.join(self.data_root, output_dirname) + mmengine.mkdir_or_exist(output) + env = lmdb.open(output, map_size=self.lmdb_map_size) + # load data + if 'data_list' not in data: + raise ValueError('Dump data must have data_list key') + data_list = data['data_list'] + cache = [] + # index start from 1 + cnt = 1 + n_samples = len(data_list) + for d in data_list: + # convert both images and labels to lmdb + label_key = 'label-%09d'.encode(self.encoding) % cnt + img_name, text = self.parser_pack_instance(d) + img_path = osp.join(self.data_root, img_name) + if not osp.exists(img_path): + warnings.warn('%s does not exist' % img_path) + continue + with open(img_path, 'rb') as f: + image_bin = f.read() + if self.verify: + if not self.check_image_is_valid(image_bin): + warnings.warn('%s is not a valid image' % img_path) + continue + image_key = 'image-%09d'.encode(self.encoding) % cnt + cache.append((image_key, image_bin)) + cache.append((label_key, text.encode(self.encoding))) + + if cnt % self.batch_size == 0: + self.write_cache(env, cache) + cache = [] + print('Written %d / %d' % (cnt, n_samples)) + cnt += 1 + n_samples = cnt - 1 + cache.append(('num-samples'.encode(self.encoding), + str(n_samples).encode(self.encoding))) + self.write_cache(env, cache) + print('Created lmdb dataset with %d samples' % n_samples) diff --git a/mmocr/datasets/preparers/obtainers/naive_data_obtainer.py b/mmocr/datasets/preparers/obtainers/naive_data_obtainer.py index e429902b3..664ca6817 100644 --- a/mmocr/datasets/preparers/obtainers/naive_data_obtainer.py +++ b/mmocr/datasets/preparers/obtainers/naive_data_obtainer.py @@ -127,7 +127,7 @@ def extract(self, elif '.finish' not in name and len(name) > 0: while True: c = input(f'{dst_path} already exists when extracting ' - '{zip_name}, whether to unzip again? (y/n)') + '{zip_name}, unzip again? (y/N) ') or 'N' if c.lower() in ['y', 'n']: extracted = c == 'n' break diff --git a/tools/dataset_converters/prepare_dataset.py b/tools/dataset_converters/prepare_dataset.py index a075804cb..84b8a0353 100644 --- a/tools/dataset_converters/prepare_dataset.py +++ b/tools/dataset_converters/prepare_dataset.py @@ -29,6 +29,13 @@ def parse_args(): default=['train', 'test', 'val'], help='A list of the split that would like to prepare.', nargs='+') + parser.add_argument( + '--lmdb', + action='store_true', + default=False, + help='Whether to dump the textrecog dataset to LMDB format, It\'s a ' + 'shortcut to force the dataset to be dumped in lmdb format. ' + 'Applicable when --task=textrecog') parser.add_argument( '--overwrite-cfg', action='store_true', @@ -73,8 +80,54 @@ def parse_meta(task: str, meta_path: str) -> None: time.sleep(1) +def force_lmdb(cfg): + """Force the dataset to be dumped in lmdb format. + + Args: + cfg (Config): Config object. + + Returns: + Config: Config object. + """ + for split in ['train', 'val', 'test']: + preparer_cfg = cfg.get(f'{split}_preparer') + if preparer_cfg: + if preparer_cfg.get('dumper') is None: + raise ValueError( + f'{split} split does not come with a dumper, ' + 'so most likely the annotations are MMOCR-ready and do ' + 'not need any adaptation, and it ' + 'cannot be dumped in LMDB format.') + preparer_cfg.dumper['type'] = 'TextRecogLMDBDumper' + + cfg.config_generator['dataset_name'] = f'{cfg.dataset_name}_lmdb' + + for split in ['train_anns', 'val_anns', 'test_anns']: + if split in cfg.config_generator: + # It can be None when users want to clear out the default + # value + if not cfg.config_generator[split]: + continue + ann_list = cfg.config_generator[split] + for ann_dict in ann_list: + ann_dict['ann_file'] = ( + osp.splitext(ann_dict['ann_file'])[0] + '.lmdb') + else: + if split == 'train_anns': + ann_list = [dict(ann_file='textrecog_train.lmdb')] + elif split == 'test_anns': + ann_list = [dict(ann_file='textrecog_test.lmdb')] + else: + ann_list = [] + cfg.config_generator[split] = ann_list + + return cfg + + def main(): args = parse_args() + if args.lmdb and args.task != 'textrecog': + raise ValueError('--lmdb only works with --task=textrecog') for dataset in args.datasets: if not osp.isdir(osp.join(args.dataset_zoo_path, dataset)): warnings.warn(f'{dataset} is not supported yet. Please check ' @@ -86,10 +139,12 @@ def main(): cfg = Config.fromfile(cfg_path) if args.overwrite_cfg and cfg.get('config_generator', None) is not None: - cfg.config_generator.overwrite = args.overwrite_cfg + cfg.config_generator.overwrite_cfg = args.overwrite_cfg cfg.nproc = args.nproc cfg.task = args.task cfg.dataset_name = dataset + if args.lmdb: + cfg = force_lmdb(cfg) preparer = DatasetPreparer.from_file(cfg) preparer.run(args.splits)