2 changes: 1 addition & 1 deletion csrc/codegen.cpp
@@ -3757,7 +3757,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {

ArgumentBuilder template_args;
template_args.arg(kernel_->paddedParallelDimensions().is_tidx_single_warp);
template_args.arg(isAligned());
template_args.arg(has_warp_specialized_ ? false : isAligned());
template_args.arg(num_grouped_iterations);
template_args.arg(reduction_scheduler_utils::getComputeBdimx(
warp_specialized_on_, lparams_.bdimx()));
14 changes: 14 additions & 0 deletions csrc/device_lower/lower2device.h
@@ -318,6 +318,16 @@ class GpuLower : public NonCopyable {
cluster_reduction_mbarrier_tensor_ = mbarrier;
}

//! Get the uniform warp id scalar
Val* uniformWarpId() const {
return uniform_warp_id_;
}

//! Set the uniform warp id scalar
void setUniformWarpId(Val* warp_id) {
uniform_warp_id_ = warp_id;
}

//! Define an alias for consumer as producer.
//!
//! If producer is already aliased, we chase the alias. If there are tensors
@@ -434,6 +444,10 @@ class GpuLower : public NonCopyable {
// The shared cluster reduction mbarrier tensor allocated during allocation
// pass
TensorView* cluster_reduction_mbarrier_tensor_ = nullptr;

// The uniform warp id scalar allocated during allocation pass for warp
// specialized kernels
Val* uniform_warp_id_ = nullptr;
};

#define NVFUSER_LOWER_VALIDATE(cond, ...) \
51 changes: 50 additions & 1 deletion csrc/device_lower/pass/allocation.cpp
@@ -976,7 +976,8 @@ Expr* initializeCircularBufferMbarrier(
GpuLower::current()
->info()
.parallelDimensionMap()
.getNumComputeThreadsEachBlock());
.getNumComputeThreadsEachBlock(
/*only_count_same_compute_warp_groups=*/true));
}

// Initialize mbarrier for each circular buffer stage. Use the thread
@@ -1231,6 +1232,47 @@ class AllocationInserter : public kir::ExprMutator {
return alloc_expr;
}

void computeUniformWarpId(Expr* expr) {
// Compute flat thread id: tid = threadIdx.x + threadIdx.y * blockDim.x +
// threadIdx.z * blockDim.x * blockDim.y
const auto& pdim = GpuLower::current()->info().parallelDimensionMap();
Val* tid = FusionGuard::getCurFusion()->zeroVal();
Val* bdimx = pdim.getRaw(ParallelType::TIDx);
Val* bdimy = pdim.getRaw(ParallelType::TIDy);
Val* bdimz = pdim.getRaw(ParallelType::TIDz);

if (bdimx != nullptr) {
tid = NamedScalar::getParallelIndex(ParallelType::TIDx);
}
if (bdimy != nullptr) {
Val* tidy = NamedScalar::getParallelIndex(ParallelType::TIDy);
if (bdimx != nullptr) {
tidy = SimplifyingIrBuilder::mulExpr(tidy, bdimx);
}
tid = SimplifyingIrBuilder::addExpr(tid, tidy);
}
if (bdimz != nullptr) {
Val* tidz = NamedScalar::getParallelIndex(ParallelType::TIDz);
if (bdimy != nullptr) {
tidz = SimplifyingIrBuilder::mulExpr(tidz, bdimy);
}
if (bdimx != nullptr) {
tidz = SimplifyingIrBuilder::mulExpr(tidz, bdimx);
}
tid = SimplifyingIrBuilder::addExpr(tid, tidz);
}

// Compute warp_id = tid / 32
Val* warp_size = IrBuilder::create<Val>(32L, DataType::Index);
Val* warp_id = SimplifyingIrBuilder::divExpr(tid, warp_size);
Collaborator (review comment):

It seems like there is nothing guaranteeing this will be warp-uniform, since the compiler cannot know the block size; unless TIDz and TIDy are both >1, it won't know that tid is the linear thread ID. So do we need to do a warp broadcast? See #2323.

Collaborator (author reply):

Yes, we need something like:

// __shfl_sync helps PTXAS prove that every thread in the warp has the same
// uniform warp id.
__device__ __forceinline__ uint32_t getUniformWarpId() {
  const unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x +
      threadIdx.z * blockDim.x * blockDim.y;
  const unsigned int warp_id = tid / 32;
  return __shfl_sync(0xFFFFFFFF, warp_id, 0);
}

This PR is not ready yet.


// Cast to UInt32 for use in predicates and store in GpuLower
Val* uniform_warp_id =
SimplifyingIrBuilder::maybeCastExpr(DataType::UInt32, warp_id);

GpuLower::current()->setUniformWarpId(uniform_warp_id);
}

// Insert cluster reduction mbarrier allocation and initialization at the
// beginning of the kernel for the first top-level expression
void insertClusterReductionMBarrier(Expr* expr) {
@@ -1679,6 +1721,13 @@ class AllocationInserter : public kir::ExprMutator {

AllocationInserter(const std::vector<Expr*>& exprs)
: gpu_lower_(GpuLower::current()) {
// Warp-id-based predicates (e.g., warp_id >= threshold) only work when
// async/compute warps have consecutive warp IDs.
if (gpu_lower_->info()
.parallelDimensionMap()
.canUseWarpIdBasedPredicate()) {
computeUniformWarpId(exprs.at(0));
}
// insert cluster reduction mbarrier at top-level scope
if (GpuLower::current()->clusterReductionCount() >= 1) {
insertClusterReductionMBarrier(exprs.at(0));
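As a sanity check on the linearization in computeUniformWarpId() above, here is a minimal standalone host-side C++ sketch (not part of the patch; the block shape is an assumed example) that reproduces the flat-tid and warp-id arithmetic:

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example block shape; the lowering pass reads the real extents
  // from the ParallelDimensionMap instead.
  const int bdimx = 128, bdimy = 3, bdimz = 1;
  for (int tz = 0; tz < bdimz; ++tz) {
    for (int ty = 0; ty < bdimy; ++ty) {
      for (int tx = 0; tx < bdimx; ++tx) {
        // tid = threadIdx.x + threadIdx.y * blockDim.x
        //     + threadIdx.z * blockDim.x * blockDim.y
        const int64_t tid =
            tx + int64_t(ty) * bdimx + int64_t(tz) * bdimx * bdimy;
        const int64_t warp_id = tid / 32;
        if (tx % 32 == 0) {
          std::printf("tz=%d ty=%d tx=%3d -> tid=%3lld warp_id=%lld\n", tz, ty,
                      tx, static_cast<long long>(tid),
                      static_cast<long long>(warp_id));
        }
      }
    }
  }
  return 0;
}

Because bdimx is a multiple of 32 here, every TIDy row starts on a warp boundary and warp_id is constant within each warp; as the review thread notes, ptxas cannot prove that on its own, which is why the proposed getUniformWarpId() helper broadcasts the value with __shfl_sync.
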
29 changes: 21 additions & 8 deletions csrc/device_lower/pass/circular_buffer.cpp
@@ -1566,17 +1566,30 @@ class WarpSpecializedCircularBufferInserter : private kir::ExprMutator {
}

// Create predicate for warp-specialized IfThenElse:
// kir::Predicate is thread_axis >= block_dim_axis - padded_value
// If uniform warp ID is available, use warp-ID-based predicate (warp_id >=
// num_compute_warps)
kir::Predicate* getAsyncWarpPredicate(const CircularBufferOptions& options) {
const ParallelDimensionMap& pdim_map =
GpuLower::current()->info().parallelDimensionMap();

Val* uniform_warp_id = GpuLower::current()->uniformWarpId();
if (uniform_warp_id != nullptr) {
// Use uniform warp ID approach: async warps have warp_id >=
// num_compute_warps
Val* num_compute_warps = pdim_map.getNumComputeWarps();
NVF_ERROR(
num_compute_warps != nullptr,
"num_compute_warps must be initialized");
return IrBuilder::create<kir::Predicate>(
IrBuilder::geExpr(uniform_warp_id, num_compute_warps));
}

// Fallback: use parallel index comparison
ParallelType warp_specialize_on =
std::get<WarpSpecialized>(options.type).on;
int64_t warp_specialization_pad =
GpuLower::current()
->info()
.parallelDimensionMap()
.getWarpSpecializationPaddedVal(warp_specialize_on);
Val* raw = GpuLower::current()->info().parallelDimensionMap().get(
warp_specialize_on);
pdim_map.getWarpSpecializationPaddedVal(warp_specialize_on);
Val* raw = pdim_map.get(warp_specialize_on);
Val* raw_minus_pad = SimplifyingIrBuilder::subExpr(
raw, IrBuilder::create<Val>(warp_specialization_pad, DataType::Index));
return IrBuilder::create<kir::Predicate>(IrBuilder::geExpr(
@@ -2019,7 +2032,7 @@ kir::ForLoop* HopperPingPongMbarriers::initializePingPongMbarrier() {
GpuLower::current()
->info()
.parallelDimensionMap()
.getNumComputeThreadsEachBlock());
.getNumComputeThreadsEachBlock(true));
kir::TensorIndex* ping_pong_mbarrier_index =
IrBuilder::create<kir::TensorIndex>(mbarriers_, loop->index());
kir::MBarrierInit* ping_pong_mbarrier_init =
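To see that the two branches of getAsyncWarpPredicate() pick out the same threads when warp IDs are consecutive, here is a small host-side C++ sketch; the configuration (warp specialization on TIDy, bdimx = 128, four compute rows plus one padded async row) is an assumption for illustration, not taken from the patch:

#include <cstdint>
#include <cstdio>

int main() {
  const int bdimx = 128, compute_y = 4, pad_y = 1;
  const int bdimy = compute_y + pad_y;
  // getNumComputeWarps(): compute threads divided by 32.
  const int64_t num_compute_warps = int64_t(bdimx) * compute_y / 32;

  int mismatches = 0;
  for (int ty = 0; ty < bdimy; ++ty) {
    for (int tx = 0; tx < bdimx; ++tx) {
      const int64_t tid = tx + int64_t(ty) * bdimx;
      const int64_t warp_id = tid / 32;
      // Warp-id-based form: async warps satisfy warp_id >= num_compute_warps.
      const bool async_by_warp = warp_id >= num_compute_warps;
      // Fallback form: threadIdx.y >= bdimy - pad.
      const bool async_by_index = ty >= bdimy - pad_y;
      if (async_by_warp != async_by_index) {
        ++mismatches;
      }
    }
  }
  std::printf("predicates disagree on %d threads\n", mismatches);  // prints 0
  return 0;
}

When canUseWarpIdBasedPredicate() returns false, only the fallback form is safe; the sketch after the next file shows such a case.
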
74 changes: 58 additions & 16 deletions csrc/parallel_dimension_map.cpp
@@ -152,7 +152,7 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() {
exact_types_.erase(ParallelType::TIDx);
}

int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) {
int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) const {
if (!dim_map_.contains(pt)) {
return 1;
}
@@ -257,12 +257,13 @@ Val* ParallelDimensionMap::getRawAsync(ParallelType pt) const {
return getRaw(pt);
}

Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const {
Val* ParallelDimensionMap::getNumComputeThreadsEachBlock(
bool only_count_same_compute_warp_groups) const {
Val* num_threads = FusionGuard::getCurFusion()->oneVal();
for (auto pt : kParallelTypeTIDs) {
// Skip the warp specialized ParallelType if the computation warp groups
// are independent.
if (isWarpSpecialized(pt) &&
if (only_count_same_compute_warp_groups && isWarpSpecialized(pt) &&
GpuLower::current()
->circularBufferInfo()
.hasIndependentComputeWarpGroups()) {
@@ -277,6 +278,35 @@ Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const {
return num_threads;
}

int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal(
ParallelType pt) const {
NVF_ERROR(isWarpSpecialized(pt), "Can't find ParallelType: ", pt);
if (!warp_specialized_parallel_type_.has_value()) {
return 1;
}
NVF_ERROR(
warp_specialized_parallel_type_.value() == pt,
"Can't find padded val for: ",
pt);
return warp_specialized_padding_value_.value();
}

Val* ParallelDimensionMap::getNumComputeWarps() const {
NVF_ERROR(
hasWarpSpecialization(),
"getNumComputeWarps() should only be called for warp specialized "
"kernels");

Val* num_compute_threads = getNumComputeThreadsEachBlock(
/*only_count_same_compute_warp_groups=*/false);

// Divide by 32 to get the number of warps
Val* num_compute_warps = SimplifyingIrBuilder::divExpr(
num_compute_threads, IrBuilder::create<Val>(32L, DataType::Index));

return num_compute_warps;
}

// For warp-specialization, the CTA is padded so the AsyncWarp contains 128
// threads. This function maps the AsyncWarp CTA to a linear index from
// [0, 128). It is used to divide AsyncWarp into four independent warps.
@@ -310,19 +340,6 @@ Val* ParallelDimensionMap::getLinearThreadIndexAsync() const {
return index;
}

int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal(
ParallelType pt) const {
NVF_ERROR(isWarpSpecialized(pt), "Can't find ParallelType: ", pt);
if (!warp_specialized_parallel_type_.has_value()) {
return 1;
}
NVF_ERROR(
warp_specialized_parallel_type_.value() == pt,
"Can't find padded val for: ",
pt);
return warp_specialized_padding_value_.value();
}

bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const {
// short-circuit: skip if warp specialization is not enabled
if (!hasWarpSpecialization()) {
@@ -344,6 +361,31 @@ bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const {
return false;
}

bool ParallelDimensionMap::canUseWarpIdBasedPredicate() const {
if (!hasWarpSpecialization()) {
return false;
}

// For consecutive warp IDs, all dimensions after the warp-specialized
// dimension must be 1. Otherwise outer dimensions create gaps in warp IDs.
NVF_ERROR(warp_specialized_parallel_type_.has_value());
ParallelType ws_pt = warp_specialized_parallel_type_.value();

bool found_ws_pt = false;
for (ParallelType pt : kParallelTypeTIDs) {
if (pt == ws_pt) {
found_ws_pt = true;
} else if (found_ws_pt) {
int64_t thread_count = getThreadCountInDim(pt);
if (thread_count == -1 || thread_count > 1) {
return false;
}
}
}

return true;
}

std::string ParallelDimensionMap::toString() const {
std::stringstream ss;
for (auto pt : kParallelTypeThreads) {
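The consecutive-warp-ID requirement enforced by canUseWarpIdBasedPredicate() can be checked concretely with the failing example from the header comment, CTA (32, 6, 2) warp specialized on TIDy. The standalone C++ sketch below assumes four compute TIDy rows plus two padded async rows (so the async group holds 128 threads); these numbers are illustrative, not taken from the patch:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int bdimx = 32, bdimy = 6, bdimz = 2;
  const int compute_y = 4;  // assumed: TIDy rows 4 and 5 form the async group
  int64_t min_async_warp = INT64_MAX;
  int64_t max_compute_warp = -1;
  for (int tz = 0; tz < bdimz; ++tz) {
    for (int ty = 0; ty < bdimy; ++ty) {
      for (int tx = 0; tx < bdimx; ++tx) {
        const int64_t tid =
            tx + int64_t(ty) * bdimx + int64_t(tz) * bdimx * bdimy;
        const int64_t warp_id = tid / 32;
        if (ty >= compute_y) {
          min_async_warp = std::min(min_async_warp, warp_id);
        } else {
          max_compute_warp = std::max(max_compute_warp, warp_id);
        }
      }
    }
  }
  // A threshold predicate "warp_id >= N" exists only if every async warp id
  // exceeds every compute warp id. Here min_async_warp = 4 while
  // max_compute_warp = 9, so this shape must be rejected and the fallback
  // index-based predicate used instead.
  std::printf("min async warp = %lld, max compute warp = %lld\n",
              static_cast<long long>(min_async_warp),
              static_cast<long long>(max_compute_warp));
  return 0;
}

Running the same loop for the passing example, CTA (32, 4, 3) specialized on TIDz with one padded z-slice, gives min_async_warp = 8 and max_compute_warp = 7, so a single threshold does exist there.
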
37 changes: 33 additions & 4 deletions csrc/parallel_dimension_map.h
@@ -73,9 +73,22 @@ class ParallelDimensionMap {
//! And this function will return (32 * 16) because the extra one for TIDy is
//! introduced by warp specialization and only used for loading circular
//! buffer tensors.
Val* getNumComputeThreadsEachBlock() const;

//! Assign linear index to each thread of CTA. Assume (TDZ, TDY, TDX) order.
Val* getNumComputeThreadsEachBlock(
bool only_count_same_compute_warp_groups) const;

//! Get the number of compute warps for warp specialized kernels.
//! This computes the total number of compute threads across all dimensions
//! (TIDx, TIDy, TIDz), using the compute dimension (minus padding) for the
//! warp specialized dimension, then divides by 32 to get the number of warps.
//! Examples:
//! - If warp specialized on TIDx: (bdimx - pad) * bdimy * bdimz / 32
//! - If warp specialized on TIDy: bdimx * (bdimy - pad) * bdimz / 32
//! - If warp specialized on TIDz: bdimx * bdimy * (bdimz - pad) / 32
Val* getNumComputeWarps() const;

//! For warp-specialization, the CTA is padded so the AsyncWarp contains 128
//! threads. This function maps the AsyncWarp CTA to a linear index from
//! [0, 128). It is used to divide AsyncWarp into four independent warps.
Val* getLinearThreadIndexAsync() const;

//! Get if the kernel uses warp specialization
@@ -96,10 +109,26 @@ class ParallelDimensionMap {
// elect-sync cannot be used.
bool canUseElectSyncInAsyncWarp() const;

//! Check if warp-id-based predicates can be used for warp specialization.
//! Warp-id-based predicates (e.g., warp_id >= N) only work when the
//! warp-specialized dimension produces consecutive warp IDs. This requires
//! that the warp-specialized dimension is the outermost dimension > 1,
//! meaning ALL dimensions after it must be 1.
//!
//! Example: warp specialized on TIDy with CTA (32, 6, 2):
//! TIDz=2 after TIDy causes non-consecutive warps (FAILS)
//! Example: warp specialized on TIDz with CTA (32, 4, 3):
//! No dimensions after TIDz -> consecutive warps (WORKS)
//!
//! Returns true only when warp specialization is used and all dimensions
//! after the warp-specialized dimension are 1; returns false otherwise.
bool canUseWarpIdBasedPredicate() const;

private:
//! Get number of threads for ParallelType axis
//! Not used: 1, Const: n, Dynamic: -1
int64_t getThreadCountInDim(ParallelType pt);
int64_t getThreadCountInDim(ParallelType pt) const;

//! TIDx may need to be marked as non-exact as it may be padded to a
//! multiple of the warp size.
Expand Down
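Finally, a worked-numbers sketch for the getNumComputeWarps() formulas documented in this header. The CTA shapes and pads are assumed examples chosen so that the async group holds 128 threads; they do not come from the patch:

#include <cstdint>
#include <cstdio>

int main() {
  struct Case {
    char ws_axis;  // 'x', 'y', or 'z': the warp-specialized TID axis
    int bdimx, bdimy, bdimz, pad;
  };
  const Case cases[] = {
      {'x', 160, 2, 2, 32},  // (bdimx - pad) * bdimy * bdimz / 32
      {'y', 128, 5, 1, 1},   // bdimx * (bdimy - pad) * bdimz / 32
      {'z', 32, 4, 3, 1},    // bdimx * bdimy * (bdimz - pad) / 32
  };
  for (const Case& c : cases) {
    int64_t compute_threads = 0;
    if (c.ws_axis == 'x') {
      compute_threads = int64_t(c.bdimx - c.pad) * c.bdimy * c.bdimz;
    } else if (c.ws_axis == 'y') {
      compute_threads = int64_t(c.bdimx) * (c.bdimy - c.pad) * c.bdimz;
    } else {
      compute_threads = int64_t(c.bdimx) * c.bdimy * (c.bdimz - c.pad);
    }
    std::printf("WS on TID%c: %lld compute threads -> %lld compute warps\n",
                c.ws_axis, static_cast<long long>(compute_threads),
                static_cast<long long>(compute_threads / 32));
  }
  return 0;
}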