diff --git a/build_variables.bzl b/build_variables.bzl index 93de9df44dee8..e211283b49623 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -704,6 +704,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp", "torch/csrc/jit/codegen/cuda/lower_index.cpp", "torch/csrc/jit/codegen/cuda/lower_index_hoist.cpp", + "torch/csrc/jit/codegen/cuda/lower_predicate_peeling.cpp", "torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp", "torch/csrc/jit/codegen/cuda/lower_instrument.cpp", "torch/csrc/jit/codegen/cuda/lower_loops.cpp", diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 6d36b0672e323..b4bfc01599ee8 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -2534,7 +2534,16 @@ class CudaKernelGenerator : private OptOutConstDispatch { } indent() << "for(nvfuser_index_t " << gen_index; - if (loop->iter_domain()->isParallelized()) { + + // TODO: need to revisit this one, + // a predicate peeled loop would be guaranteed not to be + // a degenerate loop. So the comments on the else block + // should not apply here. + if (loop->iter_domain()->isParallelized() || + loop->loopTransformInfo().predicate_peel_stage == + PredicatePeelStage::Main || + loop->loopTransformInfo().double_buffer_loop_stage == + DoubleBufferLoopStage::CircularInitProlog) { code_ << " = " << gen_start << "; "; } else { // Do not start at the start of the ID when not parallelized. Instead, diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index bde91fda24f87..534a3c3f178f5 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -796,6 +796,9 @@ void ComputeAtMap::allocateIndexVariables() { std::make_unique(DoubleBufferIndices( {{DoubleBufferLoopStage::Prolog, IrBuilder::create(c10::nullopt)}, + // TODO: need to add upper and lower prolog here too. + {DoubleBufferLoopStage::CircularInitProlog, + IrBuilder::create(c10::nullopt)}, {DoubleBufferLoopStage::Main, IrBuilder::create(c10::nullopt)}, {DoubleBufferLoopStage::Epilog, diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 8f63effad0d19..e2d64b75f8215 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1679,6 +1679,21 @@ std::vector Index::getGlobalProducerStridedIndices( if (root_ind->isZeroInt()) { continue; } else { + if (auto tile_entry = + GpuLower::current()->predicatePeelingInfo().getMaybePeeledTileEntry( + loops, root_dom[i])) { + // Add the "predicate peeling offset", see [Predicate Peeling] + // to the tensor index if this root domain is predicate peeled. + if (tile_entry.value().peel_stage != PredicatePeelStage::Prolog && + !tile_entry.value() + .for_loop->loopTransformInfo() + .is_base_index_loop) { + root_ind = SimplifyingIrBuilder::subExpr( + root_ind, + PredicatePeeling::getSplitTileMainOffset( + root_dom[i], tile_entry.value().inner_factor)); + } + } auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]); if (i == root_dom.size() - 1 && vectorize_shift != nullptr) { strided_inds[i] = @@ -2142,6 +2157,10 @@ std::vector Index::getGlobalConsumerStridedIndices( auto strides = getStrides(consumer_tv); auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph); + // Indices should now be mapped onto IterDomains in consumer, so just grab + // and use them. 
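+  // The consumer root (rfactor) domains are also grabbed below so that each root index can be checked against a peeled-tile entry; see [Predicate Peeling].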
+ auto root_dom = consumer_tv->getMaybeRFactorDomain(); + // Global striding auto vectorize_shift = loops.empty() ? nullptr : loops.back()->vectorize_shift(); @@ -2151,8 +2170,20 @@ std::vector Index::getGlobalConsumerStridedIndices( if (root_inds[i]->isZeroInt()) { continue; } else { - auto strided_ind = - SimplifyingIrBuilder::mulExpr(root_inds[i], strides[i]); + auto root_ind = root_inds[i]; + if (auto tile_entry = GpuLower::current() + ->predicatePeelingInfo() + .getMaybePeeledTileEntry(loops, root_dom[i])) { + // Add the "predicate peeling offset", see [Predicate Peeling] + // to the tensor index if this root domain is predicate peeled. + if (tile_entry.value().peel_stage != PredicatePeelStage::Prolog) { + root_ind = SimplifyingIrBuilder::subExpr( + root_ind, + PredicatePeeling::getSplitTileMainOffset( + root_dom[i], tile_entry.value().inner_factor)); + } + } + auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]); if (i == strides.size() - 1 && vectorize_shift != nullptr) { strided_inds[i] = SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift); @@ -2323,7 +2354,8 @@ std::vector Index::getNonGlobalConsumerStridedIndices( if (is_prolog && is_circular_buffer_loop) { // The buffer switching logic is the same as original index // in the case of circular buffer prolog. - db_switch_index = db_loop->index(); + db_switch_index = + db_loop->isTrivial() ? db_loop->start() : db_loop->index(); } else { // Switching index generated for main loop or epilog component. db_switch_index = SimplifyingIrBuilder::modExpr( @@ -3194,10 +3226,20 @@ std::vector Index::getReferenceRootPredicates( info.start_predicate_ = start_pred; } + auto maybe_tiled_entry = + gpu_lower->predicatePeelingInfo().getMaybePeeledTileEntry( + loops, contig_id); + + // Check if this predicate for this contig_id is being generated + // in predicate peeling prolog. See also [Predicate Peeling]. + bool has_peeled_prolog = maybe_tiled_entry.has_value() && + maybe_tiled_entry.value().peel_stage == PredicatePeelStage::Prolog; + // Build predicates for stop positions as: // stop_index + stop_offset < IterDomain::extent auto stop_offset = info.stop_offset_; - if (canOmitStopPredicate(stop_index, stop_offset, contig_id)) { + if (canOmitStopPredicate(stop_index, stop_offset, contig_id) && + !has_peeled_prolog) { info.stop_predicate_ = GpuLower::current()->kernel()->trueVal(); } else { auto offsetted_stop_index = @@ -3205,6 +3247,52 @@ std::vector Index::getReferenceRootPredicates( auto stop_pred = SimplifyingIrBuilder::ltExpr( offsetted_stop_index, contig_id->extent()) ->as(); + + // Modifying predicate math for predicate peeled loop: + // detailed definition see [Predicate Peeling] + if (maybe_tiled_entry.has_value()) { + auto tile_entry = maybe_tiled_entry.value(); + if (tile_entry.peel_stage == PredicatePeelStage::Prolog) { + // In predicate peeled prolog, the stop predicate is + // stop_index < tile_residue + stop_pred = SimplifyingIrBuilder::ltExpr( + offsetted_stop_index, + PredicatePeeling::getPrologPredicateOffset( + contig_id, tile_entry.inner_factor)) + ->as(); + } else if ( + // Handle the condition where the predicate peeled + // iterdomain is double/circular buffered. 
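+        // (In that case the circular-buffered producer load runs (stage_depth - 1) iterations ahead of the main loop, so the peeled main loop keeps a cheap, thread-uniform stop predicate below instead of dropping the bound check entirely.)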
+ tile_entry.for_loop->doubleBufferLoopStage() == + DoubleBufferLoopStage::Main && + db_axis != nullptr && + GpuLower::current()->caMap()->areMapped( + db_axis, + tile_entry.for_loop->iter_domain(), + IdMappingMode::LOOP)) { + // When the predicate peeled loop is double buffered + // or circular buffered, the producer index is skewed + // ahead of the main loop by (stage_depth-1). So on the + // predicate peeled main loop side, instead of just removing + // the predicate for this contig_id, just re-write it to + // (loop_index + stage_depth) < loop_stop, which should + // be thread uniform and very cheap to evaluate. + auto db_index = SimplifyingIrBuilder::addExpr( + tile_entry.for_loop->index(), + IrBuilder::create( + gpu_lower->doubleBufferInfo().getStageDepthFor( + tile_entry.for_loop->iter_domain()) - + 1)); + stop_pred = SimplifyingIrBuilder::ltExpr( + db_index, tile_entry.for_loop->stop()) + ->as(); + } else { + // If the predicate peeled loop is not double buffered + // then in the main stage of the predicate peeled loop + // this predicate can just be omitted. + stop_pred = gpu_lower->kernel()->trueVal(); + } + } info.stop_predicate_ = stop_pred; } diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index a3efd3fae2e60..33c6fecdbddd2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -571,6 +571,17 @@ class TORCH_CUDA_CU_API TensorView : public Val { // example, grouping multiple reductions. void updateMaxProducerPosition(); + //! A scheduler primitive requesting predicate peeling transform + //! on the loop generated corresponding to `axis_id`. + //! See [Predicate Peeling]. + void peelPredicatedLoop(int axis_id); + + //! Returns the iterdomain corresponding to the loop that will + //! be using the predicate peeling transform. + auto peeledSerialId() const { + return peeled_serial_id_; + } + protected: void setDomain(TensorDomain* td) { domain_ = td; @@ -605,6 +616,10 @@ class TORCH_CUDA_CU_API TensorView : public Val { //! Indicates the circular buffering stage depth if applicable. unsigned int circular_buffer_stage_ = 0; + //! Keeps track of the iterdomain that will use predicate + //! peeling transform on the corresponding loop. + IterDomain* peeled_serial_id_ = nullptr; + // special handling for CPU based zero-dim tensors (i.e. CPU Tensors that // only have one value). This is only used if on an input value, otherwise // ignored. 
This is important as special handling because these "scalars" diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index b4ba37ea930f4..80dd28ceec1bc 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -618,8 +618,12 @@ void IrPrinter::handle(const GroupedWelfordOp* grouped_wop) { } void IrPrinter::handle(const LoadStoreOp* ldst) { - indent() << ldst->out() << " = " << ldst->opType() << "( " << ldst->in() - << " )\n"; + indent() << ldst->out() << " = " << ldst->opType() << "( " << ldst->in(); + if (ldst->container()->isA() && ldst->predicate() != nullptr && + ldst->predicate()->hasValue()) { + os_ << ", " << ldst->predicate()->value()->toInlineString(); + } + os_ << " )\n"; } void IrPrinter::handle(const BroadcastOp* bop) { @@ -737,6 +741,9 @@ void IrPrinter::handle(const kir::Predicate* node) { } default: os_ << node->predicate_type(); + if (node->hasValue()) { + os_ << " : " << node->value()->toInlineString(); + } break; } } @@ -817,6 +824,9 @@ void IrPrinter::handle(const kir::ForLoop* node) { handle(node->index()); os_ << " in "; handle(node->iter_domain()); + os_ << ", start = " << node->start(); + os_ << ", stop = " << node->stop(); + os_ << ", step = " << node->step(); os_ << ":\n"; handleScope(node->body()); } diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index e016a9d7be45e..4b4992f1969af 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -51,6 +51,19 @@ Predicate::Predicate(IrBuilderPasskey passkey, Bool* value) TORCH_INTERNAL_ASSERT(value != nullptr); } +Predicate::Predicate(IrBuilderPasskey passkey, const Predicate* other) + : Val(passkey, ValType::Predicate, DataType::Bool), + ptype_(other->ptype_), + expr_(other->expr_), + thread_pred_(other->thread_pred_), + unrolled_loop_(other->unrolled_loop_) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); + TORCH_INTERNAL_ASSERT( + other->value_ == nullptr, "No support yet for predicate deep copy"); +} + TensorIndex::TensorIndex( IrBuilderPasskey passkey, const TensorView* view, diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 772f518e38b9e..5645fe827adce 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -61,6 +61,8 @@ class TORCH_CUDA_CU_API Predicate final : public Val { explicit Predicate(IrBuilderPasskey passkey, Bool* value); + explicit Predicate(IrBuilderPasskey passkey, const Predicate* other); + PredicateType predicate_type() const { return ptype_; } @@ -564,6 +566,10 @@ struct LoopTransformInfo { DoubleBufferLoopStage double_buffer_loop_stage = DoubleBufferLoopStage::NotApplicable; + //! Tracks the predicate peeling stage of this loop, + //! see [Predicate Peeling]. + PredicatePeelStage predicate_peel_stage = PredicatePeelStage::NoApplicable; + //! Tracks if this for loop is for base index calculation for //! lifted memory address. bool is_base_index_loop = false; @@ -580,6 +586,12 @@ struct LoopTransformInfo { return *this; } + //! 
Setter API + LoopTransformInfo& predicatePeelStage(PredicatePeelStage stage) { + predicate_peel_stage = stage; + return *this; + } + bool operator==(const LoopTransformInfo& other) const { return double_buffer_loop_stage == other.double_buffer_loop_stage && is_base_index_loop == other.is_base_index_loop; diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index ef2b79422f474..5114c43876161 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -351,6 +352,7 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { compute_at_map_->allocateIndexVariables(); addressComputeInfo().build(fusion_); + predicatePeelingInfo().build(fusion_); // Run our passes keeping the lowered expressions and forwarding // them @@ -394,13 +396,16 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { const auto exprs_double_buffered = DoubleBufferPass::run(exprs_with_precompute_address); + const auto predicate_peeled = + PredicatePeeling::peelPredicatedLoop(exprs_double_buffered); + // This pass inserts predicates as well as branches in the code. Up until now // the code is explicitly single shot for loop based. Need to be careful in // later passes when doing any kind of insertions in loop nest structure as // insertions could be on if then or else instead of directly on a for loop. dumpExprsIfEnabled(exprs_double_buffered, "Before UnrollPass"); const auto exprs_unrolled_loops = - UnrollPass::runPass(fusion_, exprs_double_buffered); + UnrollPass::runPass(fusion_, predicate_peeled); dumpExprsIfEnabled( exprs_unrolled_loops, "Before processMisalignedVectorization"); diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h index 3bfecb4c2aaed..ffe19742a5a4a 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.h +++ b/torch/csrc/jit/codegen/cuda/lower2device.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -175,6 +176,14 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable { return profile_; } + auto& predicatePeelingInfo() { + return predicate_peeling_info_; + } + + const auto& predicatePeelingInfo() const { + return predicate_peeling_info_; + } + // This is an interface to propagate information after expression // replacement on the kernel IR. E.g.: // for ... @@ -223,6 +232,7 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable { FusedReductionInfo fused_reduction_info_; std::shared_ptr sync_map_; AddressComputeInfo address_compute_info_; + PredicatePeelingInfo predicate_peeling_info_; kir::KernelPerformanceProfile profile_; std::unordered_set divisible_splits_; diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index 1bb98e96b0bdf..ea271cf10444e 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -449,6 +449,13 @@ class AllocationInserter : public kir::ExprMutator { auto out_tv = out->as(); auto default_val = gpu_lower->predicateElimination().getInitValue(out_tv); + if (out_tv->isCircularBuffered() && default_val == nullptr) { + if (GpuLower::current()->predicatePeelingInfo().hasPeeledId(out_tv)) { + // Always initialize cp async output if it has peeled id. 
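+      // (In the peeled prolog the cp.async load is predicated to the residue tile, so the elements it skips must still hold a defined value in shared memory; see [Predicate Peeling Interaction with Circular Buffering].)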
+ default_val = GpuLower::current()->kernel()->zeroVal(); + } + } + Val* init = nullptr; if (expr->isA() && out_tv->hasReduction()) { TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index 9aba511b9f31f..b37abf35e0cad 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -211,6 +211,11 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { start = IrBuilder::subExpr( double_buffer_loop_->stop(), SimplifyingIrBuilder::create(stage_depth - 1)); + } else if (loop_type_ == DoubleBufferLoopStage::CircularInitProlog) { + // See [Predicate Peeling Interaction with Circular Buffering] + TORCH_INTERNAL_ASSERT(start->isZeroInt()); + start = SimplifyingIrBuilder::create(stage_depth - 1); + stop = SimplifyingIrBuilder::create(stage_depth); } cloned_top_level_loop_ = IrBuilder::create( @@ -263,7 +268,9 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { TORCH_INTERNAL_ASSERT(!cloned_scopes_.empty()); if (loop_type_ == DoubleBufferLoopStage::Main) { - cloned_scopes_.back()->push_back(expr); + if (!canOmitInitInMainLoop(expr, double_buffer_loop_)) { + cloned_scopes_.back()->push_back(expr); + } return; } @@ -280,12 +287,82 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { TORCH_INTERNAL_ASSERT(double_buffer_tv != nullptr); return out_tv == double_buffer_tv; }); + if ((loop_type_ == DoubleBufferLoopStage::Prolog && is_double_buffer_load_expr) || (loop_type_ == DoubleBufferLoopStage::Epilog && !is_double_buffer_load_expr)) { - cloned_scopes_.back()->push_back(expr); + if (lower_utils::supportInlinePredicate(expr) && + expr->isA()) { + auto ldst = expr->as(); + cloned_scopes_.back()->push_back(IrBuilder::create( + ldst->opType(), ldst->out(), ldst->in())); + } else { + cloned_scopes_.back()->push_back(expr); + } + } else if ( + loop_type_ == DoubleBufferLoopStage::CircularInitProlog && + is_double_buffer_load_expr) { + // Only need the init expressions in circular init prolog stage + if (ir_utils::isTensorScalarFillOp(expr)) { + cloned_scopes_.back()->push_back(expr); + } + } + } + + //! Returns true if the expression is an initialization expr that + //! can be omitted in main loop. + //! See [Predicate Peeling Interaction with Circular Buffering] + bool canOmitInitInMainLoop(Expr* expr, kir::ForLoop* double_buffer_loop) { + // Check that this is an initialization for cp.async. + if (!ir_utils::isCpAsyncInit(expr) || + !GpuLower::current()->predicatePeelingInfo().shouldPeelLoop( + double_buffer_loop)) { + return false; + } + + auto out_tv = ir_utils::getTvOutput(expr); + + // Check that the double buffer loop is the main stage of + // the loop defining out_tv as there might be multiple + // loops that realize double buffers. + bool db_loop_found = false; + const auto& ca_map = GpuLower::current()->caMap(); + + if (!(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered()) || + !ca_map->areMapped( + GpuLower::current()->doubleBufferInfo().getDoubleBufferAxis(out_tv), + double_buffer_loop->iter_domain(), + IdMappingMode::LOOP)) { + return false; } + + // This optimization only applies when all the loops on the + // inner side of the double buffer main loop are either + // constant unrolled or parallel. + // TODO: + // Buffer alias and broadcast resolution might still + // break this. These are not showing up in matmul kernels, but we + // would need to build out support for general safety usage.
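+    // Walk the leaf domains of out_tv in order; once the domain mapped to the double buffer loop is found, every remaining serial domain must have a constant extent for the init to be safely lifted out.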
+ for (auto id : out_tv->domain()->domain()) { + if (db_loop_found) { + auto loop_concrete_id = + ca_map->getConcreteMappedID(id, IdMappingMode::LOOP); + + if (!loop_concrete_id->isParallelized() && + !loop_concrete_id->extent()->isConstInt()) { + return false; + } + } + + db_loop_found = db_loop_found || + ca_map->areMapped( + id, double_buffer_loop->iter_domain(), IdMappingMode::LOOP); + } + + // Only when the double buffer loop was found on out_tv could useful + // information have been inferred by this function. + return db_loop_found; } private: @@ -433,6 +510,17 @@ class DoubleBufferInserter : private kir::ExprMutator { MemoryType::Shared; }); + // If the double buffer loop is to be peeled, we will need to insert + // a circular buffer init stage to initialize the final stage of + // circular buffer space. + if (GpuLower::current()->predicatePeelingInfo().shouldPeelLoop( + double_buffer_loop) && + write_to_smem) { + auto circular_init_loop = DoubleBufferLoopCloner::clone( + double_buffer_loop, loads, DoubleBufferLoopStage::CircularInitProlog); + registerInsertBefore(double_buffer_loop, circular_init_loop); + } + // RAW sync is not inserted for double buffered tensors. The only // exception is the prologue load. bool insert_cpasync_wait = false; diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_peeling.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_peeling.cpp new file mode 100644 index 0000000000000..01e5322dd9480 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_peeling.cpp @@ -0,0 +1,309 @@ +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +bool PredicatePeeling::supportedPeelingLoop(IterDomain* id) { + // Not meaningful to peel a parallel loop + if (id->isParallelized()) { + return false; + } + + auto id_def = id->definition(); + + if (id_def == nullptr) { + // This case is not profitable so skip peeling support. + return false; + } else if (auto split = dynamic_cast(id_def)) { + auto split_in = split->in(); + + // This is the typical case that we want to peel, + // where we advance a serial iteration through + // constant-sized tiles. + return split_in->definition() == nullptr && id == split->outer() && + split->factor()->isConstInt(); + } + + // TODO: + // could possibly support more patterns using the separability analysis. + return false; +} + +void PredicatePeelingInfo::build(Fusion* fusion) { + auto used_vals = fusion->usedMathVals(); + for (auto tv : ir_utils::filterByType(used_vals)) { + // Only visit tensorviews with peeling serial id info. + if (tv->peeledSerialId() != nullptr) { + // Create a set of leaf ids for validation. + std::unordered_set leaf_id_set{ + tv->domain()->domain().begin(), tv->domain()->domain().end()}; + auto peeled_id = tv->peeledSerialId(); + + // Quick check that the peeled id at schedule + // time is still a leaf domain. + TORCH_INTERNAL_ASSERT( + leaf_id_set.count(peeled_id), + "only existing leaf domain supported for peeling\n", + tv->toString(), + " does not have\n", + peeled_id->toString()); + + // Insert the peeled concrete id into the recorded set.
+ concrete_id_of_peeled_loops_.insert( + GpuLower::current()->caMap()->getConcreteMappedID( + peeled_id, IdMappingMode::LOOP)); + } + } +} + +c10::optional PredicatePeelingInfo::getMaybePeeledTileEntry( + const std::vector& loops, + IterDomain* root_id) { + auto gpu_lower = GpuLower::current(); + + std::unordered_map concrete_id_to_loop_map; + for (auto fl : loops) { + auto concrete_loop_id = gpu_lower->caMap()->getConcreteMappedID( + fl->iter_domain(), IdMappingMode::LOOP); + concrete_id_to_loop_map[concrete_loop_id] = fl; + } + + for (auto peeled_id : concrete_id_of_peeled_loops_) { + // Need to locate the peeled loop to validate this check + auto matching_loop_it = concrete_id_to_loop_map.find(peeled_id); + if (matching_loop_it != concrete_id_to_loop_map.end()) { + // This is the only supported case at the initial stage. + // see also [Supported Case in Predicate Peeling pass] + auto split = peeled_id->definition()->as(); + if (gpu_lower->caMap()->areMapped( + split->in(), root_id, IdMappingMode::EXACT)) { + // This means the given id has been peeled. + PeeledTileEntry entry; + entry.peel_stage = + matching_loop_it->second->loopTransformInfo().predicate_peel_stage; + entry.inner_factor = split->factor(); + entry.for_loop = matching_loop_it->second; + return entry; + } + } + } + return c10::nullopt; +} + +bool PredicatePeelingInfo::hasPeeledId(const TensorView* tv) const { + for (auto id : concrete_id_of_peeled_loops_) { + if (std::any_of( + tv->domain()->domain().begin(), + tv->domain()->domain().end(), + [id](IterDomain* tv_id) { + return GpuLower::current()->caMap()->areMapped( + tv_id, id, IdMappingMode::LOOP); + })) { + return true; + } + } + return false; +} + +bool PredicatePeelingInfo::shouldPeelLoop(kir::ForLoop* forloop) const { + auto loop_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + forloop->iter_domain(), IdMappingMode::LOOP); + + return + // Base index calculation loop should not be involved here. + !forloop->loopTransformInfo().is_base_index_loop && + concrete_id_of_peeled_loops_.count(loop_concrete_id); +} + +namespace { + +//! A utility class that deep clones the loop body of original_fl +//! and add the cloned expressions into the loop body of new_fl. +class LoopNestDeepCloner { + public: + static void clone(kir::ForLoop* original_fl, kir::ForLoop* new_fl) { + LoopNestDeepCloner cloner; + cloner.cloned_scopes_.push_back(&new_fl->body()); + for (auto expr : original_fl->body().exprs()) { + cloner.handle(expr); + } + } + + private: + void handle(Expr* expr) { + if (auto fl = dynamic_cast(expr)) { + handle(fl); + } else { + cloned_scopes_.back()->push_back(expr); + } + } + + void handle(kir::ForLoop* fl) { + auto new_fl = IrBuilder::create(fl); + cloned_scopes_.push_back(&new_fl->body()); + for (auto expr : fl->body().exprs()) { + handle(expr); + } + cloned_scopes_.pop_back(); + + cloned_scopes_.back()->push_back(new_fl); + } + + private: + std::vector cloned_scopes_; +}; + +//! The predicate peeling transform pass implementation. +class PredicatePeeledLoops : kir::ExprMutator { + public: + static std::vector run(const std::vector exprs) { + PredicatePeeledLoops peeled_loops; + peeled_loops.traverseAndInsert(exprs); + return peeled_loops.exprs_; + } + + private: + using kir::ExprMutator::handle; + + // Create the predicate peeled prolog loop + kir::ForLoop* createPeeledLoop(kir::ForLoop* fl) { + // Make clone of the outermost loop, but + // limit the loop to the first iteration. 
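+    // i.e. start = 0, stop = 1, step = 1, with the loop tagged as PredicatePeelStage::Prolog so that indexing and predicate generation can recognize it.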
+ auto peeled_loop = IrBuilder::create( + fl->iter_domain(), + fl->index(), + fl->kernel()->zeroVal(), + fl->kernel()->oneVal(), + fl->kernel()->oneVal(), + false, + nullptr, + fl->isUnrollRequired(), + fl->loopTransformInfo().predicatePeelStage(PredicatePeelStage::Prolog)); + + LoopNestDeepCloner::clone(fl, peeled_loop); + return peeled_loop; + } + + // Create the predicate peeled main loop + kir::ForLoop* createMainLoop(kir::ForLoop* fl) { + auto start = + SimplifyingIrBuilder::addExpr(fl->start(), fl->kernel()->oneVal()); + + auto main_loop = IrBuilder::create( + fl->iter_domain(), + fl->index(), + start, + fl->stop(), + fl->step(), + fl->vectorize(), + fl->vectorize_shift(), + fl->isUnrollRequired(), + fl->loopTransformInfo().predicatePeelStage(PredicatePeelStage::Main)); + + LoopNestDeepCloner::clone(fl, main_loop); + return main_loop; + } + + void handle(kir::ForLoop* fl) final { + kir::ExprMutator::handle(fl); + + if (GpuLower::current()->predicatePeelingInfo().shouldPeelLoop(fl)) { + auto loop_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + fl->iter_domain(), IdMappingMode::LOOP); + + // Peel this loop if shouldPeel is true and it + // hasn't been processed by this pass already. + if (!peeled_iterdomains_.count(loop_concrete_id)) { + // Create and insert the peeled prolog. + auto peeled_loop = createPeeledLoop(fl); + registerInsertBefore(fl, peeled_loop); + + if (fl->stop()->isOneInt()) { + // This is the case for double buffer prolog, + // will just use the peeled loop as the new + // double buffer prolog. + registerRemove(fl); + } else { + // Peel off one iteration from the main + // component of the original loop. + auto new_main_loop = createMainLoop(fl); + registerReplace(fl, new_main_loop); + } + // Record peeling of this loop + peeled_iterdomains_.insert(loop_concrete_id); + } + } + } + + void handle(kir::IfThenElse* ite) final { + TORCH_INTERNAL_ASSERT( + false, "no support for inserted ite before this point"); + } + + private: + //! Keeps track of loop concrete iterdomains that has already been + //! transformed by the loop peeling pass. + //! This pass runs after double buffering pass, so + //! we may encounter forloops corresponding to the + //! same loop domain twice. In this case we only need + //! to peel the first occurrence (the prolog). + std::unordered_set peeled_iterdomains_; +}; + +} // namespace + +std::vector PredicatePeeling::peelPredicatedLoop( + const std::vector exprs) { + return PredicatePeeledLoops::run(exprs); +} + +Val* PredicatePeeling::getPrologPredicateOffset( + IterDomain* id, + Val* tile_factor) { + // Assume X = original extent, + // L = tile factor + // Offset needs to satisfy: + // X % L == 0 : return L else return X % L + + // X + L - ceildiv(X,L)*L + auto orig_extent = id->extent(); + + // X + L + auto extent_plus_factor = + SimplifyingIrBuilder::addExpr(orig_extent, tile_factor); + + // ceildiv(X,L) + auto extent_ceildiv_factor = + SimplifyingIrBuilder::ceilDivExpr(orig_extent, tile_factor); + + // ceildiv(X,L)* L + auto extent_round_up = + SimplifyingIrBuilder::mulExpr(extent_ceildiv_factor, tile_factor); + + // X + L - ceildiv(X,L)*L + return SimplifyingIrBuilder::subExpr(extent_plus_factor, extent_round_up); +} + +Val* PredicatePeeling::getSplitTileMainOffset( + IterDomain* id, + Val* tile_factor) { + // This offset is to be **subtracted** from the tensor index + // on the predicate peeling main loop. 
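+  // e.g. with extent X = 45 and tile L = 16: prolog offset = ceilMod(45, 16) = 13 and main offset = 16 - 13 = 3, so main-loop iteration i (starting at 1) indexes i*16 + j - 3, i.e. [13, 29) and [29, 45), exactly the elements left over after the prolog covers [0, 13).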
+ return SimplifyingIrBuilder::subExpr( + tile_factor, PredicatePeeling::getPrologPredicateOffset(id, tile_factor)); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_peeling.h b/torch/csrc/jit/codegen/cuda/lower_predicate_peeling.h new file mode 100644 index 0000000000000..8c6a67108b06d --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_peeling.h @@ -0,0 +1,176 @@ +#pragma once + +#include + +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +//! Note: [Predicate Peeling] +//! This is a loop transformation that attempts to eliminate predicate +//! evaluation in a serial loop. +//! +//! A simple example showing how this trick works is, say we have +//! T0 [I(T0.size[0])] -> split(32) -> T0 [Io(ceilDiv(T0.size[0],32)), +//! Ii(32)], which generates the following code: for i in +//! 0..ceilDiv(T0.size[0],32) +//! for j in 0..32: +//! // assume we need to initialize in this kernel +//! T0[i*32+j] = 0; +//! if i*32+j < T0.size[0] +//! T0[i*32+j] ... +//! The above code evaluates 32 predicates as the predicate is inlined in the +//! inner loop. +//! +//! The simplification trick is to convert the loop into: +//! +//! let ceilMod(a, b) = a % b == 0 ? b : a % b; +//! +//! // peeled residue prolog (called initial evaluation in cutlass) +//! // Very similar to the original loop except the +//! // outer loop extent and the predicate extent +//! // are modified. +//! +//! for i in 0..1 +//! for j in 0..32: +//! T0[i*32+j] = 0; +//! if i*32+j < ceilMod(T0.size[0], 32) +//! T0[i*32+j] ... +//! // peeled residue main loop +//! // (called steady-state in cutlass) +//! for i in 0..ceilDiv(T0.size[0],32)-1 +//! for j in 0..32: +//! // No need to initialize as we know the predicate +//! // is all true. +//! // This significantly reduces memory instruction +//! // congestion with cp.async kernels. +//! // No longer need to predicate here as +//! // the residue part of the root iterdomain has +//! // been peeled away. +//! T0[i*32+j + ceilMod(T0.size[0],32)] ... +//! +//! Some details on the predicate peeling pass implemented here: +//! 1. The peeled loop is separated into two `PredicatePeelStage`s: +//! The first iteration is peeled and marked as +//! PredicatePeelStage::Prolog, while +//! the rest of the iterations are PredicatePeelStage::Main. +//! +//! 2. The predicate indexing at the (predicate peeling) prolog is modified to +//! keep the access within the residue tile. +//! +//! 3. The address indexing at the (predicate peeling) main loop is modified +//! by adding the residue tile as an offset. +//! +//! 4. The initialization within the (predicate peeling) main loop can be lifted +//! out of the main loop if there are no other non-unrolled serial loops. +//! +//! Note: [Supported Case in Predicate Peeling pass]: +//! The predicate peeling transform is a very specialized pattern used in matmul +//! kernels, and some non-trivial overhead would be involved to generalize it. +//! +//! The current support for predicate peeling is for a very specific case only +//! and some consideration is needed regarding whether more complex peeling +//! patterns along this line should be pursued. +//! +//! The only supported pattern now is: +//! tile_o, tile_i = split(root_id, inner_factor); +//! where tile_o is required to be on the leaf domain and is where the loop +//! peeling primitive should be applied. +//! +//!
The inner_factor is required to be a compile-time constant. +//! +//! Note: [Predicate Peeling Interaction with Circular Buffering] +//! +//! 1. In the case where the original loop is double buffered, the first +//! iteration of the double buffer prolog loop is used as +//! PredicatePeelStage::Prolog and the rest are labeled as +//! PredicatePeelStage::Main. +//! +//! 2. If a tv is double buffered or circular buffered, the gmem load stage is +//! (stage_depth-1) iterations ahead, so an extra (simpler) predicate is +//! needed to avoid out-of-bound access. +//! +//! 3. A circular buffer init prolog is added in the case of a predicate peeled +//! and circular buffered loop, as the circular buffer loop prolog only +//! prefetches up to iteration `stage_depth-1`, and if the initialization were +//! to be lifted out of the main loop stage, it would also need to initialize for +//! iteration `stage_depth` to make sure the shared memory buffer is all zero +//! initialized. + +//! A data structure used by PredicatePeelingInfo to communicate which +//! for loop is predicate peeled along with the peeling stage and +//! original inner tiling factor. +//! TODO: some info here is redundant now. +struct PeeledTileEntry { + //! The peeling stage, see note above. + PredicatePeelStage peel_stage = PredicatePeelStage::NoApplicable; + + //! The original splitting factor, see [Supported Case in Predicate Peeling + //! pass]. + Val* inner_factor = nullptr; + + //! The actual for loop that is predicate peeled. + kir::ForLoop* for_loop = nullptr; +}; + +//! Keeps track of predicate peeled loops requested +//! by the scheduler. +class PredicatePeelingInfo { + public: + //! Returns true if predicate peeling is requested by the scheduler +//! for the given loop. + bool shouldPeelLoop(kir::ForLoop* forloop) const; + + //! Collect predicate peeling information from the fusion. + void build(Fusion* fusion); + + //! Returns the peeled entry info if the given loop is predicate + //! peeled and the given root_id matches the tiled root id. + //! + //! see also [Supported Case in Predicate Peeling pass]. + c10::optional getMaybePeeledTileEntry( + const std::vector& loops, + IterDomain* root_id); + + //! Returns true if any iterdomain on the given tv's tensor + //! domain corresponds to a predicate peeled loop. + bool hasPeeledId(const TensorView* tv) const; + + private: + //! Keeps track of loop concrete iterdomains that were predicate + //! peeled. + std::unordered_set concrete_id_of_peeled_loops_; +}; + +namespace PredicatePeeling { + +//! User space check that makes sure the loop can +//! actually be peeled to remove predicates. +//! See also +//! [Supported Case in Predicate Peeling pass]: +bool supportedPeelingLoop(IterDomain* id); + +//! Kernel IR pass that applies the predicate peeling transformation. +std::vector peelPredicatedLoop(const std::vector exprs); + +//! Utility to generate the residual extent used in the predicate +//! peeling prolog. +Val* getPrologPredicateOffset(IterDomain* id, Val* tile_factor); + +//! Utility to generate the offset applied to tensor indices +//! in the predicate peeling main loop.
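+//! (equal to tile_factor - getPrologPredicateOffset(id, tile_factor)).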
+Val* getSplitTileMainOffset(IterDomain* id, Val* tile_factor); + +} // namespace PredicatePeeling + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp index dd61f393ac6da..44071a01931a6 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp @@ -548,6 +548,10 @@ void scheduleMatmul( acw_smem->liftReadAddress(); bcw_smem->liftReadAddress(); } + + if (params.peel_main_loop) { + cc->peelPredicatedLoop(2); + } } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.h b/torch/csrc/jit/codegen/cuda/scheduler/matmul.h index 200343f6d6fc9..354e2affeab04 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.h @@ -48,6 +48,9 @@ class MatmulParam { // TODO: add gmem_write address for // latency bound kernels. } index_lift_options; + + //! Enables predicate peeling mainloop: + bool peel_main_loop = true; }; //! Prototype auto scheduling function. diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 86664cb11b398..a92ef481687a4 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include // Cleanup @@ -219,6 +220,7 @@ TensorView::TensorView(const TensorView* src, IrCloner* ir_cloner) lift_read_address_(src->lift_read_address_), lift_write_address_(src->lift_write_address_), skew_double_buffer_loop_(src->skew_double_buffer_loop_), + peeled_serial_id_(ir_cloner->clone(src->peeled_serial_id_)), compute_with_consumers_(ir_cloner->clone(src->compute_with_consumers_)), compute_with_pos_(src->compute_with_pos_) {} @@ -1429,6 +1431,13 @@ void TensorView::applyMmaSwizzle(MmaOptions options) { } } +void TensorView::peelPredicatedLoop(int axis_id) { + auto id = axis(axis_id); + TORCH_CHECK( + PredicatePeeling::supportedPeelingLoop(id), "unsupported loop peeling"); + peeled_serial_id_ = id; +} + TensorViewBuilder& TensorViewBuilder::ndims(size_t ndims) { TORCH_CHECK(shape_.empty() || shape_.size() == ndims); TORCH_CHECK(contiguity_.empty() || contiguity_.size() == ndims); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp index ff3b3983290f1..372cb27db30be 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp @@ -7174,6 +7174,92 @@ TEST_F(NVFuserTest, FusionVectorizeWelford2_CUDA) { __FILE__); } +// Simple test case for predicate peeling use pattern +TEST_F(NVFuserTest, FusionPredicatePeeling1_CUDA) { + // requires ampere+ GPU + Fusion fusion; + FusionGuard fg(&fusion); + + // Using vectorization so need to keep n multiple of 4. 
+ int m = 33, n = 48; + + TensorView* tv0 = makeContigTensor(2); + + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = set(tv1); + fusion.addOutput(tv2); + + tv2->split(0, 16); + + tv0->computeAt(tv2, 1); + tv1->computeAt(tv2, -1); + + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(2)->parallelize(ParallelType::TIDy); + tv2->peelPredicatedLoop(0); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({m, n}, options); + + FusionExecutor fe; + + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); +} + +// A circular buffer test case for predicate peeling use pattern +TEST_F(NVFuserTest, FusionPredicatePeeling2_CUDA) { + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + return; + } + // requires ampere+ GPU + Fusion fusion; + FusionGuard fg(&fusion); + + // Using vectorization so need to keep n multiple of 4. + int m = 33, n = 45; + + TensorView* tv0 = makeContigTensor(2); + + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + fusion.addOutput(tv1); + + auto tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsync); + + tv1->split(1, 16); + tv1->split(0, 16); + // make tile + tv1->reorder({{1, 2}, {2, 1}}); + + tv0->computeAt(tv1, 2); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(2)->parallelize(ParallelType::TIDx); + tv1->peelPredicatedLoop(1); + + tv2->setMemoryType(MemoryType::Shared); + tv2->circularBuffer(3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({m, n}, options); + + FusionExecutor fe; + + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + auto ref = t0.sum({1}); + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index cac4114f0245a..5b01249261fed 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -1228,6 +1228,7 @@ bool isProlog(DoubleBufferLoopStage stage) { case DoubleBufferLoopStage::Prolog: case DoubleBufferLoopStage::UpperProlog: case DoubleBufferLoopStage::LowerProlog: + case DoubleBufferLoopStage::CircularInitProlog: return true; default: @@ -1258,6 +1259,25 @@ TORCH_CUDA_CU_API std::ostream& operator<<( return os; } +std::ostream& operator<<( + std::ostream& os, + const PredicatePeelStage& peel_stage) { + switch (peel_stage) { + case PredicatePeelStage::NoApplicable: + break; + case PredicatePeelStage::Prolog: + os << "{PeeledProlog}"; + break; + case PredicatePeelStage::Main: + os << "{PeeledMain}"; + break; + default: + TORCH_INTERNAL_ASSERT(false, "unsupported loop attribute"); + break; + } + return os; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 850773797f6e4..d3d1ef4d0895c 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -328,10 +328,13 @@ enum class DoubleBufferLoopStage { Prolog, Main, Epilog, + CircularInitProlog, UpperProlog, LowerProlog }; +enum class PredicatePeelStage { NoApplicable, Prolog, Main }; + //! Returns true if the given stage is a prolog stage //! 
for some double buffered or circular buffered loop. bool isProlog(DoubleBufferLoopStage stage); @@ -387,6 +390,9 @@ TORCH_CUDA_CU_API std::ostream& operator<<( const DoubleBufferLoopStage); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const Swizzle2DType&); TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const SwizzleMode&); +TORCH_CUDA_CU_API std::ostream& operator<<( + std::ostream&, + const PredicatePeelStage&); std::string stringifyBooleanOp(const UnaryOpType); std::string stringifyBooleanOp(const BinaryOpType);
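For reference, a minimal standalone sketch (not part of the diff) of the offset arithmetic behind getPrologPredicateOffset and getSplitTileMainOffset; the extent and tile values are arbitrary examples and the helpers below are local to the sketch. It checks that the peeled prolog plus the shifted main loop cover the full extent exactly once.

#include <cassert>
#include <cstdio>

namespace {

// ceilDiv as used in [Predicate Peeling].
long ceilDiv(long a, long b) {
  return (a + b - 1) / b;
}

// Prolog (residue) offset: X + L - ceilDiv(X, L) * L, i.e. L when X % L == 0, else X % L.
long prologOffset(long extent, long tile) {
  return extent + tile - ceilDiv(extent, tile) * tile;
}

// Main-loop offset subtracted from the tensor index: L - prologOffset.
long mainOffset(long extent, long tile) {
  return tile - prologOffset(extent, tile);
}

} // namespace

int main() {
  const long extent = 45;  // illustrative root extent
  const long tile = 16;    // illustrative split factor
  const long residue = prologOffset(extent, tile);  // 13
  const long shift = mainOffset(extent, tile);      // 3

  // Peeled prolog: outer iteration 0 only, predicated against the residue tile.
  for (long j = 0; j < tile; ++j) {
    if (j < residue) {
      assert(j >= 0 && j < extent);  // covers [0, residue)
    }
  }

  // Peeled main loop: outer iterations 1 .. ceilDiv(extent, tile) - 1, unpredicated.
  for (long i = 1; i < ceilDiv(extent, tile); ++i) {
    for (long j = 0; j < tile; ++j) {
      const long index = i * tile + j - shift;
      assert(index >= residue && index < extent);  // covers [residue, extent) in full tiles
    }
  }

  std::printf("residue tile = %ld, main-loop shift = %ld\n", residue, shift);
  return 0;
}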