From 481d67b4e84d18c2ec161499783d1a337246e1d3 Mon Sep 17 00:00:00 2001
From: Chielo Newctle <ChieloNewctle@gmail.com>
Date: Wed, 24 Apr 2024 16:56:39 +0800
Subject: [PATCH] feat: init

---
 .github/dependabot.yml | 10 +
 .github/workflows/ci.yml | 54 ++++++
 .github/workflows/release-plz.yml | 29 +++
 .gitignore | 76 ++++++++
 Cargo.toml | 25 +++
 LICENSE-APACHE | 201 ++++++++++++++++++++
 LICENSE-MIT | 21 +++
 README.md | 34 ++++
 examples/rand-infer.rs | 250 ++++++++++++++++++++++++
 src/choice.rs | 30 +++
 src/lib.rs | 26 +++
 src/search_tree.rs | 304 ++++++++++++++++++++++++++++++
 src/tests.rs | 92 +++++++++
 src/utils.rs | 203 ++++++++++++++++++++
 src/vocab.rs | 122 ++++++++++++
 15 files changed, 1477 insertions(+)
 create mode 100644 .github/dependabot.yml
 create mode 100644 .github/workflows/ci.yml
 create mode 100644 .github/workflows/release-plz.yml
 create mode 100644 .gitignore
 create mode 100644 Cargo.toml
 create mode 100644 LICENSE-APACHE
 create mode 100644 LICENSE-MIT
 create mode 100644 README.md
 create mode 100644 examples/rand-infer.rs
 create mode 100644 src/choice.rs
 create mode 100644 src/lib.rs
 create mode 100644 src/search_tree.rs
 create mode 100644 src/tests.rs
 create mode 100644 src/utils.rs
 create mode 100644 src/vocab.rs

diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000..c7ecf5e
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,10 @@
+version: 2
+updates:
+  - package-ecosystem: cargo
+    directory: /
+    schedule:
+      interval: weekly
+  - package-ecosystem: github-actions
+    directory: /
+    schedule:
+      interval: weekly
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..eb54e05
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,54 @@
+name: Cargo Build & Test
+
+on:
+  push:
+  pull_request:
+  workflow_dispatch:
+
+env:
+  CARGO_TERM_COLOR: always
+
+jobs:
+  build_and_test:
+    name: Build & Test
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        toolchain:
+          - stable
+          - beta
+          - nightly
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - run: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }}
+      - run: cargo build --verbose
+      - run: cargo test --verbose
+      - run: cargo build --all-features --verbose
+      - run: cargo test --all-features --verbose
+      - run: cargo install cargo-all-features
+      - run: cargo check-all-features --verbose
+      - run: cargo build-all-features --verbose
+      - run: cargo test-all-features --verbose
+
+  rustfmt:
+    name: Check Format
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - run: rustup update stable && rustup default stable
+      - run: rustup component add rustfmt
+      - run: cargo fmt --all --check
+
+  build_docs:
+    name: Docs
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - run: rustup update nightly && rustup default nightly
+      - run: RUSTDOCFLAGS="--cfg doc_cfg --html-in-header ./docs-header.html" cargo +nightly doc --all-features --no-deps
diff --git a/.github/workflows/release-plz.yml b/.github/workflows/release-plz.yml
new file mode 100644
index 0000000..186b819
--- /dev/null
+++ b/.github/workflows/release-plz.yml
@@ -0,0 +1,29 @@
+name: Release-plz
+
+permissions:
+  pull-requests: write
+  contents: write
+
+on:
+  workflow_run:
+    workflows: [Cargo Build & Test]
+    types: [completed]
+    branches: [main]
+
+jobs:
+  release-plz:
+    name: Release-plz
+    runs-on: ubuntu-latest
+    if: ${{ github.event.workflow_run.conclusion == 'success' }}
+    steps:
+      - name: Checkout repository
+ uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + - name: Run release-plz + uses: MarcoIeni/release-plz-action@v0.5 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4cf9427 --- /dev/null +++ b/.gitignore @@ -0,0 +1,76 @@ +/target +Cargo.lock + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version + +# pytest-readme +python/tests/test_readme.py diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..1c7604f --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "mtc-token-healing" +version = "0.1.0" +edition = "2021" + +[dependencies] +derive_more = "0.99.17" +general-sam = { version = "1.0.0", features = ["trie"] } +pyo3 = { version = "0.21.2", optional = true } +serde = { version = "1.0.198", optional = true } +smallvec = "1.13.2" +thiserror = "1.0.59" + +[features] +pyo3 = ["dep:pyo3"] +serde = ["dep:serde", "smallvec/serde"] + +[dev-dependencies] +clap = { version = "4.5.4", features = ["derive", "env"] } +color-eyre = "0.6.3" +rand = "0.8.5" +regex = "1.10.4" +serde_json = "1.0.116" +tokenizers = { version = "0.19.1", features = ["hf-hub", "http"] } +tokio = { version = "1.37.0", features = ["rt-multi-thread"] } diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..c98d27d --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/LICENSE-2.0 + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..dd66111 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Chielo Newctle +Copyright (c) 2023 ModelTC Team + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished +to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS
+OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF
+OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f3a6ea3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,34 @@
+# mtc-token-healing
+
+[![Crates.io](https://img.shields.io/crates/v/mtc-token-healing.svg)](https://crates.io/crates/mtc-token-healing)
+[![Docs.rs](https://img.shields.io/docsrs/mtc-token-healing.svg)](https://docs.rs/mtc-token-healing)
+[![License](https://img.shields.io/badge/license-MIT%2FApache--2.0-informational.svg)](#license)
+[![Build status](https://github.com/ModelTC/mtc-token-healing/actions/workflows/ci.yml/badge.svg)](https://github.com/ModelTC/mtc-token-healing/actions)
+
+Token healing implementation in Rust.
+
+## Usage
+
+See [`examples/rand-infer.rs`](examples/rand-infer.rs).
+
+```sh
+echo '"def helloworl"' | cargo run --example rand-infer
+```
+
+## TODOs
+
+- [ ] Python bindings
+
+## License
+
+- © 2023 Chielo Newctle \<[ChieloNewctle@gmail.com](mailto:ChieloNewctle@gmail.com)\>
+- © 2023 ModelTC Team
+
+This project is licensed under either of
+
+- [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0) ([`LICENSE-APACHE`](LICENSE-APACHE))
+- [MIT license](https://opensource.org/licenses/MIT) ([`LICENSE-MIT`](LICENSE-MIT))
+
+at your option.
+
+The [SPDX](https://spdx.dev) license identifier for this project is `MIT OR Apache-2.0`.
diff --git a/examples/rand-infer.rs b/examples/rand-infer.rs
new file mode 100644
index 0000000..224ac72
--- /dev/null
+++ b/examples/rand-infer.rs
@@ -0,0 +1,250 @@
+use std::sync::{Arc, OnceLock};
+
+use clap::Parser;
+use color_eyre::{eyre::eyre, Result};
+use mtc_token_healing::{
+    InferRequest, InferResponse, Prediction, ReorderedTokenId, SearchTree, TokenId,
+    VocabPrefixAutomaton,
+};
+use regex::Regex;
+use tokenizers::{AddedToken, Tokenizer};
+use tokio::runtime::Runtime;
+
+pub struct DummyInfer {
+    tree: SearchTree,
+    current_tokens_buffer: Vec<TokenId>,
+}
+
+impl DummyInfer {
+    pub async fn new(tree: SearchTree) -> Result<Self> {
+        Ok(Self {
+            tree,
+            current_tokens_buffer: Default::default(),
+        })
+    }
+
+    pub async fn handle_infer_req(&mut self, req: InferRequest) -> Result<InferResponse> {
+        println!("request: {req:?}");
+
+        if req.backtrace > 0 {
+            let buf = &mut self.current_tokens_buffer;
+            assert!(buf.len() >= req.backtrace);
+            buf.drain(buf.len() - req.backtrace..);
+            println!("backtracing: {}", req.backtrace);
+        }
+
+        if let Some(token) = req.feed {
+            self.current_tokens_buffer.push(token);
+            println!("decoding: {token:?}\n{:?}", self.current_tokens_buffer);
+        } else {
+            assert!(self.current_tokens_buffer.is_empty());
+            // println!("prefilling:\n{:?}", self.tree.prefilled_token_ids())
+        }
+
+        let decoded_len = self.current_tokens_buffer.len() as i32;
+
+        let sampled = if let Some((lower, upper)) = req.sampling_id_range.as_ref() {
+            assert!(lower < upper);
+            let id = rand::random::<u32>() % (upper.0 - lower.0) + lower.0;
+            Some(Prediction {
+                token_id: ReorderedTokenId(id),
+                // log_prob: rand::random(),
+                // NOTE: The factor is to normalize accumulated random fake log_prob.
+                // **It is not needed for real log_prob generated from language models.**
+                log_prob: rand::random::<f64>() * f64::powi(0.5, decoded_len),
+            })
+        } else {
+            None
+        };
+
+        let sparse_choices = req
+            .sparse_choices
+            .iter()
+            .map(|&id| Prediction {
+                token_id: id,
+                // log_prob: rand::random(),
+                // NOTE: The factor is to normalize accumulated random fake log_prob.
+                // **It is not needed for real log_prob generated from language models.**
+                log_prob: rand::random::<f64>() * f64::powi(0.5, decoded_len + 1),
+            })
+            .collect();
+
+        let res = InferResponse {
+            sampled,
+            sparse_choices,
+        };
+
+        println!("response: {res:?}");
+
+        Ok(res)
+    }
+}
+
+fn parse_byte_repr<S: AsRef<str>>(s: S) -> Result<u8, S> {
+    static BYTE_REPR: OnceLock<Regex> = OnceLock::new();
+    let byte_repr = BYTE_REPR
+        .get_or_init(|| Regex::new("^<0[xX][0-9a-fA-F]{2}>$").expect("invalid byte repr regex?"));
+    const PRE_LEN: usize = "<0x".len();
+    const SUF_LEN: usize = ">".len();
+
+    if byte_repr.is_match(s.as_ref()) {
+        if let Some(hex) = s
+            .as_ref()
+            .get(PRE_LEN..s.as_ref().len().saturating_sub(SUF_LEN).max(PRE_LEN))
+        {
+            if let Ok(b) = u8::from_str_radix(hex, 16) {
+                return Ok(b);
+            }
+        }
+    }
+
+    Err(s)
+}
+
+fn build_vocab<T: AsRef<Tokenizer>>(tokenizer: T) -> Result<Vec<Vec<u8>>> {
+    let mut tokenizer = tokenizer.as_ref().clone();
+    let vocab_size = tokenizer.get_vocab_size(true);
+
+    let dummy_special_token = AddedToken::from("<*dummy-surrounding*>", true);
+    let add_token_res = tokenizer.add_special_tokens(&[dummy_special_token.clone()]);
+    assert!(add_token_res == 1);
+    let &dummy_token_id = tokenizer
+        .get_added_vocabulary()
+        .get_vocab()
+        .get(&dummy_special_token.content)
+        .expect("new dummy special token should be in the vocab");
+    assert!((dummy_token_id as usize) >= vocab_size);
+
+    let mut token_bytes = vec![Vec::new(); vocab_size];
+    for (token, id) in tokenizer.get_vocab(true) {
+        if id == dummy_token_id {
+            continue;
+        }
+        assert!((id as usize) < vocab_size);
+        match parse_byte_repr(token) {
+            Ok(byte) => token_bytes[id as usize].push(byte),
+            Err(_) => {
+                let decoded = tokenizer
+                    .decode(&[dummy_token_id, id, dummy_token_id], false)
+                    .map_err(|e| eyre!(e))?;
+
+                assert!(decoded.starts_with(&dummy_special_token.content));
+                assert!(decoded.ends_with(&dummy_special_token.content));
+
+                let offset = dummy_special_token.content.len();
+                token_bytes[id as usize].extend(decoded[offset..decoded.len() - offset].as_bytes())
+            }
+        }
+    }
+
+    Ok(token_bytes)
+}
+
+#[derive(Clone, Debug, Parser)]
+struct Args {
+    #[arg(short, long, env, default_value = "codellama/CodeLlama-7b-Instruct-hf")]
+    tokenizer_path: String,
+}
+
+async fn main_body() -> Result<()> {
+    let args = Args::try_parse()?;
+
+    let tokenizer =
+        Arc::new(Tokenizer::from_pretrained(&args.tokenizer_path, None).map_err(|e| eyre!(e))?);
+
+    let vocab = build_vocab(tokenizer.clone())?;
+
+    let automaton = Arc::new(VocabPrefixAutomaton::new(vocab));
+
+    println!("waiting for text (in json format) from stdin...");
+    let text: String = serde_json::from_reader(std::io::stdin())?;
+
+    println!("prompt: {text:?}\n");
+    let tokenized = tokenizer
+        .encode(text.as_str(), true)
+        .map_err(|e| eyre!(e))?;
+    let prefilled_text = tokenizer
+        .decode(tokenized.get_ids(), false)
+        .map_err(|e| eyre!(e))?;
+
+    let Some((tree, mut req)) = SearchTree::new(
+        automaton.clone(),
+        |end_pos| async {
+            let mut res = Vec::new();
+            for pos in end_pos {
+                let tokenized = tokenizer.encode(&text[..pos], true)?;
+                res.push((pos, tokenized.get_ids().to_vec()))
+            }
+            Ok::<_, tokenizers::Error>(res)
+        },
+        text.as_str(),
+        0,
+    )
+    .await
+    .map_err(|e| eyre!(e))?
+    else {
+        println!("no token healing required");
+        return Ok(());
+    };
+
+    let mut dummy_infer = DummyInfer::new(tree).await?;
+
+    println!(
+        "prefilled tokens:\n{:?}\n",
+        Vec::from_iter(
+            dummy_infer
+                .tree
+                .prefilled_token_ids()
+                .iter()
+                .map(|&id| tokenizer.id_to_token(id))
+        ),
+    );
+
+    loop {
+        let res = dummy_infer.handle_infer_req(req).await?;
+        req = if let Some(req) = dummy_infer.tree.feed(res)? {
+            req
+        } else {
+            break;
+        };
+    }
+
+    println!(
+        "\nbest choice:\n{:?}\n",
+        dummy_infer.tree.get_best_choice()?,
+    );
+
+    let best_token_ids_to_decode = dummy_infer.tree.get_best_choice()?.extra_token_ids.clone();
+    println!(
+        "best choice tokens:\n{:?}\n",
+        Vec::from_iter(
+            best_token_ids_to_decode
+                .iter()
+                .map(|&id| tokenizer.id_to_token(id))
+        ),
+    );
+
+    let full_token_ids: Vec<_> = dummy_infer
+        .tree
+        .prefilled_token_ids()
+        .iter()
+        .chain(best_token_ids_to_decode.iter())
+        .copied()
+        .collect();
+    let full_text = tokenizer
+        .decode(&full_token_ids, false)
+        .map_err(|e| eyre!(e))?;
+
+    println!(
+        "decoded best choice:\n{:?}\n",
+        &full_text[prefilled_text.len()..]
+    );
+    println!("complete best choice text:\n{:?}\n", full_text);
+
+    Ok(())
+}
+
+fn main() -> Result<()> {
+    let runtime = Runtime::new()?;
+    runtime.block_on(main_body())
+}
diff --git a/src/choice.rs b/src/choice.rs
new file mode 100644
index 0000000..7476c53
--- /dev/null
+++ b/src/choice.rs
@@ -0,0 +1,30 @@
+use crate::TokenId;
+
+#[derive(Clone, Debug)]
+pub struct BestChoice {
+    pub extra_token_ids: Vec<TokenId>,
+    pub accum_log_prob: f64,
+}
+
+impl Default for BestChoice {
+    fn default() -> Self {
+        Self {
+            extra_token_ids: Default::default(),
+            accum_log_prob: f64::NEG_INFINITY,
+        }
+    }
+}
+
+impl BestChoice {
+    pub fn update<S: IntoIterator<Item = TokenId>>(&mut self, token_ids: S, log_prob: f64) {
+        if log_prob <= self.accum_log_prob {
+            return;
+        }
+        self.accum_log_prob = log_prob;
+        self.extra_token_ids = token_ids.into_iter().collect();
+    }
+
+    pub fn valid(&self) -> bool {
+        self.accum_log_prob > f64::NEG_INFINITY
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..3a9ccc3
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,26 @@
+//! The link tree of a suffix automaton of reversed tokens
+//! is isomorphic to the prefix tree of tokens.
+//!
+//! Thus finding tokens prefixed with a string
+//! is the same as walking to the state on the suffix automaton
+//! and gathering information over the subtree of the link tree.
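+//!
+//! A minimal sketch of this query through the crate's public API,
+//! using a tiny illustrative vocabulary:
+//!
+//! ```
+//! use mtc_token_healing::VocabPrefixAutomaton;
+//!
+//! let automaton = VocabPrefixAutomaton::new(["bb", "ab", "b"]);
+//! // Each entry maps a char position to the range of tokens, in sorted
+//! // order, that are prefixed with the text suffix starting there.
+//! for (pos, cnt_info) in automaton.parse_chars("ab", 0) {
+//!     println!("{pos}: {cnt_info:?}");
+//! }
+//! ```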
+
+pub mod choice;
+pub mod search_tree;
+pub mod utils;
+pub mod vocab;
+
+pub use crate::{
+    choice::BestChoice,
+    search_tree::{InferRequest, InferResponse, Prediction, SearchTree, SearchTreeError},
+    utils::{CountInfo, ReorderedTokenId, TokenId},
+    vocab::VocabPrefixAutomaton,
+};
+
+#[cfg(test)]
+mod tests;
diff --git a/src/search_tree.rs b/src/search_tree.rs
new file mode 100644
index 0000000..eb3e4bd
--- /dev/null
+++ b/src/search_tree.rs
@@ -0,0 +1,304 @@
+use std::{collections::BTreeMap, future::Future, sync::Arc};
+
+use general_sam::{BTreeTransTable, Trie, TrieNodeID, TRIE_ROOT_NODE_ID};
+
+use crate::{BestChoice, CountInfo, ReorderedTokenId, TokenId, VocabPrefixAutomaton};
+
+#[derive(Debug, thiserror::Error)]
+pub enum SearchTreeError {
+    #[error("feed infer results to an empty stack?")]
+    EmptyStack,
+    #[error("invalid sparse choices {choices:?}, expecting {expected:?}")]
+    InvalidSparseChoices {
+        choices: Vec<Prediction>,
+        expected: Vec<ReorderedTokenId>,
+    },
+    #[error("no sampled result, while expecting one")]
+    NoSampledResult,
+    #[error("invalid sampled result, {result:?} not in id range [{lower:?}, {upper:?})")]
+    InvalidSampledResult {
+        result: Prediction,
+        lower: ReorderedTokenId,
+        upper: ReorderedTokenId,
+    },
+    #[error("no best choice can be found (neg inf log probs?)")]
+    NoBestChoice,
+}
+
+pub type SearchTreeResult<T> = Result<T, SearchTreeError>;
+
+#[derive(Clone, Debug, Default)]
+pub struct Prediction {
+    pub token_id: ReorderedTokenId,
+    pub log_prob: f64,
+}
+
+#[derive(Clone, Debug)]
+pub struct InferRequest {
+    pub backtrace: usize,
+    pub feed: Option<TokenId>,
+    pub sampling_id_range: Option<(ReorderedTokenId, ReorderedTokenId)>,
+    pub sparse_choices: Vec<ReorderedTokenId>,
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct InferResponse {
+    pub sampled: Option<Prediction>,
+    pub sparse_choices: Vec<Prediction>,
+}
+
+#[derive(Debug)]
+struct SearchState {
+    log_prob: f64,
+    sampling_id_range: Option<(ReorderedTokenId, ReorderedTokenId)>,
+    next_choices: Vec<Prediction>,
+    next_states: Vec<(ReorderedTokenId, TrieNodeID)>,
+}
+
+#[derive(Debug)]
+pub struct SearchTree {
+    automaton: Arc<VocabPrefixAutomaton>,
+
+    max_num_tokens: usize,
+
+    trie: Trie<BTreeTransTable<ReorderedTokenId>>,
+    sampling_id_range: BTreeMap<TrieNodeID, (ReorderedTokenId, ReorderedTokenId)>,
+
+    stack: Vec<SearchState>,
+
+    prefilled_token_ids: Vec<TokenId>,
+
+    current_new_token_ids: Vec<TokenId>,
+    current_accum_log_prob: f64,
+
+    best_choice: BestChoice,
+}
+
+impl SearchTree {
+    pub fn prefilled_token_ids(&self) -> &[TokenId] {
+        &self.prefilled_token_ids
+    }
+
+    pub fn get_best_choice(&self) -> SearchTreeResult<&BestChoice> {
+        if self.best_choice.valid() {
+            Ok(&self.best_choice)
+        } else {
+            Err(SearchTreeError::NoBestChoice)
+        }
+    }
+
+    pub async fn new<E, F, Fut, Seq, Ids, S: AsRef<str>>(
+        automaton: Arc<VocabPrefixAutomaton>,
+        tokenize_for_multiple_ending_positions: F,
+        text: S,
+        start_from: usize,
+    ) -> Result<Option<(Self, InferRequest)>, E>
+    where
+        F: FnOnce(Vec<usize>) -> Fut,
+        Fut: Future<Output = Result<Seq, E>> + Send,
+        Seq: IntoIterator<Item = (usize, Ids)>,
+        Ids: IntoIterator<Item = TokenId>,
+    {
+        let pos_to_cnt_info = BTreeMap::from_iter(automaton.as_ref().parse_chars(text, start_from));
+        let end_pos = Vec::from_iter(pos_to_cnt_info.keys().copied());
+        let encoded = tokenize_for_multiple_ending_positions(end_pos).await?;
+        Ok(Self::from_encoded(automaton, pos_to_cnt_info, encoded))
+    }
+
+    pub fn new_sync<E, F, Seq, Ids, S: AsRef<str>>(
+        automaton: Arc<VocabPrefixAutomaton>,
+        tokenize_for_multiple_ending_positions: F,
+        text: S,
+        start_from: usize,
+    ) -> Result<Option<(Self, InferRequest)>, E>
+    where
+        F: FnOnce(Vec<usize>) -> Result<Seq, E>,
+        Seq: IntoIterator<Item = (usize, Ids)>,
+        Ids: IntoIterator<Item = TokenId>,
+    {
+        let pos_to_cnt_info = BTreeMap::from_iter(automaton.as_ref().parse_chars(text, start_from));
+        let end_pos = Vec::from_iter(pos_to_cnt_info.keys().copied());
+        let encoded = tokenize_for_multiple_ending_positions(end_pos)?;
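+        // `from_encoded` builds the search trie over the re-tokenized
+        // candidates and yields the initial `InferRequest`, or `None`
+        // when there is nothing to heal.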
+        Ok(Self::from_encoded(automaton, pos_to_cnt_info, encoded))
+    }
+
+    pub fn from_encoded<
+        Seq: IntoIterator<Item = (usize, Ids)>,
+        Ids: IntoIterator<Item = TokenId>,
+    >(
+        automaton: Arc<VocabPrefixAutomaton>,
+        pos_to_cnt_info: BTreeMap<usize, CountInfo>,
+        encoded: Seq,
+    ) -> Option<(Self, InferRequest)> {
+        let mut tree = SearchTree {
+            automaton: automaton.clone(),
+
+            max_num_tokens: 0,
+
+            trie: Default::default(),
+            sampling_id_range: Default::default(),
+
+            stack: Default::default(),
+
+            prefilled_token_ids: Default::default(),
+
+            current_new_token_ids: Default::default(),
+            current_accum_log_prob: 0.0,
+
+            best_choice: Default::default(),
+        };
+
+        encoded.into_iter().for_each(|(pos, ids)| {
+            let Some(cnt_info) = pos_to_cnt_info.get(&pos) else {
+                return;
+            };
+            let mut num_tokens = 0;
+            let node_id = tree.trie.insert(ids.into_iter().map(|x| {
+                num_tokens += 1;
+                automaton.rank()[x as usize]
+            }));
+            tree.max_num_tokens = tree.max_num_tokens.max(num_tokens);
+            tree.sampling_id_range.insert(
+                node_id,
+                (
+                    ReorderedTokenId(cnt_info.tot_cnt_lower as _),
+                    ReorderedTokenId(cnt_info.tot_cnt_upper as _),
+                ),
+            );
+        });
+
+        let mut node_id = TRIE_ROOT_NODE_ID;
+        while !tree.sampling_id_range.contains_key(&node_id) {
+            let Some(node) = tree.trie.get_node(node_id) else {
+                break;
+            };
+            if node.get_trans().len() > 1 {
+                break;
+            }
+            let Some((&token_id, &next_node_id)) = node.get_trans().first_key_value() else {
+                break;
+            };
+            tree.prefilled_token_ids
+                .push(automaton.order()[token_id.0 as usize]);
+            node_id = next_node_id;
+        }
+
+        let node = tree.trie.get_node(node_id)?;
+        let sampling_id_range = tree.sampling_id_range.get(&node_id).copied();
+        if node.get_trans().is_empty() && sampling_id_range.is_none() {
+            return None;
+        }
+
+        let next_states = Vec::from_iter(node.get_trans().iter().map(|(&u, &v)| (u, v)));
+        let next_token_ids = next_states.iter().map(|i| i.0).collect();
+        tree.stack.push(SearchState {
+            log_prob: 0.0,
+            sampling_id_range,
+            next_choices: Default::default(),
+            next_states,
+        });
+        let request = InferRequest {
+            backtrace: 0,
+            feed: None,
+            sampling_id_range,
+            sparse_choices: next_token_ids,
+        };
+
+        Some((tree, request))
+    }
+
+    pub fn feed(&mut self, res: InferResponse) -> SearchTreeResult<Option<InferRequest>> {
+        let Some(top) = self.stack.last_mut() else {
+            return Err(SearchTreeError::EmptyStack);
+        };
+
+        if let Some((lower, upper)) = top.sampling_id_range.take() {
+            let Some(sampled) = res.sampled else {
+                return Err(SearchTreeError::NoSampledResult);
+            };
+            if sampled.token_id < lower || sampled.token_id >= upper {
+                return Err(SearchTreeError::InvalidSampledResult {
+                    result: sampled,
+                    lower,
+                    upper,
+                });
+            }
+            self.current_new_token_ids
+                .push(self.automaton.order()[sampled.token_id.0 as usize]);
+            self.best_choice.update(
+                self.current_new_token_ids.iter().copied(),
+                self.current_accum_log_prob + sampled.log_prob,
+            );
+            self.current_new_token_ids.pop();
+        }
+
+        if top.next_choices.len() != top.next_states.len() {
+            debug_assert!(top.next_choices.is_empty());
+
+            let expected_token_ids = Vec::from_iter(top.next_states.iter().map(|i| i.0));
+            if res.sparse_choices.len() != expected_token_ids.len()
+                || expected_token_ids
+                    .iter()
+                    .zip(res.sparse_choices.iter())
+                    .any(|(&i, j)| i != j.token_id)
+            {
+                return Err(SearchTreeError::InvalidSparseChoices {
+                    choices: res.sparse_choices,
+                    expected: expected_token_ids,
+                });
+            }
+
+            top.next_choices = res.sparse_choices;
+        }
+
+        let mut backtrace = 0;
+        while self
+            .stack
+            .last()
+            .is_some_and(|top| top.next_choices.is_empty())
+        {
+            let res = self.stack.pop().unwrap();
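+            // This state is exhausted: while backtracking, undo its
+            // contribution to the accumulated log-prob and the token path;
+            // the count is reported to the caller via `InferRequest::backtrace`.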
+            self.current_accum_log_prob -= res.log_prob;
+            self.current_new_token_ids.pop();
+            backtrace += 1;
+        }
+
+        let Some(top) = self.stack.last_mut() else {
+            return Ok(None);
+        };
+
+        let prediction = top.next_choices.pop().unwrap();
+        let (token_id, node_id) = top.next_states.pop().unwrap();
+
+        let node = self.trie.get_node(node_id).unwrap();
+
+        let sampling_id_range = self.sampling_id_range.get(&node_id).copied();
+
+        let next_states = Vec::from_iter(node.get_trans().iter().map(|(&u, &v)| (u, v)));
+        let next_token_ids = next_states.iter().map(|i| i.0).collect();
+
+        self.stack.push(SearchState {
+            log_prob: prediction.log_prob,
+            sampling_id_range,
+            next_choices: Default::default(),
+            next_states,
+        });
+
+        let token_id = self.automaton.order()[token_id.0 as usize];
+        self.current_new_token_ids.push(token_id);
+        self.current_accum_log_prob += prediction.log_prob;
+
+        let request = InferRequest {
+            backtrace,
+            feed: Some(token_id),
+            sampling_id_range,
+            sparse_choices: next_token_ids,
+        };
+
+        Ok(Some(request))
+    }
+
+    pub fn max_num_tokens(&self) -> usize {
+        self.max_num_tokens
+    }
+}
diff --git a/src/tests.rs b/src/tests.rs
new file mode 100644
index 0000000..1909159
--- /dev/null
+++ b/src/tests.rs
@@ -0,0 +1,92 @@
+use std::collections::BTreeSet;
+
+use crate::VocabPrefixAutomaton;
+
+fn testcase_parse_chars<T: AsRef<str>>(
+    automaton: &VocabPrefixAutomaton,
+    vocab_sorted: &[&str],
+    text: T,
+) {
+    let text = text.as_ref();
+    let res: BTreeSet<_> = automaton
+        .parse_chars(text, 0)
+        .into_iter()
+        .map(|(pos, cnt_info)| (pos, cnt_info.tot_cnt_lower, cnt_info.tot_cnt_upper))
+        .collect();
+    println!("{text}: {res:?}");
+    for (pos, _) in text.char_indices() {
+        let prefix = &text[pos..];
+        let range = vocab_sorted
+            .iter()
+            .enumerate()
+            .fold(None, |state, (k, token)| {
+                if token.starts_with(prefix) {
+                    if let Some((u, v)) = state {
+                        assert_eq!(v, k);
+                        Some((u, k + 1))
+                    } else {
+                        Some((k, k + 1))
+                    }
+                } else {
+                    state
+                }
+            });
+        if let Some((lower, upper)) = range {
+            println!("{pos}: {:?}", &vocab_sorted[lower..upper]);
+            assert!(res.contains(&(pos, lower, upper)));
+        }
+    }
+}
+
+fn testcase_vocab_prefix(vocab: &[&str], texts: &[&str]) {
+    let automaton = VocabPrefixAutomaton::new(vocab);
+    let vocab_sorted = automaton
+        .order()
+        .iter()
+        .map(|&i| vocab[i as usize])
+        .collect::<Vec<_>>();
+
+    println!("vocab_sorted: {vocab_sorted:?}");
+
+    for text in texts {
+        testcase_parse_chars(&automaton, vocab_sorted.as_slice(), text);
+    }
+}
+
+#[test]
+fn test_chinese_vocab_prefix() {
+    let vocab = ["歌曲", "聆听歌曲", "播放歌曲", "歌词", "查看歌词"];
+    let texts = [
+        "歌曲",
+        "聆听歌曲",
+        "聆听歌曲",
+        "聆听歌曲",
+        "播放歌曲",
+        "播放歌曲",
+        "播放歌曲",
+        "歌词",
+        "查看歌词",
+        "查看歌词",
+        "听歌曲",
+        "听歌曲",
+        "放歌曲",
+        "听歌",
+        "放歌",
+        "词",
+        "查看",
+        "bba",
+        "bbb",
+        "bba",
+        "bba",
+        "cacab",
+        "ccc",
+    ];
+    testcase_vocab_prefix(&vocab, &texts);
+}
+
+#[test]
+fn test_simple_vocab_prefix() {
+    let vocab = ["bb", "ca", "ab", "c", "aa", "bbaa", "a", "cc", "b"];
+    let texts = ["bba", "bbb", "bba", "bba", "cacab", "ccc"];
+    testcase_vocab_prefix(&vocab, &texts);
+}
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 0000000..f4e270b
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,203 @@
+use std::convert::Infallible;
+
+use general_sam::{
+    BTreeTransTable, BoxBisectTable, GeneralSam, TransitionTable, Trie, TrieNodeAlike,
+    SAM_ROOT_NODE_ID,
+};
+#[cfg(feature = "pyo3")]
+use pyo3::pyclass;
+use smallvec::SmallVec;
+
+pub type TokenId = u32;
+
+#[derive(
+    Clone,
+    Copy,
+    Debug,
+    Default,
+    derive_more::Deref,
+    derive_more::AsRef,
+    PartialEq,
+    Eq,
+    PartialOrd,
+    Ord,
+)]
+pub struct ReorderedTokenId(pub u32);
+
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
+#[cfg_attr(feature = "pyo3", pyclass(get_all, set_all))]
+pub struct CountInfo {
+    pub cnt: usize,
+    pub tot_cnt_lower: usize,
+    pub tot_cnt_upper: usize,
+}
+
+pub(crate) type TokenBytes = SmallVec<[u8; 32]>;
+
+#[derive(Debug)]
+pub(crate) struct SortResult {
+    pub cnt_info_of_vocab: Vec<CountInfo>,
+    pub order: Vec<TokenId>,
+    pub rank: Vec<ReorderedTokenId>,
+}
+
+pub(crate) fn gen_sam_cnt_info<
+    T: AsRef<[u8]>,
+    V: IntoIterator<Item = T>,
+    C: AsRef<[CountInfo]>,
+    TransTable: TransitionTable<KeyType = u8>,
+>(
+    sam_of_rev_tokens: &GeneralSam<TransTable>,
+    vocab: V,
+    cnt_info_of_vocab: C,
+) -> Vec<Option<CountInfo>> {
+    let mut cnt_info_of_sam_rev = vec![None; sam_of_rev_tokens.num_of_nodes()];
+
+    for (token, cnt_info) in vocab.into_iter().zip(cnt_info_of_vocab.as_ref().iter()) {
+        let mut state = sam_of_rev_tokens.get_root_state();
+
+        state.feed_ref(token.as_ref().iter().rev());
+
+        let mut new_info = cnt_info.clone();
+        new_info.cnt = 1;
+        cnt_info_of_sam_rev[state.node_id] = Some(new_info);
+    }
+
+    for &id in sam_of_rev_tokens
+        .get_topo_and_suf_len_sorted_node_ids()
+        .iter()
+        .rev()
+    {
+        if id == SAM_ROOT_NODE_ID {
+            continue;
+        }
+
+        let Some(node) = sam_of_rev_tokens.get_node(id) else {
+            continue;
+        };
+
+        let Some(cnt_info) = cnt_info_of_sam_rev[id].clone() else {
+            continue;
+        };
+
+        let link_cnt_info = &mut cnt_info_of_sam_rev[node.get_suffix_parent_id()];
+
+        let Some(link_cnt_info) = link_cnt_info.as_mut() else {
+            *link_cnt_info = Some(cnt_info);
+            continue;
+        };
+
+        link_cnt_info.cnt += cnt_info.cnt;
+        link_cnt_info.tot_cnt_lower = link_cnt_info.tot_cnt_lower.min(cnt_info.tot_cnt_lower);
+        link_cnt_info.tot_cnt_upper = link_cnt_info.tot_cnt_upper.max(cnt_info.tot_cnt_upper);
+    }
+
+    #[cfg(debug_assertions)]
+    for (id, cnt_info) in cnt_info_of_sam_rev.iter().enumerate() {
+        if id == SAM_ROOT_NODE_ID {
+            continue;
+        }
+        let Some(cnt_info) = cnt_info else {
+            continue;
+        };
+        let Some(node) = sam_of_rev_tokens.get_node(id) else {
+            continue;
+        };
+
+        let link_cnt_info = cnt_info_of_sam_rev[node.get_suffix_parent_id()].as_ref();
+
+        debug_assert!(link_cnt_info.is_some_and(|link_cnt_info| {
+            link_cnt_info.tot_cnt_lower <= cnt_info.tot_cnt_lower
+                && link_cnt_info.tot_cnt_upper >= cnt_info.tot_cnt_upper
+        }));
+    }
+
+    cnt_info_of_sam_rev
+}
+
+pub(crate) fn sort_vocab_with_trie<T: AsRef<[u8]>, V: ExactSizeIterator<Item = T>>(
+    vocab: V,
+) -> SortResult {
+    let vocab_size = vocab.len();
+
+    let (trie, trie_node_ids) = {
+        let mut trie = Trie::<BTreeTransTable<u8>>::default();
+        let trie_node_ids: Vec<_> = vocab
+            .into_iter()
+            .map(|token| trie.insert(token.as_ref().iter().copied()))
+            .collect();
+        (trie, trie_node_ids)
+    };
+
+    let mut cnt_info_of_trie = vec![CountInfo::default(); trie.num_of_nodes()];
+    trie_node_ids
+        .iter()
+        .for_each(|&i| cnt_info_of_trie[i].cnt += 1);
+
+    let mut tot_cnt = 0;
+
+    let res = trie.get_root_state().dfs_travel(|event| {
+        match event {
+            general_sam::TravelEvent::PushRoot(state)
+            | general_sam::TravelEvent::Push(state, _, _) => {
+                let id = state.node_id;
+                let cnt_info = &mut cnt_info_of_trie[id];
+                cnt_info.tot_cnt_lower = tot_cnt;
+                tot_cnt += cnt_info.cnt;
+            }
+            general_sam::TravelEvent::Pop(state, _) => {
+                let id = state.node_id;
+                let cnt_info = &mut cnt_info_of_trie[id];
+                cnt_info.tot_cnt_upper = tot_cnt;
+            }
+        }
+        Ok::<_, Infallible>(())
+    });
+    match res {
+        Ok(()) => {}
+        Err(e) => match e {},
+    }
+
+    let cnt_info_of_vocab: Vec<_> = (0..vocab_size)
+        .map(|i| cnt_info_of_trie[trie_node_ids[i]].clone())
+        .collect();
+
+    let order = {
+        let mut order: Vec<_> = (0..vocab_size as TokenId).collect();
+        order.sort_by_key(|&i| cnt_info_of_vocab[i as usize].tot_cnt_lower);
+        order
+    };
+
+    let rank = {
+        let mut rank = vec![ReorderedTokenId(0); vocab_size];
+        order
+            .iter()
+            .enumerate()
+            .for_each(|(k, &i)| rank[i as usize] = ReorderedTokenId(k as _));
+        rank
+    };
+
+    debug_assert_eq!(order.len(), vocab_size);
+    debug_assert_eq!(rank.len(), vocab_size);
+    debug_assert_eq!(cnt_info_of_vocab.len(), vocab_size);
+
+    SortResult {
+        cnt_info_of_vocab,
+        order,
+        rank,
+    }
+}
+
+pub(crate) fn build_sam_of_reversed_tokens<T: AsRef<[u8]>, V: IntoIterator<Item = T>>(
+    vocab: V,
+) -> GeneralSam<BoxBisectTable<u8>> {
+    let trie_of_rev_tokens = {
+        let mut trie = Trie::<BTreeTransTable<u8>>::default();
+        vocab.into_iter().for_each(|token| {
+            trie.insert(token.as_ref().iter().copied().rev());
+        });
+        trie
+    };
+    GeneralSam::<BTreeTransTable<u8>>::from_trie(trie_of_rev_tokens.get_root_state())
+        .alter_trans_table_into()
+}
diff --git a/src/vocab.rs b/src/vocab.rs
new file mode 100644
index 0000000..15e187f
--- /dev/null
+++ b/src/vocab.rs
@@ -0,0 +1,122 @@
+use general_sam::{BoxBisectTable, GeneralSam};
+#[cfg(feature = "pyo3")]
+use pyo3::pyclass;
+
+use crate::{
+    utils::{build_sam_of_reversed_tokens, gen_sam_cnt_info, sort_vocab_with_trie, TokenBytes},
+    CountInfo, ReorderedTokenId, TokenId,
+};
+
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "pyo3", pyclass)]
+pub struct VocabPrefixAutomaton {
+    vocab: Vec<TokenBytes>,
+    order: Vec<TokenId>,
+    rank: Vec<ReorderedTokenId>,
+    sam_of_rev_tokens: GeneralSam<BoxBisectTable<u8>>,
+    cnt_info_of_sam_rev: Vec<Option<CountInfo>>,
+}
+
+impl VocabPrefixAutomaton {
+    pub fn new<T: AsRef<[u8]>, V: IntoIterator<Item = T>>(vocab: V) -> Self {
+        let vocab: Vec<_> = vocab
+            .into_iter()
+            .map(|token| TokenBytes::from_slice(token.as_ref()))
+            .collect();
+        let sort_result = sort_vocab_with_trie(vocab.iter().map(|x| x.as_slice()));
+        let sam_of_rev_tokens = build_sam_of_reversed_tokens(vocab.iter().map(|x| x.as_slice()));
+        let cnt_info_of_sam_rev = gen_sam_cnt_info(
+            &sam_of_rev_tokens,
+            vocab.iter().map(|x| x.as_slice()),
+            &sort_result.cnt_info_of_vocab,
+        );
+        Self {
+            vocab,
+            order: sort_result.order,
+            rank: sort_result.rank,
+            sam_of_rev_tokens,
+            cnt_info_of_sam_rev,
+        }
+    }
+
+    pub fn vocab(&self) -> &[TokenBytes] {
+        &self.vocab
+    }
+
+    pub fn order(&self) -> &[TokenId] {
+        &self.order
+    }
+
+    pub fn rank(&self) -> &[ReorderedTokenId] {
+        &self.rank
+    }
+
+    pub fn parse_chars<S: AsRef<str>>(
+        &self,
+        text: S,
+        start_from: usize,
+    ) -> Vec<(usize, CountInfo)> {
+        let text = text.as_ref();
+
+        let mut last = text.len();
+        let mut state = self.sam_of_rev_tokens.get_root_state();
+        let mut res = Vec::new();
+
+        for (pos, _) in text.char_indices().rev() {
+            if pos < start_from {
+                break;
+            }
+            let c = &text.as_bytes()[pos..last];
+            state.feed_ref(c.iter().rev());
+            if state.is_nil() {
+                break;
+            }
+            if let Some(cnt_info) = self.cnt_info_of_sam_rev[state.node_id].clone() {
+                res.push((pos, cnt_info));
+            }
+            last = pos;
+        }
+
+        res
+    }
+}
+
+#[cfg(feature = "pyo3")]
+mod _pyo3 {
+    use pyo3::pymethods;
+
+    use crate::utils::CountInfo;
+
+    use super::VocabPrefixAutomaton;
+
+    #[pymethods]
+    impl VocabPrefixAutomaton {
+        #[new]
+        fn py_new(vocab: Vec<Vec<u8>>) -> Self {
+            Self::new(vocab)
+        }
+
+        #[pyo3(name = "vocab_size")]
+        fn vocab_size_py(&self) -> usize {
+            self.vocab.len()
+        }
+
+        #[pyo3(name = "get_order")]
+        fn get_order_py(&self) -> Vec<u32> {
+            self.order.clone()
+        }
+
+        #[pyo3(name = "get_rank")]
+        fn get_rank_py(&self) -> Vec<u32> {
+            self.rank.iter().map(|x| x.0).collect()
+        }
+
+        #[pyo3(name = "parse_chars")]
+        fn parse_chars_py(&self, text: &str, start_from: usize) -> Vec<(usize, CountInfo)> {
+            self.parse_chars(text, start_from)
+                .into_iter()
+                .map(|(pos, cnt_info)| (pos, cnt_info.clone()))
+                .collect()
+        }
+    }
+}