diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 10fc7867025ce5..76c5bc4ac82b0e 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -504,12 +504,60 @@ class CudaKernelGenerator : private OptOutConstDispatch { return index.str(); } + // Generate the tensor index that are directly added + // to a base address pointer. So all the components + // are computed in units of bytes. + std::string genTensorAddressIndex( + const kir::TensorIndex* ti, + DataType dtype) { + bool first = true; + std::stringstream index; + for (auto* ind : ti->indices()) { + if (!ind->isZeroInt()) { + if (!first) { + index << " + "; + } + + // Multiply all the components here by the size of the data + // type to get byte offset. + index << "(" << genInline(ind) << ")*" << dataTypeSize(dtype); + first = false; + } + } + + // If there is a uniform component in this tensor index, + // just add them too. + // See also, [Double Buffer Uniform Offset]. + if (ti->uniformAddress() != nullptr) { + if (!first) { + index << " + "; + } + index << genInline(ti->uniformAddress()); + first = false; + } + + if (first) { + index << "0"; + } + + return index.str(); + } + void handle(const kir::TensorIndex* ti) final { bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global && kernel_->summary().sync_map.needsRawSync(ti->view()).hasBID(); if (is_volatile) { code_ << "*(volatile " << ti->getDataType().value() << "*)&"; } + + if (ti->hasBaseAddress()) { + // WAR path to generate a tensor index with pointer content. + code_ << "reinterpret_cast<" << ti->view()->dtype() << "*>(" + << gen(ti->baseAddress()) << ")" + << "[" << genTensorIndex(ti) << "]"; + return; + } + code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]"; } @@ -545,20 +593,41 @@ class CudaKernelGenerator : private OptOutConstDispatch { return ss.str(); } + //! Generates the given value as a pointer address as + //! either: + //! 1. hosted_base_ptr + address_index + //! 2. &Tensor[index] + //! depending on if the given index value carries + //! a hoisted component or not. 
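  // Illustrative example of the two forms produced below (all register and
  // tensor names are hypothetical): a lifted tensor index expands to a byte
  // offset followed by the hoisted base pointer, e.g.
  //   (i42 * 4) + off1, ptr3
  // which the lifted runtime overloads consume as two separate arguments,
  // while a non-lifted index keeps the plain address-of form, e.g.
  //   &T2[i42]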
+ std::string genMaybeHoistedPointer(const Val* val) { + auto ti = dynamic_cast(val); + TORCH_INTERNAL_ASSERT(ti != nullptr, "only support tensor index input"); + std::stringstream ss; + + if (ti->hasBaseAddress()) { + ss << genTensorAddressIndex(ti, ti->view()->dtype()) << "," + << gen(ti->baseAddress()); + } else { + ss << "&" << gen(ti) << "\n"; + } + + return ss.str(); + } + // Utility function to emit a cp.async intrinsic void genCpAsync(const LoadStoreOp* ldst, int vec_size) { auto dtype = ldst->in()->getDataType().value(); if (ldst->predicate() == nullptr) { // Out of line predicate variant - indent() << "Ampere::cpAsync(" - << genVectorPointer(ldst->out(), dtype, vec_size) << "," - << genVectorPointer(ldst->in(), dtype, vec_size) << ");\n"; + indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">(" + << genMaybeHoistedPointer(ldst->out()) << "," + << genMaybeHoistedPointer(ldst->in()) << ");\n"; } else { // Inline predicate variant - indent() << "Ampere::cpAsync(" - << genVectorPointer(ldst->out(), dtype, vec_size) << "," - << genVectorPointer(ldst->in(), dtype, vec_size) << "," + indent() << "Ampere::cpAsync<" << dtype << "," << vec_size << ">(" + << genMaybeHoistedPointer(ldst->out()) << "," + << genMaybeHoistedPointer(ldst->in()) << "," << genInline(ldst->predicate()) << ");\n"; } } @@ -579,8 +648,7 @@ class CudaKernelGenerator : private OptOutConstDispatch { } code_ << " ("; code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size) - << "," - << "&" << gen(ldst->in()) << ");\n"; + << "," << genMaybeHoistedPointer(ldst->in()) << ");\n"; genBankConflictCheck(ldst->in(), 16); } @@ -697,6 +765,25 @@ class CudaKernelGenerator : private OptOutConstDispatch { !(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered())) { // Vectorized initialization indent() << varName(out_tv) << ".set(" << gen(uop->in()) << ");\n"; + } else if ( + uop->out()->isA() && + uop->out()->as()->useSmemAddress()) { + // A special resource string "smemReset" is used if + // the unary op writes to the shared memory using + // the lifted 32b shared mem pointer. + // This mode is reserved for resetting shared memory + // space at the moment currently. + auto ti = uop->out()->as(); + // Special case branch for smem reset + // FIXME: only support filling zero at the moment: + // could possibly extend. + TORCH_INTERNAL_ASSERT( + uop->in()->isZero(), "only support filling zero in smem reset"); + + indent() << "smemReset<" << ti->view()->dtype() << "," + << vector_word_size << ">(" << gen(ti->baseAddress()) + << "+" << genTensorAddressIndex(ti, ti->view()->dtype()) + << ");\n"; } else { // Note: currently arraySet option is not vectorized, so it will // rely on auto vectorization pass of cuda compiler. 
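// Illustrative example of the code emitted by the smem-reset branch above
// (register and tensor names are hypothetical): for a __half tensor reset
// with a vector width of 8, the generated call is roughly
//   smemReset<__half, 8>(a3 + (i21) * 2);
// where a3 holds the lifted 32-bit shared memory address produced by
// Turing::util::toSmem() and the second term is the byte offset produced by
// genTensorAddressIndex.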
@@ -2607,7 +2694,11 @@ class CudaKernelGenerator : private OptOutConstDispatch { alloc_map_.emplace(alloc->buffer(), alloc); if (!alloc->buffer()->isA()) { - indent() << buffer_dtype << " " << gen(alloc->buffer()) << ";\n"; + indent() << buffer_dtype << " " << gen(alloc->buffer()); + if (alloc->zeroInit()) { + code_ << " = 0"; + } + code_ << ";\n"; return; } @@ -2684,12 +2775,59 @@ class CudaKernelGenerator : private OptOutConstDispatch { } void handle(const kir::AddressCompute* address_compute) final { - indent() << "// Address tensor for indexing " - << varName(address_compute->dataTv()) << "\n"; - indent() << gen(address_compute->addressTv()) << " = " - << genTensorIndex( - address_compute->dataTv()->as()) - << ";\n"; + // FIXME: + // All the global/shared memory address/offset manipulations + // to reduce register usage are currently lumped into this single + // kernel IR operator. + // + // If there's any need to commit to the current codegen tweaks + // longer, could consider separating them into more IR nodes. + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_SWITCH) { + indent() << "doubleBufferSwitch<" << address_compute->stageNumber() << "," + << address_compute->loopOffset() << ">(" + << gen(address_compute->doubleBufferSwitchIndex()) << "," + << gen(address_compute->loopIndex()) << "," + << gen(address_compute->doubleBufferByteSize()) << ");\n"; + } else if ( + address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE) { + indent() << "doubleBufferUpdate<" << address_compute->stageNumber() << "," + << address_compute->loopOffset() << ">(" + << gen(address_compute->addressTv()) << "," + << gen(address_compute->loopIndex()) << "," + << gen(address_compute->doubleBufferByteSize()) << ");\n"; + } else if ( + address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT) { + indent() << gen(address_compute->addressTv()) << "+=" + << genTensorAddressIndex( + address_compute->incrementValue(), + address_compute->dataTv()->dtype()) + << ";\n"; + } else if ( + address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_DECREMENT) { + indent() << gen(address_compute->addressTv()) << "-=" + << genTensorAddressIndex( + address_compute->incrementValue(), + address_compute->dataTv()->dtype()) + << ";\n"; + } else { + indent() << "//Base Address:::\n"; + indent() << gen(address_compute->addressTv()); + + if (address_compute->addressTv()->dtype() == DataType::Pointer) { + code_ << " = (DataPointer) &" + << gen(address_compute->dataTv()->as()) + << ";\n"; + } else if ( + address_compute->addressTv()->dtype() == DataType::SmemAddress) { + code_ << " = Turing::util::toSmem(&" + << gen(address_compute->dataTv()->as()) + << ");\n"; + } + } } void handle(const kir::GridSync* sync) final { diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 59c63e6cfe775b..0a89399fd59ae1 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -65,6 +65,8 @@ typedef int int32_t; typedef unsigned int uint32_t; typedef long long int int64_t; typedef unsigned long long int uint64_t; +typedef char* DataPointer; +typedef unsigned SmemAddress; )"; } diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 7fbdda71f6af4c..2d62a645319547 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ 
-1681,7 +1681,20 @@ std::vector Index::getGlobalProducerStridedIndices( loops, root_dom[i])) { // Add the "predicate peeling offset", see [Predicate Peeling] // to the tensor index if this root domain is predicate peeled. - if (tile_entry.value().peel_stage != PredicatePeelStage::Prolog && + + // Incremental mode should add offset at prolog, + // See Note [Predicate Peeing interaction with Incremental Offset] + bool is_increment = + std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) { + return fl->loopTransformInfo().is_increment_loop; + }); + + bool should_add_offset = + (tile_entry.value().peel_stage != PredicatePeelStage::Prolog && + !producer_tv->shouldLiftReadAddress()) || + (tile_entry.value().peel_stage == PredicatePeelStage::Prolog && + producer_tv->shouldLiftReadAddress() && is_increment); + if (should_add_offset && !tile_entry.value() .for_loop->loopTransformInfo() .is_base_index_loop) { @@ -1701,16 +1714,6 @@ std::vector Index::getGlobalProducerStridedIndices( } } - if (shouldUseLiftedAddress(producer_tv, consumer_tv, loops)) { - auto maybe_address_record = - GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( - producer_tv, consumer_tv); - - auto address_index = generateAddressTensorIndex( - loops, maybe_address_record.value()->addressTensor()); - strided_inds.push_back(address_index); - } - return strided_inds; } @@ -1987,56 +1990,57 @@ std::vector Index::getNonGlobalProducerStridedIndices( // No need to compute double buffer index in the address compute loop // as they have been handled with addtional offsets. if (!db_loop->isBaseIndexLoop()) { - auto loop_index = - db_loop->isTrivial() ? db_loop->start() : db_loop->index(); - - // Need to add the producer outer main loop index by 1 - // in the case of lower prolog, see the example in - // [Skew Double Buffer Loop Transformation] - auto consumer_db_loop = - gpu_lower->doubleBufferInfo().getDoubleBufferLoop( - consumer_tv, loops); - - if (consumer_db_loop != nullptr) { - if (gpu_lower->doubleBufferInfo().isLowerPrologWithin( - consumer_db_loop->iter_domain(), db_loop->iter_domain())) { - if (consumer_db_loop->doubleBufferLoopStage() == - DoubleBufferLoopStage::LowerProlog) { - loop_index = SimplifyingIrBuilder::addExpr( - loop_index, gpu_lower->kernel()->oneVal()); + auto maybe_read_offset = + GpuLower::current()->doubleBufferInfo().getReadSwitchIndex( + producer_tv); + + // The double buffer switching indices are now applied in two + // different ways, depending on if the index is lifted or not. + // + // When lifted, the double buffer switching index is computed + // separately as a "double buffer offset" and added to the + // uniform section of the tensor index. + // When not lifted, the behavior stays the same as before + // i.e. they are computed inline. + // See also: + // [Double Buffer Uniform Offset]. + if (!maybe_read_offset.has_value()) { + auto loop_index = + db_loop->isTrivial() ? 
db_loop->start() : db_loop->index(); + + // Need to add the producer outer main loop index by 1 + // in the case of lower prolog, see the example in + // [Skew Double Buffer Loop Transformation] + auto consumer_db_loop = + gpu_lower->doubleBufferInfo().getDoubleBufferLoop( + consumer_tv, loops); + + if (consumer_db_loop != nullptr) { + if (gpu_lower->doubleBufferInfo().isLowerPrologWithin( + consumer_db_loop->iter_domain(), db_loop->iter_domain())) { + if (consumer_db_loop->doubleBufferLoopStage() == + DoubleBufferLoopStage::LowerProlog) { + loop_index = SimplifyingIrBuilder::addExpr( + loop_index, gpu_lower->kernel()->oneVal()); + } } } - } - auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor( - db_loop->iter_domain()); - auto db_switch_index = SimplifyingIrBuilder::modExpr( - loop_index, SimplifyingIrBuilder::create(stage_depth)); + auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor( + db_loop->iter_domain()); + auto db_switch_index = SimplifyingIrBuilder::modExpr( + loop_index, SimplifyingIrBuilder::create(stage_depth)); - auto original_alloc_size = - gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv); - auto db_strided_index = - SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size); - strided_inds.push_back(db_strided_index); + auto original_alloc_size = + gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv); + auto db_strided_index = SimplifyingIrBuilder::mulExpr( + db_switch_index, original_alloc_size); + strided_inds.push_back(db_strided_index); + } } } } - // Below is the code path to take when the indexing math has a component - // that has been pre-computed before. So just generate the logic - // that gets the correct pre-computed index. - if (should_use_lifted_address) { - auto maybe_address_record = - GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( - producer_tv, consumer_tv); - - TORCH_INTERNAL_ASSERT( - maybe_address_record.has_value(), "Address record not found"); - auto address_index = generateAddressTensorIndex( - loops, maybe_address_record.value()->addressTensor()); - strided_inds.push_back(address_index); - } - return strided_inds; } @@ -2326,7 +2330,11 @@ std::vector Index::getNonGlobalConsumerStridedIndices( TORCH_INTERNAL_ASSERT( strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size()); - if (consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) { + if ((consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) + // Lifted address case the double buffer offset is + // computed inplace into the write address buffer. + // See [Inplace double buffer update] + && !lower_utils::useDirectSmemAddress(consumer_tv)) { auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops); TORCH_INTERNAL_ASSERT( @@ -2371,18 +2379,6 @@ std::vector Index::getNonGlobalConsumerStridedIndices( } } - // Below is the code path to take when the indexing math has a component - // that has been pre-computed before. So just generate the logic - // that gets the correct pre-computed index. 
- if (shouldUseLiftedAddress(consumer_tv, consumer_tv, loops)) { - auto maybe_address_record = - GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( - consumer_tv); - auto address_index = generateAddressTensorIndex( - loops, maybe_address_record.value()->addressTensor()); - strided_inds.push_back(address_index); - } - return strided_inds; } @@ -2427,6 +2423,28 @@ kir::TensorIndex* Index::getProducerIndex( const TensorView* consumer, const std::vector& loops) { auto strided_indices = getProducerStridedIndices(producer, consumer, loops); + + // Insert base address and uniform components into the tensor + // index object directly to support separating them on the + // code gen interface. + // See also: [Pointer Addressing In Lifted Indices] + if (shouldUseLiftedAddress(producer, consumer, loops)) { + auto maybe_address_record = + GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( + producer, consumer); + + auto maybe_read_offset = + GpuLower::current()->doubleBufferInfo().getReadSwitchIndex(producer); + Val* uniform_address = nullptr; + if (maybe_read_offset.has_value()) { + uniform_address = maybe_read_offset.value(); + } + auto address_index = generateAddressTensorIndex( + loops, maybe_address_record.value()->addressTensor()); + return SimplifyingIrBuilder::create( + producer, strided_indices, address_index, uniform_address); + } + return SimplifyingIrBuilder::create( producer, strided_indices); } @@ -2456,6 +2474,22 @@ kir::TensorIndex* Index::getConsumerIndex( const TensorView* consumer, const std::vector& loops) { auto strided_indices = getConsumerStridedIndices(consumer, loops); + + // Insert base address and uniform components into the tensor + // index object directly to support separating them on the + // code gen interface. + // See also: [Pointer Addressing In Lifted Indices] + if (shouldUseLiftedAddress(consumer, consumer, loops)) { + auto maybe_address_record = + GpuLower::current()->addressComputeInfo().getMaybeLiftedAddress( + consumer); + + auto address_index = generateAddressTensorIndex( + loops, maybe_address_record.value()->addressTensor()); + return SimplifyingIrBuilder::create( + consumer, strided_indices, address_index); + } + return SimplifyingIrBuilder::create( consumer, strided_indices); } diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index b29a8bc417cd06..8f1a3f1d98901b 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -257,6 +257,13 @@ bool Val::isZeroInt() const { return int_val.has_value() && int_val.value() == 0; } +bool Val::isZero() const { + auto int_val = getInt(); + auto double_val = getDouble(); + return (int_val.has_value() && int_val.value() == 0) || + (double_val.has_value() && double_val.value() == 0); +} + bool Val::isOneInt() const { auto int_val = getInt(); return int_val.has_value() && int_val.value() == 1; diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 7d5ebad25282bc..a3d30a8365ae92 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -285,6 +285,9 @@ class TORCH_CUDA_CU_API Val : public Statement { bool isZeroInt() const; bool isOneInt() const; + // Check zero supporting both int or double. 
+ bool isZero() const; + // Returns the Expr that this value is an output of, returns nullptr if none // was found Expr* definition() const { @@ -342,6 +345,12 @@ class TORCH_CUDA_CU_API Val : public Statement { void resolveIndexDtype(); + // Provide a way to instantiate a 32b integer scalar + void to32b() { + TORCH_INTERNAL_ASSERT(vtype_ == ValType::Scalar && dtype_ == DataType::Int); + dtype_ = DataType::Int32; + } + protected: friend Fusion; diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 8d3034902c62a1..44d7eac8065cba 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -469,6 +469,10 @@ class TORCH_CUDA_CU_API TensorView : public Val { // Returns the depth of circular buffering if applicable. unsigned int circularBufferDepth() const { + if (is_double_buffered_) { + // Double buffering is circular buffering with stage 2. + return 2; + } TORCH_INTERNAL_ASSERT( is_circular_buffered_, toString(), "not circular buffered"); return circular_buffer_stage_; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 3319bf28a18a9d..af94e7bf7058e2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -1857,8 +1857,9 @@ void IterDomain::parallelize(ParallelType t) { // they are swizzled. TORCH_CHECK( t == ParallelType::Vectorize || t == ParallelType::TIDx || - t == ParallelType::Serial, - "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids"); + t == ParallelType::Serial || t == ParallelType::Mma, + "Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids", + t); } parallel_type_ = t; diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 981cfaa0400fd6..dce6e140d41b5e 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -66,10 +66,14 @@ Predicate::Predicate(IrBuilderPasskey passkey, const Predicate* other) TensorIndex::TensorIndex( IrBuilderPasskey passkey, const TensorView* view, - std::vector indices) + std::vector indices, + Val* base_address, + Val* uniform_address) : Val(passkey, ValType::TensorIndex, view->getDataType().value()), view_(view), - indices_(indices) { + indices_(indices), + base_address_(base_address), + uniform_address_(uniform_address) { TORCH_INTERNAL_ASSERT( passkey.ir_container_->isA(), "IR type only valid for Kernel container."); @@ -134,6 +138,67 @@ AddressCompute::AddressCompute( "IR type only valid for Kernel container."); } +AddressCompute::AddressCompute( + IrBuilderPasskey passkey, + Val* address_tensor, + Val* data_tensor, + TensorIndex* increment_value, + bool is_decrement) + : Expr(passkey, ExprType::AddressCompute), + op_type_(AddressCompute::AddressComputeOpType::GMEM_INCREMENT), + data_tensor_(data_tensor), + address_tensor_(address_tensor), + increment_value_(increment_value) { + if (is_decrement) { + op_type_ = AddressCompute::AddressComputeOpType::GMEM_DECREMENT; + } + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + +AddressCompute::AddressCompute( + IrBuilderPasskey passkey, + TensorView* data_tv, + Val* double_buffer_switch_index, + Val* buffer_size_in_byte, + int loop_offset, + int stage_number, + Val* loop_index) + : Expr(passkey, ExprType::AddressCompute), + 
op_type_(AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_SWITCH), + data_tensor_(data_tv), + double_buffer_switch_index_(double_buffer_switch_index), + buffer_size_in_byte_(buffer_size_in_byte), + loop_offset_(loop_offset), + stage_number_(stage_number), + loop_index_(loop_index) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + +AddressCompute::AddressCompute( + IrBuilderPasskey passkey, + Val* address_tensor, + Val* buffer_size_in_byte, + int stage_number, + int loop_offset, + TensorView* data_tensor, + Val* loop_index) + : Expr(passkey, ExprType::AddressCompute), + op_type_(AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE), + data_tensor_(data_tensor), + address_tensor_(address_tensor), + buffer_size_in_byte_(buffer_size_in_byte), + loop_offset_(loop_offset), + stage_number_(stage_number), + loop_index_(loop_index) { + TORCH_INTERNAL_ASSERT( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) : Expr(passkey, ExprType::InitMagicZero) { TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index ff4c826d1f53aa..8792315b2f1afa 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -154,7 +154,9 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { TensorIndex( IrBuilderPasskey, const TensorView* view, - std::vector indices); + std::vector indices, + Val* base_address = nullptr, + Val* uniform_address = nullptr); std::vector::size_type nDims() const { return indices_.size(); @@ -171,9 +173,33 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val { return const_cast(view_); // NOLINT } + bool hasBaseAddress() const { + return base_address_ != nullptr; + } + + Val* baseAddress() const { + return base_address_; + } + + auto uniformAddress() const { + return uniform_address_; + } + + bool useSmemAddress() const { + return use_smem_address_; + } + + TensorIndex* toSmemAddress() { + use_smem_address_ = true; + return this; + } + private: const TensorView* view_ = nullptr; std::vector indices_; + Val* base_address_ = nullptr; + Val* uniform_address_ = nullptr; + bool use_smem_address_ = false; }; //! Allocate is a lower level Node that describes a buffer of memory that @@ -291,14 +317,61 @@ class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr { //! that are not inlined. class TORCH_CUDA_CU_API AddressCompute final : public Expr { public: - enum class AddressComputeOpType { BASE_ADDRESS }; - + enum class AddressComputeOpType { + // Calculate base address for lifted memory index + BASE_ADDRESS, + // Switch a double buffer index register, + // see [Uniform Double Buffer Offset] + DOUBLE_BUFFER_SWITCH, + // Inplace update a double buffered address + // see [Inplace Double Buffer Update] + DOUBLE_BUFFER_UPDATE, + // Inplace increment a global address, see + // see [Gmem address increment] + GMEM_INCREMENT, + // Inplace increment a global address, see + // see [Gmem Increment Hoisting] + GMEM_DECREMENT + }; + + // Constructor for BASE_ADDRESS mode calculation + // (Default). 
explicit AddressCompute( IrBuilderPasskey passkey, AddressComputeOpType op_type, Val* address_tensor, Val* data_tensor); + // Constructor for gmem increment + explicit AddressCompute( + IrBuilderPasskey passkey, + Val* address_tensor, + Val* data_tensor, + TensorIndex* increment_value = nullptr, + bool is_decrement = false); + + // Constructor for double buffer offset + // calculation: + explicit AddressCompute( + IrBuilderPasskey passkey, + TensorView* data_tv, + Val* double_buffer_switch_index, + Val* buffer_size_in_byte, + int loop_offset, + int stage_number, + Val* loop_index = nullptr); + + // Constructor for double buffer offset + // inplace update: + explicit AddressCompute( + IrBuilderPasskey passkey, + Val* address_tensor, + Val* buffer_size_in_byte, + int stage_number, + int loop_offset, + TensorView* data_tensor, + Val* loop_index = nullptr); + auto dataTv() const { return data_tensor_; } @@ -311,6 +384,34 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { return op_type_; } + auto doubleBufferSwitchIndex() const { + return double_buffer_switch_index_; + } + + auto doubleBufferByteSize() const { + return buffer_size_in_byte_; + } + + auto loopOffset() const { + return loop_offset_; + } + + auto stageNumber() const { + return stage_number_; + } + + auto loopIndex() const { + return loop_index_; + } + + auto incrementValue() const { + return increment_value_; + } + + bool isDecrement() const { + return op_type_ == AddressComputeOpType::GMEM_DECREMENT; + } + private: // The type of computation this op computes, // currently only do compute address. @@ -323,6 +424,30 @@ class TORCH_CUDA_CU_API AddressCompute final : public Expr { // Tensor that stores pre-computed address for the // data tensor. Val* address_tensor_ = nullptr; + + // Double buffer switch and update parameters below: + + // The switching index that this op is updating. + Val* double_buffer_switch_index_ = nullptr; + + // The original buffer alloc size used for double buffer + // update calculation. + Val* buffer_size_in_byte_ = nullptr; + + // The double buffer loop offset that is used for + // computing the double buffer size update. + int loop_offset_ = 0; + + // The double buffer loop offset that is used for + // computing the double buffer size update. + int stage_number_ = 0; + + // The double buffer loop index. + Val* loop_index_ = nullptr; + + // Gmem increment parameters below: + // The increment value to apply to the pointer. + kir::TensorIndex* increment_value_ = nullptr; }; // Synchronize all blocks in device, implies cooperative group launch is @@ -443,6 +568,10 @@ struct LoopTransformInfo { //! lifted memory address. bool is_base_index_loop = false; + //! Tracks if this for loop is for calculating inductive variable + //! increments. + bool is_increment_loop = false; + //! Setter API LoopTransformInfo& doubleBufferStage(DoubleBufferLoopStage stage) { double_buffer_loop_stage = stage; @@ -460,6 +589,12 @@ struct LoopTransformInfo { predicate_peel_stage = stage; return *this; } + + // ! Setter API + LoopTransformInfo& incrementLoop() { + is_increment_loop = true; + return *this; + } }; //! ForLoop provides scoping around an int iterator from 0 to range. 
Exprs diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp index 05589386c863f1..7f0c232d7abdba 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp @@ -150,6 +150,69 @@ bool requireEpilogue(const std::vector& exprs) { }); } +bool isGmemIncrement(Expr* expr) { + if (auto loop = dynamic_cast(expr)) { + if (loop->body().exprs().size() != 1) { + return false; + } + return isGmemIncrement(loop->body().exprs()[0]); + } else if (auto address_compute = dynamic_cast(expr)) { + return address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT; + } + return false; +} + +//! Hoists the gmem increment ops to the beginning of the loop +//! within the scope of the given loop. +//! Note: [Gmem Increment Hoisting] +//! +//! This optimization is very useful when inplace increment +//! is used on the global memory pointers. +//! Before this optimization, the code would look like: +//! +//! for i in ... // main loop +//! load.global ... [ptr] +//! // Here we actually have an anti-dependency (WAR) on +//! // the register holding ptr and could result in +//! // non-ideal performance when we do not have enough +//! // instructions to put between the load and the increment. +//! // depending on how many other instructions we have +//! // within this loop. +//! ptr += increment_value +//! +//! After this transformation, the code looks like: +//! ptr -=increment_value // a naive way to compensate +//! // for the first iter. +//! for i in ... // main loop +//! ptr += increment_value +//! // This is actually ok as integer instructions +//! // are usually much faster than memory. +//! load.global ... [ptr] +//! +//! This function hoists the pointer increments, in the given +//! loop, assuming that the decrements have been inserted +//! on the CircularInitProlog stage. +kir::ForLoop* hoistGmemIncrement(kir::ForLoop* fl) { + auto hoisted_loop = IrBuilder::create(fl); + + // insert all gmem increment exprs + for (auto expr : fl->body().exprs()) { + if (isGmemIncrement(expr)) { + hoisted_loop->body().push_back(expr); + } + } + + // insert all non gmem increment exprs + for (auto expr : fl->body().exprs()) { + if (!isGmemIncrement(expr)) { + hoisted_loop->body().push_back(expr); + } + } + + return hoisted_loop; +} + // Replicates double buffer loops for Prologue, Main, and // Epilogue. Prologue only copies the load expressions of double // buffered tensors, whereas Epilogue does any expression other than @@ -226,10 +289,78 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { handle(double_buffer_loop_); + // insert double buffer switching for the read offset: + if (loop_type_ == DoubleBufferLoopStage::Main) { + auto& db_info = GpuLower::current()->doubleBufferInfo(); + + for (auto load : double_buffer_load_exprs_) { + if (auto tv_out = ir_utils::getTvOutput(load)) { + // calculate the switching size + auto switch_size = db_info.getOriginalAllocSize(tv_out); + auto switch_size_in_byte = SimplifyingIrBuilder::mulExpr( + switch_size, + SimplifyingIrBuilder::create(dataTypeSize(tv_out->dtype()))); + + // insert db switch expressions: + // Note:[Uniform Double Buffer Offset] + // This modification is to encourage usage of uniform registers on + // sm75+ when + // accessing shared memory double buffered tensors. + // The code before transformation: + // for i in ... // double buffer loop + // ... = ld.shared [... 
+ (i%5) * double_buffer_size] + // The above code doesn't explictly specify that the double buffer + // switch + // component is uniform. The following transformed code makes it + // explicit: + // for i in ... // double buffer loop + // ... = ld.shared [... + switch_index] + // doubleBufferSwitch(switch_index); + // So that the double buffer indices are all placed in uniform reg. + + auto maybe_read_index = db_info.getReadSwitchIndex(tv_out); + if (maybe_read_index.has_value()) { + // Instantiate and insert the update operator. + auto address_compute = + SimplifyingIrBuilder::create( + tv_out, + maybe_read_index.value(), + switch_size_in_byte, + 0, // assume this path only supports read + // so offset is 0 + db_info.getStageDepthFor( + double_buffer_loop_->iter_domain())); + + cloned_top_level_loop_->body().push_back(address_compute); + } + } + } + } + if (stage_depth > 2) { cloned_top_level_loop_->body().push_back( IrBuilder::create()); } + + // Hoist the address increment in the double buffer main + // loop, see also [Gmem Increment Hoisting] + if (loop_type_ == DoubleBufferLoopStage::Main && + std::any_of( + double_buffer_loop_->body().exprs().begin(), + double_buffer_loop_->body().exprs().end(), + isGmemIncrement) && + // FIXME: + // Below is current condition that is required for gmem increment + // hoisting because the gmem decrement is currently placed in + // CircularInitProlog which requires predicate peeling to + // be generated. + // To fix this should probably dedicate another double buffer + // loop stage, maybe GmemPointerDecrement, that is reserved + // for placing the gmem decrement before the main loop stage. + GpuLower::current()->predicatePeelingInfo().shouldPeelLoop( + double_buffer_loop_)) { + cloned_top_level_loop_ = hoistGmemIncrement(cloned_top_level_loop_); + } } void handle(kir::ForLoop* fl) final { @@ -303,6 +434,50 @@ class DoubleBufferLoopCloner : public kir::IrVisitor { cloned_scopes_.back()->push_back(expr); } } + + if (loop_type_ == DoubleBufferLoopStage::CircularInitProlog) { + // Convert the address compute ops to decrement in the circular + // buffer init prolog, see [Gmem Increment Hoisting]. + if (auto address_compute = dynamic_cast(expr)) { + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT) { + cloned_scopes_.back()->push_back( + IrBuilder::create( + address_compute->addressTv(), + address_compute->dataTv(), + address_compute->incrementValue(), + true /* is_decrement */)); + } + } + } + + // Include the double buffer update expressions in prologs too as + // prolog does write into the double buffered space. + if (loop_type_ == DoubleBufferLoopStage::Prolog) { + if (auto address_compute = dynamic_cast(expr)) { + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE) { + if (std::any_of( + double_buffer_load_exprs_.begin(), + double_buffer_load_exprs_.end(), + [address_compute](Expr* expr) { + return ir_utils::getTvOutput(expr)->sameAs( + address_compute->dataTv()); + })) { + cloned_scopes_.back()->push_back(expr); + } + } + } + } + + if (loop_type_ != DoubleBufferLoopStage::CircularInitProlog) { + if (auto address_compute = dynamic_cast(expr)) { + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT) { + cloned_scopes_.back()->push_back(expr); + } + } + } } //! 
Returns true if the expression is an initialization expr that @@ -495,6 +670,41 @@ class DoubleBufferInserter : private kir::ExprMutator { void insert( kir::ForLoop* double_buffer_loop, const std::vector& loads) { + // Allocate read switching index if they need to be updated + // independently. see [Uniform Double Buffer Offset] + for (auto load : loads) { + if (auto load_output = dynamic_cast(load->output(0))) { + auto uses = load_output->fusion()->unordered_uses(load_output); + if (load_output->getMemoryType() == MemoryType::Shared && + (load_output->isDoubleBuffered() || + load_output->isCircularBuffered()) && + load_output->shouldLiftReadAddress() && + // TODO: read switch index is only enabled for ldmatrix + // at the moment. + // Would need to extend the ld.shared usage to directly + // take pointers to use this in other cases. + std::all_of(uses.begin(), uses.end(), ir_utils::isLdMatrixOp)) { + auto switch_val = IrBuilder::create(); + switch_val->to32b(); + + // Record the read switch indexing variable so it can be + // used in the indexing pass. + // TODO: maybe want to do this in id graph instead + GpuLower::current()->doubleBufferInfo().setReadSwitchIndex( + load_output, switch_val); + + // Place allocation for the switching variable before the + // double buffer loop. + auto index_alloc = IrBuilder::create( + switch_val, + MemoryType::Local, + GpuLower::current()->kernel()->oneVal(), + true); + registerInsertBefore(double_buffer_loop, index_alloc); + } + } + } + auto prologue_loop = DoubleBufferLoopCloner::clone( double_buffer_loop, loads, DoubleBufferLoopStage::Prolog); registerInsertBefore(double_buffer_loop, prologue_loop); diff --git a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h index 29d27778bccc13..5f49b3b75a4726 100644 --- a/torch/csrc/jit/codegen/cuda/lower_double_buffer.h +++ b/torch/csrc/jit/codegen/cuda/lower_double_buffer.h @@ -226,6 +226,24 @@ class TORCH_CUDA_CU_API DoubleBufferInfo { //! skew double buffer transform within the given outer loop. bool isLowerPrologWithin(IterDomain* db_loop, IterDomain* outer_loop); + //! Record the allocated double buffer switching index, + //! see [Uniform Double Buffer Offset] + void setReadSwitchIndex(TensorView* db_tv, Val* switch_index) { + TORCH_INTERNAL_ASSERT( + read_switch_index_map_.insert(std::make_pair(db_tv, switch_index)) + .second); + } + + //! Returns the double buffer switching index if one has been + //! allocated and recorded for the given tv. + c10::optional getReadSwitchIndex(TensorView* db_tv) { + auto val_it = read_switch_index_map_.find(db_tv); + if (val_it == read_switch_index_map_.end()) { + return c10::nullopt; + } + return val_it->second; + } + private: TvInfo& getTvInfo(const TensorView* tv); void buildSkewInfo(const TensorView* tv, const TvInfo& tv_info); @@ -258,6 +276,9 @@ class TORCH_CUDA_CU_API DoubleBufferInfo { //! mapping from inner loop to outer loop. std::unordered_map concrete_skewed_double_buffer_loop_map_; + + //! 
Keep track of read switch index + std::unordered_map read_switch_index_map_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index df099078c52222..61eaf0a8d29490 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -170,6 +170,20 @@ void IndexLowering::handle(const EyeOp* eop) { void IndexLowering::handle(const UnaryOp* uop) { const auto in = lowerSrcIndex(uop->in(), uop->out()); const auto out = lowerDstIndex(uop->out()); + + // Convert the output index to direct shared memory + // address usage if this unary op is initialization + // for cp.async op. see [Lifting smem address decoding for cp.async] + // In order to use the same register for indexing the init + // expression as well, the init expr also needs to + // directly use the shared memory address. + if (ir_utils::isCpAsyncInit(uop)) { + auto out_tv = ir_utils::getTvOutput(uop); + if (out_tv->shouldLiftWriteAddress()) { + out->as()->toSmemAddress(); + } + } + pushBack(IrBuilder::create(uop->getUnaryOpType(), out, in)); GpuLower::current()->propagateExprInfo(uop, back()); } @@ -1060,6 +1074,51 @@ void IndexLowering::handle(const kir::CpAsyncCommit* commit) { } void IndexLowering::handle(const kir::AddressCompute* address_compute) { + // Logic for double buffer switching: + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_SWITCH) { + // no indexing is needed, just forward through the expression and + // attach the loop index corresponding to the double buffer loop. + auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( + address_compute->dataTv()->as(), for_loops_, false); + TORCH_INTERNAL_ASSERT(db_loop != nullptr); + auto db_index = db_loop->isTrivial() ? db_loop->start() : db_loop->index(); + + pushBack(IrBuilder::create( + address_compute->dataTv()->as(), + address_compute->doubleBufferSwitchIndex(), + address_compute->doubleBufferByteSize(), + address_compute->loopOffset(), + address_compute->stageNumber(), + db_index)); + return; + } else if ( + address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::DOUBLE_BUFFER_UPDATE) { + // Unpack the double buffer loop and double buffer allocation component + auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( + address_compute->dataTv()->as(), for_loops_, false); + TORCH_INTERNAL_ASSERT(db_loop != nullptr); + auto db_index = db_loop->isTrivial() ? db_loop->start() : db_loop->index(); + auto loop_offset = + db_loop->doubleBufferLoopStage() == DoubleBufferLoopStage::Main + ? address_compute->stageNumber() - 1 + : 0; + + // Generate index into the address tensor to update. 
+ auto indexed_address_tv = Index::generateAddressTensorIndex( + for_loops_, address_compute->addressTv()->as()); + + pushBack(IrBuilder::create( + indexed_address_tv, + address_compute->doubleBufferByteSize(), + address_compute->stageNumber(), + loop_offset, + address_compute->dataTv()->as(), + db_index)); + return; + } + // Logic for base address computation auto address_tv = address_compute->addressTv(); @@ -1071,8 +1130,29 @@ void IndexLowering::handle(const kir::AddressCompute* address_compute) { auto address_record = maybe_address_record.value(); + if (address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_INCREMENT || + address_compute->opType() == + kir::AddressCompute::AddressComputeOpType::GMEM_DECREMENT) { + // GMEM_INCREMENT/DECREMENT is only used on global producer tv + // currently, so only lowering source index for the address tensor + // to compute the amount of increment. + pushBack(IrBuilder::create( + Index::generateAddressTensorIndex( + for_loops_, address_compute->addressTv()->as()), + address_compute->dataTv(), + lowerSrcIndex( + address_record->dataTensor(), + address_record->indexReferenceTensor()) + ->as(), + address_compute->isDecrement())); + return; + } + Val* lowered_data_index = nullptr; + // This is the base address generation logic, lowering src/dst indexing + // math based on if this record is read or write. if (address_record->isRead()) { lowered_data_index = lowerSrcIndex( address_record->dataTensor(), address_record->indexReferenceTensor()); diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp index ed6857a552b653..29783cd64cb8ac 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp @@ -124,8 +124,28 @@ std::unordered_set getZeroIdSetsForAddressCompute( // Checks if this loop nest is calculating base address. bool is_address_tv_calculation = serial_loop->isBaseIndexLoop(); + // Check if this loop nest is incrementing a gmem address, + // see [Gmem address increment]; + bool is_increment = + std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) { + return fl->loopTransformInfo().is_increment_loop; + }); + std::unordered_set zero_ids; + if (is_increment) { + // In the case of increment calculation, just zero + // every loop except the serial loop from the address record. + for (auto fl : loops) { + if (fl != serial_loop) { + zero_ids.insert(fl->iter_domain()); + } + } + // Zero everything except the serial loop + // in the case of increment gmem iterator. + return zero_ids; + } + for (auto outer_loop_it = loops.begin(); outer_loop_it != loop_it; outer_loop_it++) { auto outer_loop = *outer_loop_it; @@ -189,6 +209,15 @@ std::unordered_set getZeroIdSetsForAddressCompute( loop_it++; } + if (address_record->isRead() && + address_record->dataTensor()->getMemoryType() == MemoryType::Global) { + // The serial loop is converted to increment mode, see [Gmem address + // increment] + // so it can be zeroed always. + // See also [Separability Analysis] on conditions when this is enabled. 
+ zero_ids.insert(address_record->getConcreteSerialLoopId()); + } + return zero_ids; } @@ -230,6 +259,13 @@ IndexingParameters getLinearIndexParameters( maybe_address_record.value(), loop_indexing.loops()); } + bool is_increment = std::any_of( + loop_indexing.loops().begin(), + loop_indexing.loops().end(), + [](kir::ForLoop* fl) { + return fl->loopTransformInfo().is_increment_loop; + }); + auto& loops = loop_indexing.loops(); auto& loop_domain = loop_indexing.loopDomains(); auto& loop_index_map = index_parameters.initial_concrete_id_index; @@ -254,6 +290,26 @@ IndexingParameters getLinearIndexParameters( // Default use pre-allocated integers for index loop_index_map[index_domain] = loop->index(); } + + if (is_increment) { + TORCH_INTERNAL_ASSERT(maybe_address_record.has_value()); + if (GpuLower::current()->caMap()->areMapped( + concrete_loop_domain, + maybe_address_record.value()->getConcreteSerialLoopId(), + IdMappingMode::LOOP)) { + // For the increment calculation, the current implementation + // inserts a one for the loop index corresponding to the serial + // loop. This is valid if [Separability Analysis] checks ok + // on the serial id. + // TODO: + // The current Separability restriction on the serial loop makes this + // ok + // but should eventually use the f(i+1) - f(i) instead + // of a one for the increment calculation to enable more complex + // increment patterns. + loop_index_map[index_domain] = GpuLower::current()->kernel()->oneVal(); + } + } } // Derive the halo extents from the loop indexing result. @@ -269,7 +325,7 @@ IndexingParameters getLinearIndexParameters( // Setup double buffer increment for producer case: // TODO: could unify these double buffer index calculation // in follow ups. - if (index_producer) { + if (index_producer && !maybe_address_record.has_value()) { auto double_buffer_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop( loop_indexing.consumerTv(), loops, true); diff --git a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp index 7c364cb494d843..a6ad6dcb8f5e7b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_mem_index.cpp @@ -131,6 +131,22 @@ void AddressComputeInfo::build(Fusion* fusion) { if (!ir_utils::isTvOp(expr)) { continue; } + + if (ir_utils::isCpAsyncOp(expr)) { + auto in_tv = ir_utils::getTvInput(expr); + auto out_tv = ir_utils::getTvOutput(expr); + + // FIXME: + // It'd take 2 more variants of the resource string for cp.async + // to support lifting one of the producer/consumer indices. As + // the eventual goal of these analysis is to be turned on generically, + // the use case for lifting one of the components is limited so + // not prioritizing. + TORCH_INTERNAL_ASSERT( + in_tv->shouldLiftReadAddress() == out_tv->shouldLiftWriteAddress(), + "For cp.async op only support either lifting both producer and consumer indexing or neither."); + } + for (auto consumer_tv : ir_utils::filterByType(expr->outputs())) { if (consumer_tv->shouldLiftWriteAddress()) { @@ -904,9 +920,6 @@ void AddressComputeInfo::makeAddressRecord( isSeparable(reference_tv, serial_id, contig_merged_ids), "The serial id is required to be separable for the index lifting to work."); - // Create address record: - auto address_tv = makeAddressTv(alloc_ids_vec, !is_shared_mem_access); - // Assuming we are only having two scenarios, // either accessing a consumer in the consumer's loop, // or accessing the producer in producer's loop. 
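// Illustrative scheduler-side opt-in, mirroring the matmul scheduler change
// further down in this patch (only the gmem-read flag appears in this diff;
// write-side lifting would presumably go through a matching
// TensorView::liftWriteAddress() call):
//   if (params.index_lift_options.lift_gmem_read_address) {
//     a->liftReadAddress();
//     b->liftReadAddress();
//   }
// Per the cp.async assert above, the producer and consumer sides of a
// cp.async op must either both be lifted or neither.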
@@ -914,6 +927,20 @@ void AddressComputeInfo::makeAddressRecord( ? AddressRecord::ReadWrite::WRITE : AddressRecord::ReadWrite::READ; + bool is_cp_async_write = + access_direction == AddressRecord::ReadWrite::WRITE && + ir_utils::isCpAsyncOp(data_tv->definition()); + + // Place holder for predicate lifting PR. + bool is_predicate_record = false; + + // Create address record: + auto address_tv = makeAddressTv( + alloc_ids_vec, + !is_shared_mem_access, + is_predicate_record, + is_cp_async_write); + TORCH_INTERNAL_ASSERT( serial_id != nullptr, "no support yet for global scope hoisting"); @@ -957,8 +984,24 @@ c10::optional AddressComputeInfo::getMaybeLiftedAddress( TensorView* AddressComputeInfo::makeAddressTv( std::vector address_domains, - bool is_global_address) { - DataType dtype = is_global_address ? DataType::Index : DataType::Int32; + bool is_global_address, + bool is_predicate_index, + bool is_cpasync_write) { + DataType dtype = is_predicate_index ? DataType::Index : DataType::Pointer; + + // Note: [Lifting smem address decoding for cp.async] + // A trick that saves register usage. + // Before: + // char* smem_ptr; + // for i in ... // main loop + // cp.async [smem_ptr + 123], ... + // After: + // unsigned smem_address = toSmem(smem_ptr); + // for i in ... // main loop + // cp.async [smem_addres+123], ... + if (is_cpasync_write) { + dtype = DataType::SmemAddress; + } return IrBuilder::create( IrBuilder::create( address_domains, std::vector(address_domains.size(), true)), @@ -1063,6 +1106,112 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { // put the new loopnest before the hoisted loop registerInsertBefore(loop, outermost_innermost.first); + + auto data_tensor = insertion_info.address_compute_record->dataTensor(); + + // Insert double buffer increment: + // Note: [Inplace Double Buffer Update]: + // + // The trick used in [Uniform Double Buffer Offset] should be the default + // method of handling double buffer switching index when trying to save + // general purpose registers. But there are 2 exceptions: + // 1. On sm70 or below, there are no unifrom regs to use. + // 2. With cp.async, the consumer shared memory buffer currently + // does not provide access to the uniform reg operand so we could not use + // it. (will be actively discussed internally) + // + // To still avoid using too many registers on double buffered access, + // another code gen trick is used here, to enable near term progress: + // The code before transformation: + // for i in ... // double buffer loop + // ... = ld.shared [... + (i%5) * double_buffer_size] + // The code after transformation: + // R0 = ... + // for i in ... // double buffer loop + // ... = ld.shared [R0] + // doubleBufferUpdate(R0); + // This way essentially the double buffer offset is calculated inplace + // into R0 in each double buffer loop iteration. Note that comparing with + // [Uniform Double Buffer Offset] this method uses more instructions as + // all of the pointers will need to be updated, while using uniform regs + // will only need to update the uniform switch index. + + // FIXME: should move this logic into lower_double_buffer.cpp. + // will need to formulate into a separate pass as it needs to + // clone the loop nest. + if ((data_tensor->isDoubleBuffered() || + data_tensor->isCircularBuffered()) && + insertion_info.address_compute_record->isWrite() && + // Only have support doubleBufferUpdate for + // direct smem access for now. + // FIXME: + // Would need to extend to use this on Volta. 
+ lower_utils::useDirectSmemAddress(data_tensor)) { + // Insert double buffer index update if it is a double buffered write: + // The insertion info loop nest starts with the serial loop, + // in the double buffer update we need to insert into the original + // serial loop itself, so remove the outermost level. + auto db_loop_nest = std::vector( + std::next(insertion_info.loop_nest.begin()), + insertion_info.loop_nest.end()); + + // Clone the loop nest containing the double buffered write + // expression for the consumer index update. + auto db_outer_inner = scope_utils::makeLoopNest(db_loop_nest); + + // Calculate the double buffer size. + auto& db_info = GpuLower::current()->doubleBufferInfo(); + + auto db_size_in_byte = SimplifyingIrBuilder::mulExpr( + db_info.getOriginalAllocSize(data_tensor), + SimplifyingIrBuilder::create( + dataTypeSize(data_tensor->dtype()))); + + // Create the double buffer update expression and insert + // them at the end of the double buffer loop. + auto update_expr = SimplifyingIrBuilder::create( + insertion_info.address_compute_record->addressTensor(), + db_size_in_byte, + data_tensor->circularBufferDepth(), + 0, + data_tensor); + + db_outer_inner.second->body().push_back(update_expr); + loop->body().push_back(db_outer_inner.first); + } + + // Insert gmem increment: + // Note [Gmem address increment]: + // This is a trick that helps lifting some instructions out of main + // loop. + // The code before this transformation: + // R0 = ... + // for i in ... // The serial loop on index lifting record + // x = ld.global [i*123 + R0] + // The code after transformation: + // R0 = ... + // for i in ... // The serial loop on index lifting record + // x = ld.global [R0] + // R0+=123; + // Note that [Separability Analysis] will be checked on the serial + // loop when creating these address records so doing this transformation + // on the serial loop index variable is safe. + if (data_tensor->getMemoryType() == MemoryType::Global && + insertion_info.address_compute_record->isRead()) { + // Create the loopnest to contain the increment expression. + auto increment_loop_vector = + createIncrementLoop(insertion_info.loop_nest); + auto increment_loop_outer_inner = + scope_utils::makeLoopNest(increment_loop_vector); + + // Create the increment expression. + auto inc_expr = SimplifyingIrBuilder::create( + insertion_info.address_compute_record->addressTensor(), + insertion_info.address_compute_record->dataTensor()); + + increment_loop_outer_inner.second->body().push_back(inc_expr); + loop->body().push_back(increment_loop_outer_inner.first); + } } } @@ -1096,6 +1245,31 @@ class MemoryAddressComputeInserter : public kir::ExprMutator { original_loop->loopTransformInfo().baseIndexLoop()); } + // Utility to create the loop nest for gmem increment, + // see [Gmem address increment]. 
+ std::vector createIncrementLoop( + std::vector address_compute_loop_vector) { + std::vector loop_nest_to_clone( + std::next(address_compute_loop_vector.begin()), + address_compute_loop_vector.end()); + + std::vector cloned_loop_nest; + for (auto fl : loop_nest_to_clone) { + cloned_loop_nest.push_back(IrBuilder::create( + fl->iter_domain(), + fl->index(), + fl->start(), + fl->stop(), + fl->step(), + fl->vectorize(), + fl->vectorize_shift(), + fl->isUnrollRequired(), + fl->loopTransformInfo().incrementLoop())); + } + + return cloned_loop_nest; + } + std::vector createAddressComputeLoop( AddressRecord* address_record) { // Find the loop in the current loop nest that maps the concrete serial loop diff --git a/torch/csrc/jit/codegen/cuda/lower_mem_index.h b/torch/csrc/jit/codegen/cuda/lower_mem_index.h index 4c5e5449f12af5..d17e28043a119a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_mem_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_mem_index.h @@ -166,7 +166,9 @@ class AddressComputeInfo { // Utility to help allocate space for saving pre-computed address. TensorView* makeAddressTv( std::vector address_domains, - bool is_global_address); + bool is_global_address, + bool is_predicate_index, + bool is_cpasync_write = false); void makeAddressRecord(TensorView* data_tv, TensorView* reference_tv); diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index e3df0e14721517..d7c47b60c9d0f0 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -184,6 +184,11 @@ bool isCpAsyncOp(const Expr* expr) { } bool isTensorScalarFillOp(const Expr* expr) { + // Check that this expression outputs to tensor + if (getTvOutput(expr) == nullptr) { + return false; + } + // Check that the input is a single scalar. if (expr->inputs().size() == 1 && expr->input(0)->isScalar()) { // All load store op with a single scalar input @@ -344,7 +349,9 @@ std::unordered_map getParallelDomains( } bool isCpAsyncInit(const Expr* expr) { - return isTensorScalarFillOp(expr) && + return + + isTensorScalarFillOp(expr) && // FIXME: // We'd need to add a flag to all the init // exprs so we could robustly detect initialization @@ -776,6 +783,21 @@ bool supportInlinePredicate(Expr* expr) { return false; } +bool useDirectSmemAddress(const TensorView* tv) { + // Not applicable for any indexing that's not + // lifted. + if (!tv->shouldLiftWriteAddress() || + tv->getMemoryType() != MemoryType::Shared) { + return false; + } + + auto expr = tv->definition(); + // Direct usage of smem address should be avoided at all cost, + // so only allowing this very specific case where this is the + // necessary step to take to get efficient indexing code. + return expr != nullptr && ir_utils::isCpAsyncOp(expr); +} + } // namespace lower_utils } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index f5041510ed011b..f6ebdddbec04a5 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -277,6 +277,10 @@ BasicAllocInfo getAllocInformation( //! as an inline argument. bool supportInlinePredicate(Expr* expr); +//! Returns true if the consumer indexing of this tensor directly +//! uses shared mem address. 
+bool useDirectSmemAddress(const TensorView* tv); + } // namespace lower_utils } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index e064a43090fd7e..38bacb3d6a1fc3 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -105,6 +105,55 @@ DEVICE_INLINE void ldMatrixT(Array<__half, 8, 8>& out, void const* ptr) { : "r"(addr)); } +// Below are the variants of ldmatrix wrapper that supports lifted +// memory indexing. +DEVICE_INLINE void ldMatrix( + Array<__half, 4, 4>& out, + nvfuser_index_t index, + DataPointer base_ptr) { + uint2& val = reinterpret_cast(out); + unsigned addr = util::toSmem(base_ptr); + util::adjustPartialLdMatrixAddrInTuring(addr); + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];" + : "=r"(val.x), "=r"(val.y) + : "r"(addr + (unsigned)index)); +} + +DEVICE_INLINE void ldMatrixT( + Array<__half, 4, 4>& out, + nvfuser_index_t index, + DataPointer base_ptr) { + uint2& val = reinterpret_cast(out); + unsigned addr = util::toSmem(base_ptr); + util::adjustPartialLdMatrixAddrInTuring(addr); + asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];" + : "=r"(val.x), "=r"(val.y) + : "r"(addr + (unsigned)index)); +} + +DEVICE_INLINE void ldMatrix( + Array<__half, 8, 8>& out, + nvfuser_index_t index, + DataPointer base_ptr) { + uint4& val = reinterpret_cast(out); + unsigned addr = util::toSmem(base_ptr); + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "r"(addr + (unsigned)index)); +} + +DEVICE_INLINE void ldMatrixT( + Array<__half, 8, 8>& out, + nvfuser_index_t index, + DataPointer base_ptr) { + uint4& val = reinterpret_cast(out); + unsigned addr = util::toSmem(base_ptr); + asm volatile( + "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "r"(addr + (unsigned)index)); +} + } // namespace Turing #endif // Arch 75 @@ -136,10 +185,8 @@ DEVICE_INLINE unsigned toSmem(void* ptr) { // Global to SMEM load that is asynchronous, // not guaranteed to be completed until cpAsyncBarrier() is called. template -DEVICE_INLINE void cpAsync( - Array* smem_ptr, - void const* gmem_ptr) { - unsigned smem_addr = util::toSmem(&(smem_ptr->array[0])); +DEVICE_INLINE void cpAsync(void* smem_ptr, void const* gmem_ptr) { + unsigned smem_addr = util::toSmem(smem_ptr); constexpr int byte_size = sizeof(dtype) * len; static_assert( @@ -156,10 +203,10 @@ DEVICE_INLINE void cpAsync( // not guaranteed to be completed until cpAsyncBarrier() is called. 
@@ -156,10 +203,10 @@ DEVICE_INLINE void cpAsync(
 // Global to SMEM load that is asynchronous,
 // not guaranteed to be completed until cpAsyncBarrier() is called.
 template <typename dtype, int len>
 DEVICE_INLINE void cpAsync(
-    Array<dtype, len, len>* smem_ptr,
+    void* smem_ptr,
     void const* gmem_ptr,
     bool predicate) {
-  unsigned smem_addr = util::toSmem(&(smem_ptr->array[0]));
+  unsigned smem_addr = util::toSmem(smem_ptr);
   constexpr int byte_size = sizeof(dtype) * len;
   static_assert(
@@ -177,6 +224,53 @@ DEVICE_INLINE void cpAsync(
       "r"((int)predicate));
 }
 
+// cp.async
+// This is the variant that supports lifted indexing.
+template <typename dtype, int len>
+DEVICE_INLINE void cpAsync(
+    nvfuser_index_t smem_index,
+    unsigned smem_addr,
+    nvfuser_index_t gmem_index,
+    DataPointer& gmem_ptr) {
+  constexpr int byte_size = sizeof(dtype) * len;
+
+  static_assert(
+      byte_size == 4 || byte_size == 8 || byte_size == 16,
+      "cp_async : unsupported byte size");
+
+  asm volatile(
+      "cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(
+          smem_addr + (unsigned)smem_index),
+      "l"(gmem_ptr + gmem_index),
+      "n"(byte_size));
+}
+
+// cp.async
+// This is the variant that supports lifted indexing, with the predicate inlined.
+template <typename dtype, int len>
+DEVICE_INLINE void cpAsync(
+    nvfuser_index_t smem_index,
+    unsigned smem_addr,
+    nvfuser_index_t gmem_index,
+    DataPointer& gmem_ptr,
+    bool predicate) {
+  constexpr int byte_size = sizeof(dtype) * len;
+
+  static_assert(
+      byte_size == 4 || byte_size == 8 || byte_size == 16,
+      "cp_async : unsupported byte size");
+
+  asm volatile(
+      "{\n"
+      " .reg .pred p;\n"
+      " setp.ne.b32 p, %3, 0;\n"
+      "@p cp.async.ca.shared.global [%0], [%1], %2;\n"
+      "}\n" ::"r"(smem_addr + (unsigned)smem_index),
+      "l"(gmem_ptr + gmem_index),
+      "n"(byte_size),
+      "r"((int)predicate));
+}
+
 // TODO: Might have a different category of sync if we want to build out this:
 DEVICE_INLINE void cpAsyncBarrier() {
   asm volatile("cp.async.wait_all;");
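
The two lifted-index cpAsync overloads added above take an already-converted shared-memory address plus a byte index on the destination side, and a byte pointer plus a byte index on the source side, copying sizeof(dtype) * len bytes (4, 8, or 16). Below is a host-side reference model of that contract, with memcpy standing in for the cp.async instruction and ordinary buffers standing in for the two memory spaces; DataPointer and nvfuser_index_t are redefined locally purely for illustration.

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Local stand-ins for the device-side typedefs; illustrative only.
    using DataPointer = char*;
    using nvfuser_index_t = std::int64_t;

    template <typename dtype, int len>
    void cpAsyncModel(
        nvfuser_index_t smem_index,
        char* smem_base, // models the converted shared-memory address
        nvfuser_index_t gmem_index,
        DataPointer gmem_ptr,
        bool predicate = true) {
      constexpr int byte_size = sizeof(dtype) * len;
      static_assert(
          byte_size == 4 || byte_size == 8 || byte_size == 16,
          "unsupported byte size");
      if (predicate) {
        // Both indices are byte offsets added to already-formed addresses.
        std::memcpy(smem_base + smem_index, gmem_ptr + gmem_index, byte_size);
      }
    }

    int main() {
      float gmem[8] = {0, 1, 2, 3, 4, 5, 6, 7};
      alignas(16) char smem[32] = {};

      // Copy gmem[4..7] (16 bytes) to the start of the smem stand-in buffer.
      cpAsyncModel<float, 4>(
          0, smem, 4 * sizeof(float), reinterpret_cast<DataPointer>(gmem));

      std::printf("%f\n", reinterpret_cast<float*>(smem)[1]); // prints 5
      return 0;
    }
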
@@ -195,4 +289,104 @@ DEVICE_INLINE void cpAsyncPartialBarrier() {
 
 #endif // Arch 80
 
+// Double buffer calculation utilities:
+
+// In-place update of a double buffer index that has been accumulated into the
+// data buffer pointer.
+template <int number_of_stage, int loop_offset>
+DEVICE_INLINE void doubleBufferUpdate(
+    DataPointer& data_buffer,
+    const nvfuser_index_t& loop_index,
+    nvfuser_index_t buffer_size) {
+  // static_assert(
+  //     loop_offset < number_of_stage && loop_offset > -number_of_stage);
+
+  // Convert offset to [0, number_of_stage).
+  constexpr nvfuser_index_t offset =
+      loop_offset < 0 ? (loop_offset + number_of_stage) : loop_offset;
+
+  // Rewind at stage number_of_stage-1, otherwise advance by one buffer_size.
+  nvfuser_index_t increment =
+      (loop_index % number_of_stage) == (number_of_stage - 1 - offset)
+      ? buffer_size * (-number_of_stage + 1)
+      : buffer_size;
+  data_buffer += increment;
+}
+
+template <int number_of_stage, int loop_offset>
+DEVICE_INLINE void doubleBufferUpdate(
+    unsigned& data_buffer,
+    const nvfuser_index_t& loop_index,
+    nvfuser_index_t buffer_size) {
+  // static_assert(
+  //     loop_offset < number_of_stage && loop_offset > -number_of_stage);
+
+  // Convert offset to [0, number_of_stage).
+  constexpr nvfuser_index_t offset =
+      loop_offset < 0 ? (loop_offset + number_of_stage) : loop_offset;
+
+  // Rewind at stage number_of_stage-1, otherwise advance by one buffer_size.
+  nvfuser_index_t increment =
+      (loop_index % number_of_stage) == (number_of_stage - 1 - offset)
+      ? buffer_size * (-number_of_stage + 1)
+      : buffer_size;
+  data_buffer += (unsigned)increment;
+}
+
+// Update the double buffer offset value for smem double buffered tensors.
+// See [Uniform Double Buffer Offset].
+template <int number_of_stage, int loop_offset>
+DEVICE_INLINE void doubleBufferSwitch(
+    int& buffer_offset,
+    const nvfuser_index_t& loop_index,
+    nvfuser_index_t buffer_size) {
+  constexpr nvfuser_index_t offset =
+      loop_offset < 0 ? (loop_offset + number_of_stage) : loop_offset;
+
+  // Rewind at stage number_of_stage-1, otherwise advance by one buffer_size.
+  nvfuser_index_t increment =
+      (loop_index % number_of_stage) == (number_of_stage - 1 - offset)
+      ? buffer_size * (-number_of_stage + 1)
+      : buffer_size;
+  buffer_offset += (int)increment;
+}
+
+// Reset the smem space to zero.
+// TODO: try cp.async.ignore-source ?
+template <typename dtype, int len>
+DEVICE_INLINE void smemReset(SmemAddress smem_addr) {
+  constexpr int byte_size = sizeof(dtype) * len;
+
+  static_assert(
+      byte_size == 4 || byte_size == 8 || byte_size == 16,
+      "smemReset : unsupported byte size");
+
+  switch (byte_size) {
+    case 4:
+      asm volatile(
+          "{\n"
+          "st.shared.u32 [%0], {%1};\n"
+          "}\n"
+          :
+          : "r"(smem_addr), "r"(0));
+      break;
+    case 8:
+      asm volatile(
+          "{\n"
+          "st.shared.v2.u32 [%0], {%1, %2};\n"
+          "}\n"
+          :
+          : "r"(smem_addr), "r"(0), "r"(0));
+      break;
+    case 16:
+      asm volatile(
+          "{\n"
+          "st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
+          "}\n"
+          :
+          : "r"(smem_addr), "r"(0), "r"(0), "r"(0), "r"(0));
+      break;
+  }
+}
+
 #undef DEVICE_INLINE
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
index 44071a01931a6c..484f755abe1e9b 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp
@@ -534,6 +534,8 @@ void scheduleMatmul(
           .propagateParallelType()
           .propagateToBoundary());
 
+  c->axis(-1)->parallelize(ParallelType::Vectorize);
+
   if (params.index_lift_options.lift_gmem_read_address) {
     a->liftReadAddress();
     b->liftReadAddress();
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
index a6f777ef16255c..d29e0bbec29cce 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp
@@ -847,7 +847,10 @@ TEST_F(NVFuserTest, FusionAmpereMatmulRegDoubleBuffer_CUDA) {
 
   FusionExecutor fe;
   NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
-      8, 0, fe.compileFusion(&fusion, {inputs.first, inputs.second}));
+      8,
+      0,
+      fe.compileFusion(
+          &fusion, {inputs.first, inputs.second}, LaunchParams()));
   auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
   auto tref = atMatmul(
       inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
@@ -2798,6 +2801,7 @@ TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoad_CUDA) {
   params.tile_sizes = gemm_tile;
   params.async_gmem_load_operands = true;
   params.double_buffer_options.double_buffer_smem_write = true;
+  params.double_buffer_options.double_buffer_smem_read = true;
   params.double_buffer_options.smem_double_buffer_stage = 3;
   scheduleMatmul(tv2, tv0, tv1, params);
 
diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp
index 0943b22901063c..e4a6674e3eae03 100644
--- a/torch/csrc/jit/codegen/cuda/type.cpp
+++ b/torch/csrc/jit/codegen/cuda/type.cpp
@@ -30,6 +30,7 @@ bool isFloatingPointType(DataType dtype) {
       return true;
     case DataType::Bool:
     case DataType::Index:
+    case DataType::Pointer:
     case DataType::Int:
     case DataType::Int32:
     case DataType::ComplexFloat:
@@ -75,6 +76,7 @@ bool isIntegralType(DataType dtype) {
     case DataType::ComplexDouble:
      return false;
     case DataType::Index:
+    case DataType::Pointer:
     case DataType::Int:
     case DataType::Int32:
       return true;
@@ -97,6 +99,7 @@ bool isComplexType(DataType dtype) {
     case DataType::BFloat16:
     case DataType::Int:
     case DataType::Index:
+    case DataType::Pointer:
     case DataType::Int32:
       return false;
     case DataType::Null:
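
The doubleBufferUpdate and doubleBufferSwitch helpers added to memory.cu earlier in this diff advance a pointer (or integer offset) by one buffer_size per iteration and rewind it by (number_of_stage - 1) * buffer_size when the selected stage wraps, so the address always stays inside the number_of_stage circular buffers. A host-side copy of that update rule, checked over a few iterations; the template parameter order here mirrors the reconstruction above and should be treated as an assumption.

    #include <cassert>
    #include <cstdio>

    // Host-side copy of the rotation used by doubleBufferUpdate/doubleBufferSwitch.
    template <int number_of_stage, int loop_offset>
    void doubleBufferUpdateModel(
        long& buffer_offset,
        long loop_index,
        long buffer_size) {
      constexpr long offset =
          loop_offset < 0 ? (loop_offset + number_of_stage) : loop_offset;
      long increment =
          (loop_index % number_of_stage) == (number_of_stage - 1 - offset)
          ? buffer_size * (-number_of_stage + 1)
          : buffer_size;
      buffer_offset += increment;
    }

    int main() {
      constexpr int kStages = 3;
      constexpr long kBufferBytes = 512;

      long offset = 0;
      for (long i = 0; i < 9; ++i) {
        std::printf("iter %ld reads stage at byte offset %ld\n", i, offset);
        doubleBufferUpdateModel<kStages, 0>(offset, i, kBufferBytes);
        // The offset never leaves the staged region: 0, 512, 1024, 0, 512, ...
        assert(offset >= 0 && offset < kStages * kBufferBytes);
      }
      return 0;
    }
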
@@ -236,6 +239,10 @@ static const char* data_type2string(DataType t) {
       return "int64_t";
     case DataType::Index:
       return "nvfuser_index_t";
+    case DataType::Pointer:
+      return "DataPointer";
+    case DataType::SmemAddress:
+      return "SmemAddress";
     case DataType::Int32:
       return "int";
     case DataType::ComplexFloat:
@@ -989,6 +996,7 @@ at::ScalarType data_type_to_aten(const DataType& data_type) {
     case DataType::Int:
       return at::ScalarType::Long;
     case DataType::Index:
+    case DataType::Pointer:
       TORCH_INTERNAL_ASSERT(
          false,
          "Index is determined at compile time,",
@@ -1164,6 +1172,8 @@ std::string typePrefix(const DataType data_type) {
     case DataType::Int:
     case DataType::Int32:
       return "i";
+    case DataType::Pointer:
+      return "p";
     case DataType::ComplexFloat:
     case DataType::ComplexDouble:
       return "c";
@@ -1215,6 +1225,7 @@ size_t dataTypeSize(DataType type) {
     case DataType::BFloat16:
       return sizeof(at::BFloat16);
     case DataType::Index:
+    case DataType::Pointer:
       TORCH_INTERNAL_ASSERT(
           false, "The actual type of Index is only known at compile time.");
     case DataType::Int:
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index e6d4c5b87c1e10..f7908f28b511eb 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -66,6 +66,8 @@ enum class DataType {
   Half,
   Int,
   Index,
+  Pointer,
+  SmemAddress,
   Int32,
   Bool,
   BFloat16,
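
The new DataType::Pointer and DataType::SmemAddress values are plumbed through the type utilities so that generated kernels can declare lifted addresses: data_type2string maps them to the runtime-side type names and typePrefix supplies the variable-name prefix. A small sketch of how those two pieces could combine into a generated declaration; the helpers below re-implement only the new cases, and the assembled string is an illustration rather than the literal codegen path (the patch shows a prefix only for Pointer, so none is assumed for SmemAddress).

    #include <cstdio>
    #include <string>

    // Simplified stand-ins covering only the two new DataType values; the real
    // implementations live in torch/csrc/jit/codegen/cuda/type.cpp.
    enum class DataType { Pointer, SmemAddress };

    const char* data_type2string(DataType t) {
      switch (t) {
        case DataType::Pointer:
          return "DataPointer";
        case DataType::SmemAddress:
          return "SmemAddress";
      }
      return "";
    }

    std::string typePrefix(DataType t) {
      // Only DataType::Pointer is given a prefix ("p") in this patch.
      return t == DataType::Pointer ? "p" : "";
    }

    int main() {
      DataType t = DataType::Pointer;
      // e.g. "DataPointer p12;" as a declaration for a hoisted base address.
      std::string decl =
          std::string(data_type2string(t)) + " " + typePrefix(t) + "12;";
      std::printf("%s\n", decl.c_str());
      return 0;
    }
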