This example uses RLHF (Reinforcement Learning from Human Feedback) to train a language model to summarize Reddit posts.
Make sure you have PyTorch>=2.0 installed. You can find installation instructions here.
From this directory, you can install extra requirements for running these examples with

```
pip install -r requirements.txt
```
Once the data has been prepared, you can train the GPT model:

```
python train.py
```
The default configuration can be found in `config/train.yaml`, and any option can be overridden with command-line arguments. For example, to run the training script with a different batch size:

```
python train.py --batch_size=128
```
NOTE: Apple Silicon MacBook users should pass `--device=mps` and prepend all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` (e.g. `PYTORCH_ENABLE_MPS_FALLBACK=1 python train.py --device=mps`) to enable CPU fallback for operations not yet supported on MPS.
Once you have completed supervised fine-tuning, copy the desired model checkpoint to `./out`, or update the config to point `model.name_or_path` at the relevant checkpoint in the timestamped working directory created by Hydra.
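As a sketch, the override might look like the following; only `model.name_or_path` is named above, the surrounding structure is assumed and the path is a placeholder for your own checkpoint location:

```yaml
# Illustrative fragment of config/train.yaml (structure assumed)
model:
  name_or_path: ./out  # or the checkpoint inside your Hydra run directory
```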
You can then train the reward model with:

```
python train_reward.py
```
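Reward models in RLHF are typically trained on pairs of summaries ranked by humans, using a pairwise (Bradley-Terry style) loss. The sketch below illustrates that loss in plain Python, assuming this example follows the standard convention; the function name and scores are illustrative, not this repository's API.

```python
import math

def pairwise_preference_loss(r_chosen, r_rejected):
    """-log(sigmoid(r_chosen - r_rejected)): pushes the reward of the
    human-preferred summary above the rejected one."""
    return -math.log(1.0 / (1.0 + math.exp(-(r_chosen - r_rejected))))

# Correct ranking: the chosen summary already scores higher, loss is small.
low = pairwise_preference_loss(2.0, -1.0)   # about 0.049
# Inverted ranking: the loss is large, and gradients push the scores apart.
high = pairwise_preference_loss(-1.0, 2.0)  # about 3.049
```

Minimizing this loss over many human-labelled pairs gives a scalar reward model that the final RLHF stage can optimize against.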
Once again, make sure you have either updated the configuration to point `reward_model.name_or_path` at the relevant timestamped working directory, or copied the checkpoint to `./out_reward`.
You can then train the final model by running

```
python train_rlhf.py
```
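In this final stage, the policy is usually optimized against the reward model's score minus a KL penalty that keeps it close to the supervised fine-tuned reference model. Here is a rough sketch of that shaped reward, with illustrative names and a made-up `beta`, not values taken from this repository's code:

```python
def kl_penalized_reward(reward, logp_policy, logp_ref, beta=0.1):
    # Shaped reward: the reward-model score minus beta times the
    # log-probability gap between the current policy and the frozen
    # reference model. A policy that drifts far from the reference
    # pays a penalty even when the raw reward is high.
    return reward - beta * (logp_policy - logp_ref)

# Same raw reward, but the second sample has drifted from the reference,
# so its shaped reward is lower.
close = kl_penalized_reward(1.0, logp_policy=-2.0, logp_ref=-2.0)    # 1.0
drifted = kl_penalized_reward(1.0, logp_policy=-1.0, logp_ref=-4.0)  # 0.7
```

The KL term is what prevents the policy from collapsing onto degenerate text that happens to score well under the reward model.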