diff --git a/examples/rlhf/.gitignore b/examples/rlhf/.gitignore new file mode 100644 index 00000000000..d8bad909a58 --- /dev/null +++ b/examples/rlhf/.gitignore @@ -0,0 +1,4 @@ +*.png +*.bin +*.pt +*.json diff --git a/examples/rlhf/README.md b/examples/rlhf/README.md new file mode 100644 index 00000000000..1ddca8dfb96 --- /dev/null +++ b/examples/rlhf/README.md @@ -0,0 +1,45 @@ +# RLHF example + +This example uses RLHF (Reinforcement Learning with Human Feedback) to train a language model to summarize Reddit posts. + +## Getting started + +Make sure you have PyTorch 2.0 installed. You can find installation instructions [here](https://pytorch.org/get-started/locally/). + +From this directory, you can install extra requirements for running these examples with + +```sh +pip install -r requirements.txt +``` + +## Training the models +### Training the transformer + +Once the data has been prepared, you can train the GPT model. + +```sh +python train.py +``` + +Default configuration can be found in `config/train.yaml`, and any option can be overridden with command-line arguments, for example to run the training script with a different batch size + +```sh +python train.py --batch_size=128 +``` +> **_NOTE:_** Apple Silicon Macbooks users make sure to use `--device=mps` and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback + +### Training the reward model + +Next you can train the reward model with + +```sh +python train_reward.py +``` + +### Training the final model with RLHF + +To train the final model run + +```sh +python train_rlhf.py +``` diff --git a/examples/rlhf/config/train.yaml b/examples/rlhf/config/train.yaml new file mode 100644 index 00000000000..6d27088902f --- /dev/null +++ b/examples/rlhf/config/train.yaml @@ -0,0 +1,30 @@ +io: + eval_interval: 200 + log_interval: 50 + eval_iters: 100 +data: + batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size + block_size: 550 +model: + name_or_path: gpt2 # gpt2 for pre-trained, local path for checkpoint + out_dir: ./out + dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ +train: + grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 + max_iters: 5000 # total number of training iterations + gradient_accumulation_steps: 2 # used to simulate larger batch sizes + always_save_checkpoint: False # if True, always save a checkpoint after each evaluation in out_dir + decay_lr: True # whether to decay the learning rate + optimizer: + # keyword arguments for torch.optim.AdamW + lr: 1.0e-5 + weight_decay: 1.0e-1 + betas: [0.9, 0.95] + scheduler: + # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 5000 # maximum number of iterations + eta_min: 1.0e-6 # minimum learning rate +sys: + device: cuda # examples: cpu, cuda, cuda:0, cuda:1 etc., or try mps on macbooks + dtype: bfloat16 # float32, bfloat16, or float16, the latter will auto implement a GradScaler + compile: True # use PyTorch 2.0 to compile the model to be faster diff --git a/examples/rlhf/config/train_reward.yaml b/examples/rlhf/config/train_reward.yaml new file mode 100644 index 00000000000..a5523b75fe2 --- /dev/null +++ b/examples/rlhf/config/train_reward.yaml @@ -0,0 +1,32 @@ +io: + eval_interval: 200 + log_interval: 50 + eval_iters: 100 +data: + batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size + block_size: 550 +model: + name_or_path: ./out + dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ +reward_model: + out_dir: ./out_reward + 
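+  # Checkpoints for the reward model are written to out_dir; when init_from is
+  # set to "resume" (below), train_reward.py reloads the reward model from this
+  # same directory.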
init_from: scratch # 'scratch' or 'resume' - if "resume" model will be loaded from out_dir_reward +train: + grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 + max_iters: 20000 # total number of training iterations + gradient_accumulation_steps: 2 # used to simulate larger batch sizes + always_save_checkpoint: False # if True, always save a checkpoint after each eval + decay_lr: False # whether to decay the learning rate + optimizer: + # keyword arguments for torch.optim.AdamW + lr: 1.0e-5 + weight_decay: 1.0e-1 + betas: [0.9, 0.95] + scheduler: + # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 20000 + eta_min: 1.0e-6 +sys: + device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks + dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler + compile: True # use PyTorch 2.0 to compile the model to be faster diff --git a/examples/rlhf/config/train_rlhf.yaml b/examples/rlhf/config/train_rlhf.yaml new file mode 100644 index 00000000000..0aac2d83acd --- /dev/null +++ b/examples/rlhf/config/train_rlhf.yaml @@ -0,0 +1,36 @@ +io: + eval_interval: 6 + log_interval: 1 + eval_iters: 10 +data: + batch_size: 4 # if gradient_accumulation_steps > 1, this is the micro-batch size + block_size: 550 +model: + name_or_path: ./out + out_dir: ./out_rlhf + dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ +reward_model: + name_or_path: ./out_reward +train: + grad_clip: 1.0 + max_epochs: 1000 # total number of training iterations + always_save_checkpoint: True # if True, always save a checkpoint after each eval + decay_lr: True + optimizer: + # keyword arguments for torch.optim.AdamW + lr: 5.0e-5 + weight_decay: 0.0 # 01 + betas: [0.9, 0.999] + scheduler: + # keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 3000 # max_epochs * num_rollouts / ppo_batch_size + eta_min: 5.0e-6 + ppo: + episode_length: 50 + ppo_batch_size: 16 + ppo_num_epochs: 3 + num_rollouts_per_epoch: 32 +sys: + device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks + dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler + compile: True # use PyTorch 2.0 to compile the model to be faster diff --git a/examples/rlhf/data/__init__.py b/examples/rlhf/data/__init__.py new file mode 100644 index 00000000000..433c23452f2 --- /dev/null +++ b/examples/rlhf/data/__init__.py @@ -0,0 +1,3 @@ +from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr + +__all__ = ["get_prompt_dataloader_tldr"] diff --git a/examples/rlhf/models/__init__.py b/examples/rlhf/models/__init__.py new file mode 100644 index 00000000000..7bec24cb17b --- /dev/null +++ b/examples/rlhf/models/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/rlhf/models/actor_critic.py b/examples/rlhf/models/actor_critic.py new file mode 100644 index 00000000000..e514cf9b248 --- /dev/null +++ b/examples/rlhf/models/actor_critic.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
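+# Builds the PPO actor and critic from a single GPT-2 backbone. The base
+# transformer is wrapped in LMActorCritic so that the policy (LM head) and the
+# value function share transformer weights. train_rlhf.py consumes the returned
+# pieces as follows: the VmapModule-wrapped critic is passed to GAE, the value
+# head is passed to ClipPPOLoss together with the actor, and the raw base model
+# is used for rollouts.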
+from torchrl.modules.tensordict_module.actors import LMActorCritic +from torchrl.modules.tensordict_module.common import VmapModule + +from .transformer import init_transformer + +__all__ = ["init_actor_critic"] + + +def init_actor_critic(transformer_name_or_path, dropout, device, compile_): + base_model = init_transformer( + transformer_name_or_path, + dropout, + device, + as_tensordictmodule=False, + compile_=compile_, + inference=True, + ) + model = LMActorCritic(base_model) + model.to(device) + model.eval() + actor = model.get_policy_operator() + critic = model.get_value_operator() + critic_head = model.get_value_head() + + return actor, VmapModule(critic), critic_head, base_model diff --git a/examples/rlhf/models/reward.py b/examples/rlhf/models/reward.py new file mode 100644 index 00000000000..ce84c727dd4 --- /dev/null +++ b/examples/rlhf/models/reward.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from tensordict.nn import TensorDictModule + +from torchrl.modules.models.rlhf import GPT2RewardModel + + +def init_reward_model( + transformer_path=None, reward_model_path=None, device=None, compile_=False +): + if not ((transformer_path is None) ^ (reward_model_path is None)): + raise ValueError( + "Exactly one of transformer_path or reward_model_path should be specified" + ) + if transformer_path is not None: + model = GPT2RewardModel(transformer_path) + else: + model = GPT2RewardModel.from_pretrained(reward_model_path) + + model.to(device) + if compile_: + print("Compiling the reward model...") + model = torch.compile(model) + + model = TensorDictModule( + model, + in_keys=["input_ids", "attention_mask"], + out_keys=["rewards", "end_scores"], + ) + return model diff --git a/examples/rlhf/models/transformer.py b/examples/rlhf/models/transformer.py new file mode 100644 index 00000000000..cde8ce568ae --- /dev/null +++ b/examples/rlhf/models/transformer.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +from tensordict.nn import TensorDictModule +from transformers import GPT2LMHeadModel + + +def init_transformer( + name_or_path, + dropout, + device, + compile_, + as_tensordictmodule=True, + inference=False, +): + model_kwargs = { + "resid_pdrop": dropout, + "embd_pdrop": dropout, + "attn_pdrop": dropout, + "summary_first_dropout": dropout, + } + model = GPT2LMHeadModel.from_pretrained( + name_or_path, return_dict=False, **model_kwargs + ) + model.to(device) + + if compile_: + # TODO: logging instead of printing? 
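+        # torch.compile requires PyTorch 2.0 and compiles lazily: the first
+        # forward pass after this call is slower while the graph is traced
+        # and optimized.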
+ print("Compiling transformer model...") + model = torch.compile(model) + + if as_tensordictmodule: + model = TensorDictModule( + model, + in_keys={ + "input_ids": "input_ids", + "attention_mask": "attention_mask", + "labels": "labels", + }, + out_keys=["logits"] if inference else ["loss", "logits"], + ) + return model diff --git a/examples/rlhf/requirements.txt b/examples/rlhf/requirements.txt new file mode 100644 index 00000000000..9bff1b48453 --- /dev/null +++ b/examples/rlhf/requirements.txt @@ -0,0 +1,11 @@ +datasets +hydra-core +matplotlib +numpy +PyYAML +requests +tiktoken +tqdm +transformers +git+https://github.com/pytorch/rl +git+https://github.com/pytorch-labs/tensordict diff --git a/examples/rlhf/train.py b/examples/rlhf/train.py new file mode 100644 index 00000000000..fe624213ada --- /dev/null +++ b/examples/rlhf/train.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Train the transformer model. Configurable via config/train.yaml, but any argument can +also be overridden at the command line. + +To run on a single GPU, example: +$ python train.py --batch_size=32 --compile=False +""" +import time + +import hydra +import torch +from models.transformer import init_transformer +from torch.optim.lr_scheduler import CosineAnnealingLR + +from torchrl.data.rlhf.dataset import get_dataloader +from torchrl.data.rlhf.prompt import PromptData +from utils import get_file_logger, resolve_name_or_path, setup + + +def create_loss_estimator(eval_iters, ctx): + # helps estimate an arbitrarily accurate loss over either split using many batches + + @torch.no_grad() + def estimate_loss(model, dataloader): + model.eval() + losses = torch.zeros(eval_iters) + for k in range(eval_iters): + batch = next(dataloader) + batch.batch_size = [] + with ctx: + model(batch) + losses[k] = batch.loss.item() + model.train() + return losses.mean() + + return estimate_loss + + +@hydra.main(version_base="1.1", config_path="config", config_name="train") +def main(cfg): + loss_logger = get_file_logger("loss_logger", "transformer_loss_logger.log") + + data_cfg = cfg.data + model_cfg = cfg.model + train_cfg = cfg.train + + eval_interval = cfg.io.eval_interval + log_interval = cfg.io.log_interval + eval_iters = cfg.io.eval_iters + out_dir = model_cfg.out_dir + + grad_clip = train_cfg.grad_clip + max_iters = train_cfg.max_iters + always_save_checkpoint = train_cfg.always_save_checkpoint + gradient_accumulation_steps = train_cfg.gradient_accumulation_steps + + device = cfg.sys.device + dtype = cfg.sys.dtype + compile_ = cfg.sys.compile + + ctx = setup(device=device, dtype=dtype) + + train_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + device, + dataset_name="CarperAI/openai_summarize_tldr", + split="train", + ) + val_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + device, + dataset_name="CarperAI/openai_summarize_tldr", + split="valid", + ) + + model = init_transformer( + resolve_name_or_path(model_cfg.name_or_path), + model_cfg.dropout, + device, + compile_=compile_, + ) + optimizer = torch.optim.AdamW(model.parameters(), **train_cfg.optimizer) + scheduler = None + if train_cfg.decay_lr: + scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) + + scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16")) + estimate_loss = create_loss_estimator(eval_iters, ctx) + + 
best_val_loss = float("inf") + + t0 = time.time() + next_batch = next(train_loader) # fetch the very first batch + for it in range(1, max_iters + 1): + for _ in range(gradient_accumulation_steps): + batch = next_batch + # TODO: can we handle this better with a differently structured tensorclass? + batch.batch_size = [] + with ctx: + model(batch) + # immediately async prefetch next batch while model is doing the forward pass on the GPU + next_batch = next(train_loader) + # backward pass, with gradient scaling if training in fp16 + scaler.scale(batch.loss).backward() + + # clip the gradient + if grad_clip != 0.0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + + # step the optimizer and scaler if training in fp16 + scaler.step(optimizer) + scaler.update() + # flush the gradients as soon as we can, no need for this memory anymore + optimizer.zero_grad(set_to_none=True) + + # update learning rate + if scheduler is not None: + scheduler.step() + + t1 = time.time() + dt = t1 - t0 + t0 = t1 + if it % eval_interval == 0: + # evaluate the loss on train/val sets and write checkpoints + train_loss = estimate_loss(model, train_loader) + val_loss = estimate_loss(model, val_loader) + msg = f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}" + print(msg) + loss_logger.info(msg) + if val_loss < best_val_loss or always_save_checkpoint: + best_val_loss = val_loss + if it > 0: + msg = f"saving checkpoint to {out_dir}" + print(msg) + loss_logger.info(msg) + model.module.save_pretrained(out_dir) + elif it % log_interval == 0: + # loss as float. note: this is a CPU-GPU sync point + loss = batch.loss.item() + msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt*1000:.2f}ms" + print(msg) + loss_logger.info(msg) + + +if __name__ == "__main__": + main() diff --git a/examples/rlhf/train_reward.py b/examples/rlhf/train_reward.py new file mode 100644 index 00000000000..850c6d92f1b --- /dev/null +++ b/examples/rlhf/train_reward.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
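+# Train the reward model on pairwise comparisons from
+# CarperAI/openai_summarize_comparisons: each batch holds a "chosen" and a
+# "rejected" summary for the same prompt, and compute_reward_loss pushes the
+# scalar score of the chosen response above that of the rejected one (a
+# pairwise ranking loss). Accuracy below is the fraction of pairs for which the
+# chosen response already receives the higher score.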
+import time + +import hydra +import torch +from models.reward import init_reward_model +from torch.optim.lr_scheduler import CosineAnnealingLR +from torchrl.data.rlhf.dataset import get_dataloader +from torchrl.data.rlhf.reward import PairwiseDataset +from utils import get_file_logger, resolve_name_or_path, setup + + +def _accuracy(chosen_end_scores, rejected_end_scores): + return ( + sum(chosen_end_scores > rejected_end_scores) / len(rejected_end_scores) + ).item() + + +# TODO: eliminate redundant repeated definition +# helps estimate an arbitrarily accurate loss over either split using many batches +def create_loss_estimator(eval_iters, ctx): + @torch.no_grad() + def estimate_loss(model, dataloader): + model.eval() + losses = torch.zeros(eval_iters) + accs = torch.zeros(eval_iters) + for k in range(eval_iters): + batch = next(dataloader) + with ctx: + model(batch.chosen_data) + model(batch.rejected_data) + losses[k] = model.compute_reward_loss( + batch.chosen_data, batch.rejected_data + ).item() + accs[k] = _accuracy( + batch.chosen_data.end_scores, batch.rejected_data.end_scores + ) + model.train() + return losses.mean(), accs.mean() + + return estimate_loss + + +@hydra.main(version_base="1.1", config_path="config", config_name="train_reward") +def main(cfg): + loss_logger = get_file_logger("loss_logger", "reward_loss_logger.log") + + data_cfg = cfg.data + model_cfg = cfg.model + reward_model_cfg = cfg.reward_model + train_cfg = cfg.train + + eval_interval = cfg.io.eval_interval + log_interval = cfg.io.log_interval + eval_iters = cfg.io.eval_iters + reward_out_dir = reward_model_cfg.out_dir + + max_iters = train_cfg.max_iters + always_save_checkpoint = train_cfg.always_save_checkpoint + + device = cfg.sys.device + dtype = cfg.sys.dtype + compile_ = cfg.sys.compile + + ctx = setup(device=device, dtype=dtype) + + train_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PairwiseDataset, + device, + dataset_name="CarperAI/openai_summarize_comparisons", + split="train", + ) + val_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PairwiseDataset, + device, + dataset_name="CarperAI/openai_summarize_comparisons", + split="valid1", + ) + + if reward_model_cfg.init_from == "resume": + model = init_reward_model( + reward_model_path=resolve_name_or_path(reward_model_cfg.out_dir), + device=device, + compile_=compile_, + ) + else: + model = init_reward_model( + transformer_path=resolve_name_or_path(model_cfg.name_or_path), + device=device, + compile_=compile_, + ) + # Freeze the first 70% of the hidden layers of the reward model backbone + layers = model.transformer.h + num_layers = len(layers) + num_unfrozen = int(0.3 * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + # ######## INIT TRAINING FUNCTIONS ######## + + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], **train_cfg.optimizer + ) + scheduler = None + if train_cfg.decay_lr: + scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) + + estimate_loss = create_loss_estimator(eval_iters, ctx) + + best_val_loss = float("inf") + + t0 = time.time() + for it in range(1, max_iters + 1): + batch = next(train_loader) + + with ctx: + model(batch.chosen_data) + model(batch.rejected_data) + optimizer.zero_grad(set_to_none=True) + loss = model.compute_reward_loss(batch.chosen_data, batch.rejected_data) + loss.backward() + optimizer.step() + if scheduler is not None: + scheduler.step() + + t1 = time.time() + dt = t1 - t0 + t0 = t1 + 
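+        # Every eval_interval iterations, estimate loss and accuracy on both
+        # splits and checkpoint to reward_out_dir when the validation loss
+        # improves (or always, if always_save_checkpoint is set); otherwise log
+        # the training loss and pairwise accuracy every log_interval iterations.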
if it % eval_interval == 0: + val_loss, val_acc = estimate_loss(model, val_loader) + train_loss, train_acc = estimate_loss(model, train_loader) + + msg = ( + f"VALID: {it=}: {train_loss=:.4f}, {val_loss=:.4f}, " + f"{train_acc=:.4f}, {val_acc=:.4f}" + ) + print(msg) + loss_logger.info(msg) + if val_loss < best_val_loss or always_save_checkpoint: + best_val_loss = val_loss + if it > 0: + msg = f"saving checkpoint to {reward_out_dir}" + print(msg) + loss_logger.info(msg) + model.module.save_pretrained(reward_out_dir) + elif it % log_interval == 0: + loss = loss.item() + acc = _accuracy( + batch.chosen_data.end_scores, batch.rejected_data.end_scores + ) + msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt*1000:.2f}ms" + print(msg) + loss_logger.info(msg) + + +if __name__ == "__main__": + main() diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py new file mode 100644 index 00000000000..4226bad3160 --- /dev/null +++ b/examples/rlhf/train_rlhf.py @@ -0,0 +1,339 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from copy import deepcopy + +import numpy as np +import torch + +import wandb +from models.actor_critic import init_actor_critic +from models.reward import init_reward_model + +from omegaconf import OmegaConf +from torch.optim.lr_scheduler import CosineAnnealingLR + +from torchrl.data import LazyTensorStorage +from torchrl.data.replay_buffers import ( + SamplerWithoutReplacement, + TensorDictReplayBuffer, +) +from torchrl.data.rlhf.dataset import get_dataloader +from torchrl.data.rlhf.prompt import PromptData +from torchrl.data.rlhf.utils import RolloutFromModel + +from torchrl.objectives import ClipPPOLoss +from torchrl.objectives.value import GAE +from tqdm import tqdm +from transformers import GenerationConfig, GPT2Tokenizer +from utils import get_file_logger, resolve_name_or_path, setup + + +def flatten_td(td): + # our tensordict has shape [B, T] where B = batch_size and T = trajectory length + # some trajectories may have stopped (reached EOS) before generating T tokens + # this function truncates and concatenates the trajectories, resulting in a + # tensordict that has shape [N] where N <= B * T. + done = td["next", "done"] + mask = torch.zeros_like(done) + mask[..., 1:, :] = done[..., :-1, :] # shift by one + mask = ~mask.cumsum(-2).bool().squeeze() + return td[mask] + + +class AdaptiveKLController: + """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences" + Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2 + Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py + """ + + def __init__(self, init_kl_coef: float, target: float, horizon: int): + self.value = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, current: float, n_steps: int): + """Returns adaptively updated KL coefficient, βₜ₊₁. + Arguments: + current: The current KL value between the newest policy and the initial policy. + """ + proportional_error = np.clip(current / self.target - 1, -0.2, 0.2) # ϵₜ + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult # βₜ₊₁ + return self.value + + +def create_reward_estimator( + eval_iters, episode_length, reward_model, batch, ctx, logger=None, ref_model=None +): + """Create a function to estimate the reward via sampling. 
+ + This function creates a new function which, given a model and a dataloader, will + perform multiple rollouts using the model and data sampled from the dataloader then + average the accumulated rewards. + + For debugging purposes, we also generate responses to a fixed prompt so that the + quality of the model can be visually assessed during training. + """ + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + + test_rindex = batch.prompt_rindex[0] + test_prompt_ids = batch.input_ids[:1, :test_rindex] + test_label_ids = batch.input_ids[:1, test_rindex:] + generation_config = GenerationConfig( + pad_token_id=tokenizer.pad_token_id, max_new_tokens=episode_length + ) + test_prompt = tokenizer.decode(test_prompt_ids[0, :test_rindex].tolist()) + test_label = tokenizer.decode( + test_label_ids[0, test_label_ids[0] != tokenizer.pad_token_id].tolist() + ) + _, test_label_reward = reward_model( + input_ids=batch.input_ids[:1], attention_mask=batch.attention_mask[:1] + ) + + @torch.no_grad() + def estimate_reward(model, dataloader): + rollout_from_model = RolloutFromModel(model, ref_model, reward_model) + rewards = torch.zeros(eval_iters) + for k in range(eval_iters): + batch = next(dataloader) + # NOTE: disable kl for evaluation + td = rollout_from_model.rollout_from_data(batch, kl_coef=0.0) + rewards[k] = td.get(("next", "reward")).sum(dim=1).mean().item() + test_reward = rewards.mean() + + if logger: + response_ids = model.generate( + input_ids=test_prompt_ids, generation_config=generation_config + ) + with ctx: + _, response_reward = reward_model( + input_ids=response_ids, + attention_mask=(response_ids != tokenizer.pad_token_id).to( + torch.int64 + ), + ) + reward = (response_reward - test_label_reward).item() + response_ids = response_ids[0, test_rindex:] + response = tokenizer.decode( + response_ids[response_ids != tokenizer.eos_token_id].tolist() + ) + string_to_write = ( + f"Query:\n{test_prompt}\n" + f"Response:\n{response}\n" + f"Actual response:\n{test_label}\n" + f"{reward=:4.4f}, " + f"{test_reward=:4.4f}\n" + f"====================================================\n" + ) + logger.info(string_to_write) + + return test_reward + + return estimate_reward + + +# @hydra.main(version_base="1.1", config_path="config", config_name="train_rlhf") +def main(): + cfg = OmegaConf.load("config/train_rlhf.yaml") + wandb.init( + # set the wandb project where this run will be logged + project="rlhf-training", + # track hyperparameters and run metadata + config=cfg, + ) + query_logger = get_file_logger("query_logger", "rlhf_query_logger.log") + val_reward_logger = get_file_logger("val_reward_logger", "rlhf_valid_rewards.log") + + data_cfg = cfg.data + model_cfg = cfg.model + reward_model_cfg = cfg.reward_model + train_cfg = cfg.train + ppo_cfg = train_cfg.ppo + + eval_interval = cfg.io.eval_interval + log_interval = cfg.io.log_interval + eval_iters = cfg.io.eval_iters + + rlhf_out_dir = model_cfg.out_dir + transformer_name_or_path = model_cfg.name_or_path + dropout = model_cfg.dropout + + batch_size = data_cfg.batch_size + + grad_clip = train_cfg.grad_clip + max_epochs = train_cfg.max_epochs + always_save_checkpoint = train_cfg.always_save_checkpoint + + episode_length = ppo_cfg.episode_length + ppo_batch_size = ppo_cfg.ppo_batch_size + ppo_num_epochs = ppo_cfg.ppo_num_epochs + num_rollouts_per_epoch = ppo_cfg.num_rollouts_per_epoch + + device = cfg.sys.device + dtype = cfg.sys.dtype + compile_ = cfg.sys.compile + + ctx = setup(device, dtype) + + train_loader 
= get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + device, + dataset_name="CarperAI/openai_summarize_tldr", + split="train", + ) + val_loader = get_dataloader( + data_cfg.batch_size, + data_cfg.block_size, + PromptData, + device, + dataset_name="CarperAI/openai_summarize_tldr", + split="valid", + ) + + actor, critic, critic_head, model = init_actor_critic( + resolve_name_or_path(transformer_name_or_path), dropout, device, compile_ + ) + ref_model = deepcopy(model).to("cuda:1") + ref_model.requires_grad_(False) + layers = model.transformer.h + num_layers = len(layers) + num_unfrozen = int(0.3 * num_layers) + for layer in layers[:-num_unfrozen]: + layer.requires_grad_(False) + + reward_model = init_reward_model( + reward_model_path=resolve_name_or_path(reward_model_cfg.name_or_path), + device=device, + compile_=compile_, + ) + reward_model.eval() + reward_model.requires_grad_(False) + + adv_fn = GAE( + value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True, shifted=True + ) + loss_fn = ClipPPOLoss(actor, critic_head) + + test_prompt = next(val_loader) + estimate_reward = create_reward_estimator( + eval_iters, + episode_length, + reward_model, + test_prompt, + ctx, + logger=query_logger, + ref_model=ref_model, + ) + + optimizer = torch.optim.AdamW( + [p for p in loss_fn.parameters() if p.requires_grad], **train_cfg.optimizer + ) + scheduler = None + if train_cfg.decay_lr: + scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler) + + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(episode_length * num_rollouts_per_epoch), + batch_size=episode_length * batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + ) + rb_ppo = TensorDictReplayBuffer( + storage=LazyTensorStorage(episode_length * batch_size), + batch_size=ppo_batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + ) + + rollout_from_model = RolloutFromModel(model, ref_model, reward_model) + + best_val_reward = float("-inf") + it = 0 # it is equivalent to batch_size number of episodes + with tqdm(total=int(max_epochs * num_rollouts_per_epoch / batch_size)) as pbar: + for _epoch in range(1, max_epochs + 1): + rb.empty() + rollout_rewards = [] + rollout_kl = [] + kl_controller = AdaptiveKLController(0.1, 6, 10000) + for _ in range(0, num_rollouts_per_epoch, batch_size): + batch = next(train_loader) + td = rollout_from_model.rollout_from_data( + batch, kl_coef=kl_controller.value + ) + with torch.no_grad(), ctx: + # moving this to within epoch + adv_fn(td) + # it's possible we didn't fill the replay buffer in the last iteration if + # generation stopped early, so we empty first before repopulating + rb.extend(flatten_td(td)) + done = td.get(("next", "done")) + next_reward = td.get(("next", "reward_raw"))[done] + next_kl = td.get(("next", "reward_kl"))[done] + rollout_rewards.append(next_reward.mean().cpu().item()) + rollout_kl.append(next_kl.mean().cpu().item()) + rollout_reward = torch.tensor(rollout_rewards).mean().cpu().item() + rollout_kl_reward = torch.tensor(rollout_kl).mean().cpu().item() + # recover true kl + rollout_kl = -rollout_kl_reward / kl_controller.value + kl_controller.update(rollout_kl, num_rollouts_per_epoch / batch_size) + + # FIXME: THIS PPO CYCLE WAS DIFFERENT wrt trlx. 
@tcbegley please double check + # they sample batch_size from rb and then do minibatches ppo_batch_size within + if it % log_interval == 0: + val_reward_logger.info( + f"TRAIN: {it=}: {rollout_reward=:.4f} {rollout_kl_reward=:.4f} {rollout_kl=:.4f}" + ) + wandb.log( + { + "rollout_reward": rollout_reward, + "rollout_kl_reward": rollout_kl_reward, + "rollout_kl": rollout_kl, + }, + step=it, + ) + pbar.set_description(f"TRAIN: {it=}: {rollout_reward=:.4f}") + + for batch in rb: + rb_ppo.empty() + rb_ppo.extend(batch) + for _ in range(ppo_num_epochs): # PPO epochs + optimizer.zero_grad() + # why don't we optimize at each step? Is accumulating grads better? + # usually more small steps is better than a giant one + for minibatch in rb_ppo: # GO over RB + minibatch = minibatch.to(device, non_blocking=True) + with ctx: + loss_vals = loss_fn(minibatch) + loss_val = sum( + value + for key, value in loss_vals.items() + if key.startswith("loss") + ) + loss_val.backward() + torch.nn.utils.clip_grad_norm_(loss_fn.parameters(), grad_clip) + optimizer.step() + if scheduler is not None: + scheduler.step() + it += 1 + pbar.update(1) + if it % eval_interval == 0: + val_reward = estimate_reward(model, val_loader) + val_reward_logger.info(f"VALID: {it=}: {val_reward=:.4f}") + wandb.log({"val_reward": val_reward}, step=it) + pbar.set_description(f"VALID: {it=}: {val_reward=:.4f}") + if val_reward > best_val_reward or always_save_checkpoint: + best_val_reward = val_reward + if it > 0: + val_reward_logger.info( + f"saving checkpoint to {rlhf_out_dir}" + ) + model.save_pretrained(rlhf_out_dir) + + +if __name__ == "__main__": + main() diff --git a/examples/rlhf/utils.py b/examples/rlhf/utils.py new file mode 100644 index 00000000000..aedb0501091 --- /dev/null +++ b/examples/rlhf/utils.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +from contextlib import nullcontext + +import torch +import torch._dynamo +from hydra.utils import to_absolute_path + + +def resolve_name_or_path(name_or_path): + """Hydra changes the working directory, so we need to absolutify paths.""" + if name_or_path.startswith("./") or name_or_path.startswith("/"): + return to_absolute_path(name_or_path) + return name_or_path + + +def get_file_logger(name, filename, level=logging.DEBUG): + """ + Set up logger that will log to the given filename. + """ + logger = logging.getLogger(name) + handler = logging.FileHandler(filename) + handler.setFormatter( + # logging.Formatter("%(asctime)s, %(name)s %(levelname)s %(message)s") + logging.Formatter("%(asctime)s - %(message)s") + ) + logger.addHandler(handler) + logger.setLevel(level) + return logger + + +def setup(device, dtype): + """ + Set manual seed, configure backend and autocasting. + """ + torch.manual_seed(1337) + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + torch._dynamo.config.cache_size_limit = 256 + + if "cuda" not in device: + return nullcontext() + + return torch.amp.autocast(device_type="cuda", dtype=getattr(torch, dtype))
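+
+
+# Typical usage from the training scripts (a sketch; see train.py):
+#
+#     ctx = setup(device=cfg.sys.device, dtype=cfg.sys.dtype)
+#     model_path = resolve_name_or_path(cfg.model.name_or_path)
+#     with ctx:  # autocast on CUDA, nullcontext (no-op) elsewhere
+#         model(batch)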