-
Notifications
You must be signed in to change notification settings - Fork 1
/
make_idx.py
29 lines (23 loc) · 952 Bytes
/
make_idx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from subprocess import call
import os
import argparse
from tqdm import tqdm
def make_idx_files(tfrecord_root, idx_file_root):
    """Write one ``.idx`` file per TFRecord file in *tfrecord_root*.

    For every regular file ``<name>`` directly inside *tfrecord_root* (in
    sorted order), the external ``tfrecord2idx`` tool is run to produce
    ``<idx_file_root>/<name>.idx``. Sub-directories are ignored.

    Args:
        tfrecord_root: Directory containing the TFRecord files.
        idx_file_root: Output directory for the index files; it (and any
            missing parents) is created if it does not already exist.

    Raises:
        OSError: e.g. ``FileNotFoundError`` if *tfrecord_root* does not
            exist (raised before *idx_file_root* is created), or if the
            ``tfrecord2idx`` executable cannot be found.
        subprocess.CalledProcessError: if ``tfrecord2idx`` exits with a
            non-zero status for any file.
    """
    import subprocess

    # List the inputs BEFORE creating the output directory, so a bad
    # tfrecord_root fails without leaving an empty idx directory behind.
    # Only regular files are indexed; sorting gives a deterministic order
    # (os.listdir returns entries in arbitrary order).
    names = sorted(
        name for name in os.listdir(tfrecord_root)
        if os.path.isfile(os.path.join(tfrecord_root, name))
    )
    os.makedirs(idx_file_root, exist_ok=True)

    for f in tqdm(names, desc='Writing idx'):
        tfrecord_path = os.path.join(tfrecord_root, f)
        idx_file_path = os.path.join(idx_file_root, f'{f}.idx')
        # check=True: a failed conversion must surface instead of silently
        # producing an incomplete set of index files.
        subprocess.run(['tfrecord2idx', tfrecord_path, idx_file_path],
                       check=True)
def main(root):
    """Build index files for the ``train`` and ``validation`` splits.

    For each split, TFRecords are read from ``<root>/<split>`` and the
    resulting ``.idx`` files are written to ``<root>/idx_files/<split>``.
    The splits are processed in that order, each preceded by a progress
    header on stdout.
    """
    for split in ('train', 'validation'):
        print(f'Processing {split}:')
        source_dir = os.path.join(root, split)
        target_dir = os.path.join(root, f'idx_files/{split}')
        make_idx_files(source_dir, target_dir)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Create .idx files for the TFRecords found in the '
                    'train/ and validation/ sub-directories of a root '
                    'directory.')
    parser.add_argument(
        '--tfrecord_root', default='./tf_records',
        help='root directory containing train/ and validation/ TFRecord '
             'sub-directories; index files are written under '
             '<root>/idx_files/ (default: %(default)s)')
    args = parser.parse_args()
    # Report a bad root as a usage error instead of a traceback, and before
    # main() gets a chance to create any output directories.
    if not os.path.isdir(args.tfrecord_root):
        parser.error(f'--tfrecord_root {args.tfrecord_root!r} is not a directory')
    main(args.tfrecord_root)