forked from pytorch/torchchat
Commit
fix(tokenizer): TokenizersTokenizer -> HFTokenizer
pytorch#1251
Branch: TokenizersCpp-1251
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 2bef32d · commit 2e6f8b6
Showing 5 changed files with 305 additions and 598 deletions.
@@ -0,0 +1,296 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include "tokenizer.h"

// Standard
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Third Party
#include <nlohmann/json.hpp>

// Local
#include "pre_tokenizer.h"

namespace fs = std::filesystem;
using json = nlohmann::json;

// -------------------------private method end-------------------------------
// -------------------------public method start------------------------------

HFTokenizer::HFTokenizer() : Tiktoken() {}

void HFTokenizer::load(const std::string& path) {

  // If this is a directory, look for tokenizer.json and tokenizer_config.json
  std::string model_json = path;
  std::string model_config_json = "";
  if (fs::is_directory(path)) {
    const fs::path root(path);
    model_json = (root / "tokenizer.json").string();
    if (!fs::exists(model_json)) {
      fprintf(stderr, "no tokenizer.json found in %s\n", path.c_str());
      exit(EXIT_FAILURE);
    }
    const auto model_config_json_path = root / "tokenizer_config.json";
    if (fs::exists(model_config_json_path)) {
      model_config_json = model_config_json_path.string();
    }
  }

  // Load the tokenizer.json file
  std::ifstream file(model_json);
  if (!file) {
    fprintf(stderr, "failed to open tokenizer file: %s\n", model_json.c_str());
    exit(EXIT_FAILURE);
  }
  std::string contents(
      (std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
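  // (The extra parentheses around the first istreambuf_iterator avoid C++'s
  // "most vexing parse", which would otherwise declare a function instead of
  // reading the whole file into a string.)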
  json parsed_json;
  try {
    parsed_json = json::parse(contents);
  } catch (const json::exception& e) {
    std::cerr << "Error parsing json file: " << e.what() << std::endl;
    exit(EXIT_FAILURE);
  }

  // Parse the special tokens
  try {
    const auto& special_tokens = parsed_json.at("added_tokens");
    for (auto it = special_tokens.begin(); it != special_tokens.end(); ++it) {
      const std::string token = it->at("content");
      const uint64_t token_id = it->at("id");
      if (!special_token_encoder_.emplace(token, token_id).second) {
        fprintf(stderr, "duplicate special token: %s\n", token.c_str());
        exit(EXIT_FAILURE);
      }
      if (!special_token_decoder_.emplace(token_id, token).second) {
        fprintf(stderr, "duplicate special token id: %" PRIu64 "\n", token_id);
        exit(EXIT_FAILURE);
      }
    }
  } catch (const json::out_of_range& e) {
    fprintf(stderr, "Could not parse special tokens: %s\n", e.what());
    exit(EXIT_FAILURE);
  }
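  // For illustration (values hypothetical): each "added_tokens" entry in
  // tokenizer.json looks like
  //   { "id": 128000, "content": "<|begin_of_text|>", "special": true, ... }
  // so the loop above records the string<->id mapping in both directions.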

  // Parse the standard tokens
  try {
    const auto& vocab = parsed_json.at("/model/vocab"_json_pointer);
    for (const auto& entry : vocab.items()) {
      const std::string token = entry.key();
      const uint64_t token_id = entry.value();
      // Skip adding special tokens to the standard encoder/decoder
      if (special_token_decoder_.find(token_id) == special_token_decoder_.end()) {
        if (!encoder_.emplace(token, token_id).second) {
          fprintf(stderr, "duplicate token: %s\n", token.c_str());
          exit(EXIT_FAILURE);
        }
        if (!decoder_.emplace(token_id, token).second) {
          fprintf(stderr, "duplicate token id: %" PRIu64 "\n", token_id);
          exit(EXIT_FAILURE);
        }
      }
    }
  } catch (const json::out_of_range& e) {
    fprintf(stderr, "Could not parse tokens: %s\n", e.what());
    exit(EXIT_FAILURE);
  }
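  // Note: "/model/vocab"_json_pointer is nlohmann::json's JSON-pointer literal,
  // equivalent to parsed_json.at("model").at("vocab"); the vocab is an object
  // mapping token strings to integer ids, e.g. {"hello": 24912, ...} (the id
  // here is illustrative only).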

  // Set the vocab size to include special tokens
  vocab_size_ = encoder_.size() + special_token_encoder_.size();

  // Parse out the regexes for the pre-tokenizer
  try {
    const auto& pre_tokenizer = parsed_json.at("pre_tokenizer");
    const std::string pt_type = pre_tokenizer.at("type");
    if (pt_type == "Sequence") {
      for (const auto& entry : pre_tokenizer.at("pretokenizers")) {
        const std::string entry_type = entry.at("type");
        std::string regex;
        if (entry_type == "Split") {
          regex = entry.at("Regex");
        } else if (entry_type == "Digits") {
          regex = "\\p{N}";
        } else if (entry_type == "ByteLevel") {
          // HACK! This should be better...
          regex = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+";
        } else {
          fprintf(stderr, "Unknown pre-tokenizer type: %s\n", entry_type.c_str());
          exit(EXIT_FAILURE);
        }

        // DEBUG -- Disabled for now to compile
        // // Build the regex and add it to the list
        // regexes_.push_back(_create_regex(regex));
      }
    } else {
      fprintf(stderr, "Unknown pre-tokenizer type: %s\n", pt_type.c_str());
      exit(EXIT_FAILURE);
    }
  } catch (const json::out_of_range& e) {
    fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what());
    exit(EXIT_FAILURE);
  }
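  // For illustration (shape hypothetical for this model): a "Sequence"
  // pre_tokenizer in tokenizer.json nests its pieces as
  //   "pre_tokenizer": { "type": "Sequence", "pretokenizers": [
  //     { "type": "Split", "pattern": { "Regex": "..." }, ... },
  //     { "type": "Digits", ... } ] }
  // Note that a Split entry keys its regex under "pattern", which the lookup
  // above does not yet descend into.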

  // TODO: Do we need to parse the merges?

  // If a tokenizer config file is found, parse it to look up the eos/bos tokens
  if (!model_config_json.empty()) {

    // Load it and parse it as json
    std::ifstream file(model_config_json);
    if (!file) {
      fprintf(stderr, "failed to open tokenizer config file: %s\n",
              model_config_json.c_str());
      exit(EXIT_FAILURE);
    }
    std::string contents(
        (std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
    json parsed_json;
    try {
      parsed_json = json::parse(contents);
    } catch (const json::exception& e) {
      std::cerr << "Error parsing model config json file: " << e.what() << std::endl;
      exit(EXIT_FAILURE);
    }

    // Pull out the token strings
    try {
      const std::string bos_token = parsed_json.at("bos_token");
      const std::string eos_token = parsed_json.at("eos_token");
      const auto bos_it = special_token_encoder_.find(bos_token);
      const auto eos_it = special_token_encoder_.find(eos_token);
      if (bos_it == special_token_encoder_.end()) {
        fprintf(stderr, "BOS token %s not in special tokens\n", bos_token.c_str());
        exit(EXIT_FAILURE);
      }
      if (eos_it == special_token_encoder_.end()) {
        fprintf(stderr, "EOS token %s not in special tokens\n", eos_token.c_str());
        exit(EXIT_FAILURE);
      }
      bos_tok_ = bos_it->second;
      eos_tok_ = eos_it->second;
    } catch (const json::out_of_range& e) {
      fprintf(stderr, "Could not parse eos/bos from tokenizer config: %s\n", e.what());
      exit(EXIT_FAILURE);
    }
  }
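  // For illustration (field values hypothetical): tokenizer_config.json often
  // carries plain strings such as "bos_token": "<|begin_of_text|>" and
  // "eos_token": "<|eot_id|>"; some configs instead wrap these in an object
  // with a "content" field, which this string lookup would not handle.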

  // Otherwise, make an educated guess with the following logic:
  // 1. Look for special tokens with "bos"/"begin" or "eos"/"end" in them
  // 2. Sub-qualify with the word "text" if needed
  // 3. If EOS found, but BOS is not (or vice versa), assume they are the same
  else {
    std::vector<std::string> bos_candidates;
    std::vector<std::string> eos_candidates;
    for (const auto& token : special_token_encoder_) {
      if (
        token.first.find("bos") != std::string::npos ||
        token.first.find("begin") != std::string::npos
      ) {
        bos_candidates.push_back(token.first);
      }
      if (
        token.first.find("eos") != std::string::npos ||
        token.first.find("end") != std::string::npos
      ) {
        eos_candidates.push_back(token.first);
      }
    }
    if (bos_candidates.size() > 1) {
      const auto orig_candidates = bos_candidates;
      bos_candidates.clear();
      for (const auto& cand : orig_candidates) {
        if (cand.find("text") != std::string::npos) {
          bos_candidates.push_back(cand);
        }
      }
    }
    if (eos_candidates.size() > 1) {
      const auto orig_candidates = eos_candidates;
      eos_candidates.clear();
      for (const auto& cand : orig_candidates) {
        if (cand.find("text") != std::string::npos) {
          eos_candidates.push_back(cand);
        }
      }
    }

    // Use if a single candidate
    bool bos_found = false;
    bool eos_found = false;
    if (bos_candidates.size() == 1) {
      bos_found = true;
      bos_tok_ = special_token_encoder_[bos_candidates[0]];
    }
    if (eos_candidates.size() == 1) {
      eos_found = true;
      eos_tok_ = special_token_encoder_[eos_candidates[0]];
    }

    // Make them the same if only one found
    if (bos_found && !eos_found) {
      eos_tok_ = bos_tok_;
    } else if (!bos_found && eos_found) {
      bos_tok_ = eos_tok_;
    }
  }
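  // Worked example (tokens hypothetical): with specials {"<|begin_of_text|>",
  // "<|end_of_text|>", "<|endoftext|>"}, "begin" matches only the first, so
  // BOS is found; "end" matches the last two, and the "text" filter does not
  // narrow them, so no single EOS is found and eos_tok_ falls back to bos_tok_.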

  // Mark initialized once everything is done
  initialized_ = true;
}

// std::vector<uint64_t>
// HFTokenizer::encode(const std::string& text, int8_t bos, int8_t eos) const {
//   // if (!initialized_) {
//   //   exit(EXIT_FAILURE);
//   // }
//   // auto res = _encode_with_special_token(text, _special_token_encoder).first;
//   // for (auto i = 0; i < bos; ++i) {
//   //   res.insert(res.begin(), bos_tok_);
//   // }
//   // for (auto i = 0; i < eos; ++i) {
//   //   res.push_back(eos_tok_);
//   // }
//   // return res;

//   //DEBUG
//   return {};
// }

// std::string HFTokenizer::decode(uint64_t prev, uint64_t cur) const {
//   // (void)prev;
//   // if (!initialized_) {
//   //   exit(EXIT_FAILURE);
//   // }
//   // std::string ret;

//   // std::string token_bytes;
//   // auto iter = _decoder.find(cur);
//   // if (iter != _decoder.end()) {
//   //   token_bytes = iter->second;
//   // } else {
//   //   iter = _special_token_decoder.find(cur);
//   //   if (iter != _special_token_decoder.end()) {
//   //     token_bytes = iter->second;
//   //   } else {
//   //     fprintf(stderr, "unknown token: %" PRIu64 "\n", cur);
//   //     exit(EXIT_FAILURE);
//   //   }
//   // }
//   // ret += token_bytes;

//   // return ret;

//   //DEBUG
//   return "";
// }
// -------------------------public method end--------------------------------
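For context, a minimal usage sketch of the loader added in this commit, assuming tokenizer.h declares HFTokenizer as shown above. This driver is hypothetical, not part of the diff, and encode/decode are still stubbed out in this commit:

#include <cstdio>
#include "tokenizer.h"

int main(int argc, char* argv[]) {
  if (argc < 2) {
    fprintf(stderr, "usage: %s <tokenizer.json or model dir>\n", argv[0]);
    return 1;
  }
  HFTokenizer tok;
  tok.load(argv[1]);  // exits on failure, per the loader above
  fprintf(stdout, "loaded tokenizer from %s\n", argv[1]);
  return 0;
}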