Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors TryMatmuls for better error messages when dimensions don't match #681

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

kstavro
Copy link
Contributor

@kstavro kstavro commented Apr 8, 2023

I haven't synced the branch yet with all the recent commits, but will of course do if this works out in the end.

So, I tried to move more of the dynamic dim checks into the respective traits before pushing, and now I need to split the traits and ops even more for stuff to work inside mha.rs. So, unless I "turn off" the dynamic checks as in e.g. where (B, S1, usize): MulDynamicDimCheck<(B, usize, S2)>; (and apart from tests inside mod.rs needing to be fixed in any case), multiheaded attentions still spits out errors.

Just so that you don't get completely lost in the mess that I have pushed, a quick summary of what I am doing:

  1. Replacing all the trymatmuls TryMatMul<M,K,N> from assuming that the crucial dim K is always common between the two tensors, into TryMatMul<M, LeftK, RightK, N>, so that we can make the dim checks with the respective trait. This would already work quite good, but with dynamic dims we need to split into static or dynamic traits, depending on whether K is static or dynamic (the dynamic part mostly because of transformers).
  2. Splitting all kernels and ops into more variations so that we are able to implement the checks for different combinations of shapes, e.g when trying to check different combinations within the same trait:
impl<B: Dim, S1: Dim, S2: Dim> MulDynamicDimCheck<(B, usize, usize, S2)> for (B, usize, S1, usize) {
    fn assert_dim_eq(&self, rhs: &(B, usize, usize, S2)) {
        assert_eq!(self.1, rhs.1);
        assert_eq!(self.3, rhs.2);
    }
}

impl<B: Dim, S1: Dim, S2: Dim> MulDynamicDimCheck<(B, usize, S2, usize)> for (B, usize, S1, S2) {
    fn assert_dim_eq(&self, rhs: &(B, usize, S1, S2)) {
        assert_eq!(self.1, rhs.1);
    }
}

we get conflict errors:

--> src\tensor_ops\matmul\mod.rs:277:1
    |
270 | impl<B: Dim, S1: Dim, S2: Dim> MulDynamicDimCheck<(B, usize, usize, S2)> for (B, usize, S1, usize) {
    | -------------------------------------------------------------------------------------------------- first implementation here
...
277 | impl<B: Dim, S1: Dim, S2: Dim> MulDynamicDimCheck<(B, usize, S2, usize)> for (B, usize, S1, S2) {
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ conflicting implementation for `(_, usize, _, usize)`

I have also probably forgotten to implement some stuff while lost in the splitting of traits. All the implementations that I have commented out more or less need new variations of the traits for them to work.

It would be great if we could find a nicer (and more maintainable way) to get over this, because the errors now look quite good and informative (already tested for the static cases and some dynamic cases).

@coreylowman
Copy link
Owner

Okay have you tried having a trait at the Dim level that does the dyn vs const?

trait DimEq<Rhs> {
    fn assert_dim_eq(...) { ... }
}

impl DimEq<usize> for usize {  ... }
impl DimEq<Const<N>> for Const<M> { ... }
impl DimEq<usize> for Const<M> { ... }
impl DimEq<Const<M>> for usize { ... }

That might let us combine the static vs dynamic traits you have now?

impl<M: Dim, K1: Dim, K2: Dim, N: Dim> MulDimCheck<(M, K1)> for (K2, N)
where K1: DimEq<K2> 
{ ... }

@kstavro
Copy link
Contributor Author

kstavro commented Apr 9, 2023

Okay have you tried having a trait at the Dim level that does the dyn vs const?

trait DimEq<Rhs> {
    fn assert_dim_eq(...) { ... }
}

impl<M: Dim, K1: Dim, K2: Dim, N: Dim> MulDimCheck<(M, K1)> for (K2, N)
where K1: DimEq<K2> 
{ ... }

By having only one DimEq trait and MulDimCheck we will be able to only check either static dims or dynamic dims because of the single assert_dim_eq(), right?

I tried only having the static dim checks (which is what you are elluding to?), but attention complains even if we omit the dynamic dim checks, because we still need variations of TryMatMuls now that K is split into K1 and K2 (when K1, K2 are not Dim, but usize. Explaining more below). I mentioned it somewhere inside the issue thread.

Having the static and dynamic variants isn't so much of a problem imo, because they can really serve a purpose: either make a compile or a runtime check. The main problem is that, these traits need to be implemented for dynamic shapes multiple times because of what happens inside attention. In particular, try_reshape_like. The output of let v = v.try_reshape_like(&(b, s2, H, V / H)).unwrap()?; is a Tensor<(B, S2, usize, usize), E, D>, so all the matmul traits and ops need to be implemented specifically for combinations of dims that have usize instead of Dim. I can still force the output to be Tensor<(B, S2, const<H>, usize), E, D> but there is now way for V/H not to be usize inside the tensor.

To make the above clearer on your example, one would then need to implement both
impl<M: Dim, K1: Dim, K2: Dim, N: Dim> MulDimCheck<(M, K1)> for (K2, N) where K1: DimEq<K2>,
as well as
impl<M: Dim, N: Dim> MulDimCheck<(M, usize)> for (usize, N) where {dynamic check of the dynamic dims}.

Even if I omit the dynamic implementations and checks, I still need to implement the dynamic matmuls. The problem remains the same: rust complains about conflicts in the implementations, which is mostly relevant for the rank3 and rank4 matmuls inside attention, so this is where one needs to split the traits to allow for implementations of all arising combinations of (B, S2, usize, usize), (B, S2, S1, usize), (usize, usize, S1, usize), (usize, usize, usize, S2), etc. If you have a look inside mha.rs, you will see what kind of tensors rust thinks q,k,v are, it's usize all over the place.

If we could force somehow, eg the output of let v = v.try_reshape_like(&(b, s2, H, V / H)).unwrap()?; to be sort of Tensor<(B, S2, const<H>, const<V/H), E, D>, that would solve all the problems at once. So basically either find an analogue of
let v: Tensor<(B, S2, Const<H>, {representation of V/H as a Const<>), E, D>= v.try_reshape_like(&(b, s2, H, V / H)).unwrap()?; (it works for H, but couldn't make it work for V/H) or trying to change the way try_reshape_like works (didn't find any obvious way there). But then this sort of solves the problem of having dynamic dimensions in the first place, so I guess this isn't easy to do?

@coreylowman
Copy link
Owner

Having the static and dynamic variants isn't so much of a problem imo, because they can really serve a purpose: either make a compile or a runtime check. The main problem is that, these traits need to be implemented for dynamic shapes multiple times because of what happens inside attention. In particular, try_reshape_like. The output of let v = v.try_reshape_like(&(b, s2, H, V / H)).unwrap()?; is a Tensor<(B, S2, usize, usize), E, D>, so all the matmul traits and ops need to be implemented specifically for combinations of dims that have usize instead of Dim. I can still force the output to be Tensor<(B, S2, const, usize), E, D> but there is now way for V/H not to be usize inside the tensor.

To make the above clearer on your example, one would then need to implement both
impl<M: Dim, K1: Dim, K2: Dim, N: Dim> MulDimCheck<(M, K1)> for (K2, N) where K1: DimEq,
as well as
impl<M: Dim, N: Dim> MulDimCheck<(M, usize)> for (usize, N) where {dynamic check of the dynamic dims}.

Can you expand? Dim is implemented for usize so K1: Dim should cover both usize and Const - you shouldn't need separate impls of MulDimCheck for them, one impl will cover all cases

@kstavro
Copy link
Contributor Author

kstavro commented Apr 10, 2023

Can you expand? Dim is implemented for usize so K1: Dim should cover both usize and Const - you shouldn't need separate impls of MulDimCheck for them, one impl will cover all cases

This is what I also thought/hoped, but it seems like if rust thinks an object has usize as a dimension, then it is a different thing, it isn't a Dim any more. That is what creates all the problems.

Eg in the code I have pushed in the PR, I have already implemented:

impl<B: Dim, M: Dim, LeftK: Dim, RightK: Dim, N: Dim, E: Dtype, D, T, R>
    TryStaticMatMul<Tensor<(B, RightK, N), E, D, R>> for Tensor<(B, M, LeftK), E, D, T>
where ...

and

impl<B: Dim, S1: Dim, S2: Dim, E: Dtype, D, T, R> TryDynamicMatMul<Tensor<(B, usize, S2), E, D, R>>
    for Tensor<(B, S1, usize), E, D, T>
where ...

and

impl<S1: Dim, S2: Dim, E: Dtype, D, T, R> TryDynamicMatMul1<Tensor<(usize, usize, S2), E, D, R>>
    for Tensor<(usize, S1, usize), E, D, T>
where ...

So, combinations of Tensor<(B, RightK, N), E, D, R>, Tensor<(B, usize, S2), E, D, R>, Tensor<(usize, usize, S2), E, D, R> multiplied with Tensor<(B, M, LeftK), E, D, T>, Tensor<(B, S1, usize), E, D, T>, Tensor<(usize, S1, usize), E, D, T>. Implementing all these combinations gets me inside attention through all the multiplications up to L134 inside mha.rs:

let tokens = weights.try_dynamic_matmul(v)?;

The error message for it:

error[E0599]: no method named `try_dynamic_matmul` found for struct `Tensor` in the current scope
   --> src\nn\transformer\mha.rs:134:30
    |
134 |         let tokens = weights.try_dynamic_matmul(v)?;
    |                              ^^^^^^^^^^^^^^^^^^ method not found in `Tensor<(usize, S1, S2), E, D, T>`
    |
   ::: src\tensor\tensor_impls.rs:32:1
    |
32  | pub struct Tensor<S: Shape, E: Unit, D: DeviceStorage, T = NoneTape> {
    | -------------------------------------------------------------------- method `try_dynamic_matmul` not found for this struct
    |
    = help: items from traits can only be used if the trait is implemented and in scope
note: `matmul::TryDynamicMatMul` defines an item `try_dynamic_matmul`, perhaps you need to implement it
   --> src\tensor_ops\matmul\mod.rs:78:1
    |
78  | pub trait TryDynamicMatMul<Rhs>: HasErr {
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Another dimension combination I haven't implemented (Tensor<(usize, S1, S2), E, D, T> which again, rust thinks is a new type combination, usize is not a Dim), which I probably need a new trait for.

I wish I am doing something dumb to be honest or misunderstanding some rust details. But not sure what that might be. My current insights are summarized in the last paragraph of my previous comment, if that helps.

@coreylowman
Copy link
Owner

Here's the diff of the fixes!
fix-dim-checks.patch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants