Add distributed data parallel trainer #66

Open · wants to merge 4 commits into master
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@
.ipynb_checkpoints/
__pycache__/
*.txt
156 changes: 156 additions & 0 deletions mingpt/trainer_ddp.py
@@ -0,0 +1,156 @@
"""
Simple distributed data parallel training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
Written to just use native pytorch and no mpi in order to stay true to spirit of minGPT.
Most of the code is retained from trainer.py from DataParallel.
"""

import math
import logging

from tqdm import tqdm
import numpy as np

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader, RandomSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from torch.utils.data import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp

logger = logging.getLogger(__name__)

def dist_init(rank, world_size, port=23501):
    backend = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
    dist.init_process_group(backend=backend, init_method="tcp://localhost:" + str(port), rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9 # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = None
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            setattr(self, k, v)

class TrainerDDP:

    def __init__(self, model, train_dataset, test_dataset, config, port=23501):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.port = port

    def save_checkpoint(self, ddp_model):
        # DataParallel / DistributedDataParallel wrappers keep the raw model object in the .module attribute
        raw_model = ddp_model.module if hasattr(ddp_model, "module") else ddp_model
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(raw_model.state_dict(), self.config.ckpt_path)

    def load_checkpoint(self, model=None, map_location='cuda'):
        if model is None:
            self.model.load_state_dict(torch.load(self.config.ckpt_path, map_location=map_location))
            return self.model
        else:
            model.load_state_dict(torch.load(self.config.ckpt_path, map_location=map_location))
            return model

    def train(self, rank: int, world_size: int):
        model, config = self.model, self.config
        dist_init(rank, world_size, port=self.port)

        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        dev = torch.device('cuda', rank)
        model.to(dev) # send the model to the device matching this process's rank
        ddp_model = DDP(model, device_ids=[rank])

        def run_epoch(split):
            is_train = split == 'train'
            ddp_model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            loader = DataLoader(data, pin_memory=True,
                                batch_size=config.batch_size,
                                #sampler=RandomSampler(data) if is_train else SequentialSampler(data),
                                sampler=DistributedSampler(data),
                                num_workers=config.num_workers)

            losses = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (x, y) in pbar:

                x, y = x.to(dev), y.to(dev)

                # forward the model
                with torch.set_grad_enabled(is_train):
                    logits, loss = ddp_model(idx=x, targets=y)
                    loss = loss.mean()
                    losses.append(loss.item())

                if is_train:

                    # backprop and update the parameters
                    ddp_model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    # report progress
                    pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

            if not is_train:
                test_loss = float(np.mean(losses))
                logger.info("test loss: %f", test_loss)
                return test_loss

        best_loss = float('inf')
        self.tokens = 0 # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')
            if self.test_dataset is not None:
                test_loss = run_epoch('test')

            # supports early stopping based on the test loss, or just save always if no test set is provided
            good_model = self.test_dataset is None or test_loss < best_loss
            if self.config.ckpt_path is not None and good_model:
                best_loss = test_loss
                # only rank 0 writes the checkpoint
                if rank == 0:
                    self.save_checkpoint(ddp_model)
            # all ranks sync up before starting the next epoch
            dist.barrier()
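The module docstring of trainer_ddp.py describes the pattern the new trainer relies on: one process per rank launched with torch.multiprocessing, a process group initialized over TCP, the model wrapped in DistributedDataParallel, and a DistributedSampler sharding the data across ranks. Below is a minimal, self-contained sketch of that pattern with a toy linear model; the worker function, the toy data, and the port are illustrative and not part of this diff.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def worker(rank, world_size):
    # one process per rank; NCCL needs a GPU per rank, Gloo works on CPU
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, init_method="tcp://localhost:23501",
                            rank=rank, world_size=world_size)

    # toy regression model and data, purely for illustration
    model = torch.nn.Linear(8, 1)
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
        model.to(device)
        ddp_model = DDP(model, device_ids=[rank])
    else:
        device = torch.device("cpu")
        ddp_model = DDP(model)

    data = TensorDataset(torch.randn(256, 8), torch.randn(256, 1))
    sampler = DistributedSampler(data)            # gives each rank a disjoint shard
    loader = DataLoader(data, batch_size=32, sampler=sampler)
    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=3e-4)

    for epoch in range(2):
        sampler.set_epoch(epoch)                  # reshuffle the shards each epoch
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = torch.nn.functional.mse_loss(ddp_model(x), y)
            optimizer.zero_grad()
            loss.backward()                       # DDP all-reduces gradients across ranks here
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count() or 2   # fall back to 2 CPU processes
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
```

One detail worth noting: DistributedSampler only reshuffles when sampler.set_epoch(epoch) is called each epoch; the trainer in this diff does not call it, so every epoch iterates the shards in the same order.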
112 changes: 112 additions & 0 deletions play_char.py
@@ -0,0 +1,112 @@
# set up logging
import logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

# make deterministic
from mingpt.utils import set_seed
set_seed(42)

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.multiprocessing as mp

import math
from torch.utils.data import Dataset

class CharDataset(Dataset):

    def __init__(self, data, block_size):
        chars = sorted(list(set(data)))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))

        self.stoi = { ch:i for i,ch in enumerate(chars) }
        self.itos = { i:ch for i,ch in enumerate(chars) }
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        """
        arrange data and targets so that the first i elements of x
        will be asked to predict the i-th element of y. Notice that
        the eventual language model will actually make block_size
        individual predictions at the same time based on this data,
        so we are being clever and amortizing the cost of the forward
        pass of the network. So for example if block_size is 4, then
        we could e.g. sample a chunk of text "hello", the integers in
        x will correspond to "hell" and in y will be "ello". This will
        then actually "multitask" 4 separate examples at the same time
        in the language model:
        - given just "h", please predict "e" as next
        - given "he" please predict "l" next
        - given "hel" predict "l" next
        - given "hell" predict "o" next

        In addition, because the DataLoader will create batches of examples,
        every forward/backward pass during training will simultaneously train
        a LOT of predictions, amortizing a lot of computation. In particular,
        for a batched input of integers X (B, T) where B is batch size and
        T is block_size and Y (B, T), the network will during training be
        simultaneously training to make B*T predictions, all at once! Of course,
        at test time we can parallelize across batch B, but unlike during training
        we cannot parallelize across the time dimension T - we have to run
        a forward pass of the network to recover the next single character of the
        sequence along each batch dimension, and repeatedly always feed in a next
        character to get the next one.

        So yes there is a big asymmetry between train/test time of autoregressive
        models. During training we can go B*T at a time with every forward pass,
        but during test time we can only go B at a time, T times, with T forward
        passes. (A tiny concrete sketch of this x/y construction follows the script.)
        """
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

if __name__ == '__main__':
    block_size = 128 # spatial extent of the model for its context

    # you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
    text = open('input.txt', 'r').read() # don't worry we won't run out of file handles
    train_dataset = CharDataset(text, block_size) # one line of poem is roughly 50 characters

    from mingpt.model import GPT, GPTConfig
    mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                      n_layer=8, n_head=8, n_embd=512)
    model = GPT(mconf)

    from mingpt.trainer_ddp import TrainerDDP, TrainerConfig

    # initialize a trainer_ddp instance and kick off training
    tconf = TrainerConfig(max_epochs=2, batch_size=64, learning_rate=6e-4,
                          lr_decay=True, ckpt_path='./checkpoint_ddp',
                          warmup_tokens=512*20, final_tokens=2*len(train_dataset)*block_size,
                          num_workers=4)
    world_size = torch.cuda.device_count()
    print('world size', world_size)
    trainer_ddp = TrainerDDP(model, train_dataset, train_dataset, tconf, port=23501)
    mp.spawn(trainer_ddp.train, args=(world_size,), nprocs=world_size, join=True)

    # sample from the model to check if it's decent. sample on the GPU.
    model = trainer_ddp.load_checkpoint(model=model, map_location='cuda')
    from mingpt.utils import sample
    sampling_device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    context = "What's in a name anyway? It's a name attached to an idea. A rose by any other name would smell as sweet"
    model = model.to(sampling_device)
    x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(sampling_device)
    y = sample(model, x, 2000, temperature=1.0, sample=True, top_k=10)[0]
    completion = ''.join([train_dataset.itos[int(i)] for i in y])
    print(completion)
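The long comment inside CharDataset.__getitem__ explains how each sampled chunk of block_size + 1 characters yields block_size training targets at once. Here is a tiny standalone sketch of that construction; the five-character chunk and the resulting integer ids are just the "hello" example from the comment, not values used anywhere in the diff.

```python
import torch

# toy chunk mirroring the "hello" example: block_size = 4, so we grab 5 characters
chunk = "hello"
stoi = {ch: i for i, ch in enumerate(sorted(set(chunk)))}  # {'e': 0, 'h': 1, 'l': 2, 'o': 3}
dix = [stoi[ch] for ch in chunk]                           # [1, 0, 2, 2, 3]

x = torch.tensor(dix[:-1], dtype=torch.long)  # encodes "hell"
y = torch.tensor(dix[1:], dtype=torch.long)   # encodes "ello"

# position t of x is trained to predict position t of y:
#   "h"    -> "e"
#   "he"   -> "l"
#   "hel"  -> "l"
#   "hell" -> "o"
print(x.tolist(), y.tolist())  # [1, 0, 2, 2] [0, 2, 2, 3]
```

With a batch of B such examples, a single forward pass supervises B * block_size predictions at once, which is the amortization the comment describes.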