-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_PersonaChat.py
388 lines (326 loc) · 20 KB
/
train_PersonaChat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
import sys
import logging
import os
import argparse
from transformers import BartTokenizer, AdamW, WEIGHTS_NAME, CONFIG_NAME
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.parallel import DistributedDataParallel
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, global_step_from_engine
from ignite.metrics import Accuracy, Loss, MetricsLambda, RunningAverage
from ignite.contrib.handlers import ProgressBar, PiecewiseLinear
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
import math
from pprint import pformat
from build_data_PersonaChat import create_data, build_dataloader, build_infer_dataset
from model.modeling_Tmema import TmemaModel
def average_distributed_scalar(scalar, args):
    """Return the mean of ``scalar`` over all distributed workers.

    Used to combine evaluation metrics across processes. When ``args.local_rank``
    is -1 (single-process run) the input is handed back untouched; otherwise every
    worker must call this collectively, and the averaged value is returned as a
    Python float.
    """
    is_distributed = args.local_rank != -1
    if not is_distributed:
        return scalar
    world_size = torch.distributed.get_world_size()
    # Pre-divide locally so that a SUM all-reduce produces the mean on every rank.
    local_share = torch.tensor(scalar, dtype=torch.float, device=args.device) / world_size
    torch.distributed.all_reduce(local_share, op=torch.distributed.ReduceOp.SUM)
    return local_share.item()
def init_config():
    """Parse command-line options, set up the device/process group and seed the RNGs.

    Returns:
        argparse.Namespace holding every option below plus two derived fields:
        ``cuda`` (``torch.cuda.is_available()``) and ``distributed``
        (True when ``--local_rank`` is not -1).

    Side effects:
        * distributed mode: pins this process to GPU ``local_rank``, replaces
          ``args.device`` with ``torch.device("cuda", local_rank)`` and joins an
          NCCL process group initialised from environment variables (``env://``);
        * otherwise, when CUDA is available, makes GPU ``--gpu`` the current device;
        * seeds Python's ``random``, NumPy and PyTorch with ``--seed`` (plus the
          CUDA generator and deterministic cuDNN when CUDA is available).
    """
    import random  # local import: only needed for seeding below

    parser = argparse.ArgumentParser(description='TMEM')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--seed', type=int, default=783435, metavar='S', help='random seed')
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for dialogue training")
    parser.add_argument("--infer_batch_size", type=int, default=128, help="Batch size for infer training")
    parser.add_argument("--valid_batch_size", type=int, default=2, help="Batch size for validation")
    parser.add_argument("--lr", type=float, default=5e-6, help="Learning rate")
    parser.add_argument("--warmup", type=float, default=0.2, help="warmup rate")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--num_latent", type=int, default=10, help="number of latent")
    parser.add_argument("--num_latent2", type=int, default=10, help="number of latent2")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--output_dir", type=str, default="persona",
                        help="Prefix of the directory where checkpoints and logs are saved")
    parser.add_argument("--load_from", type=str, default=None,
                        help="Directory of a saved model/tokenizer to load")
    parser.add_argument('--eval', action='store_true', default=False, help='eval model')
    parser.add_argument('--revised', action='store_true', default=False, help='use revised')
    parser.add_argument("--fp16", type=str, default="",
                        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--eval_before_start", action='store_true',
                        help="If true start with a first evaluation before training")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument('--cand', type=int, default=5, help='number of candidate')
    parser.add_argument("--max_history", type=int, default=7, help="length of dialogue context")
    parser.add_argument('--smalldataset', action='store_true', default=False, help='use 32 pairs')
    parser.add_argument('--model_type', type=str, help='model_type')
    args = parser.parse_args()

    args.cuda = torch.cuda.is_available()
    args.distributed = args.local_rank != -1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    elif args.cuda:
        torch.cuda.set_device(args.gpu)

    # Seed every RNG the script may draw from (stdlib, NumPy, torch) for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
    # parse_args already returns a fresh Namespace; no defensive copy is needed.
    return args
if __name__ == '__main__':
    # Resolve every relative path below (data files, output directories) against the script's directory.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    add_special_tokens = {'additional_special_tokens': ['<query>', '<response>', '<latent>', '<persona>']}
    args = init_config()
    # Number of candidate responses per validation example: the validation tensors are
    # reshaped with it and inference() regroups its logits with it.
    N_VAL_CANDIDATES = 20

    if args.fp16:
        # The fp16 branches previously called apex `amp` with an undefined `optimizer`
        # and `args.max_norm`; amp is never imported nor initialised, so every fp16 run
        # crashed with NameError on its first step. Fail fast with a clear message.
        raise ValueError("--fp16 is not supported by this script (apex amp is not initialised)")
    if args.eval and args.load_from is None:
        raise ValueError("--eval requires --load_from (directory of the model to evaluate)")

    data_from = "_revised" if args.revised else "_original"
    out_dir = args.output_dir + data_from
    if not args.eval:
        os.makedirs(out_dir, exist_ok=True)
        log_file = os.path.join(out_dir, "train.log")
    else:
        log_file = os.path.join(args.load_from, "eval.log")

    program = os.path.basename(sys.argv[0])
    logger = logging.getLogger(program)
    format_str = logging.Formatter('%(asctime)s: %(levelname)s: %(message)s')
    logger.setLevel(level=logging.INFO)
    sh = logging.StreamHandler()
    sh.setFormatter(format_str)
    fh = logging.FileHandler(filename=log_file, encoding='utf-8', mode='w')
    fh.setFormatter(format_str)
    logger.addHandler(sh)
    logger.addHandler(fh)
    # Space-separate argv so the logged command line is readable and re-runnable.
    logger.info("running %s", " ".join(sys.argv))
    logger.info("Arguments: %s", pformat(args))

    logger.info("Get pretrained model and tokenizer")
    if args.load_from is not None:
        tokenizer = BartTokenizer.from_pretrained(args.load_from)
        model = TmemaModel.from_pretrained(args.load_from)
    else:
        tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
        num_added_toks = tokenizer.add_special_tokens(add_special_tokens)
        logger.info('We have added {} tokens'.format(num_added_toks))
        model = TmemaModel.from_pretrained("facebook/bart-large", num_labels=1,
                                           num_token=len(tokenizer),
                                           num_latent=args.num_latent, num_latent2=args.num_latent2)
    # Keep the embedding matrix in sync with the (possibly extended) vocabulary and make
    # decoding start from the <response> special token.
    model.resize_token_embeddings(len(tokenizer))
    model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids('<response>')
    model.config.forced_bos_token_id = None
    logger.info('We have {} tokens'.format(len(tokenizer)))
    model.to(args.device)
    logger.info("Complete loading model.")

    logger.info("Build train data")
    persona, query, response, cand = create_data(f"data/ConvAI2/train_self{data_from}.txt", args.smalldataset)
    train_data = build_dataloader(persona, query, response, cand, tokenizer,
                                  max_history=args.max_history, n_cand=args.cand)
    logger.info("Build valid data")
    persona, query, response, cand = create_data(f"data/ConvAI2/valid_self{data_from}.txt", args.smalldataset)
    val_data = build_dataloader(persona, query, response, cand, tokenizer,
                                max_history=args.max_history, use_all=True)
    logger.info("Build infer data")
    infer_data = build_infer_dataset(tokenizer, "data/dnli/dialogue_nli_train.jsonl", args.smalldataset)

    # Tensor order here must match the unpacking order in update() / inference() / infer_update().
    MODEL_INPUTS = ["input_ids", "attention_mask", "lmlabels", "decoder_input_ids", "decoder_attention_mask",
                    "cls_index", "clslabel", "per_input_ids", "per_attention_mask"]
    INFER_INPUTS = ["encoder_input_ids", "decoder_input_ids", "attention_mask", "decoder_attention_mask",
                    "lmlabels"]
    trainset, valset, inferset = [], [], []
    for input_name in MODEL_INPUTS:
        if input_name == "clslabel":
            # Classification labels are flattened to 1-D (one label per example).
            tensor = train_data[input_name].view(-1)
            logger.info("{}: {}".format(input_name, tensor.size()))
            trainset.append(tensor)
            tensor = val_data[input_name].view(-1)
            logger.info("{}: {}".format(input_name, tensor.size()))
            valset.append(tensor)
        else:
            # Group rows into (examples, candidates, seq_len); training uses args.cand
            # candidates per example, validation N_VAL_CANDIDATES.
            tensor = train_data[input_name].view(-1, args.cand, train_data[input_name].size(-1))
            trainset.append(tensor)
            logger.info("{}: {}".format(input_name, tensor.size()))
            tensor = val_data[input_name].view(-1, N_VAL_CANDIDATES, val_data[input_name].size(-1))
            valset.append(tensor)
            logger.info("{}: {}".format(input_name, tensor.size()))
    for input_name in INFER_INPUTS:
        tensor = infer_data[input_name].view(-1, 1, infer_data[input_name].size(-1))
        logger.info("{}: {}".format(input_name, tensor.size()))
        inferset.append(tensor)
    train_dataset = TensorDataset(*trainset)
    val_dataset = TensorDataset(*valset)
    infer_dataset = TensorDataset(*inferset)

    logger.info("Prepare dataloader.")
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
    infer_sampler = torch.utils.data.distributed.DistributedSampler(infer_dataset) if args.distributed else None
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) if args.distributed else None
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,
                              shuffle=(not args.distributed), num_workers=0)
    infer_loader = DataLoader(infer_dataset, sampler=infer_sampler, batch_size=args.infer_batch_size,
                              shuffle=(not args.distributed), num_workers=0)
    val_loader = DataLoader(val_dataset, sampler=val_sampler, batch_size=args.valid_batch_size, shuffle=False)

    # optimizer_infer updates every parameter; optimizer_bart skips model.memory1.
    # NOTE(review): this exclusion relies on iterating model.memory1 yielding the registered
    # Parameter objects (e.g. an nn.ParameterList); if memory1 is a single tensor Parameter,
    # iteration yields row views and nothing is excluded — confirm against TmemaModel.
    memory1_ids = {id(p) for p in model.memory1}
    base_params = [p for p in model.parameters() if id(p) not in memory1_ids]
    optimizer_infer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)
    optimizer_bart = AdamW(base_params, lr=args.lr, correct_bias=True)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    def infer_update(engine, batch):
        """Run one optimisation step on an inference (DNLI) batch with optimizer_infer."""
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        (infer_input_ids, infer_decoder_input_ids, infer_attention_mask,
         infer_decoder_attention_mask, infer_lmlabels) = batch
        outputs = model(infer_input_ids=infer_input_ids,
                        infer_decoder_input_ids=infer_decoder_input_ids,
                        infer_attention_mask=infer_attention_mask,
                        infer_lmlabels=infer_lmlabels,
                        infer_decoder_attention_mask=infer_decoder_attention_mask)
        loss = outputs.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer_infer.step()
        optimizer_infer.zero_grad()
        return {'loss': loss.item()}

    infer_trainer = Engine(infer_update)

    def update(engine, batch):
        """Accumulate gradients of the dialogue losses; step optimizer_bart every
        ``args.gradient_accumulation_steps`` iterations. Returns the unscaled losses."""
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        (input_ids, attention_mask, lmlabels, decoder_input_ids, decoder_attention_mask, cls_index, clslabel,
         per_input_ids, per_attention_mask) = batch
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, lmlabels=lmlabels,
                        decoder_input_ids=decoder_input_ids,
                        decoder_attention_mask=decoder_attention_mask,
                        cls_index=cls_index, clslabel=clslabel,
                        per_input_ids=per_input_ids,
                        per_attention_mask=per_attention_mask)
        (lm_loss, cls_loss, m_loss, _, bow) = outputs.loss
        loss = lm_loss + cls_loss + m_loss + bow
        # Average (rather than sum) gradients over the accumulation window so the
        # clipping threshold applies to the mean gradient.
        (loss / args.gradient_accumulation_steps).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer_bart.step()
            model.zero_grad()
        return {'loss': loss.item(), 'cls': cls_loss.item(), 'lm': lm_loss.item(),
                "mem": m_loss.item(), "bow": bow.item()}

    trainer = Engine(update)

    def flush_accumulated_gradients(engine):
        """Apply gradients left by an incomplete accumulation window at the end of a
        trainer epoch, so they are not folded into the next infer_update step."""
        if engine.state.iteration % args.gradient_accumulation_steps != 0:
            optimizer_bart.step()
            model.zero_grad()

    # Registered before the evaluation handler below so validation sees the flushed weights.
    trainer.add_event_handler(Events.EPOCH_COMPLETED, flush_accumulated_gradients)

    def inference(engine, batch):
        """Forward a validation batch; return (lm, cls) logits and the matching labels
        for the first (index 0) candidate's LM loss and the candidate classification."""
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            (input_ids, attention_mask, lmlabels, decoder_input_ids, decoder_attention_mask, cls_index, clslabel,
             per_input_ids, per_attention_mask) = batch
            outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                            decoder_input_ids=decoder_input_ids,
                            decoder_attention_mask=decoder_attention_mask,
                            cls_index=cls_index,
                            per_input_ids=per_input_ids,
                            per_attention_mask=per_attention_mask)
            (lm_logits, cls_logits) = outputs.logits
            tmp_lmlogits = lm_logits.view(-1, N_VAL_CANDIDATES, lm_logits.size(1), lm_logits.size(2))
            tmp_lmlogits = tmp_lmlogits[:, 0, :, :].contiguous().view(-1, lm_logits.size(-1))
            lmlabels = lmlabels[:, 0, :].contiguous().view(-1)
            cls_logits = cls_logits.view(-1, N_VAL_CANDIDATES)
        return (tmp_lmlogits, cls_logits), (lmlabels, clslabel)

    evaluator = Engine(inference)

    # Schedule: each infer_trainer epoch is followed by one full trainer epoch, then validation.
    infer_trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: trainer.run(train_loader))
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        infer_trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes.
    if args.distributed:
        infer_trainer.add_event_handler(Events.EPOCH_STARTED,
                                        lambda engine: infer_sampler.set_epoch(engine.state.epoch))
        # trainer.run() is called again every outer epoch and ignite restarts a finished
        # engine's state, so trainer's own epoch is always 1; use the outer epoch instead
        # so each pass over train_loader is shuffled differently.
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  lambda _: train_sampler.set_epoch(infer_trainer.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: val_sampler.set_epoch(engine.state.epoch))

    # Linear decay to 0 over all epochs for each optimizer.
    scheduler_infer = PiecewiseLinear(optimizer_infer, "lr", [(0, args.lr), (args.epochs * len(infer_loader), 0.0)])
    scheduler_bart = PiecewiseLinear(optimizer_bart, "lr", [(0, args.lr), (args.epochs * len(train_loader), 0.0)])
    infer_trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_infer)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_bart)
    RunningAverage(output_transform=lambda x: x["loss"]).attach(infer_trainer, "loss")
    RunningAverage(output_transform=lambda x: x["loss"]).attach(trainer, "loss")
    RunningAverage(output_transform=lambda x: x["cls"]).attach(trainer, "cls")
    RunningAverage(output_transform=lambda x: x["lm"]).attach(trainer, "lm")
    RunningAverage(output_transform=lambda x: x["bow"]).attach(trainer, "bow")
    RunningAverage(output_transform=lambda x: x["mem"]).attach(trainer, "mem")

    if args.eval:
        logger.info("Begin evaluating")
        metrics = {
            "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][0], x[1][0]))}
        metrics["average_nll"] = MetricsLambda(average_distributed_scalar, metrics["nll"], args)
        metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
        for name, metric in metrics.items():
            metric.attach(evaluator, name)
        evaluator.add_event_handler(Events.COMPLETED,
                                    lambda __: logger.info("Validation: %s" % pformat(evaluator.state.metrics)))
        pbar_eval = ProgressBar(persist=True, ncols=140)
        pbar_eval.attach(evaluator)
        evaluator.run(val_loader)
    else:
        metrics = {
            "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][0], x[1][0])),
            "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
        metrics["average_nll"] = MetricsLambda(average_distributed_scalar, metrics["nll"], args)
        metrics["average_accuracy"] = MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
        metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
        for name, metric in metrics.items():
            metric.attach(evaluator, name)

        # On the main process: add progress bars, tensorboard, checkpoints and save model
        # configuration and tokenizer before training starts.
        log_dir = out_dir
        if args.local_rank in [-1, 0]:
            pbar_infer = ProgressBar(persist=True, ncols=140)
            pbar = ProgressBar(position=0, persist=True, ncols=140)
            pbar_infer.attach(infer_trainer, metric_names=["loss"])
            pbar.attach(trainer, metric_names=["loss", "cls", "lm", "mem", "bow"])
            infer_trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                            lambda engine: logger.info(f"Complete infer epoch: {engine.state.epoch}"))
            # trainer's own epoch counter restarts on every run; report the outer epoch.
            trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                      lambda _: logger.info(f"Complete trainer epoch: {infer_trainer.state.epoch}"))
            evaluator.add_event_handler(Events.COMPLETED,
                                        lambda __: logger.info("Validation: %s" % pformat(evaluator.state.metrics)))
            tb_logger = TensorboardLogger(log_dir)
            tb_logger.attach(infer_trainer, log_handler=OutputHandler(tag="infer", metric_names=["loss"]),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(trainer, log_handler=OutputHandler(tag="training",
                                                                metric_names=["loss", "cls", "lm", "mem", "bow"]),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(infer_trainer, log_handler=OptimizerParamsHandler(optimizer_infer),
                             event_name=Events.ITERATION_STARTED)
            tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer_bart),
                             event_name=Events.ITERATION_STARTED)
            # Index validation points by the outer epoch; trainer's epoch is always 1.
            tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()),
                                                                  global_step_transform=global_step_from_engine(
                                                                      infer_trainer)),
                             event_name=Events.EPOCH_COMPLETED)
            checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', n_saved=None)
            # "getattr" unwraps DistributedDataParallel so the bare model is saved.
            infer_trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
                'mymodel': getattr(model, 'module', model)})
            torch.save(args, os.path.join(log_dir, 'model_training_args.bin'))
            getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
            tokenizer.save_pretrained(log_dir)

        logger.info("Begin training")
        # Run the training
        infer_trainer.run(infer_loader, max_epochs=args.epochs)

        # On the main process: rename the last checkpoint (so TmemaModel.from_pretrained(log_dir)
        # can load it) and always close the tensorboard logger.
        if args.local_rank in [-1, 0]:
            if args.epochs > 0:
                os.rename(os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
                          os.path.join(log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
            tb_logger.close()