#!/usr/bin/env python
# -*- coding: utf-8 -*-
import collections
import os
import random
import re
import numpy as np
from PIL import Image
import torch
from torch.utils.data import dataset, sampler
from torchvision.datasets.folder import default_loader
import torchvision.transforms.functional as F
from file_utils import load_pickle
from kpt_to_pap_mask import gen_pap_masks


def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'):
    """Return a sorted list of image file paths under `directory` whose names match `ext`."""
    return sorted([os.path.join(root, f)
                   for root, _, files in os.walk(directory) for f in files
                   if re.match(r'([\w]+\.(?:' + ext + '))', f)])


class Market1501(dataset.Dataset):
    """
    Attributes:
        imgs (list of str): dataset image file paths
        _id2label (dict): mapping from person id to softmax continuous label
    """

    @staticmethod
    def id(file_path):
        """
        :param file_path: unix style file path
        :return: person id
        """
        return int(file_path.split('/')[-1].split('_')[0])

    @staticmethod
    def camera(file_path):
        """
        :param file_path: unix style file path
        :return: camera id
        """
        return int(file_path.split('/')[-1].split('_')[1][1])

    @property
    def ids(self):
        """
        :return: person id list corresponding to dataset image paths
        """
        return [self.id(path) for path in self.imgs]

    @property
    def unique_ids(self):
        """
        :return: unique person ids in ascending order
        """
        return sorted(set(self.ids))

    @property
    def cameras(self):
        """
        :return: camera id list corresponding to dataset image paths
        """
        return [self.camera(path) for path in self.imgs]
    def get_pap_mask(self, im_path):
        if 'cuhk' in im_path:
            key = 'cuhk03-np-jpg/detected' + '/' + '/'.join(im_path.split('/')[-2:])
        else:
            key = '/'.join(im_path.split('/')[-3:])
        kpt = self.im_path_to_kpt[key]['kpt']
        # Binarize keypoint confidence; use the builtin float (np.float was removed in NumPy >= 1.24).
        kpt[:, 2] = (kpt[:, 2] > 0.1).astype(float)
        pap_mask_2p, _ = gen_pap_masks(self.im_path_to_kpt[key]['im_h_w'], (24, 8), kpt, mask_type='PAP_2P')
        pap_mask_3p, _ = gen_pap_masks(self.im_path_to_kpt[key]['im_h_w'], (24, 8), kpt, mask_type='PAP_3P')
        return pap_mask_2p, pap_mask_3p

    def get_ps_label(self, im_path):
        ps_label = Image.open('/'.join([self.ps_dir] + im_path.split('/')[-2:]).replace('.jpg', '.png'))
        ps_label = ps_label.resize((16, 48), resample=Image.NEAREST)
        # ps_label = ps_label.resize((8, 24), resample=Image.NEAREST)  # TODO
        return ps_label
    def __init__(self, root, transform=None, target_transform=None, loader=default_loader,
                 training=None, kpt_file=None, ps_dir=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.imgs = [path for path in list_pictures(self.root) if self.id(path) != -1]
        # convert person id to softmax continuous label
        self._id2label = {_id: idx for idx, _id in enumerate(self.unique_ids)}
        self.training = training
        self.im_path_to_kpt = load_pickle(kpt_file) if kpt_file is not None else None
        self.ps_dir = ps_dir
    def __getitem__(self, index):
        path = self.imgs[index]
        target = {'id': self._id2label[self.id(path)]}
        img = self.loader(path)
        if self.im_path_to_kpt is not None:
            target['pap_mask_2p'], target['pap_mask_3p'] = self.get_pap_mask(path)
        if self.ps_dir is not None:
            target['ps_label'] = self.get_ps_label(path)
        if self.training is True:
            if random.random() < 0.5:
                img = F.hflip(img)
                if 'ps_label' in target:
                    target['ps_label'] = F.hflip(target['ps_label'])
        if self.transform is not None:
            img = self.transform(img)
        if 'pap_mask_2p' in target:
            target['pap_mask_2p'] = torch.from_numpy(target['pap_mask_2p']).float()
            target['pap_mask_3p'] = torch.from_numpy(target['pap_mask_3p']).float()
        if 'ps_label' in target:
            target['ps_label'] = torch.from_numpy(np.array(target['ps_label'])).long()
        return img, target

    def __len__(self):
        return len(self.imgs)
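
# Example usage of Market1501 (a minimal sketch, not part of the original file): the
# directory, keypoint pickle, and parsing-label paths below are hypothetical placeholders
# and must be adapted to an actual Market-1501 layout.
#
#   import torchvision.transforms as T
#   train_set = Market1501(
#       root='/path/to/market1501/bounding_box_train',
#       transform=T.Compose([T.Resize((384, 128)), T.ToTensor()]),
#       training=True,
#       kpt_file='/path/to/im_path_to_kpt.pkl',  # optional: enables pap_mask_2p / pap_mask_3p targets
#       ps_dir='/path/to/ps_labels')             # optional: enables ps_label targets
#   img, target = train_set[0]                   # target always contains 'id'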


class RandomIdSampler(sampler.Sampler):
    """
    Sampler for triplet semi-hard example mining.

    Attributes:
        _id2index (dict of list): mapping from person id to its image indexes in `data_source`
    """

    @staticmethod
    def _sample(population, k):
        # Oversample by repetition when a person id has fewer than k images.
        if len(population) < k:
            population = population * k
        return random.sample(population, k)

    def __init__(self, data_source, batch_image):
        """
        :param data_source: Market1501 dataset
        :param batch_image: number of images sampled per person id in a batch
        """
        super(RandomIdSampler, self).__init__(data_source)
        self.data_source = data_source
        self.batch_image = batch_image
        self._id2index = collections.defaultdict(list)
        for idx, path in enumerate(data_source.imgs):
            _id = data_source.id(path)
            self._id2index[_id].append(idx)

    def __iter__(self):
        unique_ids = self.data_source.unique_ids
        random.shuffle(unique_ids)
        imgs = []
        for _id in unique_ids:
            imgs.extend(self._sample(self._id2index[_id], self.batch_image))
        return iter(imgs)

    def __len__(self):
        return len(self._id2index) * self.batch_image
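

# A minimal end-to-end sketch of how the sampler is meant to be combined with a DataLoader
# to build PK-style batches (P person ids x K images per id). The paths and batch sizes
# below are assumptions for illustration, not values from the original project.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    import torchvision.transforms as T

    batch_id, batch_image = 16, 4                       # P ids per batch, K images per id (assumed)
    train_set = Market1501(
        root='/path/to/market1501/bounding_box_train',  # hypothetical path
        transform=T.Compose([T.Resize((384, 128)), T.ToTensor()]),
        training=True)
    loader = DataLoader(
        train_set,
        sampler=RandomIdSampler(train_set, batch_image=batch_image),
        batch_size=batch_id * batch_image,              # each batch holds P * K images
        drop_last=True)
    imgs, targets = next(iter(loader))
    print(imgs.shape, targets['id'].shape)              # e.g. torch.Size([64, 3, 384, 128]), torch.Size([64])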