Refactors TryMatmuls for better error messages when dimensions don't match #681

Open · wants to merge 1 commit into base: main
examples/02-ops.rs — 2 changes: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
 use dfdx::{
     shapes::{Rank0, Rank1, Rank2},
     tensor::{AsArray, AutoDevice, SampleTensor, Tensor},
-    tensor_ops::{MeanTo, TryMatMul},
+    tensor_ops::{MeanTo, TryStaticMatMul},
 };

 fn main() {
```
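For reference, a minimal sketch of what the static half of the split covers: both operands have fully const-generic shapes, so the inner dimensions are checked by the type system and `try_matmul` only has to report device-level failures. The trait and method names come from this diff; the setup code is an assumption modeled on the existing dfdx examples.

```rust
use dfdx::{
    shapes::Rank2,
    tensor::{AutoDevice, SampleTensor, Tensor},
    tensor_ops::TryStaticMatMul,
};

fn main() {
    let dev = AutoDevice::default();
    // Both shapes are fully known at compile time.
    let a: Tensor<Rank2<2, 3>, f32, _> = dev.sample_normal();
    let b: Tensor<Rank2<3, 4>, f32, _> = dev.sample_normal();
    // The inner dimensions (3 and 3) are verified by the type checker;
    // multiplying a Rank2<2, 3> by a Rank2<4, 4> would not compile at all.
    let _c: Tensor<Rank2<2, 4>, f32, _> = a.try_matmul(b).unwrap();
}
```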
examples/03-nn.rs — 2 changes: 1 addition & 1 deletion

```diff
@@ -33,7 +33,7 @@ fn main() {

     // Even dynamic size is supported;
     let batch_size = 3;
-    let _: Tensor<(usize, Const<2>), f32, _> = m.forward(dev.zeros_like(&(batch_size, Const)));
+    let _: Tensor<(usize, Const<2>), f32, _> = m.forward(dev.zeros_like(&(batch_size, Const::<2>)));

     // you can also combine multiple modules with tuples
     type Mlp = (Linear<4, 2>, ReLU, Linear<2, 1>);
```
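The change above swaps a bare `Const` for an explicit `Const::<2>`. `Const<N>` is a unit struct whose const parameter is normally inferred from the annotated tensor type; the turbofish pins it at the call site, which is presumably needed once the extra `AssertLayerMatch` bound enters inference. A standalone sketch of the two spellings, assuming current dfdx shape APIs:

```rust
use dfdx::{
    shapes::Const,
    tensor::{AutoDevice, Tensor, ZerosTensor},
};

fn main() {
    let dev = AutoDevice::default();
    let batch_size = 3;
    // Inferred: the `2` inside Const comes from the annotation on `_x`.
    let _x: Tensor<(usize, Const<2>), f32, _> = dev.zeros_like(&(batch_size, Const));
    // Explicit: the turbofish fixes the const parameter at the call site,
    // so the shape is known without help from the surrounding types.
    let _y: Tensor<(usize, Const<2>), f32, _> = dev.zeros_like(&(batch_size, Const::<2>));
}
```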
examples/04-gradients.rs — 2 changes: 1 addition & 1 deletion

```diff
@@ -4,7 +4,7 @@ use dfdx::{
     nn::ZeroGrads,
     shapes::{Rank0, Rank2},
     tensor::{AsArray, AutoDevice, Gradients, NoneTape, OwnedTape, SampleTensor, Tensor, Trace},
-    tensor_ops::{Backward, MeanTo, TryMatMul},
+    tensor_ops::{Backward, MeanTo, TryStaticMatMul},
 };

 fn main() {
```
src/nn/linear.rs — 60 changes: 59 additions & 1 deletion

```diff
@@ -20,6 +20,59 @@
 }
 }

+pub trait AssertLayerMatch<Rhs: Shape> {
+    const TYPE_CHECK: ();
+    fn assert_dim_eq(&self);
+}
+
+impl<const M: usize, const I: usize, const O: usize> AssertLayerMatch<(Const<I>, Const<O>)>
+    for Rank1<M>
+{
+    const TYPE_CHECK: () = assert!(
+        M == I,
+        "You are trying to stack layers whose outgoing and incoming dimensions do not match",
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as AssertLayerMatch<(Const<I>, Const<O>)>>::TYPE_CHECK;
+    }
+}
+
+impl<IN, const OUT: usize, const I: usize, const O: usize> AssertLayerMatch<(Const<I>, Const<O>)>
+    for (IN, Const<OUT>)
+{
+    const TYPE_CHECK: () = assert!(
+        OUT == I,
+        "You are trying to stack layers whose outgoing and incoming dimensions do not match",
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as AssertLayerMatch<(Const<I>, Const<O>)>>::TYPE_CHECK;
+    }
+}
+
+impl<B, IN, const OUT: usize, const I: usize, const O: usize> AssertLayerMatch<(Const<I>, Const<O>)>
+    for (B, IN, Const<OUT>)
+{
+    const TYPE_CHECK: () = assert!(
+        OUT == I,
+        "You are trying to stack layers whose outgoing and incoming dimensions do not match",
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as AssertLayerMatch<(Const<I>, Const<O>)>>::TYPE_CHECK;
+    }
+}
+
+// impl<S1: Dim, const O: usize, const IN: usize, const OUT: usize> AssertLayerMatch<Rank2<S1, O>>
+//     for Rank2<IN, OUT>
+// {
+//     const TYPE_CHECK: () = assert!(
+//         OUT == I,
+//         "You are trying to stack tensors, whose outgoing and ingoing dimensions do not match {I}",
+//     );
+//     fn assert_dim_eq(&self) {
+//         let _ = <Self as AssertLayerMatch<Rank2<I, O>>>::TYPE_CHECK;
+//     }
+// }
+
 /// A linear transformation of the form `weight * x + bias`, where `weight` is a matrix, `x` is a vector or matrix,
 /// and `bias` is a vector.
 ///
@@ -92,15 +145,20 @@ impl<const I: usize, const O: usize, E: Dtype, D: Device<E>> TensorCollection<E,

 impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T> Module<T> for Linear<I, O, E, D>
 where
-    T: SplitTape + TryMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>> + HasErr<Err = D::Err>,
+    T: SplitTape
+        + TryStaticMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>>
+        + HasErr<Err = D::Err>
+        + HasShape,
     T::Tape: Tape<E, D>,
+    T::Shape: AssertLayerMatch<Rank2<I, O>>,
     for<'a> Bias1D<'a, O, E, D>: Module<T::Output, Output = T::Output, Error = D::Err>,
 {
     type Output = T::Output;
     type Error = D::Err;

     /// 1d forward using [matmul()] and [add()].
     fn try_forward(&self, x: T) -> Result<Self::Output, D::Err> {
+        x.shape().assert_dim_eq();
         let o = x.try_matmul(self.weight.retaped::<T::Tape>().try_permute()?)?;
         Bias1D { beta: &self.bias }.try_forward(o)
     }
```
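The mechanism behind `AssertLayerMatch` is a post-monomorphization const assertion: `TYPE_CHECK` is an associated const containing an `assert!` over const generics, and `assert_dim_eq` forces the compiler to evaluate it, so a dimension mismatch becomes a compile-time error carrying the message above rather than a runtime shape panic. A stripped-down, self-contained sketch of the same trick (all names here are illustrative, not from the PR):

```rust
// The assert! runs during constant evaluation, when the const
// generics are already known concrete values.
trait DimsMatch<const I: usize> {
    const CHECK: ();
}

struct Layer<const IN: usize, const OUT: usize>;

impl<const IN: usize, const OUT: usize, const I: usize> DimsMatch<I> for Layer<IN, OUT> {
    const CHECK: () = assert!(
        OUT == I,
        "output dimension of the first layer does not match input dimension of the second",
    );
}

// Compiles only when the first layer's OUT equals the second layer's
// input width; referencing CHECK is what forces its evaluation.
fn stack<const A: usize, const B: usize, const C: usize, const D: usize>(
    _first: Layer<A, B>,
    _second: Layer<C, D>,
) where
    Layer<A, B>: DimsMatch<C>,
{
    let _ = <Layer<A, B> as DimsMatch<C>>::CHECK;
}

fn main() {
    stack(Layer::<8, 4>, Layer::<4, 2>); // ok: 4 == 4
    // stack(Layer::<8, 4>, Layer::<3, 2>); // fails to compile: 4 != 3
}
```

Because the assertion fires during monomorphization, the error appears at whichever call site instantiates a mismatched pair, which is exactly what makes the message more useful than a runtime matmul failure.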
src/nn/transformer/mha.rs — 8 changes: 4 additions & 4 deletions

```diff
@@ -127,11 +127,11 @@ where

         // Get weights
         let scalar: E = E::ONE / E::from_usize(K / H).unwrap().sqrt();
-        let weights = q.try_matmul(k)?.try_mul(scalar)?;
+        let weights = q.try_dynamic_matmul(k)?.try_mul(scalar)?;
         let weights = weights.try_softmax::<Axis<2>>()?;

         // Get new tokens
-        let tokens = weights.try_matmul(v)?;
+        let tokens = weights.try_dynamic_matmul(v)?;
         let tokens = tokens.try_permute::<_, Axes3<1, 0, 2>>()?;
         let tokens = tokens.try_reshape_like(&(s1, Const::<V>)).unwrap()?;

@@ -187,11 +187,11 @@

         // Get weights
         let scalar: E = E::ONE / E::from_usize(K / H).unwrap().sqrt();
-        let weights = q.try_matmul(k)?.try_mul(scalar)?;
+        let weights = q.try_dynamic_matmul(k)?.try_mul(scalar)?;
         let weights = weights.try_softmax::<Axis<3>>()?;

         // Get new tokens
-        let tokens = weights.try_matmul(v)?;
+        let tokens = weights.try_dynamic_matmul(v)?;
         let tokens = tokens.try_permute::<_, Axes4<0, 2, 1, 3>>()?;
         let tokens = tokens.try_reshape_like(&(b, s1, Const::<V>)).unwrap()?;

```
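Attention mixes const-generic and runtime dimensions (batch size and sequence lengths are `usize` here), so these products go through the dynamic half of the split, where a mismatch can only be detected at runtime and is reported as an `Err`. A rough sketch of that situation, assuming `try_dynamic_matmul` is reachable through the prelude (the trait's name and import path are not shown in this diff):

```rust
use dfdx::prelude::*; // assumed to re-export the renamed dynamic matmul trait

fn main() {
    let dev = AutoDevice::default();
    // All dimensions are plain usize, so nothing is known at compile time.
    let (m, inner, n) = (2, 8, 5);
    let a: Tensor<(usize, usize), f32, _> = dev.zeros_like(&(m, inner));
    let b: Tensor<(usize, usize), f32, _> = dev.zeros_like(&(inner, n));
    // The inner dimensions match here; when they do not, the call returns
    // Err with the improved error message instead of a type error.
    let _c: Tensor<(usize, usize), f32, _> = a.try_dynamic_matmul(b).unwrap();
}
```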
src/nn/unbiased_linear.rs — 2 changes: 1 addition & 1 deletion

```diff
@@ -79,7 +79,7 @@ impl<const I: usize, const O: usize, E: Dtype + Float + SampleUniform, D: Device
 impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T> Module<T>
     for UnbiasedLinear<I, O, E, D>
 where
-    T: SplitTape + TryMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>> + HasErr<Err = D::Err>,
+    T: SplitTape + TryStaticMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>> + HasErr<Err = D::Err>,
     T::Tape: Tape<E, D>,
 {
     type Output = T::Output;
```