diff --git a/Cargo.toml b/Cargo.toml
index f27ec93326..17e7e4ba57 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,7 +20,7 @@ exclude = [
 resolver = "2"
 [workspace.package]
-version = "0.7.2"
+version = "0.8.0"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.7.2" }
-candle-datasets = { path = "./candle-datasets", version = "0.7.2" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.2" }
-candle-kernels = { path = "./candle-kernels", version = "0.7.2" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.2" }
-candle-nn = { path = "./candle-nn", version = "0.7.2" }
-candle-onnx = { path = "./candle-onnx", version = "0.7.2" }
-candle-transformers = { path = "./candle-transformers", version = "0.7.2" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.8.0" }
+candle-datasets = { path = "./candle-datasets", version = "0.8.0" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.0" }
+candle-kernels = { path = "./candle-kernels", version = "0.8.0" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.0" }
+candle-nn = { path = "./candle-nn", version = "0.8.0" }
+candle-onnx = { path = "./candle-onnx", version = "0.8.0" }
+candle-transformers = { path = "./candle-transformers", version = "0.8.0" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 39ca909d88..4b73d00696 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -32,6 +32,20 @@
 //! Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
 //!
 //! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
+//!
+//! ## Other Crates
+//!
+//! Candle consists of a number of crates. This crate holds the core common data structures but you may wish
+//! to look at the docs for the other crates which can be found here:
+//!
+//! - [candle-core](https://docs.rs/candle-core/). Core data structures and data types.
+//! - [candle-nn](https://docs.rs/candle-nn/). Building blocks for Neural Nets.
+//! - [candle-datasets](https://docs.rs/candle-datasets/). Rust access to commonly used Datasets like MNIST.
+//! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use.
+//! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models.
+//! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python.
+//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implementation of many published transformer models.
+//!
#[cfg(feature = "accelerate")] mod accelerate; diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 34931c9dfd..de107a61b0 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1237,7 +1237,7 @@ impl BackendStorage for MetalStorage { let dst_el = ids_l.shape().elem_count(); let dtype = self.dtype; let device = self.device(); - let buffer = device.new_buffer(dst_el, dtype, "index_select")?; + let buffer = device.new_buffer(dst_el, dtype, "gather")?; let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "gather_u32_f32", (DType::U32, DType::F16) => "gather_u32_f16", @@ -1324,14 +1324,23 @@ impl BackendStorage for MetalStorage { let device = self.device(); let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { + (DType::U8, DType::U8) => "is_u8_u8", + (DType::U8, DType::U32) => "is_u8_u32", + (DType::U8, DType::I64) => "is_u8_i64", (DType::U8, DType::BF16) => "is_u8_bf16", (DType::U8, DType::F32) => "is_u8_f32", (DType::U8, DType::F16) => "is_u8_f16", + (DType::U32, DType::U8) => "is_u32_u8", + (DType::U32, DType::U32) => "is_u32_u32", + (DType::U32, DType::I64) => "is_u32_i64", (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::BF16) => "is_u32_bf16", + (DType::I64, DType::U8) => "is_i64_u8", + (DType::I64, DType::U32) => "is_i64_u32", + (DType::I64, DType::I64) => "is_i64_i64", (DType::I64, DType::F32) => "is_i64_f32", (DType::I64, DType::F16) => "is_i64_f16", (DType::I64, DType::BF16) => "is_i64_bf16", diff --git a/candle-examples/examples/stella-en-v5/README.md b/candle-examples/examples/stella-en-v5/README.md index 5fcc67c351..3a87b2956a 100644 --- a/candle-examples/examples/stella-en-v5/README.md +++ b/candle-examples/examples/stella-en-v5/README.md @@ -21,7 +21,7 @@ Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example. ```bash -$ cargo run --example stella-en-v5 --release --features +$ cargo run --example stella-en-v5 --release --features -- --which 1.5b > > Score: 0.8178786 @@ -37,9 +37,29 @@ $ cargo run --example stella-en-v5 --release --features > caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types > > of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties. > + +$ cargo run --example stella-en-v5 --release --features -- --which 400m + +> +> Score: 0.8397539 +> Query: What are some ways to reduce stress? +> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending +> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent +> stress from building up. +> +> +> +> Score: 0.809545 +> Query: What are the benefits of drinking green tea? +> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage +> caused by free radicals. 
Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
+> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
+>
```

## Supported options:
-- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
+- `Stella_en_v5` has two published model variants - a 1.5B variant and a 400M variant. The variant is selected with the `--which` flag, e.g. `--which 400m` or `--which 1.5b`.
+
+- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for it). In the example run this is supported with the `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled through the `--task` option.
\ No newline at end of file
diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs
index 2408262b1a..68ed7e70c6 100644
--- a/candle-examples/examples/stella-en-v5/main.rs
+++ b/candle-examples/examples/stella-en-v5/main.rs
@@ -212,6 +212,14 @@ impl EncodeTask {
 }
 }
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+ #[value(name = "1.5b")]
+ Large,
+ #[value(name = "400m")]
+ Small,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -219,6 +227,9 @@ struct Args {
 #[arg(long)]
 cpu: bool,
+ #[arg(long)]
+ which: Which,
+
 /// Enable tracing (generates a trace-timestamp.json file).
 #[arg(long)]
 tracing: bool,
@@ -250,24 +261,33 @@ struct Args {
 // Tokenizer creation is super critical in our case.
// We are going to be `padding: Left` for each batch -fn create_tokenizer(tokenizer_file: &Path) -> Result { +fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result { let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; - let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { - pad_id - } else { - return Err(anyhow!( - "Tokenizer doesn't contain expected `<|endoftext|>` token" - )); - }; - // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding - tokenizer.with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Left, - pad_id, - pad_token: "<|endoftext|>".to_string(), - ..Default::default() - })); + if which == Which::Large { + let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { + pad_id + } else { + return Err(anyhow!( + "Tokenizer doesn't contain expected `<|endoftext|>` token" + )); + }; + + // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Left, + pad_id, + pad_token: "<|endoftext|>".to_string(), + ..Default::default() + })); + } else { + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Right, + ..Default::default() + })); + } Ok(tokenizer) } @@ -298,7 +318,19 @@ fn main() -> Result<()> { Some(d) => d, None => EmbedDim::Dim1024, }; - let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string())); + + let (repo, cfg) = match args.which { + Which::Large => ( + "dunzhang/stella_en_1.5B_v5", + Config::new_1_5_b_v5(embed_dim.embed_dim()), + ), + Which::Small => ( + "dunzhang/stella_en_400M_v5", + Config::new_400_m_v5(embed_dim.embed_dim()), + ), + }; + + let repo = api.repo(Repo::model(repo.to_string())); let tokenizer_filename = match args.tokenizer_file { Some(file) => std::path::PathBuf::from(file), None => repo.get("tokenizer.json")?, @@ -330,7 +362,7 @@ fn main() -> Result<()> { println!("retrieved the files in {:?}", start.elapsed()); // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding - let tokenizer = create_tokenizer(tokenizer_filename.as_path())?; + let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?; let start = std::time::Instant::now(); @@ -343,11 +375,7 @@ fn main() -> Result<()> { let embed_vb = unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? }; - let model = EmbeddingModel::new( - &Config::new_1_5_b_v5(embed_dim.embed_dim()), - base_vb, - embed_vb, - )?; + let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index dbae908bfd..861aa86ad5 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-flash-attn" -version = "0.7.2" +version = "0.8.0" edition = "2021" description = "Flash attention layer for the candle ML framework." 
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"
 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.2" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.0" }
 half = { version = "2.3.1", features = ["num-traits"] }
 [build-dependencies]
diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml
index 40c5f01f4a..02eb95626b 100644
--- a/candle-kernels/Cargo.toml
+++ b/candle-kernels/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.7.2"
+version = "0.8.0"
 edition = "2021"
 description = "CUDA kernels for Candle"
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 52e6f210a6..30cf531f24 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-metal-kernels"
-version = "0.7.2"
+version = "0.8.0"
 edition = "2021"
 description = "Metal kernels for Candle"
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 9eee97ca0a..c14f2c1ff1 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -193,12 +193,16 @@ INDEX_OP(is_i64_f16, int64_t, half)
 INDEX_OP(is_i64_bf16, int64_t, bfloat)
 #endif
+INDEX_OP(is_u32_u8, uint32_t, uint8_t)
+INDEX_OP(is_u32_u32, uint32_t, uint32_t)
 INDEX_OP(is_u32_f32, uint32_t, float)
 INDEX_OP(is_u32_f16, uint32_t, half)
 #if defined(__HAVE_BFLOAT__)
 INDEX_OP(is_u32_bf16, uint32_t, bfloat)
 #endif
+INDEX_OP(is_u8_u8, uint8_t, uint8_t)
+INDEX_OP(is_u8_u32, uint8_t, uint32_t)
 INDEX_OP(is_u8_f32, uint8_t, float)
 INDEX_OP(is_u8_f16, uint8_t, half)
 #if defined(__HAVE_BFLOAT__)
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index fc1819f5e0..772548a01a 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -1,3 +1,5 @@
+//! Activation Functions
+//!
 use candle::{Result, Tensor};
 use serde::Deserialize;
diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs
index 68addb98bf..918dca702f 100644
--- a/candle-nn/src/kv_cache.rs
+++ b/candle-nn/src/kv_cache.rs
@@ -1,3 +1,5 @@
+//! Cache Implementations
+//!
 use candle::{Device, Result, Tensor};
 #[derive(Debug, Clone)]
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index fcac58308c..eb3cde4a75 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -1,3 +1,20 @@
+//! candle-nn
+//!
+//! ## Other Crates
+//!
+//! Candle consists of a number of crates. This crate holds structs and functions
+//! that allow you to build and train neural nets. You may wish
+//! to look at the docs for the other crates which can be found here:
+//!
+//! - [candle-core](https://docs.rs/candle-core/). Core data structures and data types.
+//! - [candle-nn](https://docs.rs/candle-nn/). Building blocks for Neural Nets.
+//! - [candle-datasets](https://docs.rs/candle-datasets/). Rust access to commonly used Datasets like MNIST.
+//! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use.
+//! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models.
+//! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python.
+//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implementation of many published transformer models.
+//!
+ pub mod activation; pub mod batch_norm; pub mod conv; diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs index fb1e11f413..03e8524d6d 100644 --- a/candle-nn/src/loss.rs +++ b/candle-nn/src/loss.rs @@ -1,3 +1,5 @@ +//! Loss Calculations +//! use candle::{Result, Tensor}; /// The negative log likelihood loss. diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 0f35285d0b..c84e297b99 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1,3 +1,6 @@ +//! Tensor ops. +//! + use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D}; use rayon::prelude::*; diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs index 1084cfb512..0191bd7e6a 100644 --- a/candle-nn/src/rotary_emb.rs +++ b/candle-nn/src/rotary_emb.rs @@ -1,3 +1,5 @@ +//! Rotary Embeddings +//! use candle::{CpuStorage, Layout, Result, Shape, Tensor, D}; use rayon::prelude::*; diff --git a/candle-nn/src/sequential.rs b/candle-nn/src/sequential.rs index bef9975287..de5ae4971b 100644 --- a/candle-nn/src/sequential.rs +++ b/candle-nn/src/sequential.rs @@ -1,3 +1,5 @@ +//! Sequential Layer +//! //! A sequential layer used to chain multiple layers and closures. use candle::{Module, Result, Tensor}; diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 00669468d6..0d836c7fd4 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,3 +1,5 @@ +//! A `VarBuilder` for variable retrieval from models +//! //! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come //! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized //! for training, e.g. using `VarBuilder::from_varmap`. diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs index 3cb27c632e..ba020746b5 100644 --- a/candle-nn/src/var_map.rs +++ b/candle-nn/src/var_map.rs @@ -1,3 +1,5 @@ +//! A `VarMap` is a store that holds named variables. +//! use candle::{DType, Device, Result, Shape, Tensor, Var}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index 5b16ae858f..fbace8cdfc 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-onnx" -version = "0.7.2" +version = "0.8.0" edition = "2021" description = "ONNX support for Candle" @@ -10,8 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -candle = { path = "../candle-core", package = "candle-core", version = "0.7.2" } -candle-nn = { path = "../candle-nn", version = "0.7.2" } +candle = { path = "../candle-core", package = "candle-core", version = "0.8.0" } +candle-nn = { path = "../candle-nn", version = "0.8.0" } prost = "0.12.1" [build-dependencies] diff --git a/candle-transformers/src/models/chinese_clip/mod.rs b/candle-transformers/src/models/chinese_clip/mod.rs index 88472f0b88..0f6eedd0f2 100644 --- a/candle-transformers/src/models/chinese_clip/mod.rs +++ b/candle-transformers/src/models/chinese_clip/mod.rs @@ -171,7 +171,8 @@ impl ChineseClipModel { ) -> Result { let output = self .text_model - .forward(input_ids, token_type_ids, attention_mask)?; + .forward(input_ids, token_type_ids, attention_mask)? 
+ .contiguous()?; self.text_projection.forward(&output) } diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs index 9d933fade5..08cd808c43 100644 --- a/candle-transformers/src/models/stella_en_v5.rs +++ b/candle-transformers/src/models/stella_en_v5.rs @@ -1,31 +1,47 @@ use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; -use candle::{DType, Device, IndexOp, Module, Result, Tensor}; -use candle_nn::{Activation, VarBuilder}; +use candle::{DType, Device, Error, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder}; use std::sync::Arc; +// internal representation for identifying which model is being used +#[derive(Debug, Copy, Clone, PartialEq, serde::Deserialize)] +pub enum ModelVariant { + Large, // 1.5B + Small, // 400M +} + +impl Default for ModelVariant { + fn default() -> Self { + Self::Large + } +} + // Same as `qwen2` family of models with the exception being the `embed_head` // The final `output` causal modelling head is swapped with a learned `dense` layer, `embed_head` -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)] pub struct Config { + pub variant: ModelVariant, pub vocab_size: usize, pub hidden_size: usize, pub intermediate_size: usize, pub num_hidden_layers: usize, pub num_attention_heads: usize, - pub num_key_value_heads: usize, pub max_position_embeddings: usize, - pub max_window_layers: usize, - pub tie_word_embeddings: bool, pub rope_theta: f64, - pub rms_norm_eps: f64, - pub hidden_act: Activation, pub embed_head: EmbedHead, + pub norm_eps: f64, // RMSNorm for 1.5B || LayerNorm for 400M + pub activation_fn: Activation, // Silu for 1.5B || Gelu for 400M + // Unique to 1.5B + pub num_key_value_heads: usize, + // Unique to 400M + pub type_vocab_size: usize, + pub scaling_factor: f64, } // Excerpt from `stella` model card: // `Stella_en_1.5B_v5` models have been trained on [MRL](https://arxiv.org/abs/2205.13147) enabling multiple output dimensions // Embed head represents the config for various embedding dims supported -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)] pub struct EmbedHead { pub in_features: usize, pub out_features: usize, @@ -51,9 +67,9 @@ impl Default for EmbedDim { } impl EmbedDim { - pub fn config(&self) -> EmbedHead { + pub fn config(&self, in_features: usize) -> EmbedHead { EmbedHead { - in_features: 1536, + in_features, out_features: match &self { Self::Dim256 => 256, Self::Dim768 => 768, @@ -74,7 +90,8 @@ impl Config { // Representing config.json at https://huggingface.co/dunzhang/stella_en_1.5B_v5/blob/main/config.json // Removed `sliding_window` related config which is basically being carried forward from `qwen2` but not used here Self { - hidden_act: candle_nn::Activation::Silu, + variant: ModelVariant::Large, + activation_fn: candle_nn::Activation::Silu, vocab_size: 151646, hidden_size: 1536, intermediate_size: 8960, @@ -82,11 +99,30 @@ impl Config { num_attention_heads: 12, num_key_value_heads: 2, max_position_embeddings: 131072, - max_window_layers: 21, - tie_word_embeddings: false, rope_theta: 1000000., - rms_norm_eps: 1e-06, - embed_head: embed_dim.config(), + norm_eps: 1e-06, + embed_head: embed_dim.config(1536), + ..Default::default() + } + } + + /// Initialize new `stella_en_400M_v5` + pub fn new_400_m_v5(embed_dim: EmbedDim) -> Self { + Self { + variant: ModelVariant::Small, 
+ vocab_size: 30528, + hidden_size: 1024, + intermediate_size: 4096, + num_hidden_layers: 24, + num_attention_heads: 16, + max_position_embeddings: 8192, + type_vocab_size: 2, + norm_eps: 1e-12, + scaling_factor: 2.0, + rope_theta: 160000.0, + activation_fn: Activation::Gelu, + embed_head: embed_dim.config(1024), + ..Default::default() } } } @@ -100,27 +136,57 @@ struct RotaryEmbedding { impl RotaryEmbedding { fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { let dim = cfg.hidden_size / cfg.num_attention_heads; - let max_seq_len = cfg.max_position_embeddings; + // Factoring in `scaling factor` for `400M` variant + let max_seq_len = if cfg.scaling_factor == 0. { + cfg.max_position_embeddings + } else { + ((cfg.max_position_embeddings as f64) * cfg.scaling_factor) as usize + }; + + // let rot_dim = if cfg.variant == ModelVariant::Small { dim / 2 } else { dim }; let inv_freq: Vec<_> = (0..dim) .step_by(2) - .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .map(|i| { + // Scaled rope_theta for 400M variant + let rope_theta = if cfg.scaling_factor == 0. { + cfg.rope_theta + } else { + cfg.rope_theta * cfg.scaling_factor + }; + let mut freq = 1. / rope_theta.powf(i as f64 / dim as f64); + + if cfg.scaling_factor != 0. { + freq /= cfg.scaling_factor.powf(2.0 / (dim as f64)) + } + + freq as f32 + }) .collect(); + let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + + // Calculate position embeddings with scaled sequence length let t = Tensor::arange(0u32, max_seq_len as u32, dev)? .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; + // if cfg.variant == ModelVariant::Small { + // freqs = Tensor::cat(&[&freqs, &freqs], 1)? + // } + Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, }) } + // TODO: re-visit this fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, 0, seq_len)?; let sin = self.sin.narrow(0, 0, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; Ok((q_embed, k_embed)) @@ -130,8 +196,9 @@ impl RotaryEmbedding { #[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { + variant: ModelVariant, gate_proj: Linear, - up_proj: Linear, + up_proj: Option, // `up_proj` only for 1.5B variant down_proj: Linear, act_fn: Activation, } @@ -140,31 +207,65 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + + let (gate_proj, up_proj, down_proj) = match cfg.variant { + ModelVariant::Large => ( + linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?, + Some(linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + )?), + linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?, + ), + ModelVariant::Small => ( + linear_no_bias(hidden_sz, intermediate_sz * 2, vb.pp("up_gate_proj"))?, + None, + linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))?, + ), + }; + Ok(Self { + variant: cfg.variant, gate_proj, up_proj, down_proj, - act_fn: cfg.hidden_act, + act_fn: cfg.activation_fn, }) } } impl 
Module for MLP { fn forward(&self, xs: &Tensor) -> Result { - let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; - let rhs = xs.apply(&self.up_proj)?; + let up = self.gate_proj.forward(xs)?; + + let (lhs, rhs) = match self.variant { + ModelVariant::Large => { + let lhs = up.apply(&self.act_fn)?; + let rhs = xs.apply(self.up_proj.as_ref().unwrap())?; + + (lhs, rhs) + } + ModelVariant::Small => { + // Get the dimensions + let (_batch_size, _seq_len, hidden_dim) = up.dims3()?; + let split_size = hidden_dim / 2; + + // Split along the last dimension (hidden_dim) + let up_states = up.narrow(2, 0, split_size)?; + let gate = up.narrow(2, split_size, split_size)?.apply(&self.act_fn)?; + + (up_states, gate) + } + }; + (lhs * rhs)?.apply(&self.down_proj) } } #[derive(Debug, Clone)] struct Attention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, + qkv_proj: Linear, o_proj: Linear, num_heads: usize, num_kv_heads: usize, @@ -172,6 +273,7 @@ struct Attention { head_dim: usize, hidden_size: usize, rotary_emb: Arc, + variant: ModelVariant, } impl Attention { @@ -179,16 +281,47 @@ impl Attention { let hidden_sz = cfg.hidden_size; let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; - let num_kv_groups = num_heads / num_kv_heads; + let num_kv_groups = if num_kv_heads > 0 { + num_heads / num_kv_heads + } else { + 0 + }; let head_dim = hidden_sz / num_heads; - let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + + let (qkv_proj, o_proj) = match cfg.variant { + ModelVariant::Large => { + // The 1.5B variant comes with separate `q, k, v` layers, let's merge it and standardize + // Weights + let q_w = vb + .pp("q_proj") + .get((num_heads * head_dim, hidden_sz), "weight")?; + let k_w = vb + .pp("k_proj") + .get((num_kv_heads * head_dim, hidden_sz), "weight")?; + let v_w = vb + .pp("v_proj") + .get((num_kv_heads * head_dim, hidden_sz), "weight")?; + // Biases + let q_b = vb.pp("q_proj").get(num_heads * head_dim, "bias")?; + let k_b = vb.pp("k_proj").get(num_kv_heads * head_dim, "bias")?; + let v_b = vb.pp("v_proj").get(num_kv_heads * head_dim, "bias")?; + + let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?; + let qkv_b = Tensor::cat(&[&q_b, &k_b, &v_b], 0)?; + + ( + Linear::from_weights(qkv_w, Some(qkv_b)), + linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?, + ) + } + ModelVariant::Small => ( + linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?, + linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?, + ), + }; + Ok(Self { - q_proj, - k_proj, - v_proj, + qkv_proj, o_proj, num_heads, num_kv_heads, @@ -196,45 +329,90 @@ impl Attention { head_dim, hidden_size: hidden_sz, rotary_emb, + variant: cfg.variant, }) } fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result { let (b_sz, q_len, _) = xs.dims3()?; - let query_states = self.q_proj.forward(xs)?; - let key_states = self.k_proj.forward(xs)?; - let value_states = self.v_proj.forward(xs)?; + let qkv = self.qkv_proj.forward(xs)?; - let query_states = query_states - .reshape((b_sz, q_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; - let key_states = key_states - .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? 
- .transpose(1, 2)?; - let value_states = value_states - .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; + let n_kv_heads = match self.variant { + ModelVariant::Large => self.num_kv_heads, + ModelVariant::Small => self.num_heads, + }; + + let (query_states, key_states, value_states) = match self.variant { + ModelVariant::Large => { + let q_sz = self.num_heads * self.head_dim; + let kv_sz = n_kv_heads * self.head_dim; + + let q = qkv.narrow(D::Minus1, 0, q_sz)?.reshape(( + b_sz, + q_len, + self.num_heads, + self.head_dim, + ))?; + let k = qkv.narrow(D::Minus1, q_sz, kv_sz)?.reshape(( + b_sz, + q_len, + n_kv_heads, + self.head_dim, + ))?; + let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)?.reshape(( + b_sz, + q_len, + n_kv_heads, + self.head_dim, + ))?; + + (q, k, v) + } + ModelVariant::Small => { + // Split into Q, K, V and reshape to match PyTorch shapes + let qkv = qkv.reshape((b_sz, q_len, 3, self.num_heads, self.head_dim))?; + + ( + qkv.i((.., .., 0, .., ..))?, + qkv.i((.., .., 1, .., ..))?, + qkv.i((.., .., 2, .., ..))?, + ) + } + }; + + let query_states = query_states.transpose(1, 2)?.contiguous()?; + let key_states = key_states.transpose(1, 2)?.contiguous()?; + let value_states = value_states.transpose(1, 2)?.contiguous()?; let (query_states, key_states) = self .rotary_emb .apply_rotary_emb_qkv(&query_states, &key_states)?; - let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; - let value_states = - crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; + // The 1.5B is expected to have grouped query attention + let (key_states, value_states) = if self.variant == ModelVariant::Large { + ( + crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?, + crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?, + ) + } else { + (key_states, value_states) + }; let attn_output = { let scale = 1f64 / f64::sqrt(self.head_dim as f64); - let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?; + let attn_weights = (attn_weights * scale)?; let attn_weights = match attention_mask { None => attn_weights, Some(mask) => attn_weights.broadcast_add(mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? }; + attn_output .transpose(1, 2)? .reshape((b_sz, q_len, self.hidden_size))? 
@@ -243,70 +421,282 @@ impl Attention { } #[derive(Debug, Clone)] -struct DecoderLayer { - self_attn: Attention, +enum NormType { + Layer(LayerNorm), + Rms(RmsNorm), +} + +#[derive(Debug, Clone)] +struct Layer { + variant: ModelVariant, + attention: Attention, mlp: MLP, - input_layernorm: RmsNorm, - post_attention_layernorm: RmsNorm, + // For 1.5B: this is `input_layernorm` + // For 400M: this is `output_layernorm` + layernorm: NormType, + post_attention_layernorm: NormType, } -impl DecoderLayer { +impl Layer { fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { - let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; - let mlp = MLP::new(cfg, vb.pp("mlp"))?; - let input_layernorm = - RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; - let post_attention_layernorm = RmsNorm::new( - cfg.hidden_size, - cfg.rms_norm_eps, - vb.pp("post_attention_layernorm"), + let attention = Attention::new( + rotary_emb, + cfg, + vb.pp(if cfg.variant == ModelVariant::Large { + "self_attn" + } else { + "attention" + }), )?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let (layernorm, post_attention_layernorm) = match cfg.variant { + ModelVariant::Large => ( + NormType::Rms(RmsNorm::new( + cfg.hidden_size, + cfg.norm_eps, + vb.pp("input_layernorm"), + )?), + NormType::Rms(RmsNorm::new( + cfg.hidden_size, + cfg.norm_eps, + vb.pp("post_attention_layernorm"), + )?), + ), + ModelVariant::Small => ( + NormType::Layer(layer_norm( + cfg.hidden_size, + candle_nn::LayerNormConfig { + eps: cfg.norm_eps, + ..Default::default() + }, + vb.pp("mlp_ln"), + )?), + NormType::Layer(layer_norm( + cfg.hidden_size, + candle_nn::LayerNormConfig { + eps: cfg.norm_eps, + ..Default::default() + }, + vb.pp("attn_ln"), + )?), + ), + }; + Ok(Self { - self_attn, + variant: cfg.variant, + attention, mlp, - input_layernorm, + layernorm, post_attention_layernorm, }) } fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result { + // Here, the application of normalizations and activation calculations differ + // For Large [1.5B]: + // residual = x + // state = other_layernorm(xs) + // state = attention(state) + // state += residual + // residual = state + // state = mlp(attention_layernorm(state)) + // -> residual + state + // For Small [400M]: + // residual = x; + // state = attention(x) + // state += residual + // state = attention_layernorm(state) + // residual = state + // state = mlp(state) + // state += residual + // -> other_layernorm(state) let residual = xs; - let xs = self.input_layernorm.forward(xs)?; - let xs = self.self_attn.forward(&xs, attention_mask)?; - let xs = (xs + residual)?; - let residual = &xs; - let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; - residual + xs + + match self.variant { + ModelVariant::Large => { + let (attn_ln, input_ln) = if let (NormType::Rms(attn_ln), NormType::Rms(input_ln)) = + (&self.post_attention_layernorm, &self.layernorm) + { + (attn_ln, input_ln) + } else { + return Err(candle::error::Error::Msg( + "Stella 1.5B expects RMSNorm".to_string(), + )); + }; + + let xs = input_ln.forward(xs)?; + let xs = (self.attention.forward(&xs, attention_mask)? 
+ residual)?;
+
+ let residual = &xs;
+ let xs = xs.apply(attn_ln)?.apply(&self.mlp)?;
+
+ residual + xs
+ }
+ ModelVariant::Small => {
+ let (attn_ln, output_ln) =
+ if let (NormType::Layer(attn_ln), NormType::Layer(input_ln)) =
+ (&self.post_attention_layernorm, &self.layernorm)
+ {
+ (attn_ln, input_ln)
+ } else {
+ return Err(candle::error::Error::Msg(
+ "Stella 400M expects LayerNorm".to_string(),
+ ));
+ };
+
+ let xs = (self.attention.forward(xs, attention_mask)? + residual)?;
+ let xs = attn_ln.forward(&xs)?;
+
+ let residual = &xs;
+ let xs = (self.mlp.forward(&xs)? + residual)?;
+
+ output_ln.forward(&xs)
+ }
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Embeddings {
+ variant: ModelVariant,
+ // For 1.5B: this is the `embed_tokens`
+ // For 400M: this is the `word_embeddings`
+ embeddings: candle_nn::Embedding,
+ // the following are specific to the 400M variant
+ token_type_embeddings: Option<candle_nn::Embedding>,
+ layer_norm: Option<LayerNorm>,
+ position_ids: Option<Tensor>,
+}
+
+impl Embeddings {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let (embeddings, token_type_embeddings, layer_norm, position_ids) = match cfg.variant {
+ ModelVariant::Large => (
+ candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?,
+ None,
+ None,
+ None,
+ ),
+ ModelVariant::Small => {
+ let vb = vb.pp("embeddings");
+ let weight = vb.pp("LayerNorm").get_with_hints(
+ cfg.hidden_size,
+ "weight",
+ candle_nn::Init::Const(1.0),
+ )?;
+ let bias = vb.pp("LayerNorm").get_with_hints(
+ cfg.hidden_size,
+ "bias",
+ candle_nn::Init::Const(0.0),
+ )?;
+ let dev = bias.device().clone();
+
+ let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps);
+
+ (
+ candle_nn::embedding(
+ cfg.vocab_size,
+ cfg.hidden_size,
+ vb.pp("word_embeddings"),
+ )?,
+ Some(candle_nn::embedding(
+ cfg.type_vocab_size,
+ cfg.hidden_size,
+ vb.pp("token_type_embeddings"),
+ )?),
+ Some(layer_norm),
+ Some(Tensor::arange(
+ 0u32,
+ cfg.max_position_embeddings as u32,
+ &dev,
+ )?),
+ )
+ }
+ };
+
+ Ok(Self {
+ variant: cfg.variant,
+ embeddings,
+ token_type_embeddings,
+ layer_norm,
+ position_ids,
+ })
+ }
+}
+
+impl Module for Embeddings {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let embd = self.embeddings.forward(xs)?;
+ // For 1.5B just forward the embeddings
+ if self.variant == ModelVariant::Large {
+ return Ok(embd);
+ }
+
+ let (token_type_embed, layer_norm, pos_ids) =
+ if let (Some(token_type_embd), Some(layer_norm), Some(position_ids)) = (
+ &self.token_type_embeddings,
+ &self.layer_norm,
+ &self.position_ids,
+ ) {
+ (token_type_embd, layer_norm, position_ids)
+ } else {
+ return Err(Error::Msg(
+ "Stella 400M requires `token_type_embeddings`, `layer_norm` and `position_ids`"
+ .to_string(),
+ ));
+ };
+
+ let (batch_size, seq_length) = xs.dims2()?;
+
+ let pos_ids = pos_ids
+ .as_ref()
+ .narrow(0, 0, seq_length)?
+ .expand((batch_size, seq_length))?;
+
+ layer_norm.forward(&embd.add(&token_type_embed.forward(&pos_ids.zeros_like()?)?)?)
} } #[derive(Debug, Clone)] pub struct Model { - embed_tokens: candle_nn::Embedding, - layers: Vec, - norm: RmsNorm, + embeddings: Embeddings, + layers: Vec, + norm: Option, device: Device, dtype: DType, } impl Model { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let vb_m = vb.pp("model"); - let embed_tokens = - candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let vb_m = match cfg.variant { + ModelVariant::Large => vb.pp("model"), + ModelVariant::Small => vb.pp("new"), + }; + // let embed_tokens = + // candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let embeddings = Embeddings::new(cfg, vb_m.clone())?; let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); let mut layers = Vec::with_capacity(cfg.num_hidden_layers); - let vb_l = vb_m.pp("layers"); + let vb_l = match cfg.variant { + ModelVariant::Large => vb_m.pp("layers"), + ModelVariant::Small => vb_m.pp("encoder").pp("layer"), + }; for layer_idx in 0..cfg.num_hidden_layers { - let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + let layer = Layer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let norm = match cfg.variant { + ModelVariant::Large => Some(RmsNorm::new( + cfg.hidden_size, + cfg.norm_eps, + vb_m.pp("norm"), + )?), + ModelVariant::Small => None, + }; Ok(Self { - embed_tokens, + embeddings, layers, norm, - // sliding_window: 0, device: vb.device().clone(), dtype: vb.dtype(), }) @@ -335,15 +725,20 @@ impl Model { Some(self.prepare_attention_mask(mask)?) }; - let mut xs = self.embed_tokens.forward(input_ids)?; + let mut xs = self.embeddings.forward(input_ids)?; for layer in self.layers.iter_mut() { xs = layer.forward(&xs, attention_mask.as_ref())? 
} - xs.apply(&self.norm) + + if let Some(n) = &self.norm { + xs.apply(n) + } else { + Ok(xs) + } } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct EmbeddingModel { base_model: Model, lm_head: Linear, diff --git a/candle-transformers/src/models/stella_en_v5_400m.rs b/candle-transformers/src/models/stella_en_v5_400m.rs index 3b54d14234..bbe16377d6 100644 --- a/candle-transformers/src/models/stella_en_v5_400m.rs +++ b/candle-transformers/src/models/stella_en_v5_400m.rs @@ -1,9 +1,9 @@ use candle::{ DType, Device, IndexOp, Module, Result, Tensor }; -use candle_nn::{ Activation, VarBuilder, layer_norm }; +use candle_nn::{ layer_norm, Activation, LayerNorm, VarBuilder }; use std::sync::Arc; use std::time::Instant; -use super::with_tracing::{ linear, linear_no_bias, Linear, RmsNorm }; +use super::with_tracing::{ linear, linear_no_bias, Linear }; #[derive(Debug, Clone, Copy)] pub enum EmbedDim { @@ -37,19 +37,19 @@ pub struct Config { pub num_attention_heads: usize, pub max_position_embeddings: usize, pub type_vocab_size: usize, - pub pad_token_id: usize, - pub hidden_dropout_prob: f64, - pub attention_probs_dropout_prob: f64, - pub layer_norm_eps: f64, - pub initializer_range: f64, - pub position_embedding_type: String, + // pub pad_token_id: usize, + // pub hidden_dropout_prob: f64, + // pub attention_probs_dropout_prob: f64, + pub norm_eps: f64, + // pub initializer_range: f64, + // pub position_embedding_type: String, pub scaling_factor: f64, pub rope_theta: f64, - pub use_memory_efficient_attention: bool, - pub unpad_inputs: bool, - pub layer_norm_type: String, - pub logn_attention_scale: bool, - pub logn_attention_clip1: bool, + // pub use_memory_efficient_attention: bool, + // pub unpad_inputs: bool, + // pub layer_norm_type: String, + // pub logn_attention_scale: bool, + // pub logn_attention_clip1: bool, pub activation_fn: Activation, pub embed_head: EmbedHead, } @@ -77,19 +77,19 @@ impl Config { num_attention_heads: 16, max_position_embeddings: 8192, type_vocab_size: 2, - pad_token_id: 0, - hidden_dropout_prob: 0.1, - attention_probs_dropout_prob: 0.0, - layer_norm_eps: 1e-12, - initializer_range: 0.02, - position_embedding_type: "rope".to_string(), + // pad_token_id: 0, + // hidden_dropout_prob: 0.1, + // attention_probs_dropout_prob: 0.0, + norm_eps: 1e-12, + // initializer_range: 0.02, + // position_embedding_type: "rope".to_string(), scaling_factor: 2.0, rope_theta: 160000.0, - use_memory_efficient_attention: true, - unpad_inputs: false, - layer_norm_type: "layer_norm".to_string(), - logn_attention_scale: false, - logn_attention_clip1: false, + // use_memory_efficient_attention: true, + // unpad_inputs: false, + // layer_norm_type: "layer_norm".to_string(), + // logn_attention_scale: false, + // logn_attention_clip1: false, activation_fn: Activation::Gelu, embed_head, } @@ -100,10 +100,10 @@ impl Config { struct RotaryEmbedding { sin: Tensor, cos: Tensor, - _scaling_factor: f64, - _mixed_b: Option, - _dim: usize, - _base: f64, + // _scaling_factor: f64, + // _mixed_b: Option, + // _dim: usize, + // _base: f64, } impl RotaryEmbedding { @@ -144,10 +144,10 @@ impl RotaryEmbedding { Ok(Self { sin: emb.sin()?, cos: emb.cos()?, - _scaling_factor: scaling_factor, - _mixed_b: None, - _dim: dim, - _base: base, + // _scaling_factor: scaling_factor, + // _mixed_b: None, + // _dim: dim, + // _base: base, }) } } @@ -155,14 +155,14 @@ impl RotaryEmbedding { #[derive(Debug, Clone)] enum NormType { LayerNorm(candle_nn::LayerNorm), - RmsNorm(RmsNorm), + // RmsNorm(RmsNorm), } impl NormType { 
fn forward(&self, x: &Tensor) -> Result { match self { Self::LayerNorm(ln) => ln.forward(x), - Self::RmsNorm(rms) => rms.forward(x), + // Self::RmsNorm(rms) => rms.forward(x), } } } @@ -170,13 +170,13 @@ impl NormType { #[derive(Debug)] pub struct Embeddings { word_embeddings: candle_nn::Embedding, - position_embeddings: Option, - token_type_embeddings: Option, - layer_norm: NormType, - _padding_idx: usize, - _position_embedding_type: String, - rotary_emb: Option>, - position_ids: Option, + // position_embeddings: Option, + token_type_embeddings: candle_nn::Embedding, + layer_norm: LayerNorm, + // _padding_idx: usize, + // _position_embedding_type: String, + rotary_emb: Arc, + position_ids: Tensor, } impl Embeddings { @@ -187,63 +187,70 @@ impl Embeddings { vb.pp("word_embeddings") )?; - let position_embeddings = if cfg.position_embedding_type == "absolute" { - Some( - candle_nn::embedding( - cfg.max_position_embeddings, - cfg.hidden_size, - vb.pp("position_embeddings") - )? - ) - } else { - None - }; - - let token_type_embeddings = if cfg.type_vocab_size > 0 { - Some( - candle_nn::embedding( - cfg.type_vocab_size, - cfg.hidden_size, - vb.pp("token_type_embeddings") - )? - ) - } else { - None - }; - - let layer_norm = if cfg.layer_norm_type == "layer_norm" { + // let position_embeddings = if cfg.position_embedding_type == "absolute" { + // Some( + // candle_nn::embedding( + // cfg.max_position_embeddings, + // cfg.hidden_size, + // vb.pp("position_embeddings") + // )? + // ) + // } else { + // None + // }; + + let token_type_embeddings = candle_nn::embedding( + cfg.type_vocab_size, + cfg.hidden_size, + vb.pp("token_type_embeddings") + )?; + // if cfg.type_vocab_size > 0 { + // Some( + // candle_nn::embedding( + // cfg.type_vocab_size, + // cfg.hidden_size, + // vb.pp("token_type_embeddings") + // )? + // ) + // } else { + // None + // }; + + //if cfg.layer_norm_type == "layer_norm" { let weight = vb .pp("LayerNorm") .get_with_hints(cfg.hidden_size, "weight", candle_nn::Init::Const(1.0))?; let bias = vb .pp("LayerNorm") .get_with_hints(cfg.hidden_size, "bias", candle_nn::Init::Const(0.0))?; - NormType::LayerNorm(candle_nn::LayerNorm::new(weight, bias, cfg.layer_norm_eps)) - } else { - NormType::RmsNorm( - RmsNorm::new(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))? - ) - }; - - let rotary_emb = if cfg.position_embedding_type == "rope" { - Some(Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?)) - } else { - None - }; - - let position_ids = if cfg.position_embedding_type == "absolute" { - Some(Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?) - } else { - None - }; + let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps); + // } else { + // NormType::RmsNorm( + // RmsNorm::new(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))? + // ) + // }; + + // let rotary_emb = if cfg.position_embedding_type == "rope" { + // Some(Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?)) + // } else { + // None + // }; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); + + // let position_ids = if cfg.position_embedding_type == "absolute" { + // Some(Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?) 
+ // } else { + // None + // }; + let position_ids = Tensor::arange(0u32, cfg.max_position_embeddings as u32, word_embeddings.embeddings().device())?; Ok(Self { word_embeddings, - position_embeddings, + // position_embeddings, token_type_embeddings, layer_norm, - _padding_idx: cfg.pad_token_id, - _position_embedding_type: cfg.position_embedding_type.clone(), + // _padding_idx: cfg.pad_token_id, + // _position_embedding_type: cfg.position_embedding_type.clone(), rotary_emb, position_ids, }) @@ -252,57 +259,64 @@ impl Embeddings { pub fn forward( &mut self, input_ids: &Tensor, - token_type_ids: Option<&Tensor>, - position_ids: Option<&Tensor>, - inputs_embeds: Option<&Tensor>, - unpad_inputs: bool, - attention_mask: Option<&Tensor> - ) -> Result<(Tensor, Option, Option<(Tensor, Tensor)>, Option>)> { + // token_type_ids: Option<&Tensor>, + // position_ids: Option<&Tensor>, + // inputs_embeds: Option<&Tensor>, + // unpad_inputs: bool, + // attention_mask: Option<&Tensor> + ) -> Result<(Tensor, Option<(Tensor, Tensor)>)> { let (batch_size, seq_length) = input_ids.dims2()?; - let mut embeddings = match inputs_embeds { - Some(e) => e.clone(), - None => self.word_embeddings.forward(input_ids)?, - }; + let mut embeddings = self.word_embeddings.forward(input_ids)?; + // match inputs_embeds { + // Some(e) => e.clone(), + // None => self.word_embeddings.forward(input_ids)?, + // }; // Get position_ids first - let position_ids = if let Some(ids) = position_ids { - ids.clone() - } else { + // let position_ids = if let Some(ids) = position_ids { + // ids.clone() + // } else { // Get device from input_ids which is always available - let device = input_ids.device(); + // let device = input_ids.device(); // Initialize position_ids if None - if self.position_ids.is_none() { - self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); - } + // if self.position_ids.is_none() { + // self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); + // } - // Now check if we need to extend it - if seq_length > self.position_ids.as_ref().unwrap().dim(0)? { - self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); - } + // // Now check if we need to extend it + // if seq_length > self.position_ids.as_ref().unwrap().dim(0)? { + // self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); + // } - if unpad_inputs { + // let position_ids = + /*if unpad_inputs { // For now, just use the same position IDs as padded case since we don't have lengths self.position_ids .as_ref() .unwrap() .narrow(0, 0, seq_length)? .expand((batch_size, seq_length))? - } else { - self.position_ids - .as_ref() - .unwrap() - .narrow(0, 0, seq_length)? - .expand((batch_size, seq_length))? - } - }; - - // Get rotary embeddings if using RoPE - let rope_embeds = if let Some(rotary) = &self.rotary_emb { + } else {*/ + // self.position_ids + // .as_ref() + // .unwrap() + // .narrow(0, 0, seq_length)? + // .expand((batch_size, seq_length))?; + // }; + // }; + + let position_ids = self.position_ids + .as_ref() + .narrow(0, 0, seq_length)? 
+ .expand((batch_size, seq_length))?; + + + let rope_embeds = { // Get the cos and sin for this sequence length - let cos = rotary.cos.narrow(0, 0, seq_length)?; // [seq_len, head_dim] - let sin = rotary.sin.narrow(0, 0, seq_length)?; // [seq_len, head_dim] + let cos = self.rotary_emb.cos.narrow(0, 0, seq_length)?; // [seq_len, head_dim] + let sin = self.rotary_emb.sin.narrow(0, 0, seq_length)?; // [seq_len, head_dim] // Index using position_ids if needed let position_ids = position_ids.flatten_all()?; @@ -310,33 +324,38 @@ impl Embeddings { let sin = sin.index_select(&position_ids, 0)?; // Use index_select instead of i() Some((cos, sin)) - } else { - None }; + // // Get rotary embeddings if using RoPE + // let rope_embeds = if let Some(rotary) = &self.rotary_emb { + + // } else { + // None + // }; // Handle token type embeddings - if let Some(token_emb) = &self.token_type_embeddings { - let token_type_ids = if let Some(ids) = token_type_ids { - ids.clone() - } else { - position_ids.zeros_like()? // Use mul(0) equivalent - }; - if unpad_inputs { - todo!("Implement unpadded case"); - } else { - embeddings = embeddings.add(&token_emb.forward(&token_type_ids)?)?; - } - } + embeddings = embeddings.add(&self.token_type_embeddings.forward(&position_ids.zeros_like()?)?).unwrap(); + // if let Some(token_emb) = &self.token_type_embeddings { + // let token_type_ids = if let Some(ids) = token_type_ids { + // ids.clone() + // } else { + // position_ids.zeros_like()? // Use mul(0) equivalent + // }; + // if unpad_inputs { + // todo!("Implement unpadded case"); + // } else { + // embeddings = embeddings.add(&token_emb.forward(&position_ids.zeros_like()?)?).unwrap(); + // } + // } // Handle absolute position embeddings - if let Some(pos_emb) = &self.position_embeddings { - let position_embeddings = pos_emb.forward(&position_ids)?; - embeddings = embeddings.add(&position_embeddings)?; - } + // if let Some(pos_emb) = &self.position_embeddings { + // let position_embeddings = pos_emb.forward(&position_ids)?; + // embeddings = embeddings.add(&position_embeddings)?; + // } let embeddings = self.layer_norm.forward(&embeddings)?; - Ok((embeddings, attention_mask.cloned(), rope_embeds, None)) + Ok((embeddings, rope_embeds)) } } @@ -347,7 +366,7 @@ struct NewAttention { num_heads: usize, head_dim: usize, hidden_size: usize, - _use_memory_efficient_attention: bool, + // _use_memory_efficient_attention: bool, } impl NewAttention { @@ -365,7 +384,7 @@ impl NewAttention { num_heads, head_dim, hidden_size: hidden_sz, - _use_memory_efficient_attention: cfg.use_memory_efficient_attention, + // _use_memory_efficient_attention: cfg.use_memory_efficient_attention, }) } @@ -374,7 +393,7 @@ impl NewAttention { hidden_states: &Tensor, attention_bias: Option<&Tensor>, rope_embeds: Option<&(Tensor, Tensor)>, - _attention_scale: Option<&Tensor> + // _attention_scale: Option<&Tensor> ) -> Result { let (b_sz, seq_len, _) = hidden_states.dims3()?; @@ -407,10 +426,10 @@ impl NewAttention { // Prepare tensors for batched matmul using matmul // Reshape tensors to merge batch and head dimensions - let bsz = b_sz as usize; - let nh = self.num_heads as usize; - let s_len = seq_len as usize; - let h_dim = self.head_dim as usize; + let bsz = b_sz; + let nh = self.num_heads; + let s_len = seq_len; + let h_dim = self.head_dim; // Reshape tensors to [batch_size * num_heads, seq_len, head_dim] let query_states_reshaped = query_states.reshape((bsz * nh, s_len, h_dim))?; @@ -435,8 +454,8 @@ impl NewAttention { // Apply attention mask let mut 
attn_weights = if let Some(bias) = attention_bias { - let attn_weights = attn_weights.broadcast_add(bias)?; - attn_weights + // let attn_weights = attn_weights.broadcast_add(bias)?; + attn_weights.broadcast_add(bias)? } else { attn_weights }; @@ -525,26 +544,28 @@ impl NewLayer { let attention = NewAttention::new(cfg, vb.pp("attention"))?; let mlp = NewGatedMLP::new(cfg, vb.pp("mlp"))?; - let ln_eps = cfg.layer_norm_eps; + // let ln_eps = cfg.layer_norm_eps; // Use LayerNorm or RmsNorm based on config - let (attn_ln, mlp_ln) = if cfg.layer_norm_type == "layer_norm" { + let (attn_ln, mlp_ln) = { let attn_ln = layer_norm( cfg.hidden_size, - candle_nn::LayerNormConfig { eps: ln_eps, ..Default::default() }, + candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, vb.pp("attn_ln") )?; let mlp_ln = layer_norm( cfg.hidden_size, - candle_nn::LayerNormConfig { eps: ln_eps, ..Default::default() }, + candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, vb.pp("mlp_ln") )?; (NormType::LayerNorm(attn_ln), NormType::LayerNorm(mlp_ln)) - } else { - let attn_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("attn_ln"))?; - let mlp_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("mlp_ln"))?; - (NormType::RmsNorm(attn_ln), NormType::RmsNorm(mlp_ln)) }; + // else + // { + // let attn_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("attn_ln"))?; + // let mlp_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("mlp_ln"))?; + // (NormType::RmsNorm(attn_ln), NormType::RmsNorm(mlp_ln)) + // }; Ok(Self { attention, @@ -559,7 +580,7 @@ impl NewLayer { hidden_states: &Tensor, attention_bias: Option<&Tensor>, rope_embeds: Option<&(Tensor, Tensor)>, - attention_scale: Option<&Tensor> + // attention_scale: Option<&Tensor> ) -> Result { // Store original input let original = hidden_states; @@ -569,7 +590,7 @@ impl NewLayer { original, attention_bias, rope_embeds, - attention_scale + // attention_scale )?; let hidden_states = original.add(&hidden_states)?; @@ -611,7 +632,7 @@ impl NewEncoder { hidden_states: &Tensor, attention_bias: Option<&Tensor>, rope_embeds: Option<&(Tensor, Tensor)>, - attention_scale: Option<&Tensor> + // attention_scale: Option<&Tensor> ) -> Result { let mut hidden_states = hidden_states.clone(); @@ -620,7 +641,7 @@ impl NewEncoder { &hidden_states, attention_bias, rope_embeds, - attention_scale + // attention_scale )?; } @@ -634,7 +655,7 @@ pub struct NewModel { encoder: NewEncoder, device: Device, dtype: DType, - config: Config, + // config: Config, } impl NewModel { @@ -647,7 +668,7 @@ impl NewModel { encoder, device: vb.device().clone(), dtype: vb.dtype(), - config: cfg.clone(), + // config: cfg.clone(), }) } @@ -672,56 +693,53 @@ impl NewModel { pub fn forward( &mut self, input_ids: &Tensor, - attention_mask: Option<&Tensor>, - token_type_ids: Option<&Tensor>, - position_ids: Option<&Tensor> + attention_mask: &Tensor, + // token_type_ids: Option<&Tensor>, + // position_ids: Option<&Tensor> ) -> Result { - let (batch_size, seq_length) = input_ids.dims2()?; + let (_, seq_length) = input_ids.dims2()?; // Get attention mask if not provided - let attention_mask = match attention_mask { - Some(mask) => mask.clone(), - None => Tensor::ones((batch_size, seq_length), self.dtype, &self.device)?, - }; + // let attention_mask = mask; // Prepare attention bias let attention_bias = if seq_length <= 1 { None } else { - Some(self.prepare_attention_mask(&attention_mask)?) + Some(self.prepare_attention_mask(attention_mask)?) 
}; // Get embeddings and rotary embeddings - let (hidden_states, _, rope_embeds, _) = self.embeddings.forward( + let (hidden_states, rope_embeds) = self.embeddings.forward( input_ids, - token_type_ids, - position_ids, - None, - self.config.unpad_inputs, - Some(&attention_mask) + // token_type_ids, + // position_ids, + // None, + // self.config.unpad_inputs, + // Some(&attention_mask) )?; // Compute attention scale if needed - let attention_scale = if self.config.logn_attention_scale { - let scale = - attention_mask.sum_keepdim(1)?.log()? / - (self.config.max_position_embeddings as f64).ln(); - if self.config.logn_attention_clip1 { - let scale = scale?; - Some(scale.maximum(&Tensor::new(1f64, &self.device)?)?) - } else { - Some(scale?) - } - } else { - None - }; + // let attention_scale = if self.config.logn_attention_scale { + // let scale = + // attention_mask.sum_keepdim(1)?.log()? / + // (self.config.max_position_embeddings as f64).ln(); + // if self.config.logn_attention_clip1 { + // let scale = scale?; + // Some(scale.maximum(&Tensor::new(1f64, &self.device)?)?) + // } else { + // Some(scale?) + // } + // } else { + // None + // }; // Forward through encoder let hidden_states = self.encoder.forward( &hidden_states, attention_bias.as_ref(), rope_embeds.as_ref(), - attention_scale.as_ref() + // attention_scale.as_ref() )?; Ok(hidden_states) @@ -729,65 +747,65 @@ impl NewModel { } // Optional pooler implementation -#[derive(Debug)] -pub struct NewPooler { - dense: Linear, -} - -impl NewPooler { - pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?; - Ok(Self { dense }) - } - - pub fn forward(&self, hidden_states: &Tensor) -> Result { - let first_token = hidden_states.i((.., 0, ..))?; - let pooled = self.dense.forward(&first_token)?; - pooled.tanh() - } -} - -// Complete model with pooler -#[derive(Debug)] -pub struct NewModelWithPooler { - model: NewModel, - pooler: Option, -} - -impl NewModelWithPooler { - pub fn new(cfg: &Config, vb: VarBuilder, add_pooling_layer: bool) -> Result { - let vb_m = vb.pp("new"); - let model = NewModel::new(cfg, vb_m.pp("model"))?; - let pooler = if add_pooling_layer { - Some(NewPooler::new(cfg, vb.pp("new").pp("pooler"))?) 
- } else { - None - }; - Ok(Self { model, pooler }) - } - - pub fn forward( - &mut self, - input_ids: &Tensor, - attention_mask: Option<&Tensor>, - token_type_ids: Option<&Tensor>, - position_ids: Option<&Tensor> - ) -> Result<(Tensor, Option)> { - let hidden_states = self.model.forward( - input_ids, - attention_mask, - token_type_ids, - position_ids - )?; - - let pooled_output = match &self.pooler { - Some(pooler) => Some(pooler.forward(&hidden_states)?), - None => None, - }; - - Ok((hidden_states, pooled_output)) - } -} +// #[derive(Debug)] +// pub struct NewPooler { +// dense: Linear, +// } + +// impl NewPooler { +// pub fn new(cfg: &Config, vb: VarBuilder) -> Result { +// let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?; +// Ok(Self { dense }) +// } + +// pub fn forward(&self, hidden_states: &Tensor) -> Result { +// let first_token = hidden_states.i((.., 0, ..))?; +// let pooled = self.dense.forward(&first_token)?; +// pooled.tanh() +// } +// } + +// // Complete model with pooler +// #[derive(Debug)] +// pub struct NewModelWithPooler { +// model: NewModel, +// pooler: Option, +// } + +// impl NewModelWithPooler { +// pub fn new(cfg: &Config, vb: VarBuilder, add_pooling_layer: bool) -> Result { +// let vb_m = vb.pp("new"); +// let model = NewModel::new(cfg, vb_m.pp("model"))?; +// let pooler = if add_pooling_layer { +// Some(NewPooler::new(cfg, vb.pp("new").pp("pooler"))?) +// } else { +// None +// }; +// Ok(Self { model, pooler }) +// } + +// pub fn forward( +// &mut self, +// input_ids: &Tensor, +// attention_mask: Option<&Tensor>, +// token_type_ids: Option<&Tensor>, +// position_ids: Option<&Tensor> +// ) -> Result<(Tensor, Option)> { +// let hidden_states = self.model.forward( +// input_ids, +// attention_mask, +// token_type_ids, +// position_ids +// )?; + +// let pooled_output = match &self.pooler { +// Some(pooler) => Some(pooler.forward(&hidden_states)?), +// None => None, +// }; + +// Ok((hidden_states, pooled_output)) +// } +// } #[derive(Debug)] pub struct EmbeddingModel { @@ -811,7 +829,7 @@ impl EmbeddingModel { } pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { - let x = self.base_model.forward(input_ids, Some(mask), None, None)?; + let x = self.base_model.forward(input_ids, mask)?;//, None, None)?; let x = self.pool(&x, mask)?; self.lm_head.forward(&x.to_dtype(DType::F32)?) }
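
Below is a minimal, illustrative sketch of how the 400M code path introduced in this patch is driven, mirroring the call sequence in the updated `stella-en-v5` example's `main.rs`. The helper function name, the path parameters, and the choice of `EmbedDim::Dim1024` are assumptions for illustration only; the actual example resolves the weight files via `hf-hub` from `dunzhang/stella_en_400M_v5`.

```rust
// Sketch only: mirrors the updated example, not part of the patch itself.
use candle::{DType, Device, Result};
use candle_nn::VarBuilder;
use candle_transformers::models::stella_en_v5::{Config, EmbedDim, EmbeddingModel};

fn load_stella_400m(
    base_weights: &[std::path::PathBuf],  // e.g. the base model safetensors (placeholder)
    embed_weights: &[std::path::PathBuf], // e.g. the dense embedding-head safetensors (placeholder)
    device: &Device,
) -> Result<EmbeddingModel> {
    // 400M variant config with the default 1024-dim embedding head.
    let cfg = Config::new_400_m_v5(EmbedDim::Dim1024);
    // Base transformer and the dense embedding head are loaded from separate
    // safetensors files, as in the example's main.rs.
    let base_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(base_weights, DType::F32, device)? };
    let embed_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(embed_weights, DType::F32, device)? };
    EmbeddingModel::new(&cfg, base_vb, embed_vb)
}
```

Usage then follows the example: the tokenizer is right-padded for the 400M variant (see `create_tokenizer` above), and the padded token ids plus the attention mask are passed to `EmbeddingModel::forward(&input_ids, &mask)` to obtain the embeddings.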