feature(whl): add rlhf pipeline. #748
base: main
ding/bonus/ppof.py
Outdated
@@ -18,6 +19,7 @@
from .model import PPOFModel
from .config import get_instance_config, get_instance_env, get_hybrid_shape
from ding.bonus.common import TrainingReturn, EvalReturn
from ..framework.middleware.collector import ChatCollector
merge it into ding.framework
"""
Overview:
The class of the collector running by steps, including model inference and transition \
process. Use the `__call__` method to execute the whole collection process.
why indent here
ding/model/common/utils.py
Outdated
def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
    """
    Filter a distribution of logits using nucleus (top-p) filtering
polish comments and add unittest
ding/model/common/utils.py
Outdated
if topp > 0:
    logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
    mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
    mask[:, :min_topk] = False
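Use `..., :min_topk` here, so the slice applies to the last dimension regardless of how many leading dimensions the input has.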
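A minimal sketch of what the suggestion is getting at, written in plain Python rather than torch so it runs without dependencies. It assumes the inputs are already probabilities (the PR's function takes `logits`, so that is an assumption), and the names `top_p_filter_1d` and `top_p_filter` are hypothetical, not part of the PR. The recursion over nested lists plays the role of `...`: the filter always acts on the innermost (last) dimension, however many leading dimensions there are. The asserts at the end could seed the requested unit test.

```python
from typing import Sequence


def top_p_filter_1d(probs: Sequence[float], topp: float = 0.9,
                    filter_value: float = 0.0, min_topk: int = 1) -> list:
    """Nucleus filtering over one distribution (hypothetical helper).

    Mirrors the diff's mask: an entry is dropped when the probability mass
    of the entries ranked before it already reaches `topp`, except that the
    first `min_topk` ranked entries are always kept.
    """
    if topp <= 0:
        return list(probs)
    # Indices sorted by descending probability (like torch.sort(..., descending=True)).
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    out = [filter_value] * len(probs)
    mass_before = 0.0  # equals cumsum - value for the current rank
    for rank, idx in enumerate(order):
        drop = mass_before >= topp and rank >= min_topk
        if not drop:
            out[idx] = probs[idx]
        mass_before += probs[idx]
    return out


def top_p_filter(probs, **kwargs):
    """Apply the 1-D filter along the last axis of arbitrarily nested lists."""
    if probs and isinstance(probs[0], (list, tuple)):
        return [top_p_filter(row, **kwargs) for row in probs]
    return top_p_filter_1d(probs, **kwargs)


# Same call works for 1-D, 2-D (batch), and 3-D (batch, time) inputs.
assert top_p_filter([0.5, 0.3, 0.15, 0.05]) == [0.5, 0.3, 0.15, 0.0]
assert top_p_filter([[0.05, 0.5, 0.15, 0.3]]) == [[0.0, 0.5, 0.15, 0.3]]
assert top_p_filter([[[0.1, 0.9]]], topp=0.5) == [[[0.0, 0.9]]]
```

With `mask[:, :min_topk]` the equivalent torch code only handles 2-D input as intended; on 3-D input the second index would slice the middle dimension instead of the vocabulary dimension.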
ding/model/template/vac.py
Outdated
@@ -1,4 +1,7 @@
from typing import Union, Dict, Optional
move these modifications to a new single file: lm_vac.py
def __init__(self, config, opt, tokenizer):
    super().__init__(config)
    self.opt = opt
why define opt here
else:
    logits = self.reward_head(output.last_hidden_state).squeeze(-1)

return (logits, )
why return a tuple here
    self._init_flag = False

def reset(self):
    self.last_batch = next(self.generator)
Do you need to restart the generator here?
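To make the concern concrete: if the generator is finite, `next(self.generator)` raises `StopIteration` once it is exhausted, so a later `reset()` would fail. Below is a minimal sketch of one way to restart it. The class `BatchSource` and the `make_generator` factory are hypothetical stand-ins, not names from the PR, and whether the PR's generator is finite at all is an assumption.

```python
from typing import Callable, Iterator, List


class BatchSource:
    """Hypothetical stand-in for an env that draws a batch on every reset()."""

    def __init__(self, make_generator: Callable[[], Iterator[List[int]]]):
        # Keep the factory so that a fresh generator can be built later.
        self._make_generator = make_generator
        self.generator = make_generator()
        self.last_batch = None

    def reset(self) -> List[int]:
        try:
            self.last_batch = next(self.generator)
        except StopIteration:
            # The generator is exhausted: rebuild it and start over.
            self.generator = self._make_generator()
            self.last_batch = next(self.generator)
        return self.last_batch


source = BatchSource(lambda: iter([[1, 2], [3, 4]]))
batches = [source.reset() for _ in range(3)]
# The third reset wraps around to the first batch instead of raising.
assert batches == [[1, 2], [3, 4], [1, 2]]
```

If the data loader behind the generator is already infinite, no restart is needed and the plain `next(self.generator)` call is enough.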
class LlamaRewardModel(LlamaForCausalLM):

    def __init__(self, config, opt, tokenizer):
Should we move the creation of the tokenizer inside the constructor of the RM?
@@ -0,0 +1,50 @@
from easydict import EasyDict
move it to dizoo/chat/entry
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #748      +/-   ##
==========================================
+ Coverage   76.78%   76.83%   +0.04%
==========================================
  Files         671      674       +3
  Lines       53196    53935     +739
==========================================
+ Hits        40847    41440     +593
- Misses      12349    12495     +146
Description
Related Issue
TODO
Check List