diff --git a/examples/02-ops.rs b/examples/02-ops.rs
index bba44fa61..240663dad 100644
--- a/examples/02-ops.rs
+++ b/examples/02-ops.rs
@@ -3,7 +3,7 @@
 use dfdx::{
     shapes::{Rank0, Rank1, Rank2},
     tensor::{AsArray, AutoDevice, SampleTensor, Tensor},
-    tensor_ops::{MeanTo, TryMatMul},
+    tensor_ops::{MeanTo, TryStaticMatMul},
 };
 
 fn main() {
diff --git a/examples/03-nn.rs b/examples/03-nn.rs
index c3a532471..97c633064 100644
--- a/examples/03-nn.rs
+++ b/examples/03-nn.rs
@@ -33,7 +33,7 @@ fn main() {
 
     // Even dynamic size is supported;
     let batch_size = 3;
-    let _: Tensor<(usize, Const<2>), f32, _> = m.forward(dev.zeros_like(&(batch_size, Const)));
+    let _: Tensor<(usize, Const<2>), f32, _> = m.forward(dev.zeros_like(&(batch_size, Const::<2>)));
 
     // you can also combine multiple modules with tuples
     type Mlp = (Linear<4, 2>, ReLU, Linear<2, 1>);
diff --git a/examples/04-gradients.rs b/examples/04-gradients.rs
index b7357a18c..1975d9fb0 100644
--- a/examples/04-gradients.rs
+++ b/examples/04-gradients.rs
@@ -4,7 +4,7 @@ use dfdx::{
     nn::ZeroGrads,
     shapes::{Rank0, Rank2},
     tensor::{AsArray, AutoDevice, Gradients, NoneTape, OwnedTape, SampleTensor, Tensor, Trace},
-    tensor_ops::{Backward, MeanTo, TryMatMul},
+    tensor_ops::{Backward, MeanTo, TryStaticMatMul},
 };
 
 fn main() {
diff --git a/src/nn/linear.rs b/src/nn/linear.rs
index f666bd590..5e95537ae 100644
--- a/src/nn/linear.rs
+++ b/src/nn/linear.rs
@@ -20,6 +20,59 @@ where
     }
 }
 
+pub trait AssertLayerMatch<Rhs: Shape> {
+    const TYPE_CHECK: ();
+    fn assert_dim_eq(&self);
+}
+
+impl<const I: usize, const M: usize> AssertLayerMatch<(Const<I>, Const<M>)>
+    for Rank1<M>
+{
+    const TYPE_CHECK: () = assert!(
+        M == I,
+        "You are trying to stack tensors, whose outgoing and ingoing dimensions do not match",
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as AssertLayerMatch<(Const<I>, Const<M>)>>::TYPE_CHECK;
+    }
+}
+
+impl<IN: Dim, const OUT: usize, const I: usize, const M: usize>
+    AssertLayerMatch<(Const<I>, Const<M>)> for (IN, Const<OUT>)
+{
+    const TYPE_CHECK: () = assert!(
+        OUT == I,
+        "You are trying to stack tensors, whose outgoing and ingoing dimensions do not match",
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as AssertLayerMatch<(Const<I>, Const<M>)>>::TYPE_CHECK;
+    }
+}
+
+impl<B: Dim, IN: Dim, const OUT: usize, const I: usize, const M: usize>
+    AssertLayerMatch<(Const<I>, Const<M>)> for (B, IN, Const<OUT>)
+{
+    const TYPE_CHECK: () = assert!(
+        OUT == I,
+        "You are trying to stack tensors, whose outgoing and ingoing dimensions do not match",
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as AssertLayerMatch<(Const<I>, Const<M>)>>::TYPE_CHECK;
+    }
+}
+
+// impl<const OUT: usize, const I: usize, const O: usize> AssertLayerMatch<Rank2<I, O>>
+//     for Rank2<OUT, O>
+// {
+//     const TYPE_CHECK: () = assert!(
+//         OUT == I,
+//         "You are trying to stack tensors, whose outgoing and ingoing dimensions do not match {I}",
+//     );
+//     fn assert_dim_eq(&self) {
+//         let _ = <Self as AssertLayerMatch<Rank2<I, O>>>::TYPE_CHECK;
+//     }
+// }
+
 /// A linear transformation of the form `weight * x + bias`, where `weight` is a matrix, `x` is a vector or matrix,
 /// and `bias` is a vector.
 ///
@@ -92,8 +145,12 @@ impl<const I: usize, const O: usize, E: Dtype, D: Device<E>> TensorCollection<E, D> for Linear<I, O, E, D>
 impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T> Module<T> for Linear<I, O, E, D>
 where
-    T: SplitTape + TryMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>> + HasErr<Err = D::Err>,
+    T: SplitTape
+        + TryStaticMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>>
+        + HasErr<Err = D::Err>
+        + HasShape,
     T::Tape: Tape<E, D>,
+    T::Shape: AssertLayerMatch<Rank2<I, O>>,
     for<'a> Bias1D<'a, O, E, D>: Module<T::Output, Output = T::Output, Error = D::Err>,
 {
     type Output = T::Output;
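The `AssertLayerMatch` impls above rely on a const-evaluation idiom: an associated `const TYPE_CHECK: ()` whose initializer is a const-evaluable `assert!` over const generics. The constant is only evaluated when something references it, which is exactly what `assert_dim_eq` forces, so a dimension mismatch surfaces as a post-monomorphization compile error instead of a runtime panic. Below is a minimal, self-contained sketch of the idiom; `DimCheck` and `Mat` are hypothetical stand-ins, not dfdx types.

pub trait DimCheck<Rhs> {
    const TYPE_CHECK: ();
    fn assert_dim_eq(&self);
}

/// Stand-in for a matrix shape known at compile time.
pub struct Mat<const R: usize, const C: usize>;

impl<const R: usize, const K1: usize, const K2: usize, const C: usize>
    DimCheck<Mat<K2, C>> for Mat<R, K1>
{
    // Evaluated only when this impl is instantiated and the const is used,
    // so `K1 != K2` becomes a compile-time error.
    const TYPE_CHECK: () = assert!(K1 == K2, "inner dimensions do not match");

    fn assert_dim_eq(&self) {
        // Referencing the associated const forces its evaluation.
        let _ = <Self as DimCheck<Mat<K2, C>>>::TYPE_CHECK;
    }
}

fn main() {
    // Compiles: inner dimensions agree (3 == 3).
    <Mat<2, 3> as DimCheck<Mat<3, 4>>>::assert_dim_eq(&Mat::<2, 3>);
    // Fails to compile if uncommented: 3 != 5 trips the const assert.
    // <Mat<2, 3> as DimCheck<Mat<5, 4>>>::assert_dim_eq(&Mat::<2, 3>);
}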
@@ -101,6 +158,7 @@
     /// 1d forward using [matmul()] and [add()].
     fn try_forward(&self, x: T) -> Result<Self::Output, D::Err> {
+        x.shape().assert_dim_eq();
         let o = x.try_matmul(self.weight.retaped::<T::Tape>().try_permute()?)?;
         Bias1D { beta: &self.bias }.try_forward(o)
     }
diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs
index a8ffbc952..3433cc2b9 100644
--- a/src/nn/transformer/mha.rs
+++ b/src/nn/transformer/mha.rs
@@ -127,11 +127,11 @@
 
         // Get weights
         let scalar: E = E::ONE / E::from_usize(K / H).unwrap().sqrt();
-        let weights = q.try_matmul(k)?.try_mul(scalar)?;
+        let weights = q.try_dynamic_matmul(k)?.try_mul(scalar)?;
         let weights = weights.try_softmax::<Axis<2>>()?;
 
         // Get new tokens
-        let tokens = weights.try_matmul(v)?;
+        let tokens = weights.try_dynamic_matmul(v)?;
         let tokens = tokens.try_permute::<_, Axes3<1, 0, 2>>()?;
         let tokens = tokens.try_reshape_like(&(s1, Const::<V>)).unwrap()?;
@@ -187,11 +187,11 @@
 
         // Get weights
         let scalar: E = E::ONE / E::from_usize(K / H).unwrap().sqrt();
-        let weights = q.try_matmul(k)?.try_mul(scalar)?;
+        let weights = q.try_dynamic_matmul(k)?.try_mul(scalar)?;
         let weights = weights.try_softmax::<Axis<3>>()?;
 
         // Get new tokens
-        let tokens = weights.try_matmul(v)?;
+        let tokens = weights.try_dynamic_matmul(v)?;
         let tokens = tokens.try_permute::<_, Axes4<0, 2, 1, 3>>()?;
         let tokens = tokens.try_reshape_like(&(b, s1, Const::<V>)).unwrap()?;
diff --git a/src/nn/unbiased_linear.rs b/src/nn/unbiased_linear.rs
index 50f1b120a..ce494356e 100644
--- a/src/nn/unbiased_linear.rs
+++ b/src/nn/unbiased_linear.rs
@@ -79,7 +79,7 @@ impl<const I: usize, const O: usize, E: Dtype, D: Device<E>, T> Module<T> for UnbiasedLinear<I, O, E, D>
 where
-    T: SplitTape + TryMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>> + HasErr<Err = D::Err>,
+    T: SplitTape + TryStaticMatMul<Tensor<Rank2<I, O>, E, D, T::Tape>> + HasErr<Err = D::Err>,
     T::Tape: Tape<E, D>,
 {
     type Output = T::Output;
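Both mha.rs hunks above compute standard scaled dot-product attention: with per-head width d = K / H, weights = softmax(q k^T / sqrt(d)) and output = weights v. Switching those two products to `try_dynamic_matmul` is what lets the sequence lengths stay runtime `usize` values. For reference, a dependency-free sketch of the scoring step (f32 slices, hypothetical function name, not dfdx code):

// scores[i][j] = dot(q[i], k[j]) / sqrt(d), then softmax over j.
fn attention_scores(q: &[Vec<f32>], k: &[Vec<f32>], d: usize) -> Vec<Vec<f32>> {
    let scale = 1.0 / (d as f32).sqrt();
    q.iter()
        .map(|qi| {
            let logits: Vec<f32> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect();
            // Numerically stable softmax over j.
            let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
            let z: f32 = exps.iter().sum();
            exps.iter().map(|e| e / z).collect()
        })
        .collect()
}

fn main() {
    let q = vec![vec![1.0, 0.0]];
    let k = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let w = attention_scores(&q, &k, 2);
    // Each row of attention weights sums to 1.
    assert!((w[0][0] + w[0][1] - 1.0).abs() < 1e-6);
}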
diff --git a/src/tensor_ops/matmul/cpu_kernel.rs b/src/tensor_ops/matmul/cpu_kernel.rs
index e8e235392..800bca9db 100644
--- a/src/tensor_ops/matmul/cpu_kernel.rs
+++ b/src/tensor_ops/matmul/cpu_kernel.rs
@@ -17,6 +17,8 @@ use cblas_sys::{
 ))]
 use matrixmultiply::{dgemm, sgemm};
 
+use super::MulStaticDimCheck;
+
 #[cfg(not(any(
     feature = "cpu-seq-matmul",
     feature = "cpu-par-matmul",
@@ -221,10 +223,10 @@ impl<E: Dtype> super::VecMatKernel<E> for Cpu
 where
     Self: MatMulImpl<E>,
 {
-    fn forward<K: Dim, N: Dim>(
+    fn forward<LeftK: Dim + MulStaticDimCheck<(RightK, N)>, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(K,), E, Self>,
-        rhs: &Tensor<(K, N), E, Self>,
+        lhs: &Tensor<(LeftK,), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
     ) -> Result<Tensor<(N,), E, Self>, Self::Err> {
         let (k, n) = rhs.shape;
         let mut out = self.try_zeros_like(&(n,))?;
@@ -239,11 +241,11 @@
         );
         Ok(out)
     }
-    fn backward<K: Dim, N: Dim>(
+    fn backward<LeftK: Dim + MulStaticDimCheck<(RightK, N)>, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(K,), E, Self>,
+        lhs: &Tensor<(LeftK,), E, Self>,
         grad_lhs: &mut Self::Vec,
-        rhs: &Tensor<(K, N), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
     ) -> Result<(), Self::Err> {
@@ -271,15 +273,16 @@
     }
 }
 
-impl<E: Dtype> super::MatMatKernel<E> for Cpu
+impl<E: Dtype> super::StaticMatMatKernel<E> for Cpu
 where
     Self: MatMulImpl<E>,
 {
-    fn forward<M: Dim, K: Dim, N: Dim>(
+    fn forward<M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(M, K), E, Self>,
-        rhs: &Tensor<(K, N), E, Self>,
+        lhs: &Tensor<(M, LeftK), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
     ) -> Result<Tensor<(M, N), E, Self>, Self::Err> {
+        // assert_eq!(lhs.shape.1.size(), rhs.shape.0.size());
         let (m, k) = lhs.shape;
         let n = rhs.shape.1;
         let mut out = self.try_zeros_like(&(m, n))?;
@@ -294,14 +297,15 @@
         );
         Ok(out)
     }
-    fn backward<M: Dim, K: Dim, N: Dim>(
+    fn backward<M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(M, K), E, Self>,
+        lhs: &Tensor<(M, LeftK), E, Self>,
         grad_lhs: &mut Self::Vec,
-        rhs: &Tensor<(K, N), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
     ) -> Result<(), Self::Err> {
+        // assert_eq!(lhs.shape.1.size(), rhs.shape.0.size());
         let (m, k) = lhs.shape;
         let n = rhs.shape.1;
         let strides = (m, n).strides();
@@ -331,10 +335,10 @@ impl<E: Dtype> super::MatMatBrKernel<E> for Cpu
 where
     Self: MatMulImpl<E>,
 {
-    fn forward<B: Dim, M: Dim, K: Dim, N: Dim>(
+    fn forward<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(B, M, K), E, Self>,
-        rhs: &Tensor<(K, N), E, Self>,
+        lhs: &Tensor<(B, M, LeftK), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
     ) -> Result<Tensor<(B, M, N), E, Self>, Self::Err> {
         let (batch, m, k) = lhs.shape;
         let n = rhs.shape.1;
@@ -353,11 +357,11 @@
         }
         Ok(out)
     }
-    fn backward<B: Dim, M: Dim, K: Dim, N: Dim>(
+    fn backward<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(B, M, K), E, Self>,
+        lhs: &Tensor<(B, M, LeftK), E, Self>,
         grad_lhs: &mut Self::Vec,
-        rhs: &Tensor<(K, N), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
     ) -> Result<(), Self::Err> {
@@ -388,14 +392,14 @@
     }
 }
 
-impl<E: Dtype> super::MatMatBatch3Kernel<E> for Cpu
+impl<E: Dtype> super::StaticMatMatBatch3Kernel<E> for Cpu
 where
     Self: MatMulImpl<E>,
 {
-    fn forward<B: Dim, M: Dim, K: Dim, N: Dim>(
+    fn forward<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(B, M, K), E, Self>,
-        rhs: &Tensor<(B, K, N), E, Self>,
+        lhs: &Tensor<(B, M, LeftK), E, Self>,
+        rhs: &Tensor<(B, RightK, N), E, Self>,
     ) -> Result<Tensor<(B, M, N), E, Self>, Self::Err> {
         let (b, m, k) = lhs.shape;
         let n = rhs.shape.2;
@@ -416,11 +420,11 @@
         }
         Ok(out)
     }
-    fn backward<B: Dim, M: Dim, K: Dim, N: Dim>(
+    fn backward<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(B, M, K), E, Self>,
+        lhs: &Tensor<(B, M, LeftK), E, Self>,
         grad_lhs: &mut Self::Vec,
-        rhs: &Tensor<(B, K, N), E, Self>,
+        rhs: &Tensor<(B, RightK, N), E, Self>,
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
     ) -> Result<(), Self::Err> {
@@ -451,14 +455,140 @@
     }
 }
 
-impl<E: Dtype> super::MatMatBatch4Kernel<E> for Cpu
+impl<E: Dtype> super::DynamicMatMatBatch3Kernel<E> for Cpu
 where
     Self: MatMulImpl<E>,
 {
-    fn forward<B: Dim, S: Dim, M: Dim, K: Dim, N: Dim>(
+    fn forward<B: Dim, S1: Dim, S2: Dim>(
         &self,
-        lhs: &Tensor<(B, S, M, K), E, Self>,
-        rhs: &Tensor<(B, S, K, N), E, Self>,
+        lhs: &Tensor<(B, S1, usize), E, Self>,
+        rhs: &Tensor<(B, usize, S2), E, Self>,
+    ) -> Result<Tensor<(B, S1, S2), E, Self>, Self::Err> {
+        let (b, m, k) = lhs.shape;
+        let n = rhs.shape.2;
+        let mut out = self.try_zeros_like(&(b, m, n))?;
+        let ap = lhs.data.as_ref();
+        let bp = rhs.data.as_ref();
+        let cp = Arc::get_mut(&mut out.data).unwrap();
+        for i in 0..b.size() {
+            Self::matmul(
+                (m, k, n),
+                ap[i * lhs.strides[0]..].as_ptr(),
+                [lhs.strides[1], lhs.strides[2]],
+                bp[i * rhs.strides[0]..].as_ptr(),
+                [rhs.strides[1], rhs.strides[2]],
+                cp[i * out.strides[0]..].as_mut_ptr(),
+                [out.strides[1], out.strides[2]],
+            )
+        }
+        Ok(out)
+    }
+    fn backward<B: Dim, S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(B, S1, usize), E, Self>,
+        grad_lhs: &mut Self::Vec,
+        rhs: &Tensor<(B, usize, S2), E, Self>,
+        grad_rhs: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Self::Err> {
+        let (b, m, k) = lhs.shape;
+        let n = rhs.shape.2;
+        let strides = (b, m, n).strides();
+        for i in 0..b.size() {
+            Self::matmul(
+                (m, n, k),
+                grad_out[i * strides[0]..].as_ptr(),
+                [strides[1], strides[2]],
+                rhs.data[i * rhs.strides[0]..].as_ptr(),
+                [rhs.strides[2], rhs.strides[1]],
+                grad_lhs[i * lhs.strides[0]..].as_mut_ptr(),
+                [lhs.strides[1], lhs.strides[2]],
+            );
+            Self::matmul(
+                (k, m, n),
+                lhs.data[i * lhs.strides[0]..].as_ptr(),
+                [lhs.strides[2], lhs.strides[1]],
+                grad_out[i * strides[0]..].as_ptr(),
+                [strides[1], strides[2]],
+                grad_rhs[i * rhs.strides[0]..].as_mut_ptr(),
+                [rhs.strides[1], rhs.strides[2]],
+            );
+        }
+        Ok(())
+    }
+}
+
+impl<E: Dtype> super::DynamicMatMatBatch3Kernel1<E> for Cpu
+where
+    Self: MatMulImpl<E>,
+{
+    fn forward<S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(usize, S1, usize), E, Self>,
+        rhs: &Tensor<(usize, usize, S2), E, Self>,
+    ) -> Result<Tensor<(usize, S1, S2), E, Self>, Self::Err> {
+        let (b, m, k) = lhs.shape;
+        let n = rhs.shape.2;
+        let mut out = self.try_zeros_like(&(b, m, n))?;
+        let ap = lhs.data.as_ref();
+        let bp = rhs.data.as_ref();
+        let cp = Arc::get_mut(&mut out.data).unwrap();
+        for i in 0..b.size() {
+            Self::matmul(
+                (m, k, n),
+                ap[i * lhs.strides[0]..].as_ptr(),
+                [lhs.strides[1], lhs.strides[2]],
+                bp[i * rhs.strides[0]..].as_ptr(),
+                [rhs.strides[1], rhs.strides[2]],
+                cp[i * out.strides[0]..].as_mut_ptr(),
+                [out.strides[1], out.strides[2]],
+            )
+        }
+        Ok(out)
+    }
+    fn backward<S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(usize, S1, usize), E, Self>,
+        grad_lhs: &mut Self::Vec,
+        rhs: &Tensor<(usize, usize, S2), E, Self>,
+        grad_rhs: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Self::Err> {
+        let (b, m, k) = lhs.shape;
+        let n = rhs.shape.2;
+        let strides = (b, m, n).strides();
+        for i in 0..b.size() {
+            Self::matmul(
+                (m, n, k),
+                grad_out[i * strides[0]..].as_ptr(),
+                [strides[1], strides[2]],
+                rhs.data[i * rhs.strides[0]..].as_ptr(),
+                [rhs.strides[2], rhs.strides[1]],
+                grad_lhs[i * lhs.strides[0]..].as_mut_ptr(),
+                [lhs.strides[1], lhs.strides[2]],
+            );
+            Self::matmul(
+                (k, m, n),
+                lhs.data[i * lhs.strides[0]..].as_ptr(),
+                [lhs.strides[2], lhs.strides[1]],
+                grad_out[i * strides[0]..].as_ptr(),
+                [strides[1], strides[2]],
+                grad_rhs[i * rhs.strides[0]..].as_mut_ptr(),
+                [rhs.strides[1], rhs.strides[2]],
+            );
+        }
+        Ok(())
+    }
+}
+
+impl<E: Dtype> super::StaticMatMatBatch4Kernel<E> for Cpu
+where
+    Self: MatMulImpl<E>,
+{
+    fn forward<B: Dim, S: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
+        &self,
+        lhs: &Tensor<(B, S, M, LeftK), E, Self>,
+        rhs: &Tensor<(B, S, RightK, N), E, Self>,
     ) -> Result<Tensor<(B, S, M, N), E, Self>, Self::Err> {
         let (b, s, m, k) = lhs.shape;
         let n = rhs.shape.3;
@@ -479,11 +609,141 @@
         }
         Ok(out)
     }
-    fn backward<B: Dim, S: Dim, M: Dim, K: Dim, N: Dim>(
+    fn backward<B: Dim, S: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
+        &self,
+        lhs: &Tensor<(B, S, M, LeftK), E, Self>,
+        grad_lhs: &mut Self::Vec,
+        rhs: &Tensor<(B, S, RightK, N), E, Self>,
+        grad_rhs: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Self::Err> {
+        let (b, s, m, k) = lhs.shape;
+        let n = rhs.shape.3;
+        let strides = (b, s, m, n).strides();
+        for i in 0..b.size() {
+            for j in 0..s.size() {
+                Self::matmul(
+                    (m, n, k),
+                    grad_out[i * strides[0] + j * strides[1]..].as_ptr(),
+                    [strides[2], strides[3]],
+                    rhs.data[i * rhs.strides[0] + j * rhs.strides[1]..].as_ptr(),
+                    [rhs.strides[3], rhs.strides[2]],
+                    grad_lhs[i * lhs.strides[0] + j * lhs.strides[1]..].as_mut_ptr(),
+                    [lhs.strides[2], lhs.strides[3]],
+                );
+                Self::matmul(
+                    (k, m, n),
+                    lhs.data[i * lhs.strides[0] + j * lhs.strides[1]..].as_ptr(),
+                    [lhs.strides[3], lhs.strides[2]],
+                    grad_out[i * strides[0] + j * strides[1]..].as_ptr(),
+                    [strides[2], strides[3]],
+                    grad_rhs[i * rhs.strides[0] + j * rhs.strides[1]..].as_mut_ptr(),
+                    [rhs.strides[2], rhs.strides[3]],
+                );
+            }
+        }
+        Ok(())
+    }
+}
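Every dynamic batch kernel here has the same skeleton: one gemm (`Self::matmul`) call per batch entry, with the entry located by multiplying the loop index by the batch stride, and the backward pass reusing the same call with a swapped stride pair to express the transposed operand. A minimal sketch of the forward loop in plain Rust, assuming row-major contiguous buffers and a hypothetical `batched_matmul` name (not dfdx code):

fn batched_matmul(
    b: usize, m: usize, k: usize, n: usize,
    lhs: &[f32],     // laid out as [b, m, k]
    rhs: &[f32],     // laid out as [b, k, n]
    out: &mut [f32], // laid out as [b, m, n]
) {
    let (ls, rs, os) = (m * k, k * n, m * n); // batch strides
    for i in 0..b {
        // Offset each operand by index * batch stride, then do one matmul.
        let a = &lhs[i * ls..(i + 1) * ls];
        let x = &rhs[i * rs..(i + 1) * rs];
        let c = &mut out[i * os..(i + 1) * os];
        for r in 0..m {
            for col in 0..n {
                let mut acc = 0.0;
                for kk in 0..k {
                    acc += a[r * k + kk] * x[kk * n + col];
                }
                c[r * n + col] = acc;
            }
        }
    }
}

fn main() {
    // Two 1x2 @ 2x1 products in one call.
    let lhs = [1.0, 2.0, 3.0, 4.0];
    let rhs = [10.0, 20.0, 10.0, 20.0];
    let mut out = [0.0; 2];
    batched_matmul(2, 1, 2, 1, &lhs, &rhs, &mut out);
    assert_eq!(out, [50.0, 110.0]);
}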
+
+impl<E: Dtype> super::DynamicMatMatBatch4Kernel<E> for Cpu
+where
+    Self: MatMulImpl<E>,
+{
+    fn forward<B: Dim, S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(B, usize, S1, usize), E, Self>,
+        rhs: &Tensor<(B, usize, usize, S2), E, Self>,
+    ) -> Result<Tensor<(B, usize, S1, S2), E, Self>, Self::Err> {
+        let (b, s, m, k) = lhs.shape;
+        let n = rhs.shape.3;
+        let mut out = self.try_zeros_like(&(b, s, m, n))?;
+        let cp = Arc::get_mut(&mut out.data).unwrap();
+        for i in 0..b.size() {
+            for j in 0..s.size() {
+                Self::matmul(
+                    (m, k, n),
+                    lhs.data[i * lhs.strides[0] + j * lhs.strides[1]..].as_ptr(),
+                    [lhs.strides[2], lhs.strides[3]],
+                    rhs.data[i * rhs.strides[0] + j * rhs.strides[1]..].as_ptr(),
+                    [rhs.strides[2], rhs.strides[3]],
+                    cp[i * out.strides[0] + j * out.strides[1]..].as_mut_ptr(),
+                    [out.strides[2], out.strides[3]],
+                );
+            }
+        }
+        Ok(out)
+    }
+    fn backward<B: Dim, S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(B, usize, S1, usize), E, Self>,
+        grad_lhs: &mut Self::Vec,
+        rhs: &Tensor<(B, usize, usize, S2), E, Self>,
+        grad_rhs: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Self::Err> {
+        let (b, s, m, k) = lhs.shape;
+        let n = rhs.shape.3;
+        let strides = (b, s, m, n).strides();
+        for i in 0..b.size() {
+            for j in 0..s.size() {
+                Self::matmul(
+                    (m, n, k),
+                    grad_out[i * strides[0] + j * strides[1]..].as_ptr(),
+                    [strides[2], strides[3]],
+                    rhs.data[i * rhs.strides[0] + j * rhs.strides[1]..].as_ptr(),
+                    [rhs.strides[3], rhs.strides[2]],
+                    grad_lhs[i * lhs.strides[0] + j * lhs.strides[1]..].as_mut_ptr(),
+                    [lhs.strides[2], lhs.strides[3]],
+                );
+                Self::matmul(
+                    (k, m, n),
+                    lhs.data[i * lhs.strides[0] + j * lhs.strides[1]..].as_ptr(),
+                    [lhs.strides[3], lhs.strides[2]],
+                    grad_out[i * strides[0] + j * strides[1]..].as_ptr(),
+                    [strides[2], strides[3]],
+                    grad_rhs[i * rhs.strides[0] + j * rhs.strides[1]..].as_mut_ptr(),
+                    [rhs.strides[2], rhs.strides[3]],
+                );
+            }
+        }
+        Ok(())
+    }
+}
+
+impl<E: Dtype> super::DynamicMatMatBatch4Kernel1<E> for Cpu
+where
+    Self: MatMulImpl<E>,
+{
+    fn forward<B: Dim, S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(B, usize, S1, S2), E, Self>,
+        rhs: &Tensor<(B, usize, S2, usize), E, Self>,
+    ) -> Result<Tensor<(B, usize, S1, usize), E, Self>, Self::Err> {
+        let (b, s, m, k) = lhs.shape;
+        let n = rhs.shape.3;
+        let mut out = self.try_zeros_like(&(b, s, m, n))?;
+        let cp = Arc::get_mut(&mut out.data).unwrap();
+        for i in 0..b.size() {
+            for j in 0..s.size() {
+                Self::matmul(
+                    (m, k, n),
+                    lhs.data[i * lhs.strides[0] + j * lhs.strides[1]..].as_ptr(),
+                    [lhs.strides[2], lhs.strides[3]],
+                    rhs.data[i * rhs.strides[0] + j * rhs.strides[1]..].as_ptr(),
+                    [rhs.strides[2], rhs.strides[3]],
+                    cp[i * out.strides[0] + j * out.strides[1]..].as_mut_ptr(),
+                    [out.strides[2], out.strides[3]],
+                );
+            }
+        }
+        Ok(out)
+    }
+    fn backward<B: Dim, S1: Dim, S2: Dim>(
         &self,
-        lhs: &Tensor<(B, S, M, K), E, Self>,
+        lhs: &Tensor<(B, usize, S1, S2), E, Self>,
         grad_lhs: &mut Self::Vec,
-        rhs: &Tensor<(B, S, K, N), E, Self>,
+        rhs: &Tensor<(B, usize, S2, usize), E, Self>,
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
     ) -> Result<(), Self::Err> {
diff --git a/src/tensor_ops/matmul/mod.rs b/src/tensor_ops/matmul/mod.rs
index 7ed4669f9..480ed0bfb 100644
--- a/src/tensor_ops/matmul/mod.rs
+++ b/src/tensor_ops/matmul/mod.rs
@@ -6,6 +6,7 @@ pub(super) mod cpu_kernel;
 pub(super) mod cuda_kernel;
 
 use crate::{
+    prelude::{Const, Rank1},
    shapes::{Dim, Dtype, Shape},
     tensor::{DeviceStorage, HasErr, Merge, PutTape, SplitTape, Tape, Tensor},
 };
@@ -60,13 +61,13 @@
 ///
 pub fn matmul<Lhs, Rhs>(lhs: Lhs, rhs: Rhs) -> Lhs::Output
 where
-    Lhs: TryMatMul<Rhs>,
+    Lhs: TryStaticMatMul<Rhs>,
 {
     lhs.matmul(rhs)
 }
 
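The renames below split the single `TryMatMul` entry point into a statically checked trait and two runtime-checked ones. A usage sketch, assuming this diff is applied (the shapes are illustrative):

use dfdx::prelude::*;
use dfdx::tensor_ops::{TryDynamicMatMul, TryStaticMatMul};

fn main() {
    let dev: Cpu = Default::default();

    // Static path: the inner dimension (4) is in the type, so a mismatch
    // trips MulStaticDimCheck::TYPE_CHECK at compile time.
    let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
    let b: Tensor<Rank2<4, 2>, f32, _> = dev.zeros();
    let _c: Tensor<Rank2<3, 2>, f32, _> = a.try_matmul(b).unwrap();

    // Dynamic path: the inner dimension is a runtime usize, so the check
    // is the assert_eq! inside MulDynamicDimCheck::assert_dim_eq.
    let k: usize = 5;
    let x: Tensor<(Const<3>, usize), f32, _> = dev.zeros_like(&(Const, k));
    let y: Tensor<(usize, Const<2>), f32, _> = dev.zeros_like(&(k, Const));
    let _z: Tensor<(Const<3>, Const<2>), f32, _> = x.try_dynamic_matmul(y).unwrap();
}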
 /// Fallible matrix multiplication. See [matmul] for examples.
-pub trait TryMatMul<Rhs>: HasErr {
+pub trait TryStaticMatMul<Rhs>: HasErr {
     type Output;
     fn matmul(self, rhs: Rhs) -> Self::Output {
         self.try_matmul(rhs).unwrap()
@@ -74,6 +75,22 @@ pub trait TryMatMul<Rhs>: HasErr {
     }
     fn try_matmul(self, rhs: Rhs) -> Result<Self::Output, Self::Err>;
 }
 
+pub trait TryDynamicMatMul<Rhs>: HasErr {
+    type Output;
+    fn matmul(self, rhs: Rhs) -> Self::Output {
+        self.try_dynamic_matmul(rhs).unwrap()
+    }
+    fn try_dynamic_matmul(self, rhs: Rhs) -> Result<Self::Output, Self::Err>;
+}
+
+pub trait TryDynamicMatMul1<Rhs>: HasErr {
+    type Output;
+    fn matmul(self, rhs: Rhs) -> Self::Output {
+        self.try_dynamic1_matmul(rhs).unwrap()
+    }
+    fn try_dynamic1_matmul(self, rhs: Rhs) -> Result<Self::Output, Self::Err>;
+}
+
 #[rustfmt::skip]
 fn try_binary_op<
     Lhs: Shape,
@@ -106,6 +123,162 @@
     Ok(out.put_tape(tape))
 }
 
+pub trait MulStaticDimCheck<Rhs> {
+    const TYPE_CHECK: ();
+    fn assert_dim_eq(&self);
+}
+
+impl<const L: usize, const R: usize> MulStaticDimCheck<Rank1<R>> for Rank1<L> {
+    const TYPE_CHECK: () = assert!(
+        L == R,
+        "You are trying to multiply vectors whose dimensions don't match."
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as MulStaticDimCheck<Rank1<R>>>::TYPE_CHECK;
+    }
+}
+
+impl<const L: usize, const R: usize, N: Dim> MulStaticDimCheck<(Const<R>, N)> for Const<L> {
+    const TYPE_CHECK: () = assert!(
+        L == R,
+        "You are trying to multiply a vector to a matrix whose row dimension does not match the dimension of the vector."
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as MulStaticDimCheck<(Const<R>, N)>>::TYPE_CHECK;
+    }
+}
+
+impl<M: Dim, const L: usize, const R: usize, N: Dim> MulStaticDimCheck<(Const<R>, N)>
+    for (M, Const<L>)
+{
+    const TYPE_CHECK: () = assert!(
+        L == R,
+        "You are trying to multiply matrices where the column dimension of the first does not match the row dimension of the second."
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as MulStaticDimCheck<(Const<R>, N)>>::TYPE_CHECK;
+    }
+}
+
+// impl<const L: usize, const R: usize> MulDimCheck<(Const<R>, usize)> for (usize, Const<L>) {
+//     const TYPE_CHECK: () = assert!(
+//         L == R,
+//         "You are trying to multiply matrices where the column dimension of the first does not match the row dimension of the second."
+//     );
+//     fn assert_dim_eq(&self) {
+//         let _ = <Self as MulDimCheck<(Const<R>, usize)>>::TYPE_CHECK;
+//     }
+// }
+
+impl<B: Dim, M: Dim, const L: usize, const R: usize, N: Dim> MulStaticDimCheck<(Const<R>, N)>
+    for (B, M, Const<L>)
+{
+    const TYPE_CHECK: () = assert!(
+        L == R,
+        "You are trying to multiply a tensor of rank 3 to a matrix where the last dimension of the first does not match the first dimension of the second."
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as MulStaticDimCheck<(Const<R>, N)>>::TYPE_CHECK;
+    }
+}
+
+impl<B: Dim, M: Dim, const L: usize, const R: usize, N: Dim> MulStaticDimCheck<(B, Const<R>, N)>
+    for (B, M, Const<L>)
+{
+    const TYPE_CHECK: () = assert!(
+        L == R,
+        "You are trying to multiply two tensors of rank 3 for Batch3Mul where the last dimension of the first does not match the second dimension of the second."
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as MulStaticDimCheck<(B, Const<R>, N)>>::TYPE_CHECK;
+    }
+}
+
+// impl<const L: usize, const R: usize> MulDimCheck<(usize, Const<R>, usize)>
+//     for (usize, usize, Const<L>)
+// {
+//     const TYPE_CHECK: () = assert!(
+//         L == R,
+//         "You are trying to multiply two tensors of rank 3 for Batch3Mul where the last dimension of the first does not match the second dimension of the second."
+//     );
+//     fn assert_dim_eq(&self) {
+//         let _ = <Self as MulDimCheck<(usize, Const<R>, usize)>>::TYPE_CHECK;
+//     }
+// }
+
+impl<B: Dim, S: Dim, M: Dim, const L: usize, const R: usize, N: Dim>
+    MulStaticDimCheck<(B, S, Const<R>, N)> for (B, S, M, Const<L>)
+{
+    const TYPE_CHECK: () = assert!(
+        L == R,
+        "You are trying to multiply two tensors of rank 4 for Batch4Mul where the last dimension of the first does not match the second to last dimension of the second."
+    );
+    fn assert_dim_eq(&self) {
+        let _ = <Self as MulStaticDimCheck<(B, S, Const<R>, N)>>::TYPE_CHECK;
+    }
+}
+
+// impl<const L: usize, const R: usize> MulDimCheck<(usize, usize, Const<R>, usize)>
+//     for (usize, usize, usize, Const<L>)
+// {
+//     const TYPE_CHECK: () = assert!(
+//         L == R,
+//         "You are trying to multiply two tensors of rank 4 for Batch4Mul where the last dimension of the first does not match the second to last dimension of the second."
+//     );
+//     fn assert_dim_eq(&self) {
+//         let _ = <Self as MulDimCheck<(usize, usize, Const<R>, usize)>>::TYPE_CHECK;
+//     }
+// }
+
+pub trait MulDynamicDimCheck<Rhs> {
+    fn assert_dim_eq(&self, rhs: &Rhs);
+}
+
+impl<M: Dim, N: Dim> MulDynamicDimCheck<(usize, N)> for (M, usize) {
+    fn assert_dim_eq(&self, rhs: &(usize, N)) {
+        assert_eq!(self.1, rhs.0);
+    }
+}
+
+// impl<B: Dim, const K: usize> MulDynamicDimCheck<(B, usize, usize, Const<K>)>
+//     for (B, usize, Const<K>, usize)
+// {
+//     fn assert_dim_eq(&self, rhs: &(B, usize, usize, Const<K>)) {
+//         assert_eq!(self.3, rhs.2);
+//     }
+// }
+
+impl<B: Dim, S1: Dim, S2: Dim> MulDynamicDimCheck<(B, usize, S2)> for (B, S1, usize) {
+    fn assert_dim_eq(&self, rhs: &(B, usize, S2)) {
+        assert_eq!(self.2, rhs.1);
+    }
+}
+
+// impl<S1: Dim, S2: Dim> MulDynamicDimCheck<(usize, usize, S2)> for (usize, S1, usize) {
+//     fn assert_dim_eq(&self, rhs: &(usize, usize, S2)) {
+//         assert_eq!(self.0, rhs.0);
+//         assert_eq!(self.2, rhs.1);
+//     }
+// }
+
+// impl<S1: Dim, S2: Dim> MulDynamicDimCheck<(usize, usize, S2)> for (usize, S1, S2) {
+//     fn assert_dim_eq(&self, rhs: &(usize, usize, S2)) {
+//         assert_eq!(self.2, rhs.1);
+//     }
+// }
+
+impl<B: Dim, S1: Dim, S2: Dim> MulDynamicDimCheck<(B, usize, usize, S2)> for (B, usize, S1, usize) {
+    fn assert_dim_eq(&self, rhs: &(B, usize, usize, S2)) {
+        assert_eq!(self.3, rhs.2);
+    }
+}
+
+// impl<B: Dim, S1: Dim, S2: Dim> MulDynamicDimCheck<(B, usize, S2, usize)> for (B, usize, S1, S2) {
+//     fn assert_dim_eq(&self, rhs: &(B, usize, S1, S2)) {
+//         assert_eq!(self.1, rhs.1);
+//     }
+// }
+
 pub trait VecVecKernel<E: Dtype>: DeviceStorage {
     fn forward<M: Dim, N: Dim>(
         &self,
@@ -124,7 +297,7 @@
 }
 
 impl<M: Dim, N: Dim, E: Dtype, D: VecVecKernel<E>, T: Tape<E, D> + Merge<R>, R: Tape<E, D>>
-    TryMatMul<Tensor<(N,), E, D, R>> for Tensor<(M,), E, D, T>
+    TryStaticMatMul<Tensor<(N,), E, D, R>> for Tensor<(M,), E, D, T>
 {
     type Output = Tensor<(M, N), E, D, T>;
     fn try_matmul(self, rhs: Tensor<(N,), E, D, R>) -> Result<Self::Output, Self::Err> {
@@ -133,54 +306,120 @@
     }
 }
 
 pub trait VecMatKernel<E: Dtype>: DeviceStorage {
-    fn forward<K: Dim, N: Dim>(
+    fn forward<LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(K,), E, Self>,
-        rhs: &Tensor<(K, N), E, Self>,
-    ) -> Result<Tensor<(N,), E, Self>, Self::Err>;
+        lhs: &Tensor<(LeftK,), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
+    ) -> Result<Tensor<(N,), E, Self>, Self::Err>
+    where
+        LeftK: MulStaticDimCheck<(RightK, N)>;
 
-    fn backward<K: Dim, N: Dim>(
+    fn backward<LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(K,), E, Self>,
+        lhs: &Tensor<(LeftK,), E, Self>,
         grad_lhs: &mut Self::Vec,
-        rhs: &Tensor<(K, N), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
-    ) -> Result<(), Self::Err>;
+    ) -> Result<(), Self::Err>
+    where
+        LeftK: MulStaticDimCheck<(RightK, N)>;
 }
 
-impl<K: Dim, N: Dim, E: Dtype, D: VecMatKernel<E>, T: Tape<E, D> + Merge<R>, R: Tape<E, D>>
-    TryMatMul<Tensor<(K, N), E, D, R>> for Tensor<(K,), E, D, T>
+impl<
+        LeftK: Dim,
+        RightK: Dim,
+        N: Dim,
+        E: Dtype,
+        D: VecMatKernel<E>,
+        T: Tape<E, D> + Merge<R>,
+        R: Tape<E, D>,
+    > TryStaticMatMul<Tensor<(RightK, N), E, D, R>> for Tensor<(LeftK,), E, D, T>
+where
+    LeftK: MulStaticDimCheck<(RightK, N)>,
 {
     type Output = Tensor<(N,), E, D, T>;
-    fn try_matmul(self, rhs: Tensor<(K, N), E, D, R>) -> Result<Self::Output, Self::Err> {
-        assert_eq!(self.shape.0, rhs.shape.0);
+    fn try_matmul(self, rhs: Tensor<(RightK, N), E, D, R>) -> Result<Self::Output, Self::Err> {
+        // assert_eq!(self.shape.0, rhs.shape.0);
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
 
-pub trait MatMatKernel<E: Dtype>: DeviceStorage {
-    fn forward<M: Dim, K: Dim, N: Dim>(
+pub trait StaticMatMatKernel<E: Dtype>: DeviceStorage {
+    fn forward<M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(M, K), E, Self>,
-        rhs: &Tensor<(K, N), E, Self>,
-    ) -> Result<Tensor<(M, N), E, Self>, Self::Err>;
+        lhs: &Tensor<(M, LeftK), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
+    ) -> Result<Tensor<(M, N), E, Self>, Self::Err>
+    where
+        (M, LeftK): MulStaticDimCheck<(RightK, N)>;
 
-    fn backward<M: Dim, K: Dim, N: Dim>(
+    fn backward<M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(M, K), E, Self>,
+        lhs: &Tensor<(M, LeftK), E, Self>,
         grad_lhs: &mut Self::Vec,
-        rhs: &Tensor<(K, N), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
-    ) -> Result<(), Self::Err>;
+    ) -> Result<(), Self::Err>
+    where
+        (M, LeftK): MulStaticDimCheck<(RightK, N)>;
+}
+
+impl<M: Dim, LeftK: Dim, RightK: Dim, N: Dim, E: Dtype, D: StaticMatMatKernel<E>, T, R>
+    TryStaticMatMul<Tensor<(RightK, N), E, D, R>> for Tensor<(M, LeftK), E, D, T>
+where
+    T: Tape<E, D> + Merge<R>,
+    R: Tape<E, D>,
+    (M, LeftK): MulStaticDimCheck<(RightK, N)>,
+{
+    type Output = Tensor<(M, N), E, D, T>;
+    /// ```compile_fail
+    /// # use dfdx::prelude::*;
+    /// # let dev: Cpu = Default::default();
+    /// let x: Tensor<Rank2<3, 2>, f32, _> = dev.zeros();
+    /// let y: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
+    /// let _: Tensor<Rank2<3, 4>, f32, _> = x.try_matmul(y);
+    /// ```
+    fn try_matmul(self, rhs: Tensor<(RightK, N), E, D, R>) -> Result<Self::Output, Self::Err> {
+        // assert_eq!(self.shape.1.size(), rhs.shape.0.size());
+        self.shape.assert_dim_eq();
+        // println!(
+        //     "Left {:?} Right {:?}",
+        //     self.shape.1.size(),
+        //     rhs.shape.0.size()
+        // );
+        try_binary_op(self, rhs, D::forward, D::backward)
+    }
 }
 
-impl<M: Dim, K: Dim, N: Dim, E: Dtype, D: MatMatKernel<E>, T, R> TryMatMul<Tensor<(K, N), E, D, R>>
-    for Tensor<(M, K), E, D, T>
+pub trait DynamicMatMatKernel<E: Dtype>: DeviceStorage {
+    fn forward<M: Dim, N: Dim>(
+        &self,
+        lhs: &Tensor<(M, usize), E, Self>,
+        rhs: &Tensor<(usize, N), E, Self>,
+    ) -> Result<Tensor<(M, N), E, Self>, Self::Err>
+    where
+        (M, usize): MulDynamicDimCheck<(usize, N)>;
+
+    fn backward<M: Dim, N: Dim>(
+        &self,
+        lhs: &Tensor<(M, usize), E, Self>,
+        grad_lhs: &mut Self::Vec,
+        rhs: &Tensor<(usize, N), E, Self>,
+        grad_rhs: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Self::Err>
+    where
+        (M, usize): MulDynamicDimCheck<(usize, N)>;
+}
+
+impl<M: Dim, N: Dim, E: Dtype, D: DynamicMatMatKernel<E>, T, R>
+    TryDynamicMatMul<Tensor<(usize, N), E, D, R>> for Tensor<(M, usize), E, D, T>
 where
     T: Tape<E, D> + Merge<R>,
     R: Tape<E, D>,
+    (M, usize): MulDynamicDimCheck<(usize, N)>,
 {
     type Output = Tensor<(M, N), E, D, T>;
     /// ```compile_fail
@@ -190,34 +429,48 @@
     /// # use dfdx::prelude::*;
     /// # let dev: Cpu = Default::default();
     /// let x: Tensor<Rank2<3, 2>, f32, _> = dev.zeros();
     /// let y: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
     /// let _: Tensor<Rank2<3, 4>, f32, _> = x.try_matmul(y);
     /// ```
-    fn try_matmul(self, rhs: Tensor<(K, N), E, D, R>) -> Result<Self::Output, Self::Err> {
-        assert_eq!(self.shape.1, rhs.shape.0);
+    fn try_dynamic_matmul(
+        self,
+        rhs: Tensor<(usize, N), E, D, R>,
+    ) -> Result<Self::Output, Self::Err> {
+        // assert_eq!(self.shape.1.size(), rhs.shape.0.size());
+        self.shape.assert_dim_eq(&rhs.shape);
+        // println!(
+        //     "Left {:?} Right {:?}",
+        //     self.shape.1.size(),
+        //     rhs.shape.0.size()
+        // );
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
 
 pub trait MatMatBrKernel<E: Dtype>: DeviceStorage {
-    fn forward<B: Dim, M: Dim, K: Dim, N: Dim>(
+    fn forward<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(B, M, K), E, Self>,
-        rhs: &Tensor<(K, N), E, Self>,
-    ) -> Result<Tensor<(B, M, N), E, Self>, Self::Err>;
+        lhs: &Tensor<(B, M, LeftK), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
+    ) -> Result<Tensor<(B, M, N), E, Self>, Self::Err>
+    where
+        (B, M, LeftK): MulStaticDimCheck<(RightK, N)>;
 
-    fn backward<B: Dim, M: Dim, K: Dim, N: Dim>(
+    fn backward<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(B, M, K), E, Self>,
+        lhs: &Tensor<(B, M, LeftK), E, Self>,
         grad_lhs: &mut Self::Vec,
-        rhs: &Tensor<(K, N), E, Self>,
+        rhs: &Tensor<(RightK, N), E, Self>,
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
-    ) -> Result<(), Self::Err>;
+    ) -> Result<(), Self::Err>
+    where
+        (B, M, LeftK): MulStaticDimCheck<(RightK, N)>;
 }
 
-impl<B: Dim, M: Dim, K: Dim, N: Dim, E: Dtype, D: MatMatBrKernel<E>, T, R>
-    TryMatMul<Tensor<(K, N), E, D, R>> for Tensor<(B, M, K), E, D, T>
+impl<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim, E: Dtype, D: MatMatBrKernel<E>, T, R>
+    TryStaticMatMul<Tensor<(RightK, N), E, D, R>> for Tensor<(B, M, LeftK), E, D, T>
 where
     T: Tape<E, D> + Merge<R>,
     R: Tape<E, D>,
+    (B, M, LeftK):
+        MulStaticDimCheck<(RightK, N)>,
 {
     type Output = Tensor<(B, M, N), E, D, T>;
     /// ```compile_fail
@@ -227,35 +480,41 @@
     /// # use dfdx::prelude::*;
     /// # let dev: Cpu = Default::default();
     /// let x: Tensor<Rank3<1, 3, 2>, f32, _> = dev.zeros();
     /// let y: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
     /// let _: Tensor<Rank3<1, 3, 4>, f32, _> = x.try_matmul(y);
     /// ```
-    fn try_matmul(self, rhs: Tensor<(K, N), E, D, R>) -> Result<Self::Output, Self::Err> {
-        assert_eq!(self.shape.2, rhs.shape.0);
+    fn try_matmul(self, rhs: Tensor<(RightK, N), E, D, R>) -> Result<Self::Output, Self::Err> {
+        // assert_eq!(self.shape.2, rhs.shape.0);
+        self.shape.assert_dim_eq();
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
 
-pub trait MatMatBatch3Kernel<E: Dtype>: DeviceStorage {
-    fn forward<B: Dim, M: Dim, K: Dim, N: Dim>(
+pub trait StaticMatMatBatch3Kernel<E: Dtype>: DeviceStorage {
+    fn forward<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(B, M, K), E, Self>,
-        rhs: &Tensor<(B, K, N), E, Self>,
-    ) -> Result<Tensor<(B, M, N), E, Self>, Self::Err>;
+        lhs: &Tensor<(B, M, LeftK), E, Self>,
+        rhs: &Tensor<(B, RightK, N), E, Self>,
+    ) -> Result<Tensor<(B, M, N), E, Self>, Self::Err>
+    where
+        (B, M, LeftK): MulStaticDimCheck<(B, RightK, N)>;
 
-    fn backward<B: Dim, M: Dim, K: Dim, N: Dim>(
+    fn backward<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
         &self,
-        lhs: &Tensor<(B, M, K), E, Self>,
+        lhs: &Tensor<(B, M, LeftK), E, Self>,
         grad_lhs: &mut Self::Vec,
-        rhs: &Tensor<(B, K, N), E, Self>,
+        rhs: &Tensor<(B, RightK, N), E, Self>,
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
-    ) -> Result<(), Self::Err>;
+    ) -> Result<(), Self::Err>
+    where
+        (B, M, LeftK): MulStaticDimCheck<(B, RightK, N)>;
 }
 
-impl<B: Dim, M: Dim, K: Dim, N: Dim, E: Dtype, D, T, R> TryMatMul<Tensor<(B, K, N), E, D, R>>
-    for Tensor<(B, M, K), E, D, T>
+impl<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim, E: Dtype, D, T, R>
+    TryStaticMatMul<Tensor<(B, RightK, N), E, D, R>> for Tensor<(B, M, LeftK), E, D, T>
 where
-    D: MatMatBatch3Kernel<E>,
+    D: StaticMatMatBatch3Kernel<E>,
     T: Tape<E, D> + Merge<R>,
     R: Tape<E, D>,
+    (B, M, LeftK): MulStaticDimCheck<(B, RightK, N)>,
 {
     type Output = Tensor<(B, M, N), E, D, T>;
     /// ```compile_fail
@@ -265,36 +524,186 @@
     /// # use dfdx::prelude::*;
     /// # let dev: Cpu = Default::default();
     /// let x: Tensor<Rank3<1, 3, 2>, f32, _> = dev.zeros();
     /// let y: Tensor<Rank3<1, 3, 4>, f32, _> = dev.zeros();
     /// let _: Tensor<Rank3<1, 3, 4>, f32, _> = x.try_matmul(y);
     /// ```
-    fn try_matmul(self, rhs: Tensor<(B, K, N), E, D, R>) -> Result<Self::Output, Self::Err> {
-        assert_eq!(self.shape.0, rhs.shape.0);
-        assert_eq!(self.shape.2, rhs.shape.1);
+    fn try_matmul(self, rhs: Tensor<(B, RightK, N), E, D, R>) -> Result<Self::Output, Self::Err> {
+        // assert_eq!(self.shape.0, rhs.shape.0);
+        // assert_eq!(self.shape.2, rhs.shape.1);
+        self.shape.assert_dim_eq();
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
 
-pub trait MatMatBatch4Kernel<E: Dtype>: DeviceStorage {
-    fn forward<B: Dim, S: Dim, M: Dim, K: Dim, N: Dim>(
+pub trait DynamicMatMatBatch3Kernel<E: Dtype>: DeviceStorage {
+    fn forward<B: Dim, S1: Dim, S2: Dim>(
         &self,
-        lhs: &Tensor<(B, S, M, K), E, Self>,
-        rhs: &Tensor<(B, S, K, N), E, Self>,
-    ) -> Result<Tensor<(B, S, M, N), E, Self>, Self::Err>;
+        lhs: &Tensor<(B, S1, usize), E, Self>,
+        rhs: &Tensor<(B, usize, S2), E, Self>,
+    ) -> Result<Tensor<(B, S1, S2), E, Self>, Self::Err>
+    where
+        (B, S1, usize): MulDynamicDimCheck<(B, usize, S2)>;
 
-    fn backward<B: Dim, S: Dim, M: Dim, K: Dim, N: Dim>(
+    fn backward<B: Dim, S1: Dim, S2: Dim>(
         &self,
-        lhs: &Tensor<(B, S, M, K), E, Self>,
+        lhs: &Tensor<(B, S1, usize), E, Self>,
         grad_lhs: &mut Self::Vec,
-        rhs: &Tensor<(B, S, K, N), E, Self>,
+        rhs: &Tensor<(B, usize, S2), E, Self>,
         grad_rhs: &mut Self::Vec,
         grad_out: &Self::Vec,
-    ) -> Result<(), Self::Err>;
+    ) -> Result<(), Self::Err>
+    where
+        (B, S1, usize): MulDynamicDimCheck<(B, usize, S2)>;
+}
+
+impl<B: Dim, S1: Dim, S2: Dim, E: Dtype, D, T, R> TryDynamicMatMul<Tensor<(B, usize, S2), E, D, R>>
+    for Tensor<(B, S1, usize), E, D, T>
+where
+    D: DynamicMatMatBatch3Kernel<E>,
+    T: Tape<E, D> + Merge<R>,
+    R: Tape<E, D>,
+    (B, S1, usize): MulDynamicDimCheck<(B, usize, S2)>,
+{
+    type Output = Tensor<(B, S1, S2), E, D, T>;
+    /// ```compile_fail
+    /// # use dfdx::prelude::*;
+    /// # let dev: Cpu = Default::default();
+    /// let x: Tensor<Rank3<1, 3, 2>, f32, _> = dev.zeros();
+    /// let y: Tensor<Rank3<1, 3, 4>, f32, _> = dev.zeros();
+    /// let _: Tensor<Rank3<1, 3, 4>, f32, _> = x.try_matmul(y);
+    /// ```
+    fn try_dynamic_matmul(
+        self,
+        rhs: Tensor<(B, usize, S2), E, D, R>,
+    ) -> Result<Self::Output, Self::Err> {
+        // assert_eq!(self.shape.0, rhs.shape.0);
+        // assert_eq!(self.shape.2, rhs.shape.1);
+        self.shape.assert_dim_eq(&rhs.shape);
+        try_binary_op(self, rhs, D::forward, D::backward)
+    }
+}
+
+pub trait DynamicMatMatBatch3Kernel1<E: Dtype>: DeviceStorage {
+    fn forward<S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(usize, S1, usize), E, Self>,
+        rhs: &Tensor<(usize, usize, S2), E, Self>,
+    ) -> Result<Tensor<(usize, S1, S2), E, Self>, Self::Err>
+    where
+        (usize, S1, usize): MulDynamicDimCheck<(usize, usize, S2)>;
+
+    fn backward<S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(usize, S1, usize), E, Self>,
+        grad_lhs: &mut Self::Vec,
+        rhs: &Tensor<(usize, usize, S2), E, Self>,
+        grad_rhs: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Self::Err>
+    where
+        (usize, S1, usize): MulDynamicDimCheck<(usize, usize, S2)>;
+}
+
+impl<S1: Dim, S2: Dim, E: Dtype, D, T, R> TryDynamicMatMul1<Tensor<(usize, usize, S2), E, D, R>>
+    for Tensor<(usize, S1, usize), E, D, T>
+where
+    D: DynamicMatMatBatch3Kernel1<E>,
+    T: Tape<E, D> + Merge<R>,
+    R: Tape<E, D>,
+    (usize, S1, usize): MulDynamicDimCheck<(usize, usize, S2)>,
+{
+    type Output = Tensor<(usize, S1, S2), E, D, T>;
+    /// ```compile_fail
+    /// # use dfdx::prelude::*;
+    /// # let dev: Cpu = Default::default();
+    /// let x: Tensor<Rank3<1, 3, 2>, f32, _> = dev.zeros();
+    /// let y: Tensor<Rank3<1, 3, 4>, f32, _> = dev.zeros();
+    /// let _: Tensor<Rank3<1, 3, 4>, f32, _> = x.try_matmul(y);
+    /// ```
+    fn try_dynamic1_matmul(
+        self,
+        rhs: Tensor<(usize, usize, S2), E, D, R>,
+    ) -> Result<Self::Output, Self::Err> {
+        // assert_eq!(self.shape.0, rhs.shape.0);
+        // assert_eq!(self.shape.2, rhs.shape.1);
+        self.shape.assert_dim_eq(&rhs.shape);
+        try_binary_op(self, rhs, D::forward, D::backward)
+    }
+}
+
+// pub trait MatMatBatch3Kernel<E: Dtype>: DeviceStorage {
+//     fn forward<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
+//         &self,
+//         lhs: &Tensor<(B, M, LeftK), E, Self>,
+//         rhs: &Tensor<(B, RightK, N), E, Self>,
+//     ) -> Result<Tensor<(B, M, N), E, Self>, Self::Err>;
+//     // where
+//     //     (B, M, LeftK): MulStaticDimCheck<(B, RightK, N)>;

+//     fn backward<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
+//         &self,
+//         lhs: &Tensor<(B, M, LeftK), E, Self>,
+//         grad_lhs: &mut Self::Vec,
+//         rhs: &Tensor<(B, RightK, N), E, Self>,
+//         grad_rhs: &mut Self::Vec,
+//         grad_out: &Self::Vec,
+//     ) -> Result<(), Self::Err>;
+//     // where
+//     //     (B, M, LeftK): MulStaticDimCheck<(B, RightK, N)>;
+// }

+// impl<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim, E: Dtype, D, T, R>
+//     TryDynamicMatMul<Tensor<(B, RightK, N), E, D, R>> for Tensor<(B, M, LeftK), E, D, T>
+// where
+//     D: MatMatBatch3Kernel<E>,
+//     T: Tape<E, D> + Merge<R>,
+//     R: Tape<E, D>,
+//     (B, M, LeftK): MulDynamicDimCheck<(B, RightK, N)>,
+// {
+//     type Output = Tensor<(B, M, N), E, D, T>;
+//     /// ```compile_fail
+//     /// # use dfdx::prelude::*;
+//     /// # let dev: Cpu = Default::default();
+//     /// let x: Tensor<Rank3<1, 3, 2>, f32, _> = dev.zeros();
+//     /// let y: Tensor<Rank3<1, 3, 4>, f32, _> = dev.zeros();
+//     /// let _: Tensor<Rank3<1, 3, 4>, f32, _> = x.try_matmul(y);
+//     /// ```
+//     fn try_dynamic_matmul(
+//         self,
+//         rhs: Tensor<(B, RightK, N), E, D, R>,
+//     ) -> Result<Self::Output, Self::Err> {
+//         // assert_eq!(self.shape.0, rhs.shape.0);
+//         // assert_eq!(self.shape.2, rhs.shape.1);
+//         // self.shape.assert_dim_eq();
+//         try_binary_op(self, rhs, D::forward, D::backward)
+//     }
+// }

+pub trait StaticMatMatBatch4Kernel<E: Dtype>: DeviceStorage {
+    fn forward<B: Dim, S: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
+        &self,
+        lhs: &Tensor<(B, S, M, LeftK), E, Self>,
+        rhs: &Tensor<(B, S, RightK, N), E, Self>,
+    ) -> Result<Tensor<(B, S, M, N), E, Self>, Self::Err>
+    where
+        (B, S, M, LeftK): MulStaticDimCheck<(B, S, RightK, N)>;
+
+    fn backward<B: Dim, S: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim>(
+        &self,
+        lhs: &Tensor<(B, S, M, LeftK), E, Self>,
+        grad_lhs: &mut Self::Vec,
+        rhs: &Tensor<(B, S, RightK, N), E, Self>,
+        grad_rhs: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Self::Err>
+    where
+        (B, S, M, LeftK): MulStaticDimCheck<(B, S, RightK, N)>;
 }
 
-impl<B: Dim, S: Dim, M: Dim, K: Dim, N: Dim, E: Dtype, D, T, R>
-    TryMatMul<Tensor<(B, S, K, N), E, D, R>> for Tensor<(B, S, M, K), E, D, T>
+impl<B: Dim, S: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim, E: Dtype, D, T, R>
+    TryStaticMatMul<Tensor<(B, S, RightK, N), E, D, R>> for Tensor<(B, S, M, LeftK), E, D, T>
 where
-    D: MatMatBatch4Kernel<E>,
+    D: StaticMatMatBatch4Kernel<E>,
     T: Tape<E, D> + Merge<R>,
     R: Tape<E, D>,
+    (B, S, M, LeftK): MulStaticDimCheck<(B, S, RightK, N)>,
 {
     type Output = Tensor<(B, S, M, N), E, D, T>;
     /// ```compile_fail
@@ -304,10 +713,114 @@
     /// # use dfdx::prelude::*;
     /// # let dev: Cpu = Default::default();
     /// let x: Tensor<Rank4<1, 5, 3, 2>, f32, _> = dev.zeros();
     /// let y: Tensor<Rank4<1, 5, 3, 4>, f32, _> = dev.zeros();
     /// let _: Tensor<Rank4<1, 5, 3, 4>, f32, _> = x.try_matmul(y);
     /// ```
-    fn try_matmul(self, rhs: Tensor<(B, S, K, N), E, D, R>) -> Result<Self::Output, Self::Err> {
-        assert_eq!(self.shape.0, rhs.shape.0);
-        assert_eq!(self.shape.1, rhs.shape.1);
-        assert_eq!(self.shape.3, rhs.shape.2);
+    fn try_matmul(
+        self,
+        rhs: Tensor<(B, S, RightK, N), E, D, R>,
+    ) -> Result<Self::Output, Self::Err> {
+        // assert_eq!(self.shape.0, rhs.shape.0);
+        // assert_eq!(self.shape.1, rhs.shape.1);
+        // assert_eq!(self.shape.3, rhs.shape.2);
+        self.shape.assert_dim_eq();
+        try_binary_op(self, rhs, D::forward, D::backward)
+    }
+}
+
+pub trait DynamicMatMatBatch4Kernel<E: Dtype>: DeviceStorage {
+    fn forward<B: Dim, S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(B, usize, S1, usize), E, Self>,
+        rhs: &Tensor<(B, usize, usize, S2), E, Self>,
+    ) -> Result<Tensor<(B, usize, S1, S2), E, Self>, Self::Err>;
+    // where
+    //     (usize, usize, usize, LeftK): MulDynamicDimCheck<(usize, usize, RightK, usize)>;
+
+    fn backward<B: Dim, S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(B, usize, S1, usize), E, Self>,
+        grad_lhs: &mut Self::Vec,
+        rhs: &Tensor<(B, usize, usize, S2), E, Self>,
+        grad_rhs: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Self::Err>;
+    // where
+    //     (usize, usize, usize, LeftK): MulDynamicDimCheck<(usize, usize, RightK, usize)>;
+}
+
+impl<B: Dim, S1: Dim, S2: Dim, E: Dtype, D, T, R>
+    TryDynamicMatMul<Tensor<(B, usize, usize, S2), E, D, R>>
+    for Tensor<(B, usize, S1, usize), E, D, T>
+where
+    D: DynamicMatMatBatch4Kernel<E>,
+    T: Tape<E, D> + Merge<R>,
+    R: Tape<E, D>,
+    // (usize, usize, usize, LeftK): MulDynamicDimCheck<(usize, usize, RightK, usize)>,
+{
+    type Output = Tensor<(B, usize, S1, S2), E, D, T>;
+    /// ```compile_fail
+    /// # use dfdx::prelude::*;
+    /// # let dev: Cpu = Default::default();
+    /// let x: Tensor<Rank4<1, 5, 3, 2>, f32, _> = dev.zeros();
+    /// let y: Tensor<Rank4<1, 5, 3, 4>, f32, _> = dev.zeros();
+    /// let _: Tensor<Rank4<1, 5, 3, 4>, f32, _> = x.try_matmul(y);
+    /// ```
+    fn try_dynamic_matmul(
+        self,
+        rhs: Tensor<(B, usize, usize, S2), E, D, R>,
+    ) -> Result<Self::Output, Self::Err> {
+        // assert_eq!(self.shape.0, rhs.shape.0);
+        // assert_eq!(self.shape.1, rhs.shape.1);
+        // assert_eq!(self.shape.3, rhs.shape.2);
+        self.shape.assert_dim_eq(&rhs.shape);
+        try_binary_op(self, rhs, D::forward, D::backward)
+    }
+}
+
+pub trait DynamicMatMatBatch4Kernel1<E: Dtype>: DeviceStorage {
+    fn forward<B: Dim, S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(B, usize, S1, S2), E, Self>,
+        rhs: &Tensor<(B, usize, S2, usize), E, Self>,
+    ) -> Result<Tensor<(B, usize, S1, usize), E, Self>, Self::Err>;
+    // where
+    //     (usize, usize, usize, LeftK): MulDynamicDimCheck<(usize, usize, RightK, usize)>;
+
+    fn backward<B: Dim, S1: Dim, S2: Dim>(
+        &self,
+        lhs: &Tensor<(B, usize, S1, S2), E, Self>,
+        grad_lhs: &mut Self::Vec,
+        rhs: &Tensor<(B, usize, S2, usize), E, Self>,
+        grad_rhs: &mut Self::Vec,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Self::Err>;
+    // where
+    //     (usize, usize, usize, LeftK): MulDynamicDimCheck<(usize, usize, RightK, usize)>;
+}
+
+impl<B: Dim, S1: Dim, S2: Dim, E: Dtype, D, T, R>
+    TryDynamicMatMul1<Tensor<(B, usize, S2, usize), E, D, R>>
+    for Tensor<(B, usize, S1, S2), E, D, T>
+where
+    D: DynamicMatMatBatch4Kernel1<E>,
+    T: Tape<E, D> + Merge<R>,
+    R: Tape<E, D>,
+    (B, usize, S1, S2): MulDynamicDimCheck<(B, usize, S2, usize)>,
+{
+    type Output = Tensor<(B, usize, S1, usize), E, D, T>;
+    /// ```compile_fail
+    /// # use dfdx::prelude::*;
+    /// # let dev: Cpu = Default::default();
+    /// let x: Tensor<Rank4<1, 5, 3, 2>, f32, _> = dev.zeros();
+    /// let y: Tensor<Rank4<1, 5, 3, 4>, f32, _> = dev.zeros();
+    /// let _: Tensor<Rank4<1, 5, 3, 4>, f32, _> = x.try_matmul(y);
+    /// ```
+    fn try_dynamic1_matmul(
+        self,
+        rhs: Tensor<(B, usize, S2, usize), E, D, R>,
+    ) -> Result<Self::Output, Self::Err> {
+        // assert_eq!(self.shape.0, rhs.shape.0);
+        // assert_eq!(self.shape.1, rhs.shape.1);
+        // assert_eq!(self.shape.3, rhs.shape.2);
+        self.shape.assert_dim_eq(&rhs.shape);
         try_binary_op(self, rhs, D::forward, D::backward)
     }
 }
@@ -350,14 +863,14 @@ mod tests {
 
         {
             let a: Tensor<Rank1<3>, TestDtype, _> = dev.zeros();
             let b: Tensor<Rank2<3, 2>, TestDtype, _> = dev.zeros();
-            let _: Tensor<Rank1<2>, TestDtype, _> = a.matmul(b);
+            let _: Tensor<Rank1<2>, TestDtype, _> = matmul::TryStaticMatMul::matmul(a, b);
         }
 
-        {
-            let a: Tensor<Rank2<5, 3>, TestDtype, _> = dev.zeros();
-            let b: Tensor<Rank2<3, 2>, TestDtype, _> = dev.zeros();
-            let _: Tensor<Rank2<5, 2>, TestDtype, _> = a.matmul(b);
-        }
+        // {
+        //     let a: Tensor<Rank2<5, 3>, TestDtype, _> = dev.zeros();
+        //     let b: Tensor<Rank2<3, 2>, TestDtype, _> = dev.zeros();
+        //     let _: Tensor<Rank2<5, 2>, TestDtype, _> = matmul::TryStaticMatMul::matmul(a, b);
+        // }
 
         {
             let a: Tensor<Rank3<1, 5, 3>, TestDtype, _> = dev.zeros();
diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs
index 8b99e110a..fbb606f2d 100644
--- a/src/tensor_ops/mod.rs
+++ b/src/tensor_ops/mod.rs
@@ -221,7 +221,7 @@ pub use huber_error::huber_error;
 pub use ln::ln;
 pub use log_softmax::log_softmax;
 pub use logsumexp_to::LogSumExpTo;
-pub use matmul::{matmul, TryMatMul};
+pub use matmul::{matmul, TryDynamicMatMul, TryDynamicMatMul1, TryStaticMatMul};
 pub use max_to::MaxTo;
 pub use maximum::maximum;
 pub use mean_to::MeanTo;
diff --git a/src/tensor_ops/utilities/device.rs b/src/tensor_ops/utilities/device.rs
index 09ed30dce..9f38c8e7e 100644
--- a/src/tensor_ops/utilities/device.rs
+++ b/src/tensor_ops/utilities/device.rs
@@ -42,11 +42,14 @@ pub trait Device<E: Dtype>:
 
     // matmuls
     + super::super::matmul::VecMatKernel<E>
-    + super::super::matmul::MatMatKernel<E>
+    + super::super::matmul::StaticMatMatKernel<E>
     + super::super::matmul::VecVecKernel<E>
     + super::super::matmul::MatMatBrKernel<E>
-    + super::super::matmul::MatMatBatch3Kernel<E>
-    + super::super::matmul::MatMatBatch4Kernel<E>
+    + super::super::matmul::StaticMatMatBatch3Kernel<E>
+    + super::super::matmul::DynamicMatMatBatch3Kernel<E>
+    + super::super::matmul::StaticMatMatBatch4Kernel<E>
+    + super::super::matmul::DynamicMatMatBatch4Kernel<E>
+    + super::super::matmul::DynamicMatMatBatch4Kernel1<E>
 
     // scalar arithmetic
     + UnaryKernel<ops::ScalarAddKernelOp<E>, E>
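The device.rs change simply grows the `Device` umbrella trait by the new kernel supertraits, so every backend must provide them and generic code can bound on the one name. For reference, a minimal, self-contained sketch of that capability-bundle pattern, with illustrative names rather than dfdx's:

trait GemmKernel {
    fn gemm(&self);
}
trait SoftmaxKernel {
    fn softmax(&self);
}

// The umbrella trait adds no methods of its own; it only bundles capabilities.
trait Device: GemmKernel + SoftmaxKernel {}

// Blanket impl: anything providing every kernel is automatically a Device.
impl<D: GemmKernel + SoftmaxKernel> Device for D {}

struct Cpu;
impl GemmKernel for Cpu {
    fn gemm(&self) {}
}
impl SoftmaxKernel for Cpu {
    fn softmax(&self) {}
}

// Callers bound on the single umbrella name instead of listing every kernel.
fn run<D: Device>(dev: &D) {
    dev.gemm();
    dev.softmax();
}

fn main() {
    run(&Cpu);
}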