# pretrain.py
# The wildcard imports below are expected to provide the shared dependencies
# used in this file (os, copy, torch, F, tqdm/trange, BertAdam, accuracy_score,
# set_seed, WEIGHTS_NAME, CONFIG_NAME).
from util import *
from model import *
from dataloader import *


class PretrainModelManager:

    def __init__(self, args, data):
        set_seed(args.seed)
        # Classification head sized to the number of known (labeled) classes.
        self.model = BertForModel.from_pretrained(args.bert_model, cache_dir="", num_labels=data.n_known_cls)
        if args.freeze_bert_parameters:
            self.freeze_parameters(self.model)

        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        n_gpu = torch.cuda.device_count()
        if n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        self.num_train_optimization_steps = int(len(data.train_labeled_examples) / args.train_batch_size) * args.num_train_epochs
        self.optimizer = self.get_optimizer(args)
        self.best_eval_score = 0
    def eval(self, args, data):
        """Run the classifier on the eval set and return accuracy (%) over known classes."""
        self.model.eval()
        total_labels = torch.empty(0, dtype=torch.long).to(self.device)
        total_logits = torch.empty((0, data.n_known_cls)).to(self.device)

        for batch in tqdm(data.eval_dataloader, desc="Iteration"):
            batch = tuple(t.to(self.device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            with torch.set_grad_enabled(False):
                _, logits = self.model(input_ids, segment_ids, input_mask, mode='eval')
                total_labels = torch.cat((total_labels, label_ids))
                total_logits = torch.cat((total_logits, logits))

        total_probs, total_preds = F.softmax(total_logits.detach(), dim=1).max(dim=1)
        y_pred = total_preds.cpu().numpy()
        y_true = total_labels.cpu().numpy()
        acc = round(accuracy_score(y_true, y_pred) * 100, 2)
        return acc
    def train(self, args, data):
        wait = 0
        best_model = None

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            self.model.train()
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0

            for step, batch in enumerate(tqdm(data.train_labeled_dataloader, desc="Iteration")):
                batch = tuple(t.to(self.device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                with torch.set_grad_enabled(True):
                    loss = self.model(input_ids, segment_ids, input_mask, label_ids, mode="train")
                    loss.backward()
                    tr_loss += loss.item()

                    self.optimizer.step()
                    self.optimizer.zero_grad()

                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

            loss = tr_loss / nb_tr_steps
            print('train_loss', loss)

            eval_score = self.eval(args, data)
            print('eval_score', eval_score)

            # Early stopping on eval accuracy: keep a copy of the best model and
            # stop after `wait_patient` epochs without improvement.
            if eval_score > self.best_eval_score:
                best_model = copy.deepcopy(self.model)
                wait = 0
                self.best_eval_score = eval_score
            else:
                wait += 1
                if wait >= args.wait_patient:
                    break

        self.model = best_model
        if args.save_model:
            self.save_model(args)
    def get_optimizer(self, args):
        param_optimizer = list(self.model.named_parameters())
        # No weight decay for biases and LayerNorm parameters.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.lr_pre,
                             warmup=args.warmup_proportion,
                             t_total=self.num_train_optimization_steps)
        return optimizer
    def save_model(self, args):
        if not os.path.exists(args.pretrain_dir):
            os.makedirs(args.pretrain_dir)
        # Unwrap DataParallel before saving so the checkpoint loads on a single
        # device; use a local variable instead of overwriting this method.
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        model_file = os.path.join(args.pretrain_dir, WEIGHTS_NAME)
        model_config_file = os.path.join(args.pretrain_dir, CONFIG_NAME)
        torch.save(model_to_save.state_dict(), model_file)
        with open(model_config_file, "w") as f:
            f.write(model_to_save.config.to_json_string())
    def freeze_parameters(self, model):
        # Freeze all BERT parameters except the last encoder layer and the pooler.
        for name, param in model.bert.named_parameters():
            param.requires_grad = False
            if "encoder.layer.11" in name or "pooler" in name:
                param.requires_grad = True
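

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving the manager, assuming the repo's entry script
# builds `args` with argparse and a `Data(args)` loader; the `init_model()` and
# `Data` names here are assumptions for illustration, not guaranteed by this file.
#
# if __name__ == '__main__':
#     parser = init_model()               # hypothetical: build the argument parser
#     args = parser.parse_args()
#     data = Data(args)                   # hypothetical: prepare labeled/eval dataloaders
#     manager = PretrainModelManager(args, data)
#     manager.train(args, data)           # pretrain on labeled data; best model kept/saved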