forked from qiaozhijian/LPD-Net-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
script.py
executable file
·81 lines (73 loc) · 2.4 KB
/
script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
import torch.nn.functional as F
# import config as cfg
import pynvml
import util.initPara as para
import os
# from torch.autograd import Variable
# import numpy as np
# NVML is initialised as an import-time side effect of this module; every
# query below goes to a single, fixed GPU.
pynvml.nvmlInit()
# 0 here is the GPU index handed to NVML.
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
# Divisor that converts NVML's byte counts into MiB.
ratio = 1024 ** 2


def print_gpu():
    """Print the memory currently in use on the GPU behind ``handle``.

    The printed number is NVML's ``used`` field divided by ``ratio``
    (bytes -> MiB).
    """
    used_mib = pynvml.nvmlDeviceGetMemoryInfo(handle).used / ratio
    print("used: ", used_mib)
class CNN(nn.Module):
    """Small two-stage conv net used as a shape-debugging scratch model.

    Pipeline::

        conv1 -> ReLU -> 2x2 max-pool -> conv2 -> ReLU -> 2x2 max-pool
        -> permute channels to the last axis -> fc1 -> ReLU -> fc2 -> ReLU -> fc3

    Because the fully-connected layers run after ``permute(0, 2, 3, 1)`` and
    not after a flatten, they act independently on the 16-channel feature
    vector at every spatial position. An ``(N, 3, H, W)`` input yields an
    ``(N, H', W', 10)`` output.

    Args:
        verbose: when True (the default, matching the original behaviour),
            print the tensor shape on entry and after every step of
            ``forward``.
    """

    def __init__(self, verbose=True):
        super(CNN, self).__init__()
        self.verbose = verbose
        # 3 input channels -> 6 output channels, 3x4 kernel, no padding.
        # (Previously written as padding=False, which only worked because
        # bool is an int subclass.)
        self.conv1 = nn.Conv2d(3, 6, (3, 4), padding=0)
        # 6 input channels -> 16 output channels, 5x5 kernel, 1-pixel padding
        # (previously padding=True).
        self.conv2 = nn.Conv2d(6, 16, 5, padding=1)
        # in_features must equal conv2's out_channels, since fc1 is applied to
        # the channel axis after the permute in forward().
        self.fc1 = nn.Linear(16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def _log_shape(self, x):
        """Print ``x.shape`` when verbose shape debugging is enabled."""
        if self.verbose:
            print(x.shape)

    def forward(self, x):
        """Run the network on a 4-D ``(N, 3, H, W)`` batch.

        ``H`` and ``W`` must be large enough that every conv/pool stage leaves
        a non-empty feature map.

        Returns:
            Tensor of shape ``(N, H', W', 10)``: per-position class scores.
        """
        self._log_shape(x)
        # input x -> conv1 -> relu -> max-pool over 2x2 windows
        x = self.conv1(x)
        self._log_shape(x)
        x = F.relu(x)
        self._log_shape(x)
        x = F.max_pool2d(x, 2)
        self._log_shape(x)
        # -> conv2 -> relu -> max-pool over 2x2 windows
        x = self.conv2(x)
        self._log_shape(x)
        x = F.relu(x)
        self._log_shape(x)
        x = F.max_pool2d(x, 2)
        self._log_shape(x)
        # (N, C, H, W) -> (N, H, W, C) so the Linear layers act on channels.
        x = x.permute(0, 2, 3, 1)
        self._log_shape(x)
        x = F.relu(self.fc1(x))
        self._log_shape(x)
        x = F.relu(self.fc2(x))
        self._log_shape(x)
        x = self.fc3(x)
        self._log_shape(x)
        return x
def get_learning_rate(epoch, base_lr=0.001, decay=0.922680834591, min_lr=0.00001):
    """Return an exponentially decayed learning rate, floored at ``min_lr``.

    Computes ``base_lr * decay ** epoch`` and clips it from below so that
    training never runs with a learning rate smaller than ``min_lr``.

    Args:
        epoch: zero-based epoch index; epoch 0 yields ``base_lr``.
        base_lr: learning rate at epoch 0.
        decay: per-epoch multiplicative decay factor.
        min_lr: lower bound applied to the decayed value.

    Returns:
        The learning rate for ``epoch`` as a float.
    """
    learning_rate = base_lr * (decay ** epoch)
    return max(learning_rate, min_lr)  # clip the learning rate
if __name__ == '__main__':
    # Re-save a pretrained checkpoint: load it into the project model
    # (para.model, built in util.initPara -- not visible here), then write it
    # back with the unwrapped module's state_dict and no optimizer state.
    checkpoint = torch.load('./pretrained/lpdnet.ckpt')
    saved_state_dict = checkpoint['state_dict']
    epoch = checkpoint['epoch']
    TOTAL_ITERATIONS = checkpoint['iter']
    ave_one_percent_recall = checkpoint['recall']
    # strict=True: checkpoint keys must match para.model's keys exactly,
    # including any 'module.' prefix if para.model is an nn.DataParallel.
    para.model.load_state_dict(saved_state_dict, strict=True)
    # Unwrap DataParallel so the saved keys carry no 'module.' prefix.
    if isinstance(para.model, nn.DataParallel):
        model_to_save = para.model.module
    else:
        model_to_save = para.model
    save_dir = para.args.model_save_path
    # torch.save does not create missing parent directories.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # os.path.join instead of `dir + '/' + name`: with an empty save_dir the
    # concatenation produced '/lpdnet.ckpt' (filesystem root).
    # NOTE(review): if model_save_path is './pretrained' this overwrites the
    # input checkpoint -- confirm that is intended.
    save_name = os.path.join(save_dir, 'lpdnet.ckpt')
    torch.save({
        'epoch': epoch,
        'iter': TOTAL_ITERATIONS,
        'state_dict': model_to_save.state_dict(),
        'optimizer': None,
        'recall': ave_one_percent_recall,
    }, save_name)