feat(tokenizer): Split Tiktoken out into BPETokenizerBase and Tiktoken
This will allow HFTokenizer to reuse all of the BPE logic with different
pre/post tokenization

pytorch#1251
Branch: TokenizersCpp-1251

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart committed Nov 13, 2024
1 parent 69a5dd0 commit 6128219
Showing 2 changed files with 71 additions and 47 deletions.
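The refactor is a template-method split: special-token handling and the BPE merge loop move into the new BPETokenizerBase, and each concrete tokenizer supplies its own regex-driven pre-tokenization by overriding the pure-virtual _encode hook. As a rough sketch of the reuse this enables (HFTokenizer does not exist yet in this commit, and everything in the body below is hypothetical):

// Hypothetical sketch only -- HFTokenizer is future work, and the method
// bodies are placeholders. It shows the intended pattern: inherit from
// BPETokenizerBase, fill in the protected encoder/decoder members during
// load(), and override _encode() with tokenizer-specific pre-tokenization.
// (encode()/decode() overrides from Tokenizer are omitted for brevity.)
class HFTokenizer : public BPETokenizerBase {
 public:
  explicit HFTokenizer() : BPETokenizerBase() {}

  void load(const std::string& tokenizer_path) override {
    // Parse e.g. a HuggingFace tokenizer.json and populate the inherited
    // members: encoder_, decoder_, special_token_encoder_,
    // special_token_decoder_, and special_token_regex_.
  }

 private:
  void _encode(
      re2::StringPiece& input,
      std::vector<uint64_t>& ret,
      uint64_t& last_piece_token_len) const override {
    // HF-specific splitting goes here; the shared special-token logic in
    // encode_with_special_token_ calls into this hook.
  }
};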
67 changes: 35 additions & 32 deletions tokenizer/tiktoken.cpp
@@ -240,20 +240,20 @@ static std::vector<uint64_t> _byte_pair_encode(
   });
 }
 // ------------------------------Util end------------------------------------
-// -------------------------private method start-------------------------------
+// -------------------------protected method start---------------------------
 
 std::pair<std::optional<std::string>, re2::StringPiece>
-Tiktoken::_split_with_allowed_special_token(
+BPETokenizerBase::split_with_allowed_special_token_(
     re2::StringPiece& input,
     const Encoder& allowed_special) const {
-  if (!_special_token_regex) {
+  if (!special_token_regex_) {
     return std::make_pair(std::nullopt, input);
   }
 
   auto start = input.begin();
   std::string special;
   while (true) {
-    if (!re2::RE2::FindAndConsume(&input, *_special_token_regex, &special)) {
+    if (!re2::RE2::FindAndConsume(&input, *special_token_regex_, &special)) {
       // No special token.
       break;
     }
@@ -269,38 +269,15 @@ Tiktoken::_split_with_allowed_special_token(
   return std::make_pair(std::nullopt, input);
 }
 
-void Tiktoken::_encode(
-    re2::StringPiece& input,
-    std::vector<uint64_t>& ret,
-    uint64_t& last_piece_token_len) const {
-  std::string piece;
-  assert(regexes_.size());
-  for (const auto& regex : regexes_) {
-    assert(regex);
-    while (re2::RE2::FindAndConsume(&input, *regex, &piece)) {
-      auto iter = encoder_.find(piece);
-      if (iter != encoder_.end()) {
-        last_piece_token_len = 1;
-        ret.push_back(iter->second);
-        continue;
-      }
-      auto tokens = _byte_pair_encode(piece, encoder_);
-
-      last_piece_token_len = tokens.size();
-      ret.insert(ret.end(), tokens.begin(), tokens.end());
-    }
-  }
-}
-
-std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
+std::pair<std::vector<uint64_t>, uint64_t> BPETokenizerBase::encode_with_special_token_(
     const std::string& text,
     const Encoder& allowed_special) const {
   std::vector<uint64_t> tokens;
   uint64_t last_piece_token_len = 0;
   re2::StringPiece input(text);
   while (true) {
     auto [special, sub_input] =
-        _split_with_allowed_special_token(input, allowed_special);
+        split_with_allowed_special_token_(input, allowed_special);
 
     _encode(sub_input, tokens, last_piece_token_len);
 
@@ -328,10 +305,36 @@ std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
   return std::make_pair(tokens, last_piece_token_len);
 }
 
+// -------------------------protected method end-------------------------------
+// -------------------------private method start-------------------------------
+
+void Tiktoken::_encode(
+    re2::StringPiece& input,
+    std::vector<uint64_t>& ret,
+    uint64_t& last_piece_token_len) const {
+  std::string piece;
+  assert(regexes_.size());
+  for (const auto& regex : regexes_) {
+    assert(regex);
+    while (re2::RE2::FindAndConsume(&input, *regex, &piece)) {
+      auto iter = encoder_.find(piece);
+      if (iter != encoder_.end()) {
+        last_piece_token_len = 1;
+        ret.push_back(iter->second);
+        continue;
+      }
+      auto tokens = _byte_pair_encode(piece, encoder_);
+
+      last_piece_token_len = tokens.size();
+      ret.insert(ret.end(), tokens.begin(), tokens.end());
+    }
+  }
+}
+
 // -------------------------private method end-------------------------------
 // -------------------------public method start-------------------------------
 
-Tiktoken::Tiktoken() : Tokenizer() {}
+Tiktoken::Tiktoken() : BPETokenizerBase() {}
 
 void Tiktoken::load(const std::string& path) {
   encoder_ = _load_encoder(path);
@@ -341,7 +344,7 @@ void Tiktoken::load(const std::string& path) {
   special_token_decoder_ = _build_decoder(special_token_encoder_);
 
   regexes_.push_back(_create_regex(_pattern));
-  _special_token_regex = _build_special_token_regex(special_token_encoder_);
+  special_token_regex_ = _build_special_token_regex(special_token_encoder_);
 
   // initialize vocab_size, bos_tok, eos_tok
   vocab_size_ = encoder_.size() + special_token_encoder_.size();
@@ -355,7 +358,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) const {
   if (!initialized_) {
     exit(EXIT_FAILURE);
   }
-  auto res = _encode_with_special_token(text, special_token_encoder_).first;
+  auto res = encode_with_special_token_(text, special_token_encoder_).first;
   for (auto i = 0; i < bos; ++i) {
     res.insert(res.begin(), bos_tok_);
   }
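One behavior worth spelling out before the header diff: encode_with_special_token_ alternates between ordinary text spans, which flow through the subclass's _encode hook, and allowed special tokens, which map directly to a single id. A hedged trace (the special token and input are illustrative, not tied to a real vocab):

// Illustrative trace, not runnable against a real vocab file:
// with "<|endoftext|>" present in allowed_special,
//   encode_with_special_token_("hi <|endoftext|> yo", allowed_special)
// conceptually yields
//   BPE("hi ") ++ { id("<|endoftext|>") } ++ BPE(" yo")
// with last_piece_token_len ending up as the token count of the final
// piece (" yo").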
51 changes: 36 additions & 15 deletions tokenizer/tokenizer.h
@@ -80,15 +80,48 @@ class SPTokenizer : public Tokenizer {
 
 // ----------------------- Tiktoken -----------------------
 // Used by OpenAI, adapted from https://github.com/sewenew/tokenizer
+//
+// The main changes from the upstream implementation are to split out the core
+// of the BPE logic into a base class that both Tiktoken and HFTokenizer can
+// inherit from.
 
 using Encoder = std::unordered_map<std::string, uint64_t>;
 using Decoder = std::unordered_map<uint64_t, std::string>;
 using Re2UPtr = std::unique_ptr<re2::RE2>;
 
-class Tiktoken : public Tokenizer {
+class BPETokenizerBase : public Tokenizer {
+ protected:
+
+  explicit BPETokenizerBase() {};
+  virtual ~BPETokenizerBase() {};
+
+  std::pair<std::optional<std::string>, re2::StringPiece>
+  split_with_allowed_special_token_(
+      re2::StringPiece& input,
+      const Encoder& allowed_special) const;
+
+  std::pair<std::vector<uint64_t>, uint64_t> encode_with_special_token_(
+      const std::string& text,
+      const Encoder& allowed_special) const;
+
+  // Protected members that can be overloaded by other BPE tokenizers
+  Re2UPtr special_token_regex_;
+  Encoder encoder_;
+  Encoder special_token_encoder_;
+  Decoder decoder_;
+  Decoder special_token_decoder_;
+
+ private:
+  virtual void _encode(
+      re2::StringPiece& input,
+      std::vector<uint64_t>& ret,
+      uint64_t& last_piece_token_len) const = 0;
+};
+
+class Tiktoken : public BPETokenizerBase {
  public:
   explicit Tiktoken();
-  ~Tiktoken(){};
+  ~Tiktoken() override {};
 
   void load(const std::string& tokenizer_path) override;
 
@@ -118,27 +151,15 @@ class Tiktoken : public Tokenizer {
     return special_tokens;
   }
 
-  std::pair<std::optional<std::string>, re2::StringPiece>
-  _split_with_allowed_special_token(
-      re2::StringPiece& input,
-      const Encoder& allowed_special) const;
-
   void _encode(
       re2::StringPiece& input,
       std::vector<uint64_t>& ret,
-      uint64_t& last_piece_token_len) const;
-
-  std::pair<std::vector<uint64_t>, uint64_t> _encode_with_special_token(
-      const std::string& text,
-      const Encoder& allowed_special) const;
+      uint64_t& last_piece_token_len) const override;
 
   // Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
   const std::string _pattern =
       R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
 
-  // Private members tht cannot be overloaded by other BPE tokenizers
-  Re2UPtr _special_token_regex;
-
  protected:
 
   // Protected members that can be overloaded by other BPE tokenizers
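For completeness, the public Tiktoken surface is untouched by the refactor, so existing callers keep working. A minimal usage sketch (the vocab path is illustrative and error handling is elided):

#include <iostream>
#include <string>
#include <vector>

#include "tokenizer.h"

int main() {
  Tiktoken tokenizer;
  // Illustrative path; load() expects a tiktoken-format vocab file.
  tokenizer.load("path/to/tokenizer.model");

  // bos=1, eos=0: prepend one BOS token, append no EOS token.
  std::vector<uint64_t> ids = tokenizer.encode("hello world", 1, 0);
  for (uint64_t id : ids) {
    std::cout << id << ' ';
  }
  std::cout << std::endl;
  return 0;
}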
