utils.py
import os
import glob
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch


def show_message(text, verbose=True, end='\n', rank=0):
    # Print only from the main process (rank 0) so that distributed runs
    # do not duplicate log output.
    if verbose and (rank == 0): print(text, end=end)


def str2bool(v):
    # Interpret common truthy/falsy strings; intended for use as an
    # argparse `type` callable.
    if isinstance(v, bool): return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'): return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False
    else: raise argparse.ArgumentTypeError('Boolean value expected.')
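
# Illustrative usage of `str2bool` as an argparse type (the `--verbose`
# argument name below is hypothetical, not part of this module):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--verbose', type=str2bool, default=True)
#     args = parser.parse_args(['--verbose', 'no'])  # args.verbose == False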


def parse_filelist(filelist_path):
    # Read a text file containing one path per line and return the list of paths.
    with open(filelist_path, 'r') as f:
        filelist = [line.strip() for line in f.readlines()]
    return filelist


def latest_checkpoint_path(dir_path, regex="checkpoint_*.pt"):
    # Find the checkpoint file with the largest numeric index in `dir_path`.
    f_list = glob.glob(os.path.join(dir_path, regex))
    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
    return f_list[-1]


def load_latest_checkpoint(logdir, model, optimizer=None):
    latest_model_path = latest_checkpoint_path(logdir, regex="checkpoint_*.pt")
    print(f'Latest checkpoint: {latest_model_path}')
    # `map_location` keeps loaded tensors on CPU; the model can be moved to the
    # target device afterwards.
    d = torch.load(
        latest_model_path,
        map_location=lambda storage, loc: storage
    )
    iteration = d['iteration']
    # Precomputed diffusion schedule buffers are filtered out of the checkpoint
    # so they do not clash with the model's own buffers when loading.
    valid_incompatible_unexp_keys = [
        'betas',
        'alphas',
        'alphas_cumprod',
        'alphas_cumprod_prev',
        'sqrt_alphas_cumprod',
        'sqrt_recip_alphas_cumprod',
        'sqrt_recipm1_alphas_cumprod',
        'posterior_log_variance_clipped',
        'posterior_mean_coef1',
        'posterior_mean_coef2'
    ]
    d['model'] = {
        key: value for key, value in d['model'].items()
        if key not in valid_incompatible_unexp_keys
    }
    model.load_state_dict(d['model'], strict=False)
    if optimizer is not None:
        optimizer.load_state_dict(d['optimizer'])
    return model, optimizer, iteration
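
# Illustrative usage when resuming training (`logdir`, `model` and `optimizer`
# are placeholders supplied by the calling training script):
#
#     model, optimizer, iteration = load_latest_checkpoint(logdir, model, optimizer)
#     show_message(f'Resumed from iteration {iteration}')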


def save_figure_to_numpy(fig):
    # Convert a rendered matplotlib figure to an (H, W, 3) uint8 RGB array.
    # np.frombuffer replaces the deprecated np.fromstring.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


def plot_tensor_to_numpy(tensor):
    # Render a 2-D array (e.g. a spectrogram) as an image and return it as a
    # numpy array suitable for logging.
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close(fig)
    return data
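
# Illustrative usage with TensorBoard (assumes a
# `torch.utils.tensorboard.SummaryWriter`; `writer`, `mel` and `iteration`
# are placeholders):
#
#     image = plot_tensor_to_numpy(mel.cpu().numpy())  # (H, W, 3) uint8
#     writer.add_image('mel', image, iteration, dataformats='HWC')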


class ConfigWrapper(object):
    """
    Wrapper around a dict that allows attribute-style access,
    e.g. `config.sample_rate` instead of `config["sample_rate"]`.
    """
    def __init__(self, **kwargs):
        # Recursively wrap nested dicts so nested keys are also
        # attribute-accessible.
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = ConfigWrapper(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def to_dict_type(self):
        # Convert back to plain (possibly nested) dicts.
        return {
            key: (value if not isinstance(value, ConfigWrapper) else value.to_dict_type())
            for key, value in dict(**self).items()
        }

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
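
# Illustrative usage of ConfigWrapper with a nested config dict (the keys
# below are made up for demonstration):
#
#     config = ConfigWrapper(**{'sample_rate': 22050, 'model': {'n_layers': 4}})
#     config.sample_rate       # 22050
#     config.model.n_layers    # 4
#     config.to_dict_type()    # back to a plain nested dict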