This repository has been archived by the owner on Nov 20, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 42
/
utils.py
84 lines (59 loc) · 1.97 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch, base_lr):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
lr = max(base_lr * (0.5 ** (epoch // 20)), 1e-5)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
#if (epoch + 1) == 51 or (epoch + 1) == 101 or (epoch + 1) == 151:
# #lr = lr * 0.1
# for param_group in optimizer.param_groups:
# param_group['lr'] *= 0.1
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def avg_class_acc(output, target, all, correct):
num_cls = all.shape[0]
for i in range(num_cls):
idx = (target == i)
all[i] += np.sum(idx)
correct[i] += np.sum( output[idx] == target[idx] )
return all, correct