From 5c532ec5dc51cd17cd4bb9ae940ecf2c9baf89f6 Mon Sep 17 00:00:00 2001
From: rainiwu
Date: Fri, 26 Jan 2024 00:29:35 -0800
Subject: [PATCH 1/8] remove deprecated ftz intrinsics

---
 dfdx-core/src/lib.rs      | 38 --------------------------------------
 dfdx/examples/12-mnist.rs |  3 ---
 2 files changed, 41 deletions(-)

diff --git a/dfdx-core/src/lib.rs b/dfdx-core/src/lib.rs
index 31e61643..c126db2c 100644
--- a/dfdx-core/src/lib.rs
+++ b/dfdx-core/src/lib.rs
@@ -128,44 +128,6 @@ pub mod prelude {
     pub use crate::tensor_ops::*;
 }
 
-/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
-pub fn flush_denormals_to_zero() {
-    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
-    {
-        use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
-    }
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
-    {
-        use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
-    }
-}
-
-/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()].
-///
-/// Some resources:
-/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
-/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
-pub fn keep_denormals() {
-    #[cfg(all(target_arch = "x86", target_feature = "sse"))]
-    {
-        use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
-    }
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
-    {
-        use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
-        unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
-    }
-}
-
 #[cfg(test)]
 pub(crate) mod tests {
     pub use num_traits::{Float, NumCast, Zero};
diff --git a/dfdx/examples/12-mnist.rs b/dfdx/examples/12-mnist.rs
index 705d14c8..00d43452 100644
--- a/dfdx/examples/12-mnist.rs
+++ b/dfdx/examples/12-mnist.rs
@@ -62,9 +62,6 @@ type Mlp = (
 const BATCH_SIZE: usize = 32;
 
 fn main() {
-    // ftz substantially improves performance
-    dfdx::flush_denormals_to_zero();
-
     let mnist_path = std::env::args()
         .nth(1)
         .unwrap_or_else(|| "./datasets/MNIST/raw".to_string());
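[PATCH 1/8] drops flush_denormals_to_zero() and keep_denormals() from the public API, so callers such as the MNIST example above lose the FTZ toggle. Downstream code that still wants the behavior can carry it locally; a minimal sketch adapted from the deleted function, assuming x86_64 with SSE (other targets would need their own handling):

    // User-side replacement for the removed dfdx::flush_denormals_to_zero().
    fn flush_denormals_to_zero() {
        #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
        {
            use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
            // Sets the MXCSR flush-to-zero bit for the current thread.
            unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
        }
    }
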
From fb91f13314fb24a67c2d8e14ad40345d2d334805 Mon Sep 17 00:00:00 2001
From: rainiwu
Date: Fri, 26 Jan 2024 00:55:48 -0800
Subject: [PATCH 2/8] suppress spurious cargo clippy warning

---
 dfdx-core/src/data/collate.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dfdx-core/src/data/collate.rs b/dfdx-core/src/data/collate.rs
index d38a2a67..5f52d636 100644
--- a/dfdx-core/src/data/collate.rs
+++ b/dfdx-core/src/data/collate.rs
@@ -55,6 +55,7 @@ impl<A, B> Collate for Vec<(A, B)> {
 impl<'a, A, B> Collate for Vec<&'a (A, B)> {
     type Collated = (Vec<&'a A>, Vec<&'a B>);
     fn collated(self) -> Self::Collated {
+        #[allow(clippy::map_identity)]
         self.into_iter().map(|(a, b)| (a, b)).unzip()
     }
 }

From 4e3f7c7a24728668f72cf3617a66f4476280f6fb Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Tue, 6 Feb 2024 18:27:46 -0500
Subject: [PATCH 3/8] avoid conv1d bound for cudnn

---
 dfdx-core/src/tensor_ops/utilities/device.rs | 50 +++++++++++++++-----
 1 file changed, 39 insertions(+), 11 deletions(-)

diff --git a/dfdx-core/src/tensor_ops/utilities/device.rs b/dfdx-core/src/tensor_ops/utilities/device.rs
index 8cbc2137..91f87cf6 100644
--- a/dfdx-core/src/tensor_ops/utilities/device.rs
+++ b/dfdx-core/src/tensor_ops/utilities/device.rs
@@ -114,25 +114,49 @@ pub trait Device<E: Dtype>:
     + crate::tensor_ops::axpy::AxpyKernel<E>
 
     // conv1d
-    + super::super::conv1d::Conv1DKernel<E>
+    + NonCudnnCuda<E>
+{
+}
+
+#[cfg(feature = "cudnn")]
+pub trait NonCudnnCuda<E: Dtype> {}
+
+#[cfg(not(feature = "cudnn"))]
+pub trait NonCudnnCuda<E: Dtype>:
+    // conv1d
+    super::super::conv1d::Conv1DKernel<E>
 {
 }
 
 #[cfg(feature = "f16")]
-impl Device<f16> for crate::tensor::Cpu {}
-#[cfg(feature = "f16")]
-impl Device<AMP<f16>> for crate::tensor::Cpu {}
+mod f16_ {
+    use super::*;
+    impl Device<f16> for crate::tensor::Cpu {}
+    impl NonCudnnCuda<f16> for crate::tensor::Cpu {}
+    impl Device<AMP<f16>> for crate::tensor::Cpu {}
+    impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cpu {}
+}
 impl Device<f32> for crate::tensor::Cpu {}
+impl NonCudnnCuda<f32> for crate::tensor::Cpu {}
 impl Device<f64> for crate::tensor::Cpu {}
+impl NonCudnnCuda<f64> for crate::tensor::Cpu {}
 
 #[cfg(all(feature = "cuda", feature = "f16"))]
-impl Device<f16> for crate::tensor::Cuda {}
-#[cfg(all(feature = "cuda", feature = "f16"))]
-impl Device<AMP<f16>> for crate::tensor::Cuda {}
-#[cfg(feature = "cuda")]
-impl Device<f32> for crate::tensor::Cuda {}
+mod cuda_f16 {
+    use super::*;
+    impl Device<f16> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f16> for crate::tensor::Cuda {}
+    impl Device<AMP<f16>> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<AMP<f16>> for crate::tensor::Cuda {}
+}
 #[cfg(feature = "cuda")]
-impl Device<f64> for crate::tensor::Cuda {}
+mod cuda {
+    use super::*;
+    impl Device<f32> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f32> for crate::tensor::Cuda {}
+    impl Device<f64> for crate::tensor::Cuda {}
+    impl NonCudnnCuda<f64> for crate::tensor::Cuda {}
+}
 
 // TODO: How can we implement this for f16 when WGSL doesn't support f16 yet?
 // #[cfg(all(feature = "webgpu", feature = "f16"))]
@@ -140,7 +164,11 @@ impl Device<f64> for crate::tensor::Cuda {}
 // #[cfg(all(feature = "webgpu", feature = "f16"))]
 // impl Device<AMP<f16>> for crate::tensor::Webgpu {}
 #[cfg(feature = "webgpu")]
-impl Device<f32> for crate::tensor::Webgpu {}
+mod webgpu {
+    use super::*;
+    impl Device<f32> for crate::tensor::Webgpu {}
+    impl NonCudnnCuda<f32> for crate::tensor::Webgpu {}
+}
 
 // TODO: How can we implement this for f64 when WGSL doesn't support f64 yet?
 // #[cfg(feature = "webgpu")]
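The restructuring in [PATCH 3/8] keeps a single Device trait for every build while making the Conv1DKernel bound conditional: the bound moves into a helper supertrait whose definition flips with the cudnn feature. A reduced sketch of the same pattern, with trait names shortened for illustration (dfdx's real traits carry an E: Dtype parameter as in the diff above):

    trait Conv1DKernel {}

    // cudnn builds: no conv1d requirement.
    #[cfg(feature = "cudnn")]
    trait NonCudnnCuda {}

    // all other builds keep the conv1d requirement.
    #[cfg(not(feature = "cudnn"))]
    trait NonCudnnCuda: Conv1DKernel {}

    // Device's own bound list stays identical across builds.
    trait Device: NonCudnnCuda {}
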
From a8bc54c5c8e02c68fe09e72fc94ba0a8b3273b9a Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Fri, 9 Feb 2024 11:53:40 -0500
Subject: [PATCH 4/8] bump gemm

---
 dfdx-core/Cargo.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dfdx-core/Cargo.toml b/dfdx-core/Cargo.toml
index 5309ef7c..0f6cd5c6 100644
--- a/dfdx-core/Cargo.toml
+++ b/dfdx-core/Cargo.toml
@@ -35,7 +35,7 @@ num-traits = { workspace = true }
 safetensors = { workspace = true, optional = true }
 memmap2 = { workspace = true, optional = true }
 half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
-gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
+gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
 rayon = { version = "1.7.0", optional = true }
 libm = { workspace = true }
 wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }

From 557687c0a9e29dfba2311fe67414863c6c5137bf Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Fri, 9 Feb 2024 12:52:05 -0500
Subject: [PATCH 5/8] clippy fix

---
 dfdx-core/src/tensor/gradients.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dfdx-core/src/tensor/gradients.rs b/dfdx-core/src/tensor/gradients.rs
index 86974ec6..d24e2e32 100644
--- a/dfdx-core/src/tensor/gradients.rs
+++ b/dfdx-core/src/tensor/gradients.rs
@@ -153,7 +153,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
     #[inline]
     pub(crate) fn many_and_ref<L: Shape, R: Shape>(
         &mut self,
-        ls: &Vec<impl Tensorlike<L, E, D>>,
+        ls: &[impl Tensorlike<L, E, D>],
         r: &impl Tensorlike<R, E, D>,
     ) -> (Vec<&mut D::Vec>, &D::Vec) {
         for i in 0..ls.len() {
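The signature change in [PATCH 5/8] is the standard fix for clippy's ptr_arg lint: a &[T] parameter accepts everything a &Vec<T> does, because &Vec<T> coerces to a slice at the call site. A minimal illustration with made-up names:

    fn sum(xs: &[f32]) -> f32 {
        xs.iter().sum()
    }

    fn main() {
        let v = vec![1.0f32, 2.0, 3.0];
        assert_eq!(sum(&v), 6.0);      // &Vec<f32> coerces to &[f32]
        assert_eq!(sum(&v[..2]), 3.0); // plain slices work too
    }
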
From d971e905295203307ed29dbca4638d5c672aaa75 Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Wed, 31 Jan 2024 01:01:29 -0500
Subject: [PATCH 6/8] Update safetensors module and naming

- Makes the safetensors module private.
  - Doesn't get exported on the prelude, avoiding a naming clash with the
    safetensors external crate.
- Change how and when the period is inserted.
  - This should make it closer to how the fields are accessed in the code.
---
 dfdx-core/src/nn_traits/tuples.rs |  4 ++--
 dfdx-core/src/nn_traits/vecs.rs   |  4 ++--
 dfdx-core/src/tensor/mod.rs       |  2 +-
 dfdx-derives/src/lib.rs           | 20 ++++++++++++++++----
 4 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/dfdx-core/src/nn_traits/tuples.rs b/dfdx-core/src/nn_traits/tuples.rs
index 97e8c7de..205c0419 100644
--- a/dfdx-core/src/nn_traits/tuples.rs
+++ b/dfdx-core/src/nn_traits/tuples.rs
@@ -25,7 +25,7 @@ macro_rules! tuple_impls {
             location: &str,
             tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
         ) {
-            $(self.$idx.write_safetensors(&format!("{location}{}.", $idx), tensors);)+
+            $(self.$idx.write_safetensors(&format!("{location}.{}", $idx), tensors);)+
         }
     }
 
@@ -36,7 +36,7 @@ macro_rules! tuple_impls {
             location: &str,
             tensors: &safetensors::SafeTensors,
         ) -> Result<(), safetensors::SafeTensorError> {
-            $(self.$idx.read_safetensors(&format!("{location}{}.", $idx), tensors)?;)+
+            $(self.$idx.read_safetensors(&format!("{location}.{}", $idx), tensors)?;)+
             Ok(())
         }
     }
diff --git a/dfdx-core/src/nn_traits/vecs.rs b/dfdx-core/src/nn_traits/vecs.rs
index 803a07d8..593b1a55 100644
--- a/dfdx-core/src/nn_traits/vecs.rs
+++ b/dfdx-core/src/nn_traits/vecs.rs
@@ -66,7 +66,7 @@ impl<T: crate::nn_traits::SaveSafeTensors> crate::nn_traits::SaveSafeTensors for
         tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
     ) {
         for (i, t) in self.iter().enumerate() {
-            t.write_safetensors(&format!("{location}{i}."), tensors);
+            t.write_safetensors(&format!("{location}.{i}"), tensors);
         }
     }
 }
@@ -79,7 +79,7 @@ impl<T: crate::nn_traits::LoadSafeTensors> crate::nn_traits::LoadSafeTensors for
         tensors: &safetensors::SafeTensors,
     ) -> Result<(), safetensors::SafeTensorError> {
         for (i, t) in self.iter_mut().enumerate() {
-            t.read_safetensors(&format!("{location}{i}."), tensors)?;
+            t.read_safetensors(&format!("{location}.{i}"), tensors)?;
         }
         Ok(())
     }
diff --git a/dfdx-core/src/tensor/mod.rs b/dfdx-core/src/tensor/mod.rs
index acc4074a..0163480a 100644
--- a/dfdx-core/src/tensor/mod.rs
+++ b/dfdx-core/src/tensor/mod.rs
@@ -151,7 +151,7 @@ pub(crate) mod webgpu;
 pub use numpy::NumpyDtype;
 mod error;
 #[cfg(feature = "safetensors")]
-pub mod safetensors;
+mod safetensors;
 mod tensorlike;
 mod unique_id;
 
diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs
index 4eca0d82..7af885f9 100644
--- a/dfdx-derives/src/lib.rs
+++ b/dfdx-derives/src/lib.rs
@@ -850,7 +850,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
                 where_clause
                     .predicates
                     .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors));
-                quote_spanned!(f.span()=>self.#name.write_safetensors(&format!("{location}{}", #name_str), tensors);)
+                quote_spanned!(f.span()=>self.#name.write_safetensors(
+                    &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str),
+                    tensors
+                );)
             } else {
                 Default::default()
             }
@@ -866,7 +869,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
                 where_clause
                     .predicates
                     .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors));
-                quote_spanned!(f.span()=>self.#index.write_safetensors(&format!("{location}{}", #index), tensors);)
+                quote_spanned!(f.span()=>self.#index.write_safetensors(
+                    &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index),
+                    tensors
+                );)
             } else {
                 Default::default()
             }
@@ -913,7 +919,10 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
                 where_clause
                     .predicates
                     .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors));
-                quote_spanned!(f.span()=>self.#name.read_safetensors(&format!("{location}{}", #name_str), tensors)?;)
+                quote_spanned!(f.span()=>self.#name.read_safetensors(
+                    &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str),
+                    tensors
+                )?;)
             } else {
                 Default::default()
             }
@@ -928,7 +937,10 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
                 where_clause
                     .predicates
                     .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors));
-                quote_spanned!(f.span()=>self.#index.read_safetensors(&format!("{location}{}", #index), tensors)?;)
+                quote_spanned!(f.span()=>self.#index.read_safetensors(
+                    &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index),
+                    tensors
+                )?;)
             } else {
                 Default::default()
             }
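The net effect of the period change in [PATCH 6/8], traced with a hypothetical model that has a `layers` field holding a Vec of modules, each with a `weight` tensor: keys now read like field-access paths (`layers.0.weight`) instead of the old suffix-period fragments (`layers0.weight` style). The separator logic, extracted for illustration and mirroring the format! calls in the derive above:

    // Insert "." between segments, skipping it when the prefix is
    // empty (i.e. at the root).
    fn child_key(location: &str, segment: &str) -> String {
        format!(
            "{location}{}{segment}",
            if location.is_empty() { "" } else { "." }
        )
    }

    fn main() {
        let root = child_key("", "layers");    // "layers"
        let elem = format!("{root}.{}", 0);    // "layers.0" (Vec/tuple rule)
        let leaf = child_key(&elem, "weight"); // "layers.0.weight"
        assert_eq!(leaf, "layers.0.weight");   // matches model.layers[0].weight
    }
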
From 4b9824ed4c9bbee499789dbb570fa93933200c97 Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Fri, 9 Feb 2024 11:29:05 -0500
Subject: [PATCH 7/8] Added {load/read/save/write}_safetensor_with methods

These alternative methods:
- Require load/read to decide whether to skip missing tensors;
- Require load/read/save/write to decide how keys should be mapped.
---
 dfdx-core/src/nn_traits/mod.rs      | 75 +++++++++++++++++++++++------
 dfdx-core/src/nn_traits/tuples.rs   | 17 +++++--
 dfdx-core/src/nn_traits/vecs.rs     | 13 +++--
 dfdx-core/src/tensor/safetensors.rs | 13 ++++-
 dfdx-derives/src/lib.rs             | 36 +++++++++-----
 5 files changed, 118 insertions(+), 36 deletions(-)

diff --git a/dfdx-core/src/nn_traits/mod.rs b/dfdx-core/src/nn_traits/mod.rs
index 20c55da2..52203373 100644
--- a/dfdx-core/src/nn_traits/mod.rs
+++ b/dfdx-core/src/nn_traits/mod.rs
@@ -116,12 +116,13 @@ pub trait ZeroGrads<E: Dtype, D: Storage<E>> {
 #[cfg(feature = "safetensors")]
 /// Something that can be saved to a .safetensors file.
 pub trait SaveSafeTensors {
-    fn save_safetensors<P: AsRef<std::path::Path>>(
+    fn save_safetensors_with<P: AsRef<std::path::Path>, F: FnMut(String) -> String>(
         &self,
         path: P,
+        key_map: &mut F,
     ) -> Result<(), safetensors::SafeTensorError> {
         let mut tensors = Vec::new();
-        self.write_safetensors("", &mut tensors);
+        self.write_safetensors_with("", &mut tensors, key_map);
         let data = tensors.iter().map(|(k, dtype, shape, data)| {
             (
                 k.clone(),
@@ -131,53 +132,88 @@ pub trait SaveSafeTensors {
         safetensors::serialize_to_file(data, &None, path.as_ref())
     }
-    fn write_safetensors(
+    fn save_safetensors<P: AsRef<std::path::Path>>(
+        &self,
+        path: P,
+    ) -> Result<(), safetensors::SafeTensorError> {
+        self.save_safetensors_with(path, &mut core::convert::identity)
+    }
+    fn write_safetensors_with<F: FnMut(String) -> String>(
         &self,
         location: &str,
         tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
+        key_map: &mut F,
     );
+    fn write_safetensors(
+        &self,
+        location: &str,
+        tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
+    ) {
+        self.write_safetensors_with(location, tensors, &mut core::convert::identity)
+    }
 }
 
 #[cfg(feature = "safetensors")]
 /// Something that can be loaded from a .safetensors file.
 pub trait LoadSafeTensors {
-    fn load_safetensors<P: AsRef<std::path::Path>>(
+    fn load_safetensors_with<P: AsRef<std::path::Path>, F: FnMut(String) -> String>(
         &mut self,
         path: P,
+        skip_missing: bool,
+        key_map: &mut F,
     ) -> Result<(), safetensors::SafeTensorError> {
         let f = std::fs::File::open(path)?;
         let buffer = unsafe { memmap2::MmapOptions::new().map(&f)? };
         let tensors = safetensors::SafeTensors::deserialize(&buffer)?;
-        self.read_safetensors("", &tensors)
+        self.read_safetensors_with("", &tensors, skip_missing, key_map)
+    }
+    fn load_safetensors<P: AsRef<std::path::Path>>(
+        &mut self,
+        path: P,
+    ) -> Result<(), safetensors::SafeTensorError> {
+        self.load_safetensors_with(path, false, &mut core::convert::identity)
     }
 
-    fn read_safetensors(
+    fn read_safetensors_with<F: FnMut(String) -> String>(
         &mut self,
         location: &str,
         tensors: &safetensors::SafeTensors,
+        skip_missing: bool,
+        key_map: &mut F,
     ) -> Result<(), safetensors::SafeTensorError>;
+    fn read_safetensors(
+        &mut self,
+        location: &str,
+        tensors: &safetensors::SafeTensors,
+    ) -> Result<(), safetensors::SafeTensorError> {
+        self.read_safetensors_with(location, tensors, false, &mut core::convert::identity)
+    }
 }
 
 #[cfg(feature = "safetensors")]
 impl<S: Shape, E: Dtype, D: CopySlice<E>, T> LoadSafeTensors for Tensor<S, E, D, T> {
-    fn read_safetensors(
+    fn read_safetensors_with<F: FnMut(String) -> String>(
         &mut self,
         location: &str,
         tensors: &safetensors::SafeTensors,
+        skip_missing: bool,
+        key_map: &mut F,
     ) -> Result<(), safetensors::SafeTensorError> {
-        self.load_safetensor(tensors, location)
+        self.load_safetensor(tensors, location, skip_missing, key_map)
     }
 }
 
 #[cfg(feature = "safetensors")]
 impl<S: Shape, E: Dtype, D: CopySlice<E>, T> SaveSafeTensors for Tensor<S, E, D, T> {
-    fn write_safetensors(
+    fn write_safetensors_with<F: FnMut(String) -> String>(
         &self,
         location: &str,
         tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
+        key_map: &mut F,
     ) {
+        let location = key_map(location.to_string());
         tensors.push((
-            location.to_string(),
+            location,
             <E as crate::dtypes::SafeTensorsDtype>::DTYPE,
             self.shape.concrete().into(),
             self.as_vec().iter().flat_map(|e| e.to_le_bytes()).collect(),
@@ -189,15 +225,17 @@ macro_rules! unit_safetensors {
     ($Ty:ty) => {
         #[cfg(feature = "safetensors")]
         impl SaveSafeTensors for $Ty {
-            fn write_safetensors(
+            fn write_safetensors_with<F: FnMut(String) -> String>(
                 &self,
                 location: &str,
                 tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
+                key_map: &mut F,
             ) {
+                let location = key_map(location.to_string());
                 #[allow(unused_imports)]
                 use crate::dtypes::ToLeBytes;
                 tensors.push((
-                    location.to_string(),
+                    location,
                     <$Ty as crate::dtypes::SafeTensorsDtype>::DTYPE,
                     Vec::new(),
                     self.to_le_bytes().to_vec(),
@@ -207,14 +245,23 @@ macro_rules! unit_safetensors {
 
         #[cfg(feature = "safetensors")]
         impl LoadSafeTensors for $Ty {
-            fn read_safetensors(
+            fn read_safetensors_with<F: FnMut(String) -> String>(
                 &mut self,
                 location: &str,
                 tensors: &safetensors::SafeTensors,
+                skip_missing: bool,
+                key_map: &mut F,
             ) -> Result<(), safetensors::SafeTensorError> {
+                let location = key_map(location.to_string());
                 #[allow(unused_imports)]
                 use crate::dtypes::FromLeBytes;
-                let view = tensors.tensor(location)?;
+                let view = match tensors.tensor(&location) {
+                    Ok(ok) => ok,
+                    Err(safetensors::SafeTensorError::TensorNotFound(_name)) if skip_missing => {
+                        return Ok(());
+                    }
+                    Err(e) => return Err(e),
+                };
                 *self = Self::from_le_bytes(view.data().try_into().unwrap());
                 Ok(())
             }
diff --git a/dfdx-core/src/nn_traits/tuples.rs b/dfdx-core/src/nn_traits/tuples.rs
index 205c0419..7f267482 100644
--- a/dfdx-core/src/nn_traits/tuples.rs
+++ b/dfdx-core/src/nn_traits/tuples.rs
@@ -20,23 +20,32 @@ macro_rules! tuple_impls {
 
         #[cfg(feature = "safetensors")]
         impl<$($name: crate::nn_traits::SaveSafeTensors, )+> crate::nn_traits::SaveSafeTensors for ($($name,)+) {
-            fn write_safetensors(
+            fn write_safetensors_with<F: FnMut(String) -> String>(
                 &self,
                 location: &str,
                 tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
+                key_map: &mut F,
             ) {
-                $(self.$idx.write_safetensors(&format!("{location}.{}", $idx), tensors);)+
+                $(
+                    let name = &format!("{location}.{}", $idx);
+                    self.$idx.write_safetensors_with(name, tensors, key_map);
+                )+
             }
         }
 
         #[cfg(feature = "safetensors")]
         impl<$($name: crate::nn_traits::LoadSafeTensors, )+> crate::nn_traits::LoadSafeTensors for ($($name,)+) {
-            fn read_safetensors(
+            fn read_safetensors_with<F: FnMut(String) -> String>(
                 &mut self,
                 location: &str,
                 tensors: &safetensors::SafeTensors,
+                skip_missing: bool,
+                key_map: &mut F,
             ) -> Result<(), safetensors::SafeTensorError> {
-                $(self.$idx.read_safetensors(&format!("{location}.{}", $idx), tensors)?;)+
+                $(
+                    let name = &format!("{location}.{}", $idx);
+                    self.$idx.read_safetensors_with(name, tensors, skip_missing, key_map)?;
+                )+
                 Ok(())
             }
         }
diff --git a/dfdx-core/src/nn_traits/vecs.rs b/dfdx-core/src/nn_traits/vecs.rs
index 593b1a55..201dd932 100644
--- a/dfdx-core/src/nn_traits/vecs.rs
+++ b/dfdx-core/src/nn_traits/vecs.rs
@@ -60,26 +60,31 @@ impl<E: Dtype, D: Storage<E>, T: crate::nn_traits::ZeroGrads<E, D>> crate::nn_tra
 
 #[cfg(feature = "safetensors")]
 impl<T: crate::nn_traits::SaveSafeTensors> crate::nn_traits::SaveSafeTensors for Vec<T> {
-    fn write_safetensors(
+    fn write_safetensors_with<F: FnMut(String) -> String>(
         &self,
         location: &str,
         tensors: &mut Vec<(String, safetensors::Dtype, Vec<usize>, Vec<u8>)>,
+        key_map: &mut F,
     ) {
         for (i, t) in self.iter().enumerate() {
-            t.write_safetensors(&format!("{location}.{i}"), tensors);
+            let name = &format!("{location}.{i}");
+            t.write_safetensors_with(name, tensors, key_map);
         }
     }
 }
 
 #[cfg(feature = "safetensors")]
 impl<T: crate::nn_traits::LoadSafeTensors> crate::nn_traits::LoadSafeTensors for Vec<T> {
-    fn read_safetensors(
+    fn read_safetensors_with<F: FnMut(String) -> String>(
         &mut self,
         location: &str,
         tensors: &safetensors::SafeTensors,
+        skip_missing: bool,
+        key_map: &mut F,
     ) -> Result<(), safetensors::SafeTensorError> {
         for (i, t) in self.iter_mut().enumerate() {
-            t.read_safetensors(&format!("{location}.{i}"), tensors)?;
+            let name = &format!("{location}.{i}");
+            t.read_safetensors_with(name, tensors, skip_missing, key_map)?;
         }
         Ok(())
     }
diff --git a/dfdx-core/src/tensor/safetensors.rs b/dfdx-core/src/tensor/safetensors.rs
index c0566c40..626eaeaa 100644
--- a/dfdx-core/src/tensor/safetensors.rs
+++ b/dfdx-core/src/tensor/safetensors.rs
@@ -5,12 +5,21 @@ use std::vec::Vec;
 
 impl<S: Shape, E: Dtype, D: CopySlice<E>, T> Tensor<S, E, D, T> {
     /// Loads data from the [SafeTensors] `Storage` with the given `key`
-    pub fn load_safetensor(
+    pub fn load_safetensor<F: FnMut(String) -> String>(
         &mut self,
         tensors: &SafeTensors,
         key: &str,
+        skip_missing: bool,
+        key_map: &mut F,
     ) -> Result<(), SafeTensorError> {
-        let tensor_view = tensors.tensor(key)?;
+        let key = key_map(key.to_string());
+        let tensor_view = match tensors.tensor(&key) {
+            Ok(ok) => ok,
+            Err(safetensors::SafeTensorError::TensorNotFound(_name)) if skip_missing => {
+                return Ok(());
+            }
+            Err(e) => return Err(e),
+        };
         let v = tensor_view.data();
         let num_bytes = std::mem::size_of::<E>();
         assert_eq!(
diff --git a/dfdx-derives/src/lib.rs b/dfdx-derives/src/lib.rs
index 7af885f9..3c68fcb3 100644
--- a/dfdx-derives/src/lib.rs
+++ b/dfdx-derives/src/lib.rs
@@ -196,18 +196,21 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let safetensors_impls = if cfg!(feature = "safetensors") {
         quote! {
             impl #built_impl ::dfdx::nn_traits::SaveSafeTensors for #builder_name #built_ty #built_where {
-                fn write_safetensors(
+                fn write_safetensors_with<KeyMap: FnMut(String) -> String>(
                     &self,
                     location: &str,
                     tensors: &mut Vec<(String, ::dfdx::safetensors::Dtype, Vec<usize>, Vec<u8>)>,
+                    key_map: &mut KeyMap,
                 ) {}
             }
 
             impl #built_impl ::dfdx::nn_traits::LoadSafeTensors for #builder_name #built_ty #built_where {
-                fn read_safetensors<'a>(
+                fn read_safetensors_with<'a, KeyMap: FnMut(String) -> String>(
                     &mut self,
                     location: &str,
                     tensors: &::dfdx::safetensors::SafeTensors<'a>,
+                    skip_missing: bool,
+                    key_map: &mut KeyMap,
                 ) -> Result<(), ::dfdx::safetensors::SafeTensorError> {
                     Ok(())
                 }
@@ -850,9 +853,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
                 where_clause
                     .predicates
                     .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors));
-                quote_spanned!(f.span()=>self.#name.write_safetensors(
+                quote_spanned!(f.span()=>self.#name.write_safetensors_with(
                     &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str),
-                    tensors
+                    tensors,
+                    key_map
                 );)
             } else {
                 Default::default()
@@ -869,9 +873,10 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
                 where_clause
                     .predicates
                     .push(parse_quote!(#ty: ::dfdx::nn_traits::SaveSafeTensors));
-                quote_spanned!(f.span()=>self.#index.write_safetensors(
+                quote_spanned!(f.span()=>self.#index.write_safetensors_with(
                     &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index),
-                    tensors
+                    tensors,
+                    key_map
                 );)
             } else {
                 Default::default()
@@ -890,10 +895,11 @@ pub fn save_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     proc_macro::TokenStream::from(quote! {
         // note: SaveSafeTensors definition is already gated by the safetensors feature
         impl #impl_generics ::dfdx::nn_traits::SaveSafeTensors for #name #ty_generics #where_clause {
-            fn write_safetensors(
+            fn write_safetensors_with<KeyMap: FnMut(String) -> String>(
                 &self,
                 location: &str,
                 tensors: &mut Vec<(String, ::dfdx::safetensors::Dtype, Vec<usize>, Vec<u8>)>,
+                key_map: &mut KeyMap,
             ) {
                 #save_fields
             }
@@ -919,9 +925,11 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
                 where_clause
                     .predicates
                     .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors));
-                quote_spanned!(f.span()=>self.#name.read_safetensors(
+                quote_spanned!(f.span()=>self.#name.read_safetensors_with(
                     &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #name_str),
-                    tensors
+                    tensors,
+                    skip_missing,
+                    key_map
                 )?;)
             } else {
                 Default::default()
@@ -937,9 +945,11 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
                 where_clause
                     .predicates
                     .push(parse_quote!(#ty: ::dfdx::nn_traits::LoadSafeTensors));
-                quote_spanned!(f.span()=>self.#index.read_safetensors(
+                quote_spanned!(f.span()=>self.#index.read_safetensors_with(
                     &format!("{location}{}{}", if location.is_empty() { "" } else { "." }, #index),
-                    tensors
+                    tensors,
+                    skip_missing,
+                    key_map
                 )?;)
             } else {
                 Default::default()
             }
@@ -958,10 +968,12 @@ pub fn load_safetensors(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     proc_macro::TokenStream::from(quote! {
         // note: LoadSafeTensors definition is already gated by the safetensors feature
        impl #impl_generics ::dfdx::nn_traits::LoadSafeTensors for #name #ty_generics #where_clause {
-            fn read_safetensors<'a>(
+            fn read_safetensors_with<'a, KeyMap: FnMut(String) -> String>(
                 &mut self,
                 location: &str,
                 tensors: &::dfdx::safetensors::SafeTensors<'a>,
+                skip_missing: bool,
+                key_map: &mut KeyMap,
             ) -> Result<(), ::dfdx::safetensors::SafeTensorError> {
                 #load_fields
                 Ok(())
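Taken together, [PATCH 7/8] leaves the old load/save entry points intact and threads two new knobs through the whole module tree: skip_missing (tolerate absent tensors) and key_map (rewrite keys on the way in or out). A usage sketch against the trait methods added above; the checkpoint name and the "model." prefix are made up for illustration:

    use dfdx::nn_traits::{LoadSafeTensors, SaveSafeTensors};

    fn roundtrip<M: SaveSafeTensors + LoadSafeTensors>(
        model: &mut M,
    ) -> Result<(), dfdx::safetensors::SafeTensorError> {
        // Namespace every tensor under "model." when saving...
        model.save_safetensors_with("ckpt.safetensors", &mut |k| format!("model.{k}"))?;
        // ...and apply the same mapping to lookups when loading.
        // skip_missing = true leaves tensors absent from the file
        // at their current in-memory values instead of erroring.
        model.load_safetensors_with("ckpt.safetensors", true, &mut |k| format!("model.{k}"))
    }
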
From 95e619f3a4d55fc97b668021d34fd732a5dfdc36 Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Mon, 19 Feb 2024 21:42:32 -0500
Subject: [PATCH 8/8] allow loading safetensors from a byte array

---
 dfdx-core/src/nn_traits/mod.rs | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/dfdx-core/src/nn_traits/mod.rs b/dfdx-core/src/nn_traits/mod.rs
index 52203373..869e1047 100644
--- a/dfdx-core/src/nn_traits/mod.rs
+++ b/dfdx-core/src/nn_traits/mod.rs
@@ -173,6 +173,21 @@ pub trait LoadSafeTensors {
     ) -> Result<(), safetensors::SafeTensorError> {
         self.load_safetensors_with(path, false, &mut core::convert::identity)
     }
+    fn load_safetensors_from_bytes_with<F: FnMut(String) -> String>(
+        &mut self,
+        bytes: &[u8],
+        skip_missing: bool,
+        key_map: &mut F,
+    ) -> Result<(), safetensors::SafeTensorError> {
+        let tensors = safetensors::SafeTensors::deserialize(&bytes)?;
+        self.read_safetensors_with("", &tensors, skip_missing, key_map)
+    }
+    fn load_safetensors_from_bytes(
+        &mut self,
+        bytes: &[u8],
+    ) -> Result<(), safetensors::SafeTensorError> {
+        self.load_safetensors_from_bytes_with(bytes, false, &mut core::convert::identity)
+    }
 
     fn read_safetensors_with<F: FnMut(String) -> String>(
         &mut self,
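A sketch of what [PATCH 8/8] enables: restoring weights without touching the filesystem, for example from bytes embedded in the binary. The include_bytes! path is hypothetical:

    use dfdx::nn_traits::LoadSafeTensors;

    // Weights baked into the executable at compile time.
    static WEIGHTS: &[u8] = include_bytes!("model.safetensors");

    fn restore<M: LoadSafeTensors>(
        model: &mut M,
    ) -> Result<(), dfdx::safetensors::SafeTensorError> {
        // Deserializes straight from the slice; no File::open or mmap involved.
        model.load_safetensors_from_bytes(WEIGHTS)
    }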