[MatMul] CUTLASS style predicate evaluation : peeled predicate shift #1973
base: skew_double_buffer
@@ -778,7 +778,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
pushBack(IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in));
GpuLower::current()->propagateExprInfo(ldst, back());
if (ldst->predicate()) {
back()->setPredicate(ldst->predicate());
back()->setPredicate(IrBuilder::create<kir::Predicate>(ldst->predicate()));
Why create a new instance here?
return SimplifyingIrBuilder::subExpr(extent_plus_factor, extent_round_up);
}

Val* PredicatePeeling::getSplitTileMainOffset(
Note to self: How is this used?
//! Returns true if the expression is an initialization expr that
//! can be omitted in main loop.
//! See [Predicate Peeling Interaction with Circular Buffering]
bool canOmitInitInMainLoop(Expr* expr, kir::ForLoop* double_buffer_loop) {
Trying to understand how this works with the PredicatePeeling2 test, but I didn't see any zero initialization even when this function was modified to always return false. I changed the test not to use cp.async, and then the zero init appeared. Why is there no init when cp.async is used?
// This optimization only applies when all the loops on the
// inner side of the double buffer main loop are either
// constant unrolled or parallel.
Why does it apply only in those cases?
// If the double buffer loop is to be peeled. Will need to insert
// a circular buffer init stage to initialize the final stage of
// circular buffer space.
Not following this. Can you please give an example?
Here's the generated code from the
What does this loop?
As far as I can see, it's a
The end part of the main loop also looks odd to me:
I think
This PR introduces a very specific predicate simplification technique that works well with matmul kernels.
As a simple example of how this trick works, say we have
T0 [I(T0.size[0])] -> split(32) -> T0 [Io(ceilDiv(T0.size[0],32)), Ii(32)], which generates the following code:
The above code generates 32 predicates as the predicate is inlined in the inner loop.
The simplification trick is to convert the loop into:
This way, the main loop that remains after peeling off the residue needs no predicate and no initialization.
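To make the before/after comparison concrete, here is a minimal host-side C++ sketch of the idea, not the code nvfuser actually generates. The function names `sumPredicated` and `sumPeeled` are hypothetical, the reduction body is a stand-in for the real tile work, and the sketch assumes the residue tile is handled first (the CUTLASS-style ordering the PR title refers to), with the residue size computed as `extent + 32 - roundUp(extent, 32)` so that it falls in [1, 32].

```cpp
#include <vector>

// Baseline: T0 split by 32 with the bounds check inlined in the inner loop,
// so the predicate is evaluated on every inner iteration.
inline long long sumPredicated(const std::vector<int>& t0) {
  const long long n = static_cast<long long>(t0.size());
  const long long outer = (n + 31) / 32;  // ceilDiv(n, 32)
  long long acc = 0;
  for (long long io = 0; io < outer; ++io) {
    for (long long ii = 0; ii < 32; ++ii) {
      const long long idx = io * 32 + ii;
      if (idx < n) {  // predicate inside the inner loop
        acc += t0[idx];
      }
    }
  }
  return acc;
}

// Peeled (assumed arrangement): one predicated residue tile, then a main
// loop of full tiles whose base offset is shifted by the residue size.
inline long long sumPeeled(const std::vector<int>& t0) {
  const long long n = static_cast<long long>(t0.size());
  if (n == 0) {
    return 0;  // nothing to peel
  }
  const long long outer = (n + 31) / 32;     // ceilDiv(n, 32)
  const long long residue = n + 32 - outer * 32;  // in [1, 32]
  long long acc = 0;

  // Peeled iteration: the only place a predicate is needed.
  for (long long ii = 0; ii < 32; ++ii) {
    if (ii < residue) {
      acc += t0[ii];
    }
  }

  // Main loop: the remaining outer - 1 tiles are all full, so no predicate.
  for (long long io = 1; io < outer; ++io) {
    const long long base = residue + (io - 1) * 32;  // shifted offset
    for (long long ii = 0; ii < 32; ++ii) {
      acc += t0[base + ii];
    }
  }
  return acc;
}
```

For an extent of 70, the peeled iteration covers elements 0 to 5 and the two main-loop tiles cover 6 to 37 and 38 to 69, so both functions visit exactly the same elements.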
This PR provides a simple implementation of the above technique. It currently supports only a direct split of the K dimension from a root iterdomain, which covers the matmul and batched matmul use cases.
TODO: while useful for matmul, this technique would be tricky to extend to work well with tensor contraction kernels.