[Example] RLHF end to end example #1324
Conversation
examples/rlhf/train_rlhf.py
Outdated
"""Returns adaptively updated KL coefficient, βₜ₊₁. | ||
Arguments: | ||
current: The current KL value between the newest policy and the initial policy. | ||
""" |
wrong formatting
For debugging purposes, we also generate responses to a fixed prompt so that the
quality of the model can be visually assessed during training.
"""
Missing args and example
batch = next(dataloader)
# NOTE: disable kl for evaluation
td = rollout_from_model.rollout_from_data(batch, kl_coef=0.0)
rewards[k] = td.get(("next", "reward")).sum(dim=1).mean().item()
why item?
to get a scalar instead of a scalar-tensor
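For context, .item() converts a zero-dimensional tensor into a plain Python number; a minimal illustration with made-up values:

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
x.sum(dim=1).mean()         # tensor(5.), a zero-dimensional tensor
x.sum(dim=1).mean().item()  # 5.0, a plain Python float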
Co-authored-by: Vincent Moens <[email protected]>
Co-authored-by: Alessandro Pietro Bardelli <[email protected]>
Co-authored-by: Vincent Moens <[email protected]>
# Conflicts:
#   test/test_rlhf.py
#   torchrl/data/rlhf/utils.py
#   torchrl/modules/tensordict_module/actors.py
#   torchrl/modules/tensordict_module/common.py
torchrl/data/rlhf/utils.py
Outdated
model,
ref_model,
reward_model,
kl_controller,
this is missing from the doc.
I'm not so sure about this kl_controller being passed to the module; I feel it should be handled separately. It's like passing the lr_scheduler to the optimizer: we don't do that because it mixes responsibilities between modules. It gives the impression that one module has multiple responsibilities, and it is less clear than doing things explicitly in the main code.
are you suggesting we go back to passing just the kl coefficient?
I think the KL controller should be another class, but then we have 2 options:
either the KL controller changes the KL coefficient of the other class (like the LR scheduler changes the LR of the optimizer, or the target param updaters in torchrl update the target params of the loss), or we explicitly pass the kl coef.
I think the first option is more "pytorch"-style.
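As an illustration of the first option, a rough sketch with hypothetical names (not the actual torchrl API): the controller holds a reference to the object that owns the coefficient and mutates it in place, the way an LR scheduler mutates the optimizer's learning rate.

class AdaptiveKLScheduler:
    """Hypothetical sketch (option 1): mutates the rollout object's kl_coef in place."""

    def __init__(self, rollout, init_kl_coef: float, target: float, horizon: int):
        self.rollout = rollout
        self.rollout.kl_coef = init_kl_coef  # the coefficient lives on the rollout/model side
        self.target = target
        self.horizon = horizon

    def step(self, kl_value: float, n_steps: int):
        # same update rule as the adaptive controller below, applied to the owning object
        proportional_error = max(min(kl_value / self.target - 1, 0.2), -0.2)
        self.rollout.kl_coef *= 1 + proportional_error * n_steps / self.horizon

# usage in the training loop, analogous to optimizer + lr_scheduler:
#   kl_scheduler = AdaptiveKLScheduler(rollout_from_model, 0.1, 6, 10000)
#   kl_scheduler.step(current_kl, n_steps=batch_size)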
I think the KL controller should be another class
actually it is another class
the KL controller changes the KL coefficient of the other class
isn't this what we are currently doing?
fixed
Should we not remove it then?
)

rollout_from_model = RolloutFromModel(model, ref_model, reward_model)
kl_controller = AdaptiveKLController(rollout_from_model, 0.1, 6, 10000)
I don't think AdaptiveKLController takes the rollout_from_model as input does it?
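Going by the __init__ signature defined later in this diff (init_kl_coef, target, horizon), presumably the intended call is just:

kl_controller = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000)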
torchrl/data/rlhf/utils.py
Outdated
class KLControllerBase(abc.ABC):
    """Base class for KL controllers.

    Each controller must implement an update method that takes the current KL value and
    the number of steps and updates the self.coef attribute, which will multiply
    the KL during calculation of the reward.
    """

    @abc.abstractmethod
    def update(self, kl_value: float, n_steps: int):
        pass


class ConstantKLController(KLControllerBase):
    """Constant KL Controller.

    This controller maintains a fixed coefficient no matter what values it is updated
    with.

    Arguments:
        coefficient (float): The coefficient to multiply KL with when calculating the
            reward.
    """

    def __init__(self, coefficient):
        self.coef = coefficient

    def update(self, kl_value: float, n_steps: int):
        pass


class AdaptiveKLController(KLControllerBase):
    """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences".

    Arguments:
        init_kl_coef (float): The starting value of the coefficient.
        target (float): The target KL value. When the observed KL is smaller, the
            coefficient is decreased, thereby relaxing the KL penalty in the training
            objective and allowing the model to stray further from the reference model.
            When the observed KL is greater than the target, the KL coefficient is
            increased, thereby pulling the model back towards the reference model.
        horizon (int): Scaling factor to control how aggressively we update the
            coefficient.

    Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2
    Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
    """

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.coef = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, kl_value: float, n_steps: int):
        """Update ``self.coef`` adaptively.

        Arguments:
            kl_value: The current KL value between the newest policy and the initial
                policy.
            n_steps: The number of training steps taken since last update.
        """
        proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2)  # ϵₜ
        mult = 1 + proportional_error * n_steps / self.horizon
        self.coef *= mult  # βₜ₊₁
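As a quick sanity check on the direction of the adaptive update, a short example with made-up values, using the class above:

controller = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000)
controller.update(kl_value=9.0, n_steps=64)
# proportional_error = clip(9 / 6 - 1, -0.2, 0.2) = 0.2
# coef becomes 0.1 * (1 + 0.2 * 64 / 10000) ≈ 0.100128: the KL penalty tightens
# when the observed KL exceeds the target, and relaxes when it falls below it.
print(controller.coef)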
Are these guys really part of data?
They seem more related to the model to me.
They act on a class that belongs to data (maybe should be moved to collector tbh) but the KL coef is something that has to do with the stochastic policy (the language model, in our case), not the data.
New classes should be added to the doc (provided we're sure of where they belong)
torchrl/data/rlhf/utils.py
Outdated
model,
ref_model,
reward_model,
kl_controller,
Should we not remove it then?
torchrl/data/rlhf/utils.py
Outdated
"""Makes a step in the KL coefficient schedule.""" | ||
raise NotImplementedError | ||
self.kl_controller.update(kl_value, n_steps) |
ditto, maybe this function should go away?
torchrl/data/rlhf/utils.py
Outdated
@@ -167,7 +242,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio, kl_coef=0.1)
)
reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1)
reward_raw = reward_raw * done
- reward_kl = -kl_coef * log_ratio.unsqueeze(-1)
+ reward_kl = -self.kl_controller.coef * log_ratio.unsqueeze(-1)
same here
merge after #1309, #1319, #1316, #1315 + rebase
Adds a complete end-to-end RLHF pipeline