# RLHF example

This example uses RLHF (Reinforcement Learning from Human Feedback) to train a
language model to summarize Reddit posts.

## Getting started

Make sure you have PyTorch>=2.0 installed. You can find installation instructions
[here](https://pytorch.org/get-started/locally/).
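For reference, an install might look like the sketch below; the exact command
depends on your OS, package manager, and accelerator, so prefer the selector on
the linked page.

```sh
# Illustrative only: use the command generated by the PyTorch "Get Started"
# selector for your platform and accelerator.
pip install "torch>=2.0"
```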

From this directory, you can install extra requirements for running these
examples with

```sh
pip install -r requirements.txt
```

Once the data has been prepared, you can train the GPT model.
python train.py
```

The default configuration can be found in `config/train.yaml`, and any option
can be overridden with command-line arguments. For example, to run the training
script with a different batch size:

```sh
python train.py --batch_size=128
```
> **_NOTE:_** Apple Silicon MacBook users should use `--device=mps` and prepend
> all commands with `PYTORCH_ENABLE_MPS_FALLBACK=1` to enable CPU fallback for
> operations that are not yet supported on MPS.

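For example, on an Apple Silicon machine the fine-tuning command above would be
launched as follows (the `--batch_size` value is just illustrative):

```sh
# Select the MPS device and enable CPU fallback for unsupported operations
PYTORCH_ENABLE_MPS_FALLBACK=1 python train.py --device=mps --batch_size=64
```
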
### Training the reward model

Once you have completed supervised fine-tuning, copy the desired model
checkpoint to `./out` or update the config to point `model.name_or_path` at
the relevant checkpoint in the timestamped working directory created by Hydra.
You can then train the reward model with:

```sh
python train_reward.py
```
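If you go the copy route, the sequence might look like the sketch below; the
paths are hypothetical and depend on when you ran `train.py` and how the
checkpoint is named inside your run's working directory.

```sh
# Hypothetical paths: substitute the timestamped Hydra run directory and the
# checkpoint produced by your own supervised fine-tuning run.
mkdir -p ./out
cp -r ./outputs/<date>/<time>/<checkpoint> ./out/
python train_reward.py
```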

### Training the final model with RLHF

Once again, make sure you have either updated the configuration to point
`reward_model.name_or_path` at the relevant timestamped working directory, or
copied the checkpoint to `./out_reward`.
You can then train the final model by running:

```sh
python train_rlhf.py
```
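As with the previous step, if you copy the reward-model checkpoint rather than
editing the config, the preparation might look like this (paths again
hypothetical):

```sh
# Hypothetical paths: substitute the reward-model checkpoint from your own run.
mkdir -p ./out_reward
cp -r ./outputs/<date>/<time>/<reward_checkpoint> ./out_reward/
python train_rlhf.py
```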
