Skip to content

Commit

Permalink
fix(hf_tokenizer): Rename to HFTokenizer and corresponding flags
Browse files Browse the repository at this point in the history
#1251
Branch: TokenizersTokenizer-1251

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>
  • Loading branch information
gabe-l-hart committed Oct 24, 2024
1 parent 87bcf5c commit 5f332e7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
4 changes: 2 additions & 2 deletions tokenizer/tokenizers.py → tokenizer/hf_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from .base import TokenizerBase


class TokenizersTokenizer(TokenizerBase):
class HFTokenizer(TokenizerBase):
"""
Wrapper around the `tokenizers` library for API compatibility
Wrapper around the Huggingface `tokenizers` library for API compatibility
"""

def __init__(self, file_path: str):
Expand Down
28 changes: 14 additions & 14 deletions torchchat/cli/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class TokenizerArgs:
tokenizer_path: Optional[Union[Path, str]] = None
is_sentencepiece: bool = False
is_tiktoken: bool = False
is_tokenizers: bool = False
is_hf_tokenizer: bool = False
t: Optional[Any] = None

def __post_init__(self):
Expand All @@ -200,7 +200,7 @@ def __post_init__(self):
self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
self.is_tiktoken = True
self.is_sentencepiece = False
self.is_tokenizers = False
self.is_hf_tokenizer = False
return
except:
pass
Expand All @@ -211,25 +211,25 @@ def __post_init__(self):
self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
self.is_tiktoken = False
self.is_sentencepiece = True
self.is_tokenizers = False
self.is_hf_tokenizer = False
return
except:
pass

try:
from tokenizer.tokenizers import TokenizersTokenizer
from tokenizer.hf_tokenizer import HFTokenizer

self.t = TokenizersTokenizer(str(self.tokenizer_path))
self.t = HFTokenizer(str(self.tokenizer_path))
self.is_tiktoken = False
self.is_sentencepiece = False
self.is_tokenizers = True
self.is_hf_tokenizer = True
return
except:
pass

self.is_tiktoken = False
self.is_sentencepiece = False
self.is_tokenizers = False
self.is_hf_tokenizer = False
self.t = None
return

Expand All @@ -241,25 +241,25 @@ def validate_model(
if model is None:
return

if len(list(filter(lambda x: x, [self.is_tiktoken, self.is_tokenizers, self.is_sentencepiece]))) != 1:
if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
is_tokenizers = self.is_tokenizers
is_hf_tokenizer = self.is_hf_tokenizer
use_tiktoken = model.config.use_tiktoken
use_tokenizers = model.config.use_tokenizers
use_sentencepiece = not (use_tiktoken or use_tokenizers)
use_hf_tokenizer = model.config.use_hf_tokenizer
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)

if (
(is_tiktoken and not use_tiktoken) or
(is_tokenizers and not use_tokenizers) or
(is_hf_tokenizer and not use_hf_tokenizer) or
(is_sentencepiece and not use_sentencepiece)
):
raise RuntimeError(
"model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
tokenizer_setting_to_name(use_tiktoken, use_tokenizers),
tokenizer_setting_to_name(is_tiktoken, is_tokenizers),
tokenizer_setting_to_name(use_tiktoken, use_hf_tokenizer),
tokenizer_setting_to_name(is_tiktoken, is_hf_tokenizer),
model_description,
)
)
Expand Down
12 changes: 6 additions & 6 deletions torchchat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ class TransformerArgs:
ffn_dim_multiplier: Optional[int] = None
# Select the desired tokenizer. Defaults to sentencepiece
use_tiktoken: bool = False
use_tokenizers: bool = False
use_hf_tokenizer: bool = False
max_seq_length: int = 8192
rope_scaling: Optional[Dict[str, Any]] = None
# For pipeline parallel
Expand Down Expand Up @@ -329,14 +329,14 @@ class ModelArgs:
model_type: ModelType
transformer_args: Dict[str, Dict[str, Any]]
use_tiktoken: bool
use_tokenizers: bool
use_hf_tokenizer: bool

def __init__(
self,
transformer_args: Dict[str, Dict[str, Any]],
model_type: ModelType = ModelType.TextOnly,
use_tiktoken: bool = False,
use_tokenizers: bool = False,
use_hf_tokenizer: bool = False,
) -> None:
self._sanity_check(transformer_args, model_type)

Expand All @@ -345,7 +345,7 @@ def __init__(

# Model-level attributes
self.use_tiktoken = use_tiktoken
self.use_tokenizers = use_tokenizers
self.use_hf_tokenizer = use_hf_tokenizer

def _sanity_check(
self,
Expand All @@ -372,8 +372,8 @@ def from_params(cls, params_path):
}

use_tiktoken = loaded_params.get("use_tiktoken", False)
use_tokenizers = loaded_params.get("use_tokenizers", False)
return cls(transformer_args, model_type, use_tiktoken, use_tokenizers)
use_hf_tokenizer = loaded_params.get("use_hf_tokenizer", False)
return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer)

@classmethod
def from_table(cls, name: str):
Expand Down

0 comments on commit 5f332e7

Please sign in to comment.