train.py (forked from pytorch/rl)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train the transformer model. Configurable via config/train.yaml, but any argument can
also be overridden at the command line using Hydra's dotted override syntax.

To run on a single GPU, for example:

$ python train.py data.batch_size=32 sys.compile=False
"""
import time

import hydra
import torch
from models.transformer import init_transformer
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchrl.data.rlhf.dataset import get_dataloader
from torchrl.data.rlhf.prompt import PromptData

from utils import get_file_logger, resolve_name_or_path, setup


def create_loss_estimator(eval_iters, ctx):
    # helps estimate an arbitrarily accurate loss over either split using many batches

    @torch.no_grad()
    def estimate_loss(model, dataloader):
        model.eval()
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            batch = next(dataloader)
            batch.batch_size = []
            with ctx:
                model(batch)
            losses[k] = batch.loss.item()
        model.train()
        return losses.mean()

    return estimate_loss


@hydra.main(version_base="1.1", config_path="config", config_name="train")
def main(cfg):
    loss_logger = get_file_logger("loss_logger", "transformer_loss_logger.log")

    data_cfg = cfg.data
    model_cfg = cfg.model
    train_cfg = cfg.train

    eval_interval = cfg.io.eval_interval
    log_interval = cfg.io.log_interval
    eval_iters = cfg.io.eval_iters
    out_dir = model_cfg.out_dir

    grad_clip = train_cfg.grad_clip
    max_iters = train_cfg.max_iters
    always_save_checkpoint = train_cfg.always_save_checkpoint
    gradient_accumulation_steps = train_cfg.gradient_accumulation_steps

    device = cfg.sys.device
    dtype = cfg.sys.dtype
    compile_ = cfg.sys.compile

    ctx = setup(device=device, dtype=dtype)
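
    # get_dataloader yields PromptData batches from the TL;DR summarization prompts
    # (CarperAI/openai_summarize_tldr), already placed on the target device; the
    # loaders are consumed as iterators via next() in the loop below.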
    train_loader = get_dataloader(
        data_cfg.batch_size,
        data_cfg.block_size,
        PromptData,
        device,
        dataset_name="CarperAI/openai_summarize_tldr",
        split="train",
    )
    val_loader = get_dataloader(
        data_cfg.batch_size,
        data_cfg.block_size,
        PromptData,
        device,
        dataset_name="CarperAI/openai_summarize_tldr",
        split="valid",
    )

    model = init_transformer(
        resolve_name_or_path(model_cfg.name_or_path),
        model_cfg.dropout,
        device,
        compile_model=compile_,
    )
    optimizer = torch.optim.AdamW(model.parameters(), **train_cfg.optimizer)
    scheduler = None
    if train_cfg.decay_lr:
        scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler)

    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))

    estimate_loss = create_loss_estimator(eval_iters, ctx)

    best_val_loss = float("inf")

    t0 = time.time()
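    # Each iteration below accumulates gradients over `gradient_accumulation_steps`
    # micro-batches (per-batch losses are not rescaled, so gradients are summed),
    # then takes a single optimizer/scaler step.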
    next_batch = next(train_loader)  # fetch the very first batch
    for it in range(1, max_iters + 1):
        for _ in range(gradient_accumulation_steps):
            batch = next_batch
            # TODO: can we handle this better with a differently structured tensorclass?
            batch.batch_size = []
            with ctx:
                model(batch)
            # immediately async prefetch next batch while model is doing the forward pass on the GPU
            next_batch = next(train_loader)
            # backward pass, with gradient scaling if training in fp16
            scaler.scale(batch.loss).backward()

        # clip the gradient
        if grad_clip != 0.0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        # step the optimizer and scaler if training in fp16
        scaler.step(optimizer)
        scaler.update()
        # flush the gradients as soon as we can, no need for this memory anymore
        optimizer.zero_grad(set_to_none=True)
        # update learning rate
        if scheduler is not None:
            scheduler.step()

        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        if it % eval_interval == 0:
            # evaluate the loss on train/val sets and write checkpoints
            train_loss = estimate_loss(model, train_loader)
            val_loss = estimate_loss(model, val_loader)

            msg = f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}"
            print(msg)
            loss_logger.info(msg)

            if val_loss < best_val_loss or always_save_checkpoint:
                best_val_loss = val_loss
                if it > 0:
                    msg = f"saving checkpoint to {out_dir}"
                    print(msg)
                    loss_logger.info(msg)
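                    # save the wrapped pretrained model (exposed as `.module`) with
                    # save_pretrained() so it can be reloaded later via from_pretrained()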
                    model.module.save_pretrained(out_dir)
        elif it % log_interval == 0:
            # loss as float. note: this is a CPU-GPU sync point
            loss = batch.loss.item()
            msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt*1000:.2f}ms"
            print(msg)
            loss_logger.info(msg)


if __name__ == "__main__":
    main()