-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
129 lines (109 loc) · 5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import math
import time
import numpy as np
class ModelWrapper(torch.nn.Module):
    """Image encoder followed by a linear classification head.

    The wrapped ``model`` must expose ``encode_image``; its output is fed
    (optionally after norm-scaling along the last dimension) into a
    ``torch.nn.Linear(feature_dim, num_classes)`` head.
    """

    def __init__(self, model, feature_dim, num_classes, normalize=False, initial_weights=None):
        """Build the wrapper and initialise the classification head.

        Args:
            model: Backbone object providing ``encode_image(images)``.
            feature_dim: Input width of the linear head.
            num_classes: Output width of the linear head.
            normalize: When true, features are divided by their norm over
                the last dimension before classification.
            initial_weights: Optional tensor copied into the head's weight.
                When ``None``, a Kaiming-uniform tensor (``a=sqrt(5)``) with
                the head's weight shape is used. The bias always starts at zero.
        """
        super().__init__()
        self.model = model
        self.classification_head = torch.nn.Linear(feature_dim, num_classes)
        self.normalize = normalize

        head = self.classification_head
        if initial_weights is None:
            initial_weights = torch.zeros_like(head.weight)
            torch.nn.init.kaiming_uniform_(initial_weights, a=math.sqrt(5))
        # Clone so the caller's tensor is never shared with the parameter.
        head.weight = torch.nn.Parameter(initial_weights.clone())
        head.bias = torch.nn.Parameter(torch.zeros_like(head.bias))

        # Note: modified. Get rid of the language part.
        if hasattr(self.model, 'transformer'):
            del self.model.transformer

    def forward(self, images, return_features=False):
        """Return logits for ``images``; also return the features if requested.

        Args:
            images: Input passed straight to ``self.model.encode_image``.
            return_features: When true, return ``(logits, features)``
                instead of just ``logits``.
        """
        features = self.model.encode_image(images)
        if self.normalize:
            features = features / features.norm(dim=-1, keepdim=True)
        logits = self.classification_head(features)
        return (logits, features) if return_features else logits
def get_model_from_sd(state_dict, base_model):
    """Rebuild a ``ModelWrapper`` from ``state_dict`` and wrap it for multi-GPU use.

    The head dimensions are read from the shape of
    ``state_dict['classification_head.weight']`` (rows = classes,
    columns = features). The wrapper is created with ``normalize=True``, its
    parameters are cast with ``.float()``, the state dict is loaded, and the
    model is moved to CUDA.

    Args:
        state_dict: State dict containing at least ``classification_head.weight``.
        base_model: Backbone passed through to ``ModelWrapper``.

    Returns:
        ``torch.nn.DataParallel`` over every device reported by
        ``torch.cuda.device_count()``.
    """
    head_weight = state_dict['classification_head.weight']
    num_classes, feature_dim = head_weight.shape[0], head_weight.shape[1]

    wrapped = ModelWrapper(base_model, feature_dim, num_classes, normalize=True)
    # Cast parameters before loading the checkpoint.
    for param in wrapped.parameters():
        param.data = param.data.float()
    wrapped.load_state_dict(state_dict)
    wrapped = wrapped.cuda()

    gpu_ids = list(range(torch.cuda.device_count()))
    return torch.nn.DataParallel(wrapped, device_ids=gpu_ids)
def maybe_dictionarize_batch(batch):
    """Return ``batch`` as a dict keyed by ``images``/``labels``[/``metadata``].

    A dict is returned unchanged (same object). A 2-element sequence maps to
    ``images`` and ``labels``; a 3-element sequence additionally maps its last
    element to ``metadata``.

    Raises:
        ValueError: if a non-dict batch has any other length.
    """
    if isinstance(batch, dict):
        return batch

    size = len(batch)
    if size not in (2, 3):
        raise ValueError(f'Unexpected number of elements: {size}')

    named = {'images': batch[0], 'labels': batch[1]}
    if size == 3:
        named['metadata'] = batch[2]
    return named
def test_model_on_dataset(model, dataset):
    """Evaluate ``model`` on ``dataset`` and return the fraction classified correctly.

    Batches are drawn from ``dataset.test_loader``, except when the dataset's
    class is named ``ImageNet2p``, in which case ``dataset.train_loader`` is
    used after a consistency check on the held-out split. Inputs and labels
    are moved to the GPU with ``.cuda()``, so a CUDA device is required.

    Optional hooks on ``dataset`` are used when present:
      * ``project_logits(logits, device)`` - applied to the model output.
      * ``project_labels(labels, device)`` - applied to the labels.
      * ``accuracy(logits, labels, image_paths, None)`` - must return
        ``(num_correct, num_total)``; replaces the default argmax comparison.
        ``image_paths`` is ``None`` for batches without an ``'image_paths'`` key.

    Progress (running accuracy and timings) is printed every 20 batches.

    Args:
        model: Module called as ``model(inputs)``; may return the logits or a
            list whose first element is the logits.
        dataset: Object providing the loaders and optional hooks listed above.

    Returns:
        ``correct / n`` accumulated over all batches.

    Raises:
        AssertionError: if the ``ImageNet2p`` held-out split check fails.
        ZeroDivisionError: if no samples were counted (e.g. an empty loader).
    """
    model.eval()
    device = 'cuda'
    with torch.no_grad():
        correct, n = 0., 0.
        end = time.time()
        loader = dataset.test_loader
        if type(dataset).__name__ == 'ImageNet2p':
            loader = dataset.train_loader
            # assert to make sure the imagenet held-out minival logic is consistent across machines.
            # tested on a few machines but if this fails for you please submit an issue and we will resolve.
            assert dataset.train_dataset.__getitem__(dataset.sampler.indices[1000])['image_paths'].endswith('n01675722_4108.JPEG')

        for i, batch in enumerate(loader):
            batch = maybe_dictionarize_batch(batch)
            inputs, labels = batch['images'].cuda(), batch['labels'].cuda()
            data_time = time.time() - end
            y = labels
            # Re-read every batch: previously this name was only bound when the
            # key existed, so a batch without it either raised NameError or
            # silently reused the paths of an earlier batch.
            image_paths = batch.get('image_paths')

            logits = model(inputs)

            projection_fn = getattr(dataset, 'project_logits', None)
            if projection_fn is not None:
                logits = projection_fn(logits, device)

            if hasattr(dataset, 'project_labels'):
                y = dataset.project_labels(y, device)
            if isinstance(logits, list):
                logits = logits[0]

            if hasattr(dataset, 'accuracy'):
                # Dataset-specific scoring.
                acc1, num_total = dataset.accuracy(logits, y, image_paths, None)
                correct += acc1
                n += num_total
            else:
                # Default top-1: compare the argmax class with the label.
                pred = logits.argmax(dim=1, keepdim=True).to(device)
                correct += pred.eq(y.view_as(pred)).sum().item()
                n += y.size(0)

            batch_time = time.time() - end
            end = time.time()
            if i % 20 == 0:
                percent_complete = 100.0 * i / len(loader)
                print(
                    f"[{percent_complete:.0f}% {i}/{len(loader)}]\t"
                    f"Acc: {100 * (correct/n):.2f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
                )

        return correct / n
def assign_learning_rate(param_group, new_lr):
    """Store ``new_lr`` under the ``"lr"`` key of ``param_group`` (in place)."""
    param_group["lr"] = new_lr
def _warmup_lr(base_lr, warmup_length, step):
    """Return the linear-warmup rate ``base_lr * (step + 1) / warmup_length``.

    ``step`` is zero-based, so the final warmup step
    (``step == warmup_length - 1``) yields ``base_lr``.
    """
    steps_taken = step + 1
    return base_lr * steps_taken / warmup_length
def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Build a step -> learning-rate setter with linear warmup and cosine decay.

    Args:
        optimizer: Object whose ``param_groups`` are updated on every call.
        base_lrs: Peak learning rate. Either a list with one entry per
            parameter group, or a single value reused for every group.
        warmup_length: Number of initial steps that use ``_warmup_lr``.
        steps: Step at which the cosine term reaches zero.

    Returns:
        A function ``adjust(step)`` that writes the scheduled learning rate
        into each parameter group. For ``step >= warmup_length`` the rate is
        ``0.5 * (1 + cos(pi * e / es)) * base_lr`` with
        ``e = step - warmup_length`` and ``es = steps - warmup_length``.
    """
    if not isinstance(base_lrs, list):
        # Broadcast the single value across all parameter groups.
        base_lrs = [base_lrs] * len(optimizer.param_groups)
    assert len(base_lrs) == len(optimizer.param_groups)

    def _scheduled_lr(peak_lr, step):
        # Warmup phase first; everything after it follows the cosine curve.
        if step < warmup_length:
            return _warmup_lr(peak_lr, warmup_length, step)
        elapsed = step - warmup_length
        decay_span = steps - warmup_length
        return 0.5 * (1 + np.cos(np.pi * elapsed / decay_span)) * peak_lr

    def _lr_adjuster(step):
        for group, peak_lr in zip(optimizer.param_groups, base_lrs):
            assign_learning_rate(group, _scheduled_lr(peak_lr, step))

    return _lr_adjuster