-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
137 lines (121 loc) · 4.61 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import sys
import json
import random
import numpy as np
import torch
def setup_device(gpu_id):
    """Select which GPU is visible to CUDA and return the torch device to use.

    ``CUDA_DEVICE_ORDER`` is set to ``PCI_BUS_ID`` and ``CUDA_VISIBLE_DEVICES``
    is first cleared. When ``gpu_id`` is non-negative, ``CUDA_VISIBLE_DEVICES``
    is then set to that single id.

    Args:
        gpu_id: Anything accepted by ``int()``. A negative value leaves
            ``CUDA_VISIBLE_DEVICES`` empty.

    Returns:
        ``torch.device("cuda:0")`` if ``torch.cuda.is_available()`` reports
        True, otherwise ``torch.device("cpu")``.
    """
    # The environment is touched before the int() conversion, so a bad
    # gpu_id still leaves CUDA_VISIBLE_DEVICES cleared.
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "",
    })

    gpu_index = int(gpu_id)
    if gpu_index >= 0:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
        print("set CUDA_VISIBLE_DEVICES=%s" % gpu_index)

    # Only one GPU is ever exposed, so it is always addressed as cuda:0.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print("using device %s" % device)
    return device
def setup_savedir(prefix="",basedir="./experiments",args=None,append_args=[]):
savedir = prefix
if len(append_args) > 0 and args is not None:
for arg_opt in append_args:
arg_value = getattr(args, arg_opt)
savedir +="_"+arg_opt+"-"+str(arg_value)
else:
savedir += "exp"
savedir = savedir.replace(" ","").replace("'","").replace('"','')
savedir = os.path.join(basedir,savedir)
#if exists, append _num-[num]
i = 1
savedir_ori = savedir
while True:
try:
os.makedirs(savedir)
break
except FileExistsError as e:
savedir = savedir_ori+"_num-%d"%i
i+=1
print("made the log directory",savedir)
return savedir
def save_args(savedir, args, name="args.json"):
    """Write the attributes of ``args`` to ``<savedir>/<name>`` as JSON.

    Args:
        savedir: Directory the file is written into.
        args: Object passed to ``vars()`` (e.g. an argparse namespace).
        name: File name inside ``savedir``.
    """
    target = os.path.join(savedir, name)
    with open(target, "w") as handle:
        # Sorted keys and a 4-space indent keep the file stable and readable.
        json.dump(vars(args), handle, indent=4, sort_keys=True)
    print("args saved as %s" % target)
def save_json(dict, path):
    """Write ``dict`` to ``path`` as JSON with sorted keys and 4-space indent.

    Args:
        dict: JSON-serialisable object to store. The name shadows the builtin
            but is kept so keyword callers keep working.
        path: Destination file; it is opened in write mode.
    """
    with open(path, "w") as handle:
        json.dump(dict, handle, indent=4, sort_keys=True)
    print("log saved at %s" % path)
def _shapes_compatible(loaded, current):
    """Return False only when both values expose ``shape`` and they differ."""
    loaded_shape = getattr(loaded, "shape", None)
    current_shape = getattr(current, "shape", None)
    if loaded_shape is None or current_shape is None:
        return True
    return tuple(loaded_shape) == tuple(current_shape)


def resume_model(model, resume, state_dict_key="model"):
    """Load weights from a checkpoint file into ``model``.

    A full ``load_state_dict`` is attempted first. If it raises
    ``RuntimeError``, the error is printed and a partial load is done instead:
    only entries whose key exists in the model's own state dict and whose
    shape matches the model's entry are taken from the checkpoint; everything
    else is reported as ignored and the model keeps its current values there.

    Args:
        model: PyTorch model (must provide ``state_dict`` and
            ``load_state_dict``).
        resume: Path to the checkpoint file.
        state_dict_key: Key of the state dict inside the checkpoint. Pass
            None when the checkpoint itself is the state dict.

    Returns:
        The same ``model`` object, with the weights loaded.
    """
    print("resuming trained weights from %s" % resume)
    # NOTE: torch.load is pickle-based; only resume from checkpoint files
    # that come from a trusted source.
    checkpoint = torch.load(resume, map_location='cpu')
    if state_dict_key is not None:
        pretrained_dict = checkpoint[state_dict_key]
    else:
        pretrained_dict = checkpoint

    try:
        model.load_state_dict(pretrained_dict)
    except RuntimeError as e:
        print(e)
        print("can't load the all weights due to error above, trying to load part of them!")
        model_dict = model.state_dict()

        # 1. keep only entries the model can actually accept. Filtering on
        #    the key alone is not enough: an entry with a known name but a
        #    different shape would make the second load below fail as well.
        pretrained_dict_use = {}
        pretrained_dict_ignored = {}
        for k, v in pretrained_dict.items():
            if k in model_dict and _shapes_compatible(v, model_dict[k]):
                pretrained_dict_use[k] = v
            else:
                pretrained_dict_ignored[k] = v

        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict_use)
        # 3. load the merged state dict (complete, so the strict load passes)
        model.load_state_dict(model_dict)
        print("resumed only", pretrained_dict_use.keys())
        print("ignored:", pretrained_dict_ignored.keys())
    return model
def save_checkpoint(path, model, key="model"):
    """Save the model's state dict to ``path`` wrapped in a one-entry dict.

    Args:
        path: Destination handed to ``torch.save``.
        model: Object providing ``state_dict()``.
        key: Dictionary key the state dict is stored under.
    """
    torch.save({key: model.state_dict()}, path)
    print("checkpoint saved at", path)
def check_gitstatus():
    """Report the current commit and the dirty state of the enclosing git repo.

    Returns:
        - None if the ``git`` module (GitPython) cannot be imported.
        - The error message as a string if querying the repository raises.
        - Otherwise a dict with keys ``"hash"`` (hex SHA of HEAD),
          ``"changed"`` (paths from ``repo.index.diff(None)``) and
          ``"untracked"`` (``repo.untracked_files``).
    """
    try:
        import git
    except ImportError:
        # Only a failed import is treated as "GitPython unavailable"; a bare
        # except here would also swallow KeyboardInterrupt / SystemExit.
        print("cannot import gitpython ; try pip install gitpython")
        return None
    #from https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script
    #from https://stackoverflow.com/questions/31540449/how-to-check-if-a-git-repo-has-uncommitted-changes-using-python
    #from https://stackoverflow.com/questions/33733453/get-changed-files-using-gitpython/42792158
    try:
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha
        untracked = repo.untracked_files
        changed = [item.a_path for item in repo.index.diff(None)]
    except Exception as e:
        # Best effort: the git status is informational only, so any failure
        # is reported to the caller as text instead of being raised.
        print(e)
        return str(e)
    return {"hash": sha, "changed": changed, "untracked": untracked}
def make_deterministic(seed, strict=False):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Seed handed to every RNG; its string form is also stored in the
            ``PYTHONHASHSEED`` environment variable.
        strict: When True, cuDNN is additionally disabled altogether and a
            reminder about ``num_workers`` is printed.
    """
    #https://github.com/pytorch/pytorch/issues/7068#issuecomment-487907668
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch_and_numpy_seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # if you are using multi-GPU.
    )
    for apply_seed in torch_and_numpy_seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    if not strict:
        return
    #https://github.com/pytorch/pytorch/issues/7068#issuecomment-515728600
    cudnn.enabled = False
    print("strict reproducability required! cudnn disabled. make sure to set num_workers=0 too!")