[MatMul] CUTLASS style predicate evaluation : peeled predicate shift #1973

Open · wants to merge 18 commits into base: skew_double_buffer
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -704,6 +704,7 @@ libtorch_cuda_core_sources = [
     "torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp",
     "torch/csrc/jit/codegen/cuda/lower_index.cpp",
     "torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp",
+    "torch/csrc/jit/codegen/cuda/lower_predicate_peeling.cpp",
     "torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp",
     "torch/csrc/jit/codegen/cuda/lower_instrument.cpp",
     "torch/csrc/jit/codegen/cuda/lower_loops.cpp",
11 changes: 10 additions & 1 deletion torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -2534,7 +2534,16 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   }

   indent() << "for(nvfuser_index_t " << gen_index;
-  if (loop->iter_domain()->isParallelized()) {
+
+  // TODO: need to revisit this one;
+  // a predicate peeled loop is guaranteed not to be
+  // a degenerate loop, so the comments on the else block
+  // should not apply here.
+  if (loop->iter_domain()->isParallelized() ||
+      loop->loopTransformInfo().predicate_peel_stage ==
+          PredicatePeelStage::Main ||
+      loop->loopTransformInfo().double_buffer_loop_stage ==
+          DoubleBufferLoopStage::CircularInitProlog) {
     code_ << " = " << gen_start << "; ";
   } else {
     // Do not start at the start of the ID when not parallelized. Instead,
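The branch above only changes the emitted loop-index initializer. A minimal sketch of the two forms the generator can produce, with illustrative index and extent names that are not taken from this PR:

// Ordinary serial loop: the index always starts at 0, and per-element
// predicates guard any gap below the iterdomain's actual start.
for (nvfuser_index_t i42 = 0; i42 < T0_size0; ++i42) {
  // ... predicated body ...
}

// Parallelized loop, predicate peeled main stage, or circular-init
// prolog: the index starts at the loop's real start value, e.g. 1 when
// iteration 0 has been peeled off into a predicated prolog.
for (nvfuser_index_t i42 = 1; i42 < num_k_tiles; ++i42) {
  // ... unpredicated body ...
}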
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -796,6 +796,9 @@ void ComputeAtMap::allocateIndexVariables() {
             std::make_unique<DoubleBufferIndices>(DoubleBufferIndices(
                 {{DoubleBufferLoopStage::Prolog,
                   IrBuilder::create<Int>(c10::nullopt)},
+                 // TODO: need to add upper and lower prolog here too.
+                 {DoubleBufferLoopStage::CircularInitProlog,
+                  IrBuilder::create<Int>(c10::nullopt)},
                  {DoubleBufferLoopStage::Main,
                   IrBuilder::create<Int>(c10::nullopt)},
                  {DoubleBufferLoopStage::Epilog,
96 changes: 92 additions & 4 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -1679,6 +1679,21 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
       if (root_ind->isZeroInt()) {
         continue;
       } else {
+        if (auto tile_entry =
+                GpuLower::current()->predicatePeelingInfo().getMaybePeeledTileEntry(
+                    loops, root_dom[i])) {
+          // Apply the "predicate peeling offset" (see [Predicate Peeling])
+          // to the tensor index if this root domain is predicate peeled.
+          if (tile_entry.value().peel_stage != PredicatePeelStage::Prolog &&
+              !tile_entry.value()
+                   .for_loop->loopTransformInfo()
+                   .is_base_index_loop) {
+            root_ind = SimplifyingIrBuilder::subExpr(
+                root_ind,
+                PredicatePeeling::getSplitTileMainOffset(
+                    root_dom[i], tile_entry.value().inner_factor));
+          }
+        }
         auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]);
         if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
           strided_inds[i] =
@@ -2142,6 +2157,10 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
   auto strides = getStrides(consumer_tv);
   auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph);

+  // Indices should now be mapped onto IterDomains in consumer, so just grab
+  // and use them.
+  auto root_dom = consumer_tv->getMaybeRFactorDomain();
+
   // Global striding
   auto vectorize_shift =
       loops.empty() ? nullptr : loops.back()->vectorize_shift();
@@ -2151,8 +2170,20 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
     if (root_inds[i]->isZeroInt()) {
       continue;
     } else {
-      auto strided_ind =
-          SimplifyingIrBuilder::mulExpr(root_inds[i], strides[i]);
+      auto root_ind = root_inds[i];
+      if (auto tile_entry = GpuLower::current()
+                                ->predicatePeelingInfo()
+                                .getMaybePeeledTileEntry(loops, root_dom[i])) {
+        // Apply the "predicate peeling offset" (see [Predicate Peeling])
+        // to the tensor index if this root domain is predicate peeled.
+        if (tile_entry.value().peel_stage != PredicatePeelStage::Prolog) {
+          root_ind = SimplifyingIrBuilder::subExpr(
+              root_ind,
+              PredicatePeeling::getSplitTileMainOffset(
+                  root_dom[i], tile_entry.value().inner_factor));
+        }
+      }
+      auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]);
       if (i == strides.size() - 1 && vectorize_shift != nullptr) {
         strided_inds[i] =
             SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift);
@@ -2323,7 +2354,8 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
     if (is_prolog && is_circular_buffer_loop) {
       // The buffer switching logic is the same as original index
       // in the case of circular buffer prolog.
-      db_switch_index = db_loop->index();
+      db_switch_index =
+          db_loop->isTrivial() ? db_loop->start() : db_loop->index();
     } else {
       // Switching index generated for main loop or epilog component.
       db_switch_index = SimplifyingIrBuilder::modExpr(
@@ -3194,17 +3226,73 @@ std::vector<RootPredicateInfo> Index::getReferenceRootPredicates(
       info.start_predicate_ = start_pred;
     }

+    auto maybe_tiled_entry =
+        gpu_lower->predicatePeelingInfo().getMaybePeeledTileEntry(
+            loops, contig_id);
+
+    // Check if the predicate for this contig_id is being generated
+    // in a predicate peeling prolog. See also [Predicate Peeling].
+    bool has_peeled_prolog = maybe_tiled_entry.has_value() &&
+        maybe_tiled_entry.value().peel_stage == PredicatePeelStage::Prolog;
+
     // Build predicates for stop positions as:
     //   stop_index + stop_offset < IterDomain::extent
     auto stop_offset = info.stop_offset_;
-    if (canOmitStopPredicate(stop_index, stop_offset, contig_id)) {
+    if (canOmitStopPredicate(stop_index, stop_offset, contig_id) &&
+        !has_peeled_prolog) {
       info.stop_predicate_ = GpuLower::current()->kernel()->trueVal();
     } else {
       auto offsetted_stop_index =
           SimplifyingIrBuilder::addExpr(stop_index, stop_offset);
       auto stop_pred = SimplifyingIrBuilder::ltExpr(
                            offsetted_stop_index, contig_id->extent())
                            ->as<Bool>();

+      // Modify predicate math for the predicate peeled loop;
+      // for the detailed definition see [Predicate Peeling].
+      if (maybe_tiled_entry.has_value()) {
+        auto tile_entry = maybe_tiled_entry.value();
+        if (tile_entry.peel_stage == PredicatePeelStage::Prolog) {
+          // In the predicate peeled prolog, the stop predicate is
+          //   stop_index < tile_residue.
+          stop_pred = SimplifyingIrBuilder::ltExpr(
+                          offsetted_stop_index,
+                          PredicatePeeling::getPrologPredicateOffset(
+                              contig_id, tile_entry.inner_factor))
+                          ->as<Bool>();
+        } else if (
+            // Handle the condition where the predicate peeled
+            // iterdomain is double/circular buffered.
+            tile_entry.for_loop->doubleBufferLoopStage() ==
+                DoubleBufferLoopStage::Main &&
+            db_axis != nullptr &&
+            GpuLower::current()->caMap()->areMapped(
+                db_axis,
+                tile_entry.for_loop->iter_domain(),
+                IdMappingMode::LOOP)) {
+          // When the predicate peeled loop is double buffered
+          // or circular buffered, the producer index is skewed
+          // ahead of the main loop by (stage_depth - 1). So on the
+          // predicate peeled main loop side, instead of just removing
+          // the predicate for this contig_id, re-write it as
+          //   (loop_index + stage_depth - 1) < loop_stop, which should
+          // be thread uniform and very cheap to evaluate.
+          auto db_index = SimplifyingIrBuilder::addExpr(
+              tile_entry.for_loop->index(),
+              IrBuilder::create<Int>(
+                  gpu_lower->doubleBufferInfo().getStageDepthFor(
+                      tile_entry.for_loop->iter_domain()) -
+                  1));
+          stop_pred = SimplifyingIrBuilder::ltExpr(
+                          db_index, tile_entry.for_loop->stop())
+                          ->as<Bool>();
+        } else {
+          // If the predicate peeled loop is not double buffered,
+          // then in the main stage of the predicate peeled loop
+          // this predicate can just be omitted.
+          stop_pred = gpu_lower->kernel()->trueVal();
+        }
+      }
       info.stop_predicate_ = stop_pred;
     }

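The index and predicate rewrites above are easier to follow with concrete numbers. Below is a self-contained sketch of the assumed arithmetic; the exact formulas live in PredicatePeeling::getPrologPredicateOffset and PredicatePeeling::getSplitTileMainOffset, which this diff only calls, so the residue and offset expressions here are illustrative assumptions:

#include <cassert>
#include <cstdint>

int main() {
  // Assumed setup: an iterdomain of extent E = 70, split by tile T = 32
  // (the split's inner_factor).
  const int64_t E = 70;
  const int64_t T = 32;
  // Assumed prolog residue: the first, partial tile. 70 % 32 = 6.
  const int64_t residue = (E % T == 0) ? T : E % T;
  // Assumed main-stage shift subtracted from the root index: 26.
  const int64_t main_offset = T - residue;

  // Prolog (peel_stage == Prolog): one fully predicated iteration whose
  // stop predicate is tightened to stop_index < residue instead of < E.
  assert(residue <= T);

  // Main stage: iterations i = 1, 2, ... are shifted down by main_offset
  // so that every remaining tile is full and in bounds. The per-element
  // stop predicate can then be dropped, or, when the loop is also
  // double/circular buffered, replaced by the cheap thread-uniform check
  //   (loop_index + stage_depth - 1) < loop_stop.
  for (int64_t i = 1; i * T - main_offset < E; ++i) {
    const int64_t tile_begin = i * T - main_offset; // 6, then 38
    const int64_t tile_end = tile_begin + T; // 38, then 70
    assert(tile_begin >= 0 && tile_end <= E); // never out of bounds
  }
  return 0;
}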
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -571,6 +571,17 @@ class TORCH_CUDA_CU_API TensorView : public Val {
   // example, grouping multiple reductions.
   void updateMaxProducerPosition();

+  //! A scheduler primitive requesting the predicate peeling transform
+  //! on the loop generated for `axis_id`.
+  //! See [Predicate Peeling].
+  void peelPredicatedLoop(int axis_id);
+
+  //! Returns the iterdomain corresponding to the loop that will
+  //! be using the predicate peeling transform.
+  auto peeledSerialId() const {
+    return peeled_serial_id_;
+  }
+
  protected:
   void setDomain(TensorDomain* td) {
     domain_ = td;
@@ -605,6 +616,10 @@ class TORCH_CUDA_CU_API TensorView : public Val {
   //! Indicates the circular buffering stage depth if applicable.
   unsigned int circular_buffer_stage_ = 0;

+  //! Keeps track of the iterdomain that will use the predicate
+  //! peeling transform on the corresponding loop.
+  IterDomain* peeled_serial_id_ = nullptr;
+
   // special handling for CPU based zero-dim tensors (i.e. CPU Tensors that
   // only have one value). This is only used if on an input value, otherwise
   // ignored. This is important as special handling because these "scalars"
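A hypothetical scheduling snippet showing how the new primitive might be called; the tensor names, the cp.async cache, and the axis position are illustrative assumptions, not code from this PR:

// Stage an input through shared memory with cp.async, then request
// predicate peeling on the serial K-tile loop feeding it.
auto tv_a_smem = tv_a->cacheAfter(LoadStoreOpType::CpAsync);
tv_a_smem->setMemoryType(MemoryType::Shared);
// ... tiling and parallelization omitted ...
// Axis 0 is assumed to be the serial K-tile loop after scheduling.
tv_a_smem->peelPredicatedLoop(0);
TORCH_INTERNAL_ASSERT(tv_a_smem->peeledSerialId() != nullptr);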
14 changes: 12 additions & 2 deletions torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -618,8 +618,12 @@ void IrPrinter::handle(const GroupedWelfordOp* grouped_wop) {
 }

 void IrPrinter::handle(const LoadStoreOp* ldst) {
-  indent() << ldst->out() << " = " << ldst->opType() << "( " << ldst->in()
-           << " )\n";
+  indent() << ldst->out() << " = " << ldst->opType() << "( " << ldst->in();
+  if (ldst->container()->isA<kir::Kernel>() && ldst->predicate() != nullptr &&
+      ldst->predicate()->hasValue()) {
+    os_ << ", " << ldst->predicate()->value()->toInlineString();
+  }
+  os_ << " )\n";
 }

 void IrPrinter::handle(const BroadcastOp* bop) {
@@ -737,6 +741,9 @@ void IrPrinter::handle(const kir::Predicate* node) {
     }
     default:
       os_ << node->predicate_type();
+      if (node->hasValue()) {
+        os_ << " : " << node->value()->toInlineString();
+      }
       break;
   }
 }
@@ -817,6 +824,9 @@ void IrPrinter::handle(const kir::ForLoop* node) {
   handle(node->index());
   os_ << " in ";
   handle(node->iter_domain());
+  os_ << ", start = " << node->start();
+  os_ << ", stop = " << node->stop();
+  os_ << ", step = " << node->step();
   os_ << ":\n";
   handleScope(node->body());
 }
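With both printer changes, a kernel IR dump would show loop bounds and inline predicate values, roughly like the following; this is illustrative output, not captured from a real run, and the exact iterdomain and index formatting is assumed:

FOR i42 in iS10{( ceilDiv(K, 32) )}, start = 0, stop = ( ceilDiv(K, 32) ), step = 1:
  T2[ ... ] = CpAsync( T0[ ... ], ( i42 * 32 ) + threadIdx.x < K )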
13 changes: 13 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -51,6 +51,19 @@ Predicate::Predicate(IrBuilderPasskey passkey, Bool* value)
   TORCH_INTERNAL_ASSERT(value != nullptr);
 }

+Predicate::Predicate(IrBuilderPasskey passkey, const Predicate* other)
+    : Val(passkey, ValType::Predicate, DataType::Bool),
+      ptype_(other->ptype_),
+      expr_(other->expr_),
+      thread_pred_(other->thread_pred_),
+      unrolled_loop_(other->unrolled_loop_) {
+  TORCH_INTERNAL_ASSERT(
+      passkey.ir_container_->isA<kir::Kernel>(),
+      "IR type only valid for Kernel container.");
+  TORCH_INTERNAL_ASSERT(
+      other->value_ == nullptr, "No support yet for predicate deep copy");
+}
+
 TensorIndex::TensorIndex(
     IrBuilderPasskey passkey,
     const TensorView* view,
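A sketch of a plausible call site for the new constructor; setPredicate and the surrounding names are assumptions, and per the assert above only predicates that do not yet carry an evaluated value can be copied:

// Duplicate an expression's predicate so a peeled clone of the
// expression does not alias the original kir::Predicate node.
auto cloned_pred = IrBuilder::create<kir::Predicate>(expr->predicate());
cloned_expr->setPredicate(cloned_pred);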
12 changes: 12 additions & 0 deletions torch/csrc/jit/codegen/cuda/kernel_ir.h
@@ -61,6 +61,8 @@ class TORCH_CUDA_CU_API Predicate final : public Val {

   explicit Predicate(IrBuilderPasskey passkey, Bool* value);

+  explicit Predicate(IrBuilderPasskey passkey, const Predicate* other);
+
   PredicateType predicate_type() const {
     return ptype_;
   }
@@ -564,6 +566,10 @@ struct LoopTransformInfo {
   DoubleBufferLoopStage double_buffer_loop_stage =
       DoubleBufferLoopStage::NotApplicable;

+  //! Tracks the predicate peeling stage of this loop,
+  //! see [Predicate Peeling].
+  PredicatePeelStage predicate_peel_stage = PredicatePeelStage::NoApplicable;
+
   //! Tracks if this for loop is for base index calculation for
   //! lifted memory address.
   bool is_base_index_loop = false;
@@ -580,6 +586,12 @@ struct LoopTransformInfo {
     return *this;
   }

+  //! Setter API
+  LoopTransformInfo& predicatePeelStage(PredicatePeelStage stage) {
+    predicate_peel_stage = stage;
+    return *this;
+  }
+
   bool operator==(const LoopTransformInfo& other) const {
     return double_buffer_loop_stage == other.double_buffer_loop_stage &&
         is_base_index_loop == other.is_base_index_loop;
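The new setter follows the fluent style this struct already uses, so a loop-cloning site can tag several transforms in one expression. A sketch, assuming the neighboring doubleBufferLoopStage setter that the fields above imply:

// Tag a cloned loop as the predicate peeled main stage of a double
// buffered main loop.
auto xform = LoopTransformInfo()
                 .doubleBufferLoopStage(DoubleBufferLoopStage::Main)
                 .predicatePeelStage(PredicatePeelStage::Main);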
7 changes: 6 additions & 1 deletion torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -18,6 +18,7 @@
 #include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
 #include <torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.h>
 #include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
+#include <torch/csrc/jit/codegen/cuda/lower_predicate_peeling.h>
 #include <torch/csrc/jit/codegen/cuda/lower_replace_size.h>
 #include <torch/csrc/jit/codegen/cuda/lower_shift.h>
 #include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
@@ -351,6 +352,7 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
   compute_at_map_->allocateIndexVariables();

   addressComputeInfo().build(fusion_);
+  predicatePeelingInfo().build(fusion_);

   // Run our passes keeping the lowered expressions and forwarding
   // them
@@ -394,13 +396,16 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
   const auto exprs_double_buffered =
       DoubleBufferPass::run(exprs_with_precompute_address);

+  const auto predicate_peeled =
+      PredicatePeeling::peelPredicatedLoop(exprs_double_buffered);
+
   // This pass inserts predicates as well as branches in the code. Up until now
   // the code is explicitly single shot for loop based. Need to be careful in
   // later passes when doing any kind of insertions in loop nest structure as
   // insertions could be on if then or else instead of directly on a for loop.
   dumpExprsIfEnabled(exprs_double_buffered, "Before UnrollPass");
   const auto exprs_unrolled_loops =
-      UnrollPass::runPass(fusion_, exprs_double_buffered);
+      UnrollPass::runPass(fusion_, predicate_peeled);

   dumpExprsIfEnabled(
       exprs_unrolled_loops, "Before processMisalignedVectorization");
10 changes: 10 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower2device.h
@@ -13,6 +13,7 @@
 #include <torch/csrc/jit/codegen/cuda/lower_mem_index.h>
 #include <torch/csrc/jit/codegen/cuda/lower_predicate.h>
 #include <torch/csrc/jit/codegen/cuda/lower_predicate_elimination.h>
+#include <torch/csrc/jit/codegen/cuda/lower_predicate_peeling.h>
 #include <torch/csrc/jit/codegen/cuda/lower_shift.h>
 #include <torch/csrc/jit/codegen/cuda/lower_sync_information.h>
 #include <torch/csrc/jit/codegen/cuda/lower_thread_predicate.h>
@@ -175,6 +176,14 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
     return profile_;
   }

+  auto& predicatePeelingInfo() {
+    return predicate_peeling_info_;
+  }
+
+  const auto& predicatePeelingInfo() const {
+    return predicate_peeling_info_;
+  }
+
   // This is an interface to propagate information after expression
   // replacement on the kernel IR. E.g.:
   //   for ...
@@ -223,6 +232,7 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
   FusedReductionInfo fused_reduction_info_;
   std::shared_ptr<const SyncMap> sync_map_;
   AddressComputeInfo address_compute_info_;
+  PredicatePeelingInfo predicate_peeling_info_;
   kir::KernelPerformanceProfile profile_;
   std::unordered_set<Split*> divisible_splits_;
7 changes: 7 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_allocation.cpp
@@ -449,6 +449,13 @@ class AllocationInserter : public kir::ExprMutator {
   auto out_tv = out->as<TensorView>();
   auto default_val = gpu_lower->predicateElimination().getInitValue(out_tv);

+  if (out_tv->isCircularBuffered() && default_val == nullptr) {
+    if (GpuLower::current()->predicatePeelingInfo().hasPeeledId(out_tv)) {
+      // Always initialize cp.async output if it has a peeled id.
+      default_val = GpuLower::current()->kernel()->zeroVal();
+    }
+  }
+
   Val* init = nullptr;
   if (expr->isA<ReductionOp>() && out_tv->hasReduction()) {
     TORCH_INTERNAL_ASSERT(
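The zero-initialization above follows from dropping per-element predicates in the peeled main stage: shared-memory elements that a guarded prolog copy skipped must still hold a harmless value when the unpredicated consumer reads them, and zero contributes nothing to a matmul accumulation. A rough, simplified picture of what the inserted initialization lowers to (names are illustrative):

// Clear the cp.async staging buffer before the peeled, circular
// buffered loop starts, so elements skipped by predicated copies
// still read as 0 in the unpredicated main-stage math.
for (nvfuser_index_t i = 0; i < SMEM_TILE_ELEMS; ++i) {
  T_a_smem[i] = 0.f; // default_val == kernel()->zeroVal()
}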