Skip to content

Commit

Permalink
feat(tokenizers): add SequencePreTokenizer
Browse files Browse the repository at this point in the history
pytorch#1251
Branch: TokenizersCpp-1251

Signed-off-by: Gabe Goodhart <[email protected]>
  • Loading branch information
gabe-l-hart committed Nov 15, 2024
1 parent 34bea83 commit f1663d5
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 2 deletions.
21 changes: 21 additions & 0 deletions tokenizer/pre_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,24 @@ ByteLevelPreTokenizer::pre_tokenize(re2::StringPiece input) const {

return unicode_regex_split(input_str, {pattern_});
}

// SequencePreTokenizer ////////////////////////////////////////////////////////

// Construct from the ordered list of pre-tokenizers to apply in sequence.
// The vector is taken by value and moved into the member, so the caller's
// shared_ptr handles are transferred without an extra copy of the vector.
SequencePreTokenizer::SequencePreTokenizer(
std::vector<PreTokenizer::Ptr> pre_tokenizers)
: pre_tokenizers_(std::move(pre_tokenizers))
{}

/**
 * Apply each pre-tokenizer in order, feeding the pieces produced by one
 * stage into the next (mirrors HF tokenizers' Sequence pre-tokenizer).
 *
 * @param input The raw text to split
 * @return The sub-pieces produced by the final stage, in input order
 */
std::vector<std::string> SequencePreTokenizer::pre_tokenize(re2::StringPiece input) const {
  // Seed with the whole input as a single piece
  std::vector<std::string> pieces{std::string(input)};
  for (const auto& pre_tokenizer : pre_tokenizers_) {
    std::vector<std::string> new_pieces;
    for (const auto& piece : pieces) {
      // Move each sub-piece into the accumulator rather than copying the
      // string (the per-stage result vector is a temporary we own)
      for (auto&& subpiece : pre_tokenizer->pre_tokenize(piece)) {
        new_pieces.push_back(std::move(subpiece));
      }
    }
    pieces = std::move(new_pieces);
  }
  return pieces;
}
23 changes: 23 additions & 0 deletions tokenizer/pre_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
class PreTokenizer {
public:

/** Shared pointer type */
typedef std::shared_ptr<PreTokenizer> Ptr;

/** Split the input string piece into sub-pieces
*
* This pre-tokenization may result in sub-pieces that are not contained
Expand Down Expand Up @@ -94,3 +97,23 @@ class ByteLevelPreTokenizer : public PreTokenizer {
const bool add_prefix_space_;

}; // end class ByteLevelPreTokenizer

// -- Sequence -----------------------------------------------------------------
// Used by tokenizers
// CITE: https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/sequence.rs

// Runs a list of pre-tokenizers in order, re-splitting every piece produced
// by the previous stage with the next pre-tokenizer.
class SequencePreTokenizer : public PreTokenizer {
public:

/**
* @param pre_tokenizers: The sequence of owned pre-tokenizer objects to use
*/
explicit SequencePreTokenizer(std::vector<PreTokenizer::Ptr> pre_tokenizers);

/** Perform pre-tokenization by applying each pre-tokenizer in sequence */
std::vector<std::string> pre_tokenize(re2::StringPiece input) const override;

private:
const std::vector<PreTokenizer::Ptr> pre_tokenizers_;

}; // end class SequencePreTokenizer
27 changes: 25 additions & 2 deletions tokenizer/tests/pre_tokenizer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ void assert_split_match(
) {
re2::StringPiece prompt_view(prompt);
const auto& got = ptok.pre_tokenize(prompt_view);
EXPECT_EQ(got.size(), expected.size());
EXPECT_EQ(expected.size(), got.size());
for (auto i = 0; i < got.size(); ++i) {
EXPECT_EQ(got[i], expected[i]);
EXPECT_EQ(expected[i], got[i]);
}
}

Expand Down Expand Up @@ -93,3 +93,26 @@ TEST_F(ByteLevelPreTokenizerTest, PreTokenizeNoPrefix) {
{"Hello", "ĠWorld"}
);
}

// Byte-level splitting with a caller-supplied regex: every occurrence of
// the pattern ("o") becomes its own piece, with no prefix space added.
TEST_F(ByteLevelPreTokenizerTest, PreTokenizeCustomRegex) {
  ByteLevelPreTokenizer tokenizer(false, R"(o)");
  assert_split_match(tokenizer, "Hello World", {"Hell", "o", "ĠW", "o", "rld"});
}

// SequencePreTokenizer ////////////////////////////////////////////////////////
class SequencePreTokenizerTest : public ::testing::Test {};

// The sequence threads the output of each stage into the next: digits are
// isolated first, then byte-level splitting runs on every resulting piece.
TEST_F(SequencePreTokenizerTest, PreTokenizeDigitAndByteLevel) {
  // Prefer make_shared over raw `new`: exception-safe and a single
  // allocation for object + control block.
  PreTokenizer::Ptr dptok = std::make_shared<DigitsPreTokenizer>(true);
  PreTokenizer::Ptr bptok = std::make_shared<ByteLevelPreTokenizer>(false);
  SequencePreTokenizer ptok({dptok, bptok});
  assert_split_match(
      ptok,
      "The number 1 then 234 then 5.",
      {"The", "Ġnumber", "Ġ", "1", "Ġthen", "Ġ", "2", "3", "4", "Ġthen", "Ġ", "5", "."});
}

0 comments on commit f1663d5

Please sign in to comment.