fix(tokenizers): Fix how bos/eos tokens are parsed from tokenizers (lib)
Branch: GraniteCodeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart committed Oct 9, 2024
1 parent 79e4ccb commit f2cba4c
Showing 1 changed file with 50 additions and 22 deletions.
72 changes: 50 additions & 22 deletions tokenizer/tokenizers.py
@@ -5,8 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 # Standard
-from typing import List
+from typing import List, Optional
 import json
+import os
 
 # Third Party
 from tokenizers import Tokenizer
@@ -21,26 +22,53 @@ class TokenizersTokenizer(TokenizerBase):
"""

def __init__(self, file_path: str):
self._tokenizer = Tokenizer.from_file(file_path)
# The BOS and EOS tokens are not easily visible from the tokenizer
# object itself, so we extract them at construction with a sample call
self._bos_token = self._tokenizer.encode("Test", add_special_tokens=True).ids[0]
# There is no explicit BOS token in many tokenizers, so we look for a
# single special token that most resembles the BOS token.
self._eos_token = None
tok_content = json.loads(self._tokenizer.to_str())
end_toks = [
tok for tok in tok_content['added_tokens']
if tok["special"] and "end" in tok["content"]
]
assert end_toks, "Unable to find an EOS token in the added tokens"
if len(end_toks) > 1:
end_text_toks = [
tok for tok in end_toks if "text" in tok["content"]
# If the path is a directory, look for "tokenizer.json" which is
# standard for transformers checkpoints and also look for the
# "tokenizer_config.json" file to parse eos/bos tokens
if os.path.isdir(file_path):
tokenizer_path = os.path.join(file_path, "tokenizer.json")
tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
else:
tokenizer_path = file_path
tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json")
if not os.path.isfile(tokenizer_path):
tokenizer_config_path = None

# Load the tokenizer itself
self._tokenizer = Tokenizer.from_file(tokenizer_path)

# If available, parse bos/eos tokens from the tokenizer config
self._bos_id, self._eos_id = None, None
if tokenizer_config_path is not None:
with open(tokenizer_config_path, "r") as handle:
tok_config = json.load(handle)
bos_token = tok_config.get("bos_token")
eos_token = tok_config.get("eos_token")
if bos_token is not None:
self._bos_id = self._tokenizer.token_to_id(bos_token)
if eos_token is not None:
self._eos_id = self._tokenizer.token_to_id(eos_token)

# If no eos/bos tokens found, go looking for them!
if None in [self._bos_id, self._eos_id]:
tok_content = json.loads(self._tokenizer.to_str())
if self._bos_id is None:
self._bos_id = self._look_for_special_token(tok_content, ["begin", "text"])
if self._eos_id is None:
self._eos_id = self._look_for_special_token(tok_content, ["end", "text"])

assert None not in [self._bos_id, self._eos_id], "Unable to find an BOS/EOS tokens"

@staticmethod
def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optional[int]:
candidate_toks = added_tokens
for search_str in search_strs:
candidate_toks = [
tok for tok in candidate_toks
if tok["special"] and search_str in tok["content"]
]
if len(end_text_toks) == 1:
self._eos_token = end_text_toks[0]["id"]
assert self._eos_token is not None, "Unable to find an EOS token in the added tokens"
if len(candidate_toks) == 1:
return candidate_toks[0]["id"]

def encode(
self,
@@ -58,7 +86,7 @@ def decode(self, ids: List[int]) -> str:
         return self._tokenizer.decode(ids)
 
     def bos_id(self) -> int:
-        return self._bos_token
+        return self._bos_id
 
     def eos_id(self) -> int:
-        return self._eos_token
+        return self._eos_id
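
For readers trying the change out, here is a minimal usage sketch of the two lookup paths this commit adds. It assumes the repository root is on PYTHONPATH; the directory name and token strings are hypothetical examples and not part of the commit, while `TokenizersTokenizer`, `bos_id()`, and `eos_id()` are the pieces actually touched above.

# Minimal sketch, assuming a Hugging Face-style checkpoint layout; the paths
# and token strings below are illustrative only.
from tokenizer.tokenizers import TokenizersTokenizer

# Case 1: a checkpoint directory. tokenizer_config.json usually carries
# "bos_token" / "eos_token" entries, so the ids come straight from the config.
tok = TokenizersTokenizer("checkpoints/granite-code")
print(tok.bos_id(), tok.eos_id())

# Case 2: a bare tokenizer.json with no tokenizer_config.json next to it.
# The constructor then scans tokenizer.json's "added_tokens" list (entries of
# the form {"id": ..., "content": ..., "special": true}) for one special token
# matching ["begin", "text"] and one matching ["end", "text"], e.g.
# "<|begin_of_text|>" / "<|end_of_text|>".
tok = TokenizersTokenizer("checkpoints/granite-code/tokenizer.json")
print(tok.bos_id(), tok.eos_id())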
