Skip to content

Commit

Permalink
Add gemma chat template
Browse files Browse the repository at this point in the history
  • Loading branch information
huyiwen committed Jul 11, 2024
1 parent db7b76c commit a71e77a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 15 deletions.
40 changes: 26 additions & 14 deletions utilization/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ def add_space(
msg: str,
context: str,
auto_leading_space: bool = True,
remove_space_between: bool = True,
no_space_between: bool = True,
starts: Optional[List[str]] = None,
ends: Optional[List[str]] = None
) -> str:
if starts is None or ends is None or remove_space_between is False:
if starts is None or ends is None or no_space_between is False:
context_ends_special = False
msg_starts_special = False
else:
Expand All @@ -24,9 +24,7 @@ def add_space(
return msg


def smart_space(
parts: List[Tuple[str, bool]], auto_leading_space: bool, remove_space_between: bool, seq: List[str]
) -> str:
def smart_space(parts: List[Tuple[str, bool]], auto_leading_space: bool, no_space_between: bool, seq: List[str]) -> str:
starts = [seq[role + "_start"] for role in ["system", "user", "assistant"] if (role + "_start") in seq]
ends = [seq[role + "_end"] for role in ["system", "user", "assistant"] if (role + "_end") in seq]
if "bos_token" in seq:
Expand All @@ -38,7 +36,7 @@ def smart_space(
part[0],
rendered,
auto_leading_space=auto_leading_space and part[1],
remove_space_between=remove_space_between,
no_space_between=no_space_between,
starts=starts,
ends=ends
)
Expand All @@ -65,7 +63,7 @@ def smart_space(
"{%- set data.parts = data.parts + [(seq['generation_prompt'], True)] -%}"
"{%- endif -%}"
""
"{{ data.parts | smart_space(auto_leading_space, remove_space_between, seq) }}"
"{{ data.parts | smart_space(auto_leading_space, no_space_between, seq) }}"
)

# Chat configs format:
Expand All @@ -86,7 +84,9 @@ def smart_space(
# - assistant_end: The string to append to the assistant message.
# - auto_leading_space: Whether to add a leading space when concatenating two
# strings if the first string does not end with a whitespace.
# - no_space_between: If True, do not insert a leading space between special tokens.
# - default_stop: A list of strings that indicate the end of a message.
# - merge_system_to_user: Whether to convert system message to part of next user message.
#
DEFAULT_CHAT_CONFIGS: Dict[str, Union[Dict[str, Any], str]] = {
"base": {
Expand All @@ -98,7 +98,7 @@ def smart_space(
"assistant_end": "\n\n",
"auto_leading_space": True,
"final_rstrip": True,
"remove_space_between": False,
"no_space_between": False,
"default_stop": [],
},
"llama2": {
Expand All @@ -110,7 +110,7 @@ def smart_space(
"assistant_end": " </s>",
"auto_leading_space": True,
"final_rstrip": False,
"remove_space_between": True,
"no_space_between": True,
"default_stop": [],
},
"chatml": {
Expand All @@ -122,7 +122,7 @@ def smart_space(
"assistant_end": "<|im_end|>\n",
"auto_leading_space": True,
"final_rstrip": False,
"remove_space_between": True,
"no_space_between": True,
"default_stop": ["<|im_end|>"],
},
"zephyr": {
Expand All @@ -134,7 +134,7 @@ def smart_space(
"assistant_end": "</s>\n",
"auto_leading_space": True,
"final_rstrip": False,
"remove_space_between": True,
"no_space_between": True,
"default_stop": ["</s>"],
},
"phi3": {
Expand All @@ -146,7 +146,7 @@ def smart_space(
"assistant_end": "<|end|>\n",
"auto_leading_space": True,
"final_rstrip": False,
"remove_space_between": True,
"no_space_between": True,
"default_stop": ["<|end|>", "<|endoftext|>"],
},
"llama3": {
Expand All @@ -158,7 +158,7 @@ def smart_space(
"assistant_end": "<|eot_id|>",
"auto_leading_space": True,
"final_rstrip": False,
"remove_space_between": True,
"no_space_between": True,
"default_stop": ["<|eot_id|>"],
},
"alpaca": {
Expand All @@ -170,7 +170,19 @@ def smart_space(
"assistant_end": "\n\n",
"auto_leading_space": True,
"final_rstrip": False,
"remove_space_between": False,
"no_space_between": False,
"default_stop": ["###"],
},
"gemma": {
"all_start": "<bos>",
"merge_system_to_user": True,
"user_start": "<start_of_turn>user\n",
"user_end": "<end_of_turn>",
"assistant_start": "<start_of_turn>model\n",
"assistant_end": "<end_of_turn>",
"auto_leading_space": True,
"final_rstrip": False,
"no_space_between": True,
"default_stop": ["<end_of_turn>", "<start_of_turn>"],
}
}
9 changes: 9 additions & 0 deletions utilization/model/model_utils/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
self.auto_leading_space = chat_config.pop("auto_leading_space", True)
self.final_lstrip = chat_config.pop("final_lstrip", True)
self.final_rstrip = chat_config.pop("final_rstrip", True)
self.merge_system_to_user = chat_config.pop("merge_system_to_user", False)

# api model does not need bos_token
if "bos_token" not in chat_config:
Expand Down Expand Up @@ -375,6 +376,12 @@ def get_generation_results(self) -> Union[str, Tuple[str, ...]]:
assert self.messages[-1]["role"] == "assistant"
return self.messages[-1]["content"]

def _merge_system_to_user(self):
"""Whether to convert system message to part of next user message."""
if self.merge_system_to_user and self.messages[0]["role"] == "system":
self.messages[1]["content"] = self.messages[0]["content"] + self.messages[1]["content"]
self.messages.pop(0)

def set_formatter(
self,
formatter: ConversationFormatter,
Expand All @@ -384,6 +391,8 @@ def set_formatter(
self.formatter = formatter
self.model_evaluation_method = model_evaluation_method
self.split = split and self.get_segs_num() > 1
self.merge_system_to_user = self.formatter.merge_system_to_user
self._merge_system_to_user()

def to_model_prompt(
self,
Expand Down
3 changes: 2 additions & 1 deletion utilization/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ def __post_init__(self):
if self.model_name_or_path in API_MODELS:
auto_model_type = API_MODELS[self.model_name_or_path]["model_type"]
elif self.is_local_model():
auto_model_type = "chat" if re.search(r"chat|instruct", self.model_name_or_path.lower()) else "base"
# gemma names its instruction-tuned variants with an "it" suffix.
# NOTE(review): a bare "it" alternation also matches substrings such as "8bit"
# or "vit" in model paths — consider anchoring (e.g. r"\bit\b" or "-it").
auto_model_type = "chat" if re.search(r"chat|instruct|it", self.model_name_or_path.lower()) else "base"
else:
auto_model_type = None

Expand Down

0 comments on commit a71e77a

Please sign in to comment.