[MatMul] CUTLASS style predicate evaluation : peeled predicate shift #1973

Open: shmsong wants to merge 18 commits into base: skew_double_buffer

Conversation

@shmsong commented Sep 13, 2022

This PR introduces a very specific predicate simplification technique that works well with matmul kernels.

A simple example of how this trick works: say we have
T0 [I(T0.size[0])] -> split(32) -> T0 [Io(ceilDiv(T0.size[0],32)), Ii(32)], which generates the following code:

for i in 0..ceilDiv(T0.size[0],32):
  for j in 0..32:
    // assume we need to initialize in this kernel
    T0[i*32+j] = 0;
    if i*32+j < T0.size[0]:
      T0[i*32+j] ...

The above code evaluates the predicate on every inner-loop iteration, i.e. 32 times per outer iteration, since the check is inlined in the inner loop.

The simplification trick is to convert the loop into:

let ceilMod(a, b) = (a % b == 0) ? b : a % b;

// Peeled residue prolog (called "initial evaluation" in CUTLASS):
//   very similar to the original loop, except the outer loop extent
//   and the predicate extent are modified.
for i in 0..1:
  for j in 0..32:
    T0[i*32+j] = 0;
    if i*32+j < ceilMod(T0.size[0], 32):
      T0[i*32+j] ...

// Peeled residue main loop
//   (called "steady-state" in CUTLASS):
for i in 0..ceilDiv(T0.size[0],32)-1:
  for j in 0..32:
    // No need to initialize, as we know the predicate is all true.
    // This significantly reduces memory instruction congestion
    // with cp.async kernels.

    // No longer need to predicate here, as the residue part of the
    // root iterdomain has been peeled away.
    T0[i*32+j + ceilMod(T0.size[0],32)] ...

This way there is no predicate and no initialization in the peeled residue main loop.
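
For concreteness, here is a minimal host-side C++ sketch of the same transformation. It is illustrative only: the 1-D copy, the padded output buffer, and the helper definitions are assumptions made for the sketch, not code from this PR.

#include <vector>

// Helpers matching the pseudocode above.
constexpr int ceilDiv(int a, int b) { return (a + b - 1) / b; }
constexpr int ceilMod(int a, int b) { return (a % b == 0) ? b : a % b; }

// Copies n = in.size() elements into out, which is assumed to be padded
// to a multiple of 32 elements (like a shared-memory tile).
void copyPeeled(const std::vector<float>& in, std::vector<float>& out) {
  const int n = static_cast<int>(in.size());
  const int residue = ceilMod(n, 32); // e.g. ceilMod(70, 32) = 6, ceilMod(64, 32) = 32

  // Peeled residue prolog: only the first `residue` elements are valid,
  // so the predicate extent is ceilMod(n, 32) instead of n.
  for (int j = 0; j < 32; ++j) {
    out[j] = 0.0f;
    if (j < residue) {
      out[j] = in[j];
    }
  }

  // Peeled residue main loop: residue + i*32 + j < n holds by
  // construction, so neither the predicate nor the zero-init is needed.
  for (int i = 0; i < ceilDiv(n, 32) - 1; ++i) {
    for (int j = 0; j < 32; ++j) {
      out[residue + i * 32 + j] = in[residue + i * 32 + j];
    }
  }
}

Note that only the prolog tile still carries the zero-fill and the bounds check; in the cp.async setting, dropping both from the steady-state loop is where the memory-instruction savings come from.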

This PR provides a simple implementation of the above technique. It currently only supports a direct split of the K dimension from a root iterdomain, which covers the matmul and batched matmul use cases.

TODO: while useful for matmul, this technique would be tricky to extend to work well with tensor contraction kernels.

@shmsong shmsong changed the title [Matmul perf] CUTLASS style predicate evaluation : peeled predicate shift WIP: [Not ready for review] [Matmul perf] CUTLASS style predicate evaluation : peeled predicate shift Sep 13, 2022
@shmsong shmsong changed the title WIP: [Not ready for review] [Matmul perf] CUTLASS style predicate evaluation : peeled predicate shift CUTLASS style predicate evaluation : peeled predicate shift Sep 20, 2022
@@ -778,7 +778,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
   pushBack(IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in));
   GpuLower::current()->propagateExprInfo(ldst, back());
   if (ldst->predicate()) {
-    back()->setPredicate(ldst->predicate());
+    back()->setPredicate(IrBuilder::create<kir::Predicate>(ldst->predicate()));

Collaborator: Why create a new instance here?

return SimplifyingIrBuilder::subExpr(extent_plus_factor, extent_round_up);
}

Val* PredicatePeeling::getSplitTileMainOffset(

Collaborator: Note to self: How is this used?

//! Returns true if the expression is an initialization expr that
//! can be omitted in main loop.
//! See [Predicate Peeling Interaction with Circular Buffering]
bool canOmitInitInMainLoop(Expr* expr, kir::ForLoop* double_buffer_loop) {

Collaborator: Trying to understand how this works with the PredicatePeeling2 test, but I didn't see any zero initialization even when this function was modified to always return false. I changed the test not to use cp.async, and then the zero init appeared. Why is there no init when cp.async is used?

Comment on lines +334 to +336
// This optimization only applies when all the loops on the
// inner side of the double buffer main loop are either
// constant unrolled or parallel.

Collaborator: Why does it apply only in those cases?

Comment on lines +507 to +509
// If the double buffer loop is to be peeled. Will need to insert
// a circular buffer init stage to initialize the final stage of
// circular buffer space.

Collaborator: Not following this. Can you please give an example?

@naoyam (Collaborator) commented Sep 22, 2022

Here's the generated code from the PredicatePeeling2 test:

__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 1> T1) {
  alignas(16) extern __shared__ char array[];
  unsigned offset = 0;
  NVFUSER_DEFINE_MAGIC_ZERO
  offset = alignBufferSize(offset, 16);
  float* T2 = reinterpret_cast<float*>(array + offset);
  offset += (((16 * 16) * 3) * sizeof(float));
  int64_t i36;
  i36 = (((nvfuser_index_t)blockIdx.x) * 16) + ((nvfuser_index_t)threadIdx.x);
  if ((i36 < T0.size[0])) {
    T1[i36] = 0.00000000000000000e+00;
  }
  NVFUSER_UPDATE_MAGIC_ZERO
  #pragma unroll
  for(nvfuser_index_t i17 = 0; i17 < 16; ++i17) {
    int64_t i43;
    i43 = (((nvfuser_index_t)blockIdx.x) * 16) + (i17 + nvfuser_zero);
    Ampere::cpAsync(reinterpret_cast<Array<float,1,1>*>(&T2[(i17 * 16) + ((nvfuser_index_t)threadIdx.x)]),reinterpret_cast<Array<float,1,1>*>(&T0[(i43 * T0.size[1]) + ((nvfuser_index_t)threadIdx.x)]),((i43 < T0.size[0]) && (((nvfuser_index_t)threadIdx.x) < ((T0.size[1] + 16) + (-((ceilDiv(T0.size[1], 16)) * 16))))));
  }
  NVFUSER_UPDATE_MAGIC_ZERO
  Ampere::cpAsyncCommit();
  #pragma unroll
  for(nvfuser_index_t i18 = 1; i18 < 2; ++i18) {
    #pragma unroll
    for(nvfuser_index_t i17 = 0; i17 < 16; ++i17) {
      int64_t i61;
      i61 = (((nvfuser_index_t)blockIdx.x) * 16) + (i17 + nvfuser_zero);
      Ampere::cpAsync(reinterpret_cast<Array<float,1,1>*>(&T2[(i17 * 16) + ((nvfuser_index_t)threadIdx.x) + (i18 * (16 * 16))]),reinterpret_cast<Array<float,1,1>*>(&T0[(i61 * T0.size[1]) + (((i18 * 16) + ((nvfuser_index_t)threadIdx.x)) + (-((-((T0.size[1] + 16) + (-((ceilDiv(T0.size[1], 16)) * 16)))) + 16)))]),(i61 < T0.size[0]));
    }
    Ampere::cpAsyncCommit();
  }
  NVFUSER_UPDATE_MAGIC_ZERO
  #pragma unroll
  for(nvfuser_index_t i19 = 2; i19 < 3; ++i19) {
    Ampere::cpAsyncCommit();
  }
  NVFUSER_UPDATE_MAGIC_ZERO
  Ampere::cpAsyncPartialBarrier<1>();
  __barrier_sync(0);
  #pragma unroll 1
  for(nvfuser_index_t i20 = 0; i20 < (ceilDiv(T0.size[1], 16)); ++i20) {
    int64_t i92;
    i92 = ((i20 + 2) * 16) + ((nvfuser_index_t)threadIdx.x);
    #pragma unroll
    for(nvfuser_index_t i17 = 0; i17 < 16; ++i17) {
      int64_t i94;
      i94 = (((nvfuser_index_t)blockIdx.x) * 16) + (i17 + nvfuser_zero);
      Ampere::cpAsync(reinterpret_cast<Array<float,1,1>*>(&T2[(i17 * 16) + ((nvfuser_index_t)threadIdx.x) + (((i20 + 2) % 3) * (16 * 16))]),reinterpret_cast<Array<float,1,1>*>(&T0[(i94 * T0.size[1]) + (i92 + (-((-((T0.size[1] + 16) + (-((ceilDiv(T0.size[1], 16)) * 16)))) + 16)))]),((i94 < T0.size[0]) && ((i20 + 2) < (ceilDiv(T0.size[1], 16)))));
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    #pragma unroll
    for(nvfuser_index_t i22 = 0; i22 < 16; ++i22) {
      if ((i36 < T0.size[0])) {
        T1[i36]
          = T1[i36]
          + T2[(((nvfuser_index_t)threadIdx.x) * 16) + i22 + ((i20 % 3) * (16 * 16))];
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    Ampere::cpAsyncPartialBarrier<1>();
    __barrier_sync(0);
    Ampere::cpAsyncCommit();
  }
}

What does this loop do?

#pragma unroll
for(nvfuser_index_t i19 = 2; i19 < 3; ++i19) {
  Ampere::cpAsyncCommit();
}

As far as I can see, it's a CircularInitProlog loop. What is it supposed to do? It looks like the loop initially has a zero-init expression, but for some reason it doesn't show up in the final code.

The end part of the main loop also looks odd to me:

Ampere::cpAsyncPartialBarrier<1>();
__barrier_sync(0);
Ampere::cpAsyncCommit();

I think __barrier_sync is necessary because of the way T1 and T2 are parallelized, but is this order of async barrier followed by commit correct? Shouldn't it be reversed?
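
For reference, here is a sketch of the ordering the question seems to expect, assuming cpAsyncCommit maps to cp.async.commit_group and cpAsyncPartialBarrier<N> maps to cp.async.wait_group N (i.e. waits until at most N committed groups are still in flight). The loop skeleton and num_tiles are illustrative, not nvfuser output:

#pragma unroll 1
for (nvfuser_index_t i = 0; i < num_tiles; ++i) {
  // (1) issue the cp.async loads for the stage being prefetched
  Ampere::cpAsyncCommit();            // (2) commit the group just issued
  Ampere::cpAsyncPartialBarrier<1>(); // (3) wait until at most 1 group is in flight
  __barrier_sync(0);                  // (4) make the arrived tile visible block-wide
  // (5) consume the current stage from shared memory
}

That is, the commit for the newly issued group would come before the wait, whereas the generated main loop above commits only after the barrier.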

@csarofeen csarofeen changed the title CUTLASS style predicate evaluation : peeled predicate shift [MatMul] CUTLASS style predicate evaluation : peeled predicate shift Oct 19, 2022