-
Notifications
You must be signed in to change notification settings - Fork 28
/
mnist_eval.py
62 lines (50 loc) · 1.61 KB
/
mnist_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
import config
from temporal_ensembling import train
from utils import GaussianNoise, savetime, save_exp
class CNN(nn.Module):
    """Small two-convolution classifier used for the MNIST experiments.

    Forward pipeline:
        [GaussianNoise, training mode only] ->
        conv1 (3x3, 1 -> fm1, weight-normed) -> maxpool -> ReLU ->
        conv2 (3x3, fm1 -> fm2, weight-normed) -> maxpool -> ReLU ->
        flatten -> dropout -> linear (fm2*7*7 -> 10 scores)

    Inputs must have a single channel (``conv1`` takes 1 input channel).
    The ``fm2 * 7 * 7`` input size of ``fc`` assumes 28x28 images: each
    ``MaxPool2d(3, stride=2, padding=1)`` maps 28 -> 14 -> 7.

    Args:
        batch_size: forwarded to ``GaussianNoise`` (imported from ``utils``;
            presumably used to size the noise buffer -- TODO confirm).
        std: standard deviation forwarded to ``GaussianNoise``.
        p: dropout probability applied right before the final linear layer.
        fm1: number of feature maps produced by ``conv1``.
        fm2: number of feature maps produced by ``conv2``.
    """

    def __init__(self, batch_size, std, p=0.5, fm1=16, fm2=32):
        super(CNN, self).__init__()
        self.fm1 = fm1
        self.fm2 = fm2
        self.std = std
        self.gn = GaussianNoise(batch_size, std=self.std)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p)
        self.conv1 = weight_norm(nn.Conv2d(1, self.fm1, 3, padding=1))
        self.conv2 = weight_norm(nn.Conv2d(self.fm1, self.fm2, 3, padding=1))
        self.mp = nn.MaxPool2d(3, stride=2, padding=1)
        self.fc = nn.Linear(self.fm2 * 7 * 7, 10)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, 10)``.

        Args:
            x: image batch, assumed ``(batch, 1, 28, 28)`` (see class docstring).

        Raises:
            RuntimeError: if the feature maps after the second pooling stage
                do not have ``fm2 * 7 * 7`` elements per sample.
        """
        if self.training:
            # Noise is injected only in training mode.
            x = self.gn(x)
        # ReLU is applied after pooling; because ReLU is monotonic this equals
        # pool(relu(conv(x))) but runs on the smaller pooled tensor.
        x = self.act(self.mp(self.conv1(x)))
        x = self.act(self.mp(self.conv2(x)))
        # Keep the batch dimension explicit so that a wrongly-sized input
        # raises here instead of being silently regrouped into a different
        # batch size (which an inferred -1 batch dimension would allow).
        x = x.view(x.size(0), self.fm2 * 7 * 7)
        x = self.drop(x)
        x = self.fc(x)
        return x
# Metrics collected across repeated experiments; one entry per run.
accs = []
accs_best = []
losses = []
sup_losses = []
unsup_losses = []
idxs = []

ts = savetime()
# vars(module) is the module's attribute dict, so every setting in config.py
# becomes a key here.
# NOTE(review): it also includes module dunders (__name__, __file__, ...),
# which are forwarded via **cfg; train/save_exp presumably accept **kwargs
# and ignore them -- confirm.
cfg = vars(config)

# Each run takes its own seed. Check up front instead of raising IndexError
# after earlier (unsaved) runs have already spent their training time.
if len(cfg['seeds']) < cfg['n_exp']:
    raise ValueError('config.seeds has %d entries but n_exp is %d'
                     % (len(cfg['seeds']), cfg['n_exp']))

# range() and print('...') behave the same on Python 2 and also run on
# Python 3, unlike xrange() and the print statement.
for i in range(cfg['n_exp']):
    # Build a fresh model per run so experiments do not share weights.
    model = CNN(cfg['batch_size'], cfg['std'])
    seed = cfg['seeds'][i]
    acc, acc_best, l, sl, usl, indices = train(model, seed, **cfg)
    accs.append(acc)
    accs_best.append(acc_best)
    losses.append(l)
    sup_losses.append(sl)
    unsup_losses.append(usl)
    idxs.append(indices)

print('saving experiment')
save_exp(ts, losses, sup_losses, unsup_losses,
         accs, accs_best, idxs, **cfg)