-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_coco_mask.py
61 lines (53 loc) · 2.65 KB
/
get_coco_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
get semantic segmentation annotations from coco data set.
"""
from PIL import Image
import imgviz
import argparse
import os
import tqdm
from pycocotools.coco import COCO
import shutil
import numpy as np
def save_colored_mask(mask, save_path):
lbl_pil = Image.fromarray(mask.astype(np.uint8), mode="P")
colormap = imgviz.label_colormap()
lbl_pil.putpalette(colormap.flatten())
lbl_pil.save(save_path)
# input_dir = ../input/data/
# split = train_all
def main(args):
annotation_file = os.path.join(args.input_dir, '{}.json'.format(args.split))
os.makedirs(os.path.join(args.input_dir, 'SegmentationClass', 'batch_04'), exist_ok=True)
os.makedirs(os.path.join(args.input_dir, 'SegmentationClass', 'batch_05'), exist_ok=True)
os.makedirs(os.path.join(args.input_dir, 'SegmentationClass', 'batch_06'), exist_ok=True)
os.makedirs(os.path.join(args.input_dir, 'JPEGImages', 'batch_04'), exist_ok=True)
os.makedirs(os.path.join(args.input_dir, 'JPEGImages', 'batch_05'), exist_ok=True)
os.makedirs(os.path.join(args.input_dir, 'JPEGImages', 'batch_06'), exist_ok=True)
coco = COCO(annotation_file)
catIds = coco.getCatIds()
imgIds = coco.getImgIds()
print("catIds len:{}, imgIds len:{}".format(len(catIds), len(imgIds)))
for imgId in tqdm.tqdm(imgIds, ncols=100):
img = coco.loadImgs(imgId)[0]
annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None)
anns = coco.loadAnns(annIds)
if len(annIds) > 0:
mask = coco.annToMask(anns[0]) * anns[0]['category_id']
for i in range(len(anns) - 1):
mask += coco.annToMask(anns[i + 1]) * anns[i + 1]['category_id']
img_origin_path = os.path.join(args.input_dir, img['file_name'])
img_output_path = os.path.join(args.input_dir, 'JPEGImages', img['file_name'].replace('batch_01_vt','batch_04').replace('batch_02_vt','batch_05').replace('batch_03','batch_06'))
seg_output_path = os.path.join(args.input_dir, 'SegmentationClass', img['file_name'].replace('.jpg', '.png').replace('batch_01_vt','batch_04').replace('batch_02_vt','batch_05').replace('batch_03','batch_06'))
shutil.copy(img_origin_path, img_output_path)
save_colored_mask(mask, seg_output_path)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", default="../input/data/", type=str,
help="coco dataset directory")
parser.add_argument("--split", default="train_all", type=str,
help="train_all or train or val")
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
main(args)