RLHF example

This example uses RLHF (Reinforcement Learning from Human Feedback) to train a language model to summarize Reddit posts.

Getting started

Make sure you have PyTorch>=2.0 installed. Installation instructions are available on the PyTorch website.

From this directory, install the extra requirements for running these examples with:

pip install -r requirements.txt

Training the models

Training the transformer

Once the data has been prepared, you can train the GPT model with supervised fine-tuning:

python train.py
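The supervised step trains the GPT model to predict each next token of the reference text. As an illustration only, here is a minimal pure-Python sketch of the usual next-token cross-entropy objective. The function names are hypothetical, and this is not the code in train.py, which presumably works on batched PyTorch tensors rather than Python lists.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def next_token_loss(step_logits: Sequence[Sequence[float]], tokens: Sequence[int]) -> float:
    """Mean cross-entropy of predicting tokens[t + 1] from the logits at position t.

    step_logits[t] holds the (hypothetical) model's logits after reading tokens[: t + 1].
    """
    if len(step_logits) != len(tokens) - 1:
        raise ValueError("need one logit vector per predicted token")
    losses = [
        -log_softmax(logits)[target]  # negative log-likelihood of the true next token
        for logits, target in zip(step_logits, tokens[1:])
    ]
    return sum(losses) / len(losses)
```

With uniform logits over a vocabulary of 4 tokens, every prediction costs log(4); a confident, correct prediction drives the loss toward 0.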

The default configuration is in config/train.yaml. Any option can be overridden with command-line arguments; for example, to run the training script with a different batch size:

python train.py --batch_size=128
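To illustrate how `key=value` command-line overrides can map onto a nested YAML-style configuration, here is a hypothetical helper operating on a plain dictionary. It is a simplified sketch, not Hydra's override parser (which supports much richer syntax), and the keys used in the usage example are made up rather than taken from config/train.yaml.

```python
import ast
import copy
from typing import Any, Dict, List


def apply_overrides(config: Dict[str, Any], args: List[str]) -> Dict[str, Any]:
    """Return a copy of `config` with `key=value` or `--key=value` overrides applied.

    Dotted keys (e.g. "optim.lr=1e-4") address nested sections. Values are parsed
    as Python literals when possible, otherwise kept as strings.
    """
    result = copy.deepcopy(config)  # never mutate the defaults
    for arg in args:
        key, sep, raw = arg.lstrip("-").partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value: {arg!r}")
        try:
            value = ast.literal_eval(raw)  # "128" -> 128, "1e-4" -> 0.0001
        except (ValueError, SyntaxError):
            value = raw  # plain strings such as "mps"
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            node = node[part]  # raises KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"unknown option: {key}")
        node[leaf] = value
    return result


# Hypothetical defaults, for illustration only.
defaults = {"batch_size": 64, "device": "cpu", "optim": {"lr": 1e-3}}
cfg = apply_overrides(defaults, ["--batch_size=128", "optim.lr=1e-4", "device=mps"])
```

Rejecting unknown keys mirrors the usual behaviour of structured configs: a typo in an option name fails loudly instead of being silently ignored.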

NOTE: Users of Apple Silicon MacBooks should pass --device=mps and prepend all commands with PYTORCH_ENABLE_MPS_FALLBACK=1, which makes PyTorch fall back to the CPU for operations not yet supported on MPS.

Training the reward model

Once you have completed supervised fine-tuning, either copy the desired model checkpoint to ./out or update the config so that model.name_or_path points at the relevant checkpoint in the timestamped working directory that Hydra created for the run. You can then train the reward model with:

python train_reward.py
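Reward models for summarization are commonly trained on human preference pairs (a preferred and a rejected summary of the same post) with a pairwise, Bradley-Terry-style loss that pushes the preferred summary's score above the rejected one's. The sketch below assumes that objective; check train_reward.py for the loss it actually uses. The function names are hypothetical.

```python
import math
from typing import Sequence


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def pairwise_reward_loss(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    """Mean of -log sigmoid(r_chosen - r_rejected) over a batch of comparisons.

    chosen[i] and rejected[i] are the scalar rewards assigned to the preferred and
    the rejected summary of the same post.
    """
    if len(chosen) != len(rejected) or not chosen:
        raise ValueError("need equally sized, non-empty batches")
    return -sum(log_sigmoid(c - r) for c, r in zip(chosen, rejected)) / len(chosen)
```

When the model scores both summaries equally the loss is log(2); it shrinks as the margin in favour of the preferred summary grows and rises when the ranking is inverted.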

Training the final model with RLHF

Once again, make sure you have either updated the configuration so that reward_model.name_or_path points at the relevant timestamped working directory, or copied the reward model checkpoint to ./out_reward. You can then train the final model with:

python train_rlhf.py
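A common RLHF recipe optimizes the fine-tuned policy with PPO against the reward model's score, minus a KL penalty that keeps the policy close to the frozen supervised model. The pure-Python sketch below shows those two ingredients for a single sample; it assumes this recipe rather than describing train_rlhf.py, and the names `kl_penalized_reward`, `ppo_clip_objective`, `kl_coef` and `clip_eps` are hypothetical.

```python
import math
from typing import Sequence


def kl_penalized_reward(
    rm_score: float,
    policy_logprobs: Sequence[float],
    ref_logprobs: Sequence[float],
    kl_coef: float,
) -> float:
    """Reward-model score minus kl_coef times a sampled estimate of KL(policy || reference).

    The log-probabilities are those of the generated summary tokens under the policy
    being trained and under the frozen supervised (reference) model.
    """
    kl_estimate = sum(p - r for p, r in zip(policy_logprobs, ref_logprobs))
    return rm_score - kl_coef * kl_estimate


def ppo_clip_objective(
    new_logprob: float, old_logprob: float, advantage: float, clip_eps: float = 0.2
) -> float:
    """PPO's clipped surrogate for one action (to be maximized)."""
    ratio = math.exp(new_logprob - old_logprob)  # pi_new(a|s) / pi_old(a|s)
    clipped = max(min(ratio, 1.0 + clip_eps), 1.0 - clip_eps)
    # Taking the minimum gives a pessimistic bound, so large policy updates are not rewarded.
    return min(ratio * advantage, clipped * advantage)
```

With a positive advantage the clipped ratio caps the gain at (1 + clip_eps) times the advantage; with a negative advantage the unclipped term is kept, so harmful updates are still fully penalized.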