Commit

addressing comments about klcontroller
apbard committed Jul 5, 2023
1 parent c07ac93 commit f463e0e
Showing 2 changed files with 24 additions and 18 deletions.
8 changes: 4 additions & 4 deletions examples/rlhf/train_rlhf.py
@@ -73,7 +73,7 @@ def __call__(self, model, dataloader):
model,
self.ref_model,
self.reward_model,
kl_controller=ConstantKLController(0.0), # disable KL for evaluation
kl_coef=0, # disable KL for evaluation
max_new_tokens=self.episode_length,
)
rewards = torch.zeros(self.eval_iters)
@@ -188,7 +188,6 @@ def main():
scheduler = None
if train_cfg.decay_lr:
scheduler = CosineAnnealingLR(optimizer, **train_cfg.scheduler)
kl_controller = AdaptiveKLController(0.1, 6, 10000)

rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(episode_length * num_rollouts_per_epoch),
@@ -203,7 +202,8 @@ def main():
prefetch=10,
)

rollout_from_model = RolloutFromModel(model, ref_model, reward_model, kl_controller)
rollout_from_model = RolloutFromModel(model, ref_model, reward_model)
kl_controller = AdaptiveKLController(rollout_from_model, 0.1, 6, 10000)

best_val_reward = float("-inf")
it = 0 # it is equivalent to batch_size number of episodes
@@ -231,7 +231,7 @@ def main():
rollout_kl_reward = torch.tensor(rollout_kl).mean().cpu().item()
# recover true kl
rollout_kl = -rollout_kl_reward / kl_controller.coef
rollout_from_model.kl_update(
kl_controller.update(
rollout_kl, num_rollouts_per_epoch / batch_size
)

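Taken together, the changes to examples/rlhf/train_rlhf.py invert the ownership: the rollout object no longer holds a KL controller; the controller now wraps the rollout object and keeps its kl_coef attribute in sync. A condensed sketch of the new wiring, using only names that appear in this diff (keyword arguments spelled out for clarity; data loading, optimization and logging around it omitted):

    # The rollout object only stores a KL coefficient; no controller is passed in.
    rollout_from_model = RolloutFromModel(model, ref_model, reward_model)
    # The controller wraps the rollout object and writes rollout_from_model.kl_coef.
    kl_controller = AdaptiveKLController(
        rollout_from_model, init_kl_coef=0.1, target=6, horizon=10000
    )

    # Inside the training loop, after averaging the KL rewards of an epoch of rollouts:
    rollout_kl = -rollout_kl_reward / kl_controller.coef  # recover the true KL from the KL reward
    kl_controller.update(rollout_kl, num_rollouts_per_epoch / batch_size)

In the evaluation helper, the KL penalty is now disabled by passing kl_coef=0 instead of ConstantKLController(0.0).
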
34 changes: 20 additions & 14 deletions torchrl/data/rlhf/utils.py
@@ -22,8 +22,8 @@ class KLControllerBase(abc.ABC):
"""Base class for KL controllers.
Each controller must implement an update method that takes the current KL value and
the number of steps and updates the self.coef attribute, which will multiply
the KL during calculation of the reward.
the number of steps and updates the kl_coef attribute of the wrapped model,
which will multiply the KL during calculation of the reward.
"""

@abc.abstractmethod
@@ -38,21 +38,27 @@ class ConstantKLController(KLControllerBase):
with.
Arguments:
coefficient (float): The coefficient to multiply KL with when calculating the
model: wrapped model that needs to be controlled. Must have attribute 'kl_coef'
kl_coef (float): The coefficient to multiply KL with when calculating the
reward.
"""

def __init__(self, coefficient):
self.coef = coefficient
def __init__(self, model, kl_coef):
self.model = model
if not hasattr(model, "kl_coef"):
raise AttributeError("Model input to ConstantKLController doesn't have attribute 'kl_coef'")
self.coef = kl_coef
self.model.kl_coef = self.coef

def update(self, kl_value: float, n_steps: int):
pass
self.model.kl_coef = self.coef


class AdaptiveKLController(KLControllerBase):
"""Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences".
Arguments:
model: wrapped model that needs to be controlled. Must have attribute 'kl_coef'
init_kl_coef (float): The starting value of the coefficient.
target (float): The target KL value. When the observed KL is smaller, the
coefficient is decreased, thereby relaxing the KL penalty in the training
@@ -66,10 +72,12 @@ class AdaptiveKLController(KLControllerBase):
Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
"""

def __init__(self, init_kl_coef: float, target: float, horizon: int):
def __init__(self, model, init_kl_coef: float, target: float, horizon: int):
self.model = model
self.coef = init_kl_coef
self.target = target
self.horizon = horizon
self.model.kl_coef = self.coef

def update(self, kl_value: float, n_steps: int):
"""Update ``self.coef`` adaptively.
@@ -82,6 +90,7 @@ def update(self, kl_value: float, n_steps: int):
proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ
mult = 1 + proportional_error * n_steps / self.horizon
self.coef *= mult # βₜ₊₁
self.model.kl_coef = self.coef


class RolloutFromModel:
@@ -101,6 +110,7 @@ class RolloutFromModel:
reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given
``input_ids`` and ``attention_mask``, calculates rewards for each token and
end_scores (the reward for the final token in each sequence).
kl_coef (float, optional): initial KL coefficient. Defaults to 0.1.
max_new_tokens (int, optional): the maximum length of the sequence.
Defaults to 50.
score_clip (float, optional): Scores from the reward model are clipped to the
@@ -159,7 +169,7 @@ def __init__(
model,
ref_model,
reward_model,
kl_controller,
kl_coef=0.1,
max_new_tokens=50,
score_clip=10.0,
):
@@ -173,11 +183,7 @@ def __init__(
self.reward_model = reward_model
self.max_new_tokens = max_new_tokens
self.score_clip = score_clip
self.kl_controller = kl_controller

def kl_update(self, kl_value, n_steps):
"""Makes a step in the KL coefficient schedule."""
self.kl_controller.update(kl_value, n_steps)
self.kl_coef = kl_coef

@torch.no_grad()
def rollout_from_data(self, batch):
@@ -242,7 +248,7 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio):
)
reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1)
reward_raw = reward_raw * done
reward_kl = -self.kl_controller.coef * log_ratio.unsqueeze(-1)
reward_kl = -self.kl_coef * log_ratio.unsqueeze(-1)
reward = reward_raw + reward_kl
td = {
"action": action,
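For a quick sanity check of the reworked controller API, here is a minimal, self-contained sketch. The DummyRollout class is a stand-in invented for illustration (any object exposing a kl_coef attribute qualifies), the import path follows the file shown above, and the numbers in the comments simply apply the update rule from AdaptiveKLController.update.

    from torchrl.data.rlhf.utils import AdaptiveKLController, ConstantKLController

    class DummyRollout:
        """Stand-in for RolloutFromModel: any object with a ``kl_coef`` attribute."""
        kl_coef = 0.1

    rollout = DummyRollout()

    # ConstantKLController pins the wrapped object's kl_coef and never changes it.
    const = ConstantKLController(rollout, kl_coef=0.0)
    const.update(kl_value=1.0, n_steps=10)
    assert rollout.kl_coef == 0.0

    # AdaptiveKLController (Ziegler et al.) nudges the coefficient towards the target KL.
    adaptive = AdaptiveKLController(rollout, init_kl_coef=0.1, target=6.0, horizon=10000)
    adaptive.update(kl_value=12.0, n_steps=100)
    # proportional_error = clip(12 / 6 - 1, -0.2, 0.2) = 0.2
    # mult = 1 + 0.2 * 100 / 10000 = 1.002
    print(rollout.kl_coef)  # 0.1 * 1.002 ≈ 0.1002

Both controllers write the coefficient back to the wrapped object in __init__ and in update, so RolloutFromModel.create_rollout_td always sees the current value when it computes reward_kl = -self.kl_coef * log_ratio.
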
