fix FIRST_DOT by adding dependency from tmaCopy to wait
manman-ren committed Jun 20, 2024
1 parent 0bc5294 commit deac629
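In brief: the subview that consumes a TMA async copy's buffer has no SSA edge to the ttng::WaitBarrierOp that guards it, so the scheduler could place the wait in a different stage than its consumer. The fix threads a side table (TMAUserToWait, a DenseMap from each viewLoad to its waitOp) through scheduling, and insertDepsOfOp consults it so the wait is scheduled with its user. A minimal sketch of that lookup, with strings standing in for MLIR operations (hypothetical names, not the Triton API):

  #include <cassert>
  #include <string>
  #include <unordered_map>
  #include <vector>

  // Sketch: ops are strings, the "schedule" is an ordered vector. The real
  // pass keys a DenseMap<Operation *, Operation *> from the TMA subview to
  // its wait_barrier and consults it before walking SSA operands.
  using Op = std::string;

  void insertDepsOfOp(const Op &op, std::vector<Op> &schedule,
                      const std::unordered_map<Op, Op> *additionalDep) {
    if (additionalDep) {
      auto it = additionalDep->find(op);
      if (it != additionalDep->end())
        schedule.push_back(it->second); // the wait joins op's stage/cluster
    }
    // ... the real code then recurses through op's SSA operands ...
  }

  int main() {
    std::unordered_map<Op, Op> tmaUserToWait = {{"viewLoad", "wait_barrier"}};
    std::vector<Op> schedule;
    insertDepsOfOp("viewLoad", schedule, &tmaUserToWait);
    assert(schedule.size() == 1 && schedule.front() == "wait_barrier");
    return 0;
  }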
Showing 2 changed files with 49 additions and 16 deletions.
56 changes: 40 additions & 16 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -102,7 +102,14 @@ class CoarseSchedule {
}

void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
bool includeArg) {
bool includeArg,
DenseMap<Operation *, Operation *> *additionalDep) {
// First consult additionalDep: if op has an extra, non-SSA dependency
// recorded there (e.g. the barrier wait feeding a TMA copy's user),
// schedule that wait along with op.
if (additionalDep && additionalDep->find(op) != additionalDep->end()) {
Operation *wait = (*additionalDep)[op];
if (insertIfAbsent(wait, stage, cluster))
insertDepsOfOp(wait, stage, cluster, includeArg, additionalDep);
}
for (Value operand : op->getOperands()) {
Value v = operand;
llvm::SmallDenseSet<Value> seen;
@@ -121,7 +128,7 @@
Operation *defOp = v.getDefiningOp();
if (defOp && defOp->getBlock() == op->getBlock()) {
if (insertIfAbsent(defOp, stage, cluster)) {
insertDepsOfOp(defOp, stage, cluster, includeArg);
insertDepsOfOp(defOp, stage, cluster, includeArg, additionalDep);
}
}
}
@@ -310,11 +317,12 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
loadOp.erase();
}

static void createTMAAsyncCopy(
static Operation *createTMAAsyncCopy(
scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc,
Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp,
Value phase, CoarseSchedule &schedule,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages) {
llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
assert(phase && "Phase value is required for TMA async copy.");
OpBuilder builder(forOp);
Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
@@ -344,6 +352,7 @@ static void createTMAAsyncCopy(
loadOffsets[0] = extractIdx;
auto viewLoad =
builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
TMAUserToWait[viewLoad] = waitOp; // viewLoad will depend on barrierWait
if (isMMV3Load) {
auto alloc = cast<ttg::LocalAllocOp>((*loadOp->getUsers().begin()));
alloc.replaceAllUsesWith(viewLoad.getResult());
@@ -366,6 +375,7 @@
loadOp->replaceAllUsesWith(result);
}
loadOp.erase();
return copy;
}
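For orientation, a minimal standalone sketch of the multi-buffered copy/wait pattern createTMAAsyncCopy participates in, assuming N buffers with one barrier each and a prologue that runs N-1 copies ahead (schematic indices and printf in place of real mbarrier ops):

  #include <cstdio>

  // Schematic only: buffer indices stand in for MemDescSubviewOp slots and
  // printf stands in for mbarrier arrive/wait.
  int main() {
    const int N = 3, iters = 6;
    int insertIdx = 0, extractIdx = 0;
    for (int i = 0; i < N - 1; ++i) { // prologue: fill the pipeline
      std::printf("prologue: TMA async copy -> buf %d\n", insertIdx);
      insertIdx = (insertIdx + 1) % N;
    }
    for (int i = 0; i < iters; ++i) { // steady state: overlap copy and use
      std::printf("iter %d: TMA async copy -> buf %d\n", i, insertIdx);
      std::printf("iter %d: wait_barrier(buf %d); subview + use buf %d\n",
                  i, extractIdx, extractIdx);
      insertIdx = (insertIdx + 1) % N;
      extractIdx = (extractIdx + 1) % N;
    }
    return 0;
  }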

// If all the transitive uses of the given value are used by a convert to
@@ -806,16 +816,17 @@ schedulePrologueAndEpilogue(scf::ForOp forOp, CoarseSchedule &schedule,

// Add dependencies of anchor ops to the coarse schedule. Schedule them to
// the same stage and ordering cluster as the anchor op.
static void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule,
int numStages) {
static void
scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule, int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
SmallVector<std::tuple<Operation *, int, CoarseSchedule::Cluster>>
opsInOrder = schedule.getOpsInOrder(forOp);
// Schedule dependencies stage by stage.
for (int stage = 0; stage < numStages; stage++) {
for (auto [op, stage_, cluster] : opsInOrder) {
if (stage_ != stage)
continue;
schedule.insertDepsOfOp(op, stage, cluster, false);
schedule.insertDepsOfOp(op, stage, cluster, false, &TMAUserToWait);
}
}
}
@@ -856,14 +867,14 @@ static void scheduleDistanceOneDependencies(scf::ForOp forOp,
// Exception: Schedule loads with a distance of 1 together
// with the current op.
schedule.insertIfAbsent(defOp, stage, cluster);
schedule.insertDepsOfOp(defOp, stage, cluster, true);
schedule.insertDepsOfOp(defOp, stage, cluster, true, nullptr);
} else {
if (dist1Cluster.count(&cluster) == 0) {
dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster);
}
schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]);
schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster],
true);
true, nullptr);
}
}
}
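A toy encoding of the distance-one policy in the hunk above: a value produced in iteration i and consumed in iteration i+1 (through the loop carry) gets its producer scheduled one stage after the consumer, except loads, which join the consumer's stage and cluster. Sketch only, with a string kind standing in for the op type:

  #include <cstdio>
  #include <string>

  // Loads ride with the consumer; everything else goes one stage later,
  // in a cluster ordered just before the consumer's.
  int stageForDist1Producer(const std::string &producerKind, int consumerStage) {
    if (producerKind == "load")
      return consumerStage;   // exception: loads share the consumer's stage
    return consumerStage + 1; // default: one stage later
  }

  int main() {
    std::printf("arith op feeding stage-1 consumer -> stage %d\n",
                stageForDist1Producer("arith", 1));
    std::printf("load feeding stage-1 consumer -> stage %d\n",
                stageForDist1Producer("load", 1));
    return 0;
  }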
@@ -973,6 +984,8 @@ static void createTMABarrierAndWait(
}
SmallVector<SmallVector<AsyncLoad *>> loadGroups;
llvm::SmallDenseSet<Operation *> visited;
// TODO: do not group if SWP_FIRST_DOT is true. We need to make sure there is
// one barrier for each loadOp.
// Find groups of loads that can share the same barrier: walk consecutive
// loads and check for uses in between; only loads with no intervening
// uses may share one barrier (see the sketch after this hunk).
for (AsyncLoad &asyncLoad : asyncLoads) {
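A hedged sketch of that grouping policy with plain structs in place of MLIR operations (the hasUserBeforeNext flag is an assumed precomputed "uses in between" bit):

  #include <cstdio>
  #include <vector>

  struct AsyncLoad {
    int id;
    bool hasUserBeforeNext; // assumption: result is used before the next load
  };

  // A new group is cut whenever the previous load's result is used before
  // the next load issues, so loads within one group can share a barrier.
  std::vector<std::vector<AsyncLoad>> groupLoads(const std::vector<AsyncLoad> &loads) {
    std::vector<std::vector<AsyncLoad>> groups;
    for (const AsyncLoad &ld : loads) {
      if (groups.empty() || groups.back().back().hasUserBeforeNext)
        groups.emplace_back(); // cut a group at an intervening use
      groups.back().push_back(ld);
    }
    return groups;
  }

  int main() {
    // Loads 0 and 1 are back-to-back; load 1's result is consumed before
    // load 2, so load 2 starts a second group (its own barrier).
    std::vector<AsyncLoad> loads = {{0, false}, {1, true}, {2, false}};
    std::printf("%zu groups\n", groupLoads(loads).size()); // prints: 2 groups
    return 0;
  }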
@@ -1066,7 +1079,8 @@ static void createTMABarrierAndWait(
static SmallVector<Value>
createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule,
llvm::MapVector<Operation *, LoadInfo> &loadToInfo,
SmallVector<Value> &barriers, int numStages) {
SmallVector<Value> &barriers, int numStages,
DenseMap<Operation *, Operation *> &TMAUserToWait) {
// Calculate the number of buffers needed for each load.
// TODO pawel: we could do more fine-grained allocation here and
// allocate only the number of buffers that specific loads need.
@@ -1083,6 +1097,7 @@ createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule,
// pipelining post-processing.
numBuffers++;
};
LDBG("numBuffers" << numBuffers);

SmallVector<AsyncLoad> asyncLoads;
SmallVector<Value> allocs;
@@ -1160,11 +1175,19 @@ createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule,
schedule, prefetchCluster, loadToInfo, numStages);
} else {
auto descLoad = cast<tt::ExperimentalDescriptorLoadOp>(asyncLoad.loadOp);
createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx,
extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase,
schedule, loadToInfo, numStages);
Operation *copy = createTMAAsyncCopy(
forOp, descLoad, asyncLoad.alloc, insertIdx, extractIdx,
asyncLoad.barrier, asyncLoad.waitOp, phase, schedule, loadToInfo,
numStages, TMAUserToWait);
// Note: the returned copy op is unused for now; the viewLoad -> waitOp
// dependency is recorded inside createTMAAsyncCopy via TMAUserToWait.
}
}
// Sanity check: no two TMA users may share the same waitOp.
DenseSet<Operation *> uniqueWaits;
for (auto [user, wait] : TMAUserToWait) {
assert(!uniqueWaits.count(wait) && "each TMA user needs a unique waitOp");
uniqueWaits.insert(wait);
}
SmallVector<Value> newYieldOperands = {insertIdx, extractIdx};
if (phase)
newYieldOperands.push_back(phase);
@@ -1208,9 +1231,10 @@ bool mlir::triton::preProcessLoopAndGetSchedule(
});

SmallVector<Value> barriers;
DenseMap<Operation *, Operation *> TMAUserToWait;
// Convert the loads into async loads and create the allocs.
SmallVector<Value> allocs =
createAsyncOps(forOp, coarseSchedule, loadToInfo, barriers, numStages);
SmallVector<Value> allocs = createAsyncOps(
forOp, coarseSchedule, loadToInfo, barriers, numStages, TMAUserToWait);

LLVM_DEBUG({
LDBG("Coarse schedule with async loads:");
@@ -1224,7 +1248,7 @@
coarseSchedule.dump();
});

scheduleDependencies(forOp, coarseSchedule, numStages);
scheduleDependencies(forOp, coarseSchedule, numStages, TMAUserToWait);
LLVM_DEBUG({
LDBG("Coarse schedule with dependencies:");
coarseSchedule.dump();
9 changes: 9 additions & 0 deletions (second changed file)
@@ -76,6 +76,15 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
return op;
}

// Predicate wait_barrier by wrapping it in an scf.if on pred.
if (auto wait = dyn_cast<ttng::WaitBarrierOp>(op)) {
rewriter.setInsertionPoint(wait);
auto ifOp =
rewriter.create<scf::IfOp>(wait->getLoc(), pred, /*else=*/false);
// Move the wait into the ifOp's then-block.
rewriter.moveOpBefore(wait, ifOp.thenBlock(), ifOp.thenBlock()->begin());
return ifOp;
}
assert("don't know how to predicate this op" && false);
return op;
}
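A scalar analogue of what this rewrite achieves; waitBarrier is a hypothetical stand-in for the ttng::WaitBarrierOp being predicated:

  #include <cstdio>

  // Hypothetical stand-in for the hardware barrier wait.
  void waitBarrier() { std::printf("wait_barrier executed\n"); }

  // Before predication the wait ran unconditionally; after, it runs only
  // when pred is true, which is exactly what moving it into the scf.if's
  // then-block does.
  void predicatedWait(bool pred) {
    if (pred)        // scf.if %pred
      waitBarrier(); // wait moved into the then-block
  }

  int main() {
    predicatedWait(true);  // prints
    predicatedWait(false); // skipped, e.g. in masked peeled iterations
    return 0;
  }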
