-
Notifications
You must be signed in to change notification settings - Fork 312
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
977 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
*.png | ||
*.bin | ||
*.pt | ||
*.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# RLHF example | ||
|
||
This example uses RLHF (Reinforcement Learning with Human Feedback) to train a language model to summarize Reddit posts. | ||
|
||
## Getting started | ||
|
||
Make sure you have PyTorch 2.0 installed. You can find installation instructions [here](https://pytorch.org/get-started/locally/). | ||
|
||
From this directory, you can install extra requirements for running these examples with | ||
|
||
```sh | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Training the models | ||
### Training the transformer | ||
|
||
Once the data has been prepared, you can train the GPT model. | ||
|
||
```sh | ||
python train.py | ||
``` | ||
|
||
Default configuration can be found in `config/train.yaml`, and any option can be overridden with command-line arguments, for example to run the training script with a different batch size | ||
|
||
```sh | ||
python train.py --batch_size=128 | ||
``` | ||
> **_NOTE:_** Apple Silicon Macbooks users make sure to use `--device=mps` and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback | ||
### Training the reward model | ||
|
||
Next you can train the reward model with | ||
|
||
```sh | ||
python train_reward.py | ||
``` | ||
|
||
### Training the final model with RLHF | ||
|
||
To train the final model run | ||
|
||
```sh | ||
python train_rlhf.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
io: | ||
eval_interval: 200 | ||
log_interval: 50 | ||
eval_iters: 100 | ||
data: | ||
batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size | ||
block_size: 550 | ||
model: | ||
name_or_path: gpt2 # gpt2 for pre-trained, local path for checkpoint | ||
out_dir: ./out | ||
dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ | ||
train: | ||
grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 | ||
max_iters: 5000 # total number of training iterations | ||
gradient_accumulation_steps: 2 # used to simulate larger batch sizes | ||
always_save_checkpoint: False # if True, always save a checkpoint after each evaluation in out_dir | ||
decay_lr: True # whether to decay the learning rate | ||
optimizer: | ||
# keyword arguments for torch.optim.AdamW | ||
lr: 1.0e-5 | ||
weight_decay: 1.0e-1 | ||
betas: [0.9, 0.95] | ||
scheduler: | ||
# keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR | ||
T_max: 5000 # maximum number of iterations | ||
eta_min: 1.0e-6 # minimum learning rate | ||
sys: | ||
device: cuda # examples: cpu, cuda, cuda:0, cuda:1 etc., or try mps on macbooks | ||
dtype: bfloat16 # float32, bfloat16, or float16, the latter will auto implement a GradScaler | ||
compile: True # use PyTorch 2.0 to compile the model to be faster |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
io: | ||
eval_interval: 200 | ||
log_interval: 50 | ||
eval_iters: 100 | ||
data: | ||
batch_size: 16 # if gradient_accumulation_steps > 1, this is the micro-batch size | ||
block_size: 550 | ||
model: | ||
name_or_path: ./out | ||
dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ | ||
reward_model: | ||
out_dir: ./out_reward | ||
init_from: scratch # 'scratch' or 'resume' - if "resume" model will be loaded from out_dir_reward | ||
train: | ||
grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 | ||
max_iters: 20000 # total number of training iterations | ||
gradient_accumulation_steps: 2 # used to simulate larger batch sizes | ||
always_save_checkpoint: False # if True, always save a checkpoint after each eval | ||
decay_lr: False # whether to decay the learning rate | ||
optimizer: | ||
# keyword arguments for torch.optim.AdamW | ||
lr: 1.0e-5 | ||
weight_decay: 1.0e-1 | ||
betas: [0.9, 0.95] | ||
scheduler: | ||
# keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR | ||
T_max: 20000 | ||
eta_min: 1.0e-6 | ||
sys: | ||
device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks | ||
dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler | ||
compile: True # use PyTorch 2.0 to compile the model to be faster |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
io: | ||
eval_interval: 6 | ||
log_interval: 1 | ||
eval_iters: 10 | ||
data: | ||
batch_size: 4 # if gradient_accumulation_steps > 1, this is the micro-batch size | ||
block_size: 550 | ||
model: | ||
name_or_path: ./out | ||
out_dir: ./out_rlhf | ||
dropout: 0.1 # for pretraining 0 is good, for finetuning try 0.1+ | ||
reward_model: | ||
name_or_path: ./out_reward | ||
train: | ||
grad_clip: 1.0 | ||
max_epochs: 1000 # total number of training iterations | ||
always_save_checkpoint: True # if True, always save a checkpoint after each eval | ||
decay_lr: True | ||
optimizer: | ||
# keyword arguments for torch.optim.AdamW | ||
lr: 5.0e-5 | ||
weight_decay: 0.0 # 01 | ||
betas: [0.9, 0.999] | ||
scheduler: | ||
# keyword arguments for torch.optim.lr_scheduler.CosineAnnealingLR | ||
T_max: 3000 # max_epochs * num_rollouts / ppo_batch_size | ||
eta_min: 5.0e-6 | ||
ppo: | ||
episode_length: 50 | ||
ppo_batch_size: 16 | ||
ppo_num_epochs: 3 | ||
num_rollouts_per_epoch: 32 | ||
sys: | ||
device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks | ||
dtype: bfloat16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler | ||
compile: True # use PyTorch 2.0 to compile the model to be faster |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr | ||
|
||
__all__ = ["get_prompt_dataloader_tldr"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from torchrl.modules.tensordict_module.actors import LMActorCritic | ||
from torchrl.modules.tensordict_module.common import VmapModule | ||
|
||
from .transformer import init_transformer | ||
|
||
__all__ = ["init_actor_critic"] | ||
|
||
|
||
def init_actor_critic(transformer_name_or_path, dropout, device, compile_): | ||
base_model = init_transformer( | ||
transformer_name_or_path, | ||
dropout, | ||
device, | ||
as_tensordictmodule=False, | ||
compile_=compile_, | ||
inference=True, | ||
) | ||
model = LMActorCritic(base_model) | ||
model.to(device) | ||
model.eval() | ||
actor = model.get_policy_operator() | ||
critic = model.get_value_operator() | ||
critic_head = model.get_value_head() | ||
|
||
return actor, VmapModule(critic), critic_head, base_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from tensordict.nn import TensorDictModule | ||
|
||
from torchrl.modules.models.rlhf import GPT2RewardModel | ||
|
||
|
||
def init_reward_model( | ||
transformer_path=None, reward_model_path=None, device=None, compile_=False | ||
): | ||
if not ((transformer_path is None) ^ (reward_model_path is None)): | ||
raise ValueError( | ||
"Exactly one of transformer_path or reward_model_path should be specified" | ||
) | ||
if transformer_path is not None: | ||
model = GPT2RewardModel(transformer_path) | ||
else: | ||
model = GPT2RewardModel.from_pretrained(reward_model_path) | ||
|
||
model.to(device) | ||
if compile_: | ||
print("Compiling the reward model...") | ||
model = torch.compile(model) | ||
|
||
model = TensorDictModule( | ||
model, | ||
in_keys=["input_ids", "attention_mask"], | ||
out_keys=["rewards", "end_scores"], | ||
) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import torch | ||
from tensordict.nn import TensorDictModule | ||
from transformers import GPT2LMHeadModel | ||
|
||
|
||
def init_transformer( | ||
name_or_path, | ||
dropout, | ||
device, | ||
compile_, | ||
as_tensordictmodule=True, | ||
inference=False, | ||
): | ||
model_kwargs = { | ||
"resid_pdrop": dropout, | ||
"embd_pdrop": dropout, | ||
"attn_pdrop": dropout, | ||
"summary_first_dropout": dropout, | ||
} | ||
model = GPT2LMHeadModel.from_pretrained( | ||
name_or_path, return_dict=False, **model_kwargs | ||
) | ||
model.to(device) | ||
|
||
if compile_: | ||
# TODO: logging instead of printing? | ||
print("Compiling transformer model...") | ||
model = torch.compile(model) | ||
|
||
if as_tensordictmodule: | ||
model = TensorDictModule( | ||
model, | ||
in_keys={ | ||
"input_ids": "input_ids", | ||
"attention_mask": "attention_mask", | ||
"labels": "labels", | ||
}, | ||
out_keys=["logits"] if inference else ["loss", "logits"], | ||
) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
datasets | ||
hydra-core | ||
matplotlib | ||
numpy | ||
PyYAML | ||
requests | ||
tiktoken | ||
tqdm | ||
transformers | ||
git+https://github.com/pytorch/rl | ||
git+https://github.com/pytorch-labs/tensordict |
Oops, something went wrong.