utils.py (forked from JayoungKim408/STaSy)
import logging
import os

import tensorflow as tf
import torch
import torch.nn.functional as F

def restore_checkpoint(ckpt_dir, state, device):
    """Load a training state from `ckpt_dir`, or return `state` unchanged if no checkpoint exists."""
    if not tf.io.gfile.exists(ckpt_dir):
        # No checkpoint yet: create the directory so later saves succeed, then warn.
        tf.io.gfile.makedirs(os.path.dirname(ckpt_dir))
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        return state
    else:
        loaded_state = torch.load(ckpt_dir, map_location=device)
        state['optimizer'].load_state_dict(loaded_state['optimizer'])
        state['model'].load_state_dict(loaded_state['model'], strict=False)
        state['ema'].load_state_dict(loaded_state['ema'])
        state['step'] = loaded_state['step']
        # Older checkpoints may not store an 'epoch' entry.
        if 'epoch' in loaded_state:
            state['epoch'] = loaded_state['epoch']
        return state

def save_checkpoint(ckpt_dir, state):
    """Serialize the optimizer, model, EMA, and progress counters to `ckpt_dir`."""
    saved_state = {
        'optimizer': state['optimizer'].state_dict(),
        'model': state['model'].state_dict(),
        'ema': state['ema'].state_dict(),
        'step': state['step'],
        'epoch': state['epoch'],
    }
    torch.save(saved_state, ckpt_dir)
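
# A minimal usage sketch (not from this repo) showing how the two checkpoint helpers
# are typically wired together. The model and EMA constructors below are hypothetical
# placeholders; only the `state` dict layout is taken from the functions above.
#
#   model = ScoreModel(...)                                    # hypothetical model class
#   ema = ExponentialMovingAverage(model.parameters(), 0.999)  # assumed EMA helper
#   optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
#   state = dict(optimizer=optimizer, model=model, ema=ema, step=0, epoch=0)
#
#   state = restore_checkpoint("checkpoints/ckpt.pth", state, device)
#   ...train, updating state['step'] and state['epoch']...
#   save_checkpoint("checkpoints/ckpt.pth", state)
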
def apply_activate(data, output_info):
    """Apply per-column output activations to generated data.

    `output_info` is a list of (dimension, activation) pairs; columns of `data`
    are consumed left to right, `dimension` columns at a time.
    """
    data_t = []
    st = 0
    for item in output_info:
        ed = st + item[0]
        if item[1] == 'tanh':
            data_t.append(torch.tanh(data[:, st:ed]))
        elif item[1] == 'sigmoid':
            # Sigmoid-tagged columns are passed through unchanged here.
            data_t.append(data[:, st:ed])
        elif item[1] == 'softmax':
            data_t.append(F.softmax(data[:, st:ed], dim=-1))
        else:
            raise ValueError(f"Unknown activation type: {item[1]}")
        st = ed
    return torch.cat(data_t, dim=1)
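
# A minimal example (the `output_info` layout is assumed, inferred from the indexing
# above): one numerical column mapped through tanh, followed by a 3-category one-hot
# block mapped through softmax.
#
#   output_info = [(1, 'tanh'), (3, 'softmax')]
#   raw = torch.randn(16, 4)                      # raw model output, 4 columns total
#   activated = apply_activate(raw, output_info)
#   # activated[:, 0] lies in (-1, 1); each row of activated[:, 1:4] sums to 1.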