Skip to content

Commit

Permalink
Merge pull request #1 from AnubhabB/temp-stella-400M
Browse files Browse the repository at this point in the history
Your implementation of `Stella 400M` and the previous `Stella 1.5B` now supported in a single file and entry point
  • Loading branch information
iskng authored Nov 23, 2024
2 parents 91d4602 + ed7fd9b commit 437f5f1
Show file tree
Hide file tree
Showing 22 changed files with 906 additions and 383 deletions.
18 changes: 9 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ exclude = [
resolver = "2"

[workspace.package]
version = "0.7.2"
version = "0.8.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
Expand All @@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.7.2" }
candle-datasets = { path = "./candle-datasets", version = "0.7.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.2" }
candle-kernels = { path = "./candle-kernels", version = "0.7.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.2" }
candle-nn = { path = "./candle-nn", version = "0.7.2" }
candle-onnx = { path = "./candle-onnx", version = "0.7.2" }
candle-transformers = { path = "./candle-transformers", version = "0.7.2" }
candle = { path = "./candle-core", package = "candle-core", version = "0.8.0" }
candle-datasets = { path = "./candle-datasets", version = "0.8.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.0" }
candle-kernels = { path = "./candle-kernels", version = "0.8.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.0" }
candle-nn = { path = "./candle-nn", version = "0.8.0" }
candle-onnx = { path = "./candle-onnx", version = "0.8.0" }
candle-transformers = { path = "./candle-transformers", version = "0.8.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
Expand Down
14 changes: 14 additions & 0 deletions candle-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@
//! Python can really add overhead in more complex workflows and the [GIL](https://www.backblaze.com/blog/the-python-gil-past-present-and-future/) is a notorious source of headaches.
//!
//! Rust is cool, and a lot of the HF ecosystem already has Rust crates [safetensors](https://github.com/huggingface/safetensors) and [tokenizers](https://github.com/huggingface/tokenizers)
//!
//! ## Other Crates
//!
//! Candle consists of a number of crates. This crate holds core the common data structures but you may wish
//! to look at the docs for the other crates which can be found here:
//!
//! - [candle-core](https://docs.rs/candle-core/). Core Datastructures and DataTypes.
//! - [candle-nn](https://docs.rs/candle-nn/). Building blocks for Neural Nets.
//! - [candle-datasets](https://docs.rs/candle-datasets/). Rust access to commonly used Datasets like MNIST.
//! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use.
//! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models.
//! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python.
//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implemntation of many published transformer models.
//!

#[cfg(feature = "accelerate")]
mod accelerate;
Expand Down
11 changes: 10 additions & 1 deletion candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,7 @@ impl BackendStorage for MetalStorage {
let dst_el = ids_l.shape().elem_count();
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let buffer = device.new_buffer(dst_el, dtype, "gather")?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
Expand Down Expand Up @@ -1324,14 +1324,23 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::U8) => "is_u8_u8",
(DType::U8, DType::U32) => "is_u8_u32",
(DType::U8, DType::I64) => "is_u8_i64",
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U8, DType::F32) => "is_u8_f32",
(DType::U8, DType::F16) => "is_u8_f16",

(DType::U32, DType::U8) => "is_u32_u8",
(DType::U32, DType::U32) => "is_u32_u32",
(DType::U32, DType::I64) => "is_u32_i64",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",

(DType::I64, DType::U8) => "is_i64_u8",
(DType::I64, DType::U32) => "is_i64_u32",
(DType::I64, DType::I64) => "is_i64_i64",
(DType::I64, DType::F32) => "is_i64_f32",
(DType::I64, DType::F16) => "is_i64_f16",
(DType::I64, DType::BF16) => "is_i64_bf16",
Expand Down
24 changes: 22 additions & 2 deletions candle-examples/examples/stella-en-v5/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling
The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.

```bash
$ cargo run --example stella-en-v5 --release --features <metal | cuda>
$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 1.5b

>
> Score: 0.8178786
Expand All @@ -37,9 +37,29 @@ $ cargo run --example stella-en-v5 --release --features <metal | cuda>
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types >
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>

$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 400m

>
> Score: 0.8397539
> Query: What are some ways to reduce stress?
> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
> stress from building up.
>
>
>
> Score: 0.809545
> Query: What are the benefits of drinking green tea?
> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>
```

## Supported options:
- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
- `Stella_en_v5` has 2 model variants published - a 1.5B variant and 400M variant. This is enabled through the flag `--which`. E.g. `--which 400m` or `--which 1.5b`.

- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.

- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option.
74 changes: 51 additions & 23 deletions candle-examples/examples/stella-en-v5/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,24 @@ impl EncodeTask {
}
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "1.5b")]
Large,
#[value(name = "400m")]
Small,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,

#[arg(long)]
which: Which,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
Expand Down Expand Up @@ -250,24 +261,33 @@ struct Args {

// Tokenizer creation is super critical in our case.
// We are going to be `padding: Left` for each batch
fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> {
fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result<Tokenizer> {
let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
pad_id
} else {
return Err(anyhow!(
"Tokenizer doesn't contain expected `<|endoftext|>` token"
));
};

// This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_id,
pad_token: "<|endoftext|>".to_string(),
..Default::default()
}));
if which == Which::Large {
let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
pad_id
} else {
return Err(anyhow!(
"Tokenizer doesn't contain expected `<|endoftext|>` token"
));
};

// This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_id,
pad_token: "<|endoftext|>".to_string(),
..Default::default()
}));
} else {
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Right,
..Default::default()
}));
}

Ok(tokenizer)
}
Expand Down Expand Up @@ -298,7 +318,19 @@ fn main() -> Result<()> {
Some(d) => d,
None => EmbedDim::Dim1024,
};
let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string()));

let (repo, cfg) = match args.which {
Which::Large => (
"dunzhang/stella_en_1.5B_v5",
Config::new_1_5_b_v5(embed_dim.embed_dim()),
),
Which::Small => (
"dunzhang/stella_en_400M_v5",
Config::new_400_m_v5(embed_dim.embed_dim()),
),
};

let repo = api.repo(Repo::model(repo.to_string()));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
Expand Down Expand Up @@ -330,7 +362,7 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed());

// Initializing the tokenizer which would require us to add padding to the `left` for batch encoding
let tokenizer = create_tokenizer(tokenizer_filename.as_path())?;
let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?;

let start = std::time::Instant::now();

Expand All @@ -343,11 +375,7 @@ fn main() -> Result<()> {
let embed_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };

let model = EmbeddingModel::new(
&Config::new_1_5_b_v5(embed_dim.embed_dim()),
base_vb,
embed_vb,
)?;
let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?;

println!("loaded the model in {:?}", start.elapsed());

Expand Down
4 changes: 2 additions & 2 deletions candle-flash-attn/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.7.2"
version = "0.8.0"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
Expand All @@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.2" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.0" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion candle-kernels/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.7.2"
version = "0.8.0"
edition = "2021"

description = "CUDA kernels for Candle"
Expand Down
2 changes: 1 addition & 1 deletion candle-metal-kernels/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.7.2"
version = "0.8.0"
edition = "2021"

description = "Metal kernels for Candle"
Expand Down
4 changes: 4 additions & 0 deletions candle-metal-kernels/src/indexing.metal
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,16 @@ INDEX_OP(is_i64_f16, int64_t, half)
INDEX_OP(is_i64_bf16, int64_t, bfloat)
#endif

INDEX_OP(is_u32_u8, uint32_t, uint8_t)
INDEX_OP(is_u32_u32, uint32_t, uint32_t)
INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
#endif

INDEX_OP(is_u8_u8, uint8_t, uint8_t)
INDEX_OP(is_u8_u32, uint8_t, uint32_t)
INDEX_OP(is_u8_f32, uint8_t, float)
INDEX_OP(is_u8_f16, uint8_t, half)
#if defined(__HAVE_BFLOAT__)
Expand Down
2 changes: 2 additions & 0 deletions candle-nn/src/activation.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Activation Functions
//!
use candle::{Result, Tensor};
use serde::Deserialize;

Expand Down
2 changes: 2 additions & 0 deletions candle-nn/src/kv_cache.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Cache Implementations
//!
use candle::{Device, Result, Tensor};

#[derive(Debug, Clone)]
Expand Down
17 changes: 17 additions & 0 deletions candle-nn/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
//! candle-nn
//!
//! ## Other Crates
//!
//! Candle consists of a number of crates. This crate holds structs and functions
//! that allow you to build and train neural nets. You may wish
//! to look at the docs for the other crates which can be found here:
//!
//! - [candle-core](https://docs.rs/candle-core/). Core Datastructures and DataTypes.
//! - [candle-nn](https://docs.rs/candle-nn/). Building blocks for Neural Nets.
//! - [candle-datasets](https://docs.rs/candle-datasets/). Rust access to commonly used Datasets like MNIST.
//! - [candle-examples](https://docs.rs/candle-examples/). Examples of Candle in Use.
//! - [candle-onnx](https://docs.rs/candle-onnx/). Loading and using ONNX models.
//! - [candle-pyo3](https://docs.rs/candle-pyo3/). Access to Candle from Python.
//! - [candle-transformers](https://docs.rs/candle-transformers/). Candle implemntation of many published transformer models.
//!

pub mod activation;
pub mod batch_norm;
pub mod conv;
Expand Down
2 changes: 2 additions & 0 deletions candle-nn/src/loss.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Loss Calculations
//!
use candle::{Result, Tensor};

/// The negative log likelihood loss.
Expand Down
3 changes: 3 additions & 0 deletions candle-nn/src/ops.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
//! Tensor ops.
//!

use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;

Expand Down
2 changes: 2 additions & 0 deletions candle-nn/src/rotary_emb.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Rotary Embeddings
//!
use candle::{CpuStorage, Layout, Result, Shape, Tensor, D};
use rayon::prelude::*;

Expand Down
2 changes: 2 additions & 0 deletions candle-nn/src/sequential.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Sequential Layer
//!
//! A sequential layer used to chain multiple layers and closures.
use candle::{Module, Result, Tensor};

Expand Down
2 changes: 2 additions & 0 deletions candle-nn/src/var_builder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! A `VarBuilder` for variable retrieval from models
//!
//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come
//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized
//! for training, e.g. using `VarBuilder::from_varmap`.
Expand Down
2 changes: 2 additions & 0 deletions candle-nn/src/var_map.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! A `VarMap` is a store that holds named variables.
//!
use candle::{DType, Device, Result, Shape, Tensor, Var};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
Expand Down
6 changes: 3 additions & 3 deletions candle-onnx/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.7.2"
version = "0.8.0"
edition = "2021"

description = "ONNX support for Candle"
Expand All @@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.7.2" }
candle-nn = { path = "../candle-nn", version = "0.7.2" }
candle = { path = "../candle-core", package = "candle-core", version = "0.8.0" }
candle-nn = { path = "../candle-nn", version = "0.8.0" }
prost = "0.12.1"

[build-dependencies]
Expand Down
3 changes: 2 additions & 1 deletion candle-transformers/src/models/chinese_clip/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ impl ChineseClipModel {
) -> Result<Tensor> {
let output = self
.text_model
.forward(input_ids, token_type_ids, attention_mask)?;
.forward(input_ids, token_type_ids, attention_mask)?
.contiguous()?;
self.text_projection.forward(&output)
}

Expand Down
Loading

0 comments on commit 437f5f1

Please sign in to comment.