From 135014dffd940843fc42e40e4a35346e98b06f24 Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Sat, 2 May 2026 20:20:23 +0800 Subject: [PATCH 01/10] PTOPlanMemory: align with HIVM design (SPEC_LEVEL_3 + retry loop) --- lib/PTO/Transforms/PTOPlanMemory.cpp | 649 ++++++++++++++++++++++++--- lib/PTO/Transforms/PTOPlanMemory.h | 145 +++++- 2 files changed, 715 insertions(+), 79 deletions(-) diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index 38d0fd703..ddc8e26c5 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -153,6 +153,11 @@ static void sortValuesByStableOrder( }); } +static void appendUniqueValue(SmallVectorImpl &values, Value value) { + if (!llvm::is_contained(values, value)) + values.push_back(value); +} + static SmallVector getScratchBuffersFromEffects(Operation *op, ValueRange dpsInits, const StableValueOrderMap &stableValueOrder) { @@ -215,6 +220,136 @@ struct ReserveBufferPlan { using ReserveBufferPlans = SmallVector; +static bool isReserveBufferAddr(Value value, ReserveBufferOp reserveOp) { + return value && value == reserveOp.getAddr(); +} + +static LogicalResult computeFifoLocalBufferSizeBytes(Operation *op, + int64_t slotSizeBytes, + IntegerAttr localSlotNumAttr, + int64_t &sizeBytes) { + if (!localSlotNumAttr) + return success(); + + int64_t localSlotNum = localSlotNumAttr.getInt(); + if (slotSizeBytes <= 0) + return op->emitOpError("expects FIFO slot_size to be positive"); + if (localSlotNum <= 0) + return op->emitOpError("expects FIFO local_slot_num to be positive"); + if (slotSizeBytes > std::numeric_limits::max() / localSlotNum) + return op->emitOpError("FIFO local buffer size overflows int64_t"); + + sizeBytes = slotSizeBytes * localSlotNum; + if (sizeBytes > std::numeric_limits::max()) + return op->emitOpError( + "FIFO local buffer size exceeds reserve_buffer size attribute range"); + return success(); 
+} + +static FailureOr +computeAutoReserveBufferSizeBytes(func::FuncOp funcOp, + ReserveBufferOp reserveOp) { + std::optional fifoLocalSizeBytes; + bool failedToCompute = false; + + auto updateFromFifo = [&](Operation *op, int64_t slotSizeBytes, + IntegerAttr localSlotNumAttr) -> LogicalResult { + if (!localSlotNumAttr) + return success(); + + int64_t currentSizeBytes = 0; + if (failed(computeFifoLocalBufferSizeBytes( + op, slotSizeBytes, localSlotNumAttr, currentSizeBytes))) + return failure(); + + // One reserve_buffer normally feeds one FIFO. If the IR shares it across + // multiple pipe init ops, reserve enough for the largest local buffer. + fifoLocalSizeBytes = + fifoLocalSizeBytes ? std::max(*fifoLocalSizeBytes, currentSizeBytes) + : currentSizeBytes; + return success(); + }; + + WalkResult walkResult = funcOp.walk([&](Operation *op) -> WalkResult { + if (auto initOp = dyn_cast(op)) { + if (!isReserveBufferAddr(initOp.getLocalAddr(), reserveOp) && + !isReserveBufferAddr(initOp.getPeerLocalAddr(), reserveOp)) + return WalkResult::advance(); + if (failed(updateFromFifo(initOp.getOperation(), initOp.getSlotSize(), + initOp.getLocalSlotNumAttr()))) { + failedToCompute = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + + if (auto initOp = dyn_cast(op)) { + if (!isReserveBufferAddr(initOp.getC2vConsumerBuf(), reserveOp) && + !isReserveBufferAddr(initOp.getV2cConsumerBuf(), reserveOp)) + return WalkResult::advance(); + if (failed(updateFromFifo(initOp.getOperation(), initOp.getSlotSize(), + initOp.getLocalSlotNumAttr()))) { + failedToCompute = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + + if (auto initOp = dyn_cast(op)) { + if (!isReserveBufferAddr(initOp.getC2vConsumerBuf(), reserveOp) && + !isReserveBufferAddr(initOp.getV2cConsumerBuf(), reserveOp)) + return WalkResult::advance(); + if (failed(updateFromFifo(initOp.getOperation(), initOp.getSlotSize(), + initOp.getLocalSlotNumAttr()))) { + 
failedToCompute = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + } + + return WalkResult::advance(); + }); + + (void)walkResult; + if (failedToCompute) + return failure(); + return fifoLocalSizeBytes.value_or(reserveOp.getSize()); +} + +static void setReserveBufferSizeBytes(ReserveBufferOp reserveOp, + int64_t sizeBytes) { + reserveOp->setAttr( + "size", + IntegerAttr::get(IntegerType::get(reserveOp.getContext(), kI32BitWidth), + sizeBytes)); +} + +static LogicalResult +validateAutoReserveBufferCapacity(ReserveBufferPlans &plans) { + DenseMap reservedBytesByAddressSpace; + + for (ReserveBufferPlan &plan : plans) { + if (plan.mode != ReserveBufferMode::Auto) + continue; + + int64_t alignedSizeBytes = alignUpBytes(plan.sizeBytes, plan.alignBytes); + if (alignedSizeBytes > plan.capacityBytes) + return plan.reserveOp.emitOpError( + "exceeds available local memory capacity"); + + int64_t usedBytes = reservedBytesByAddressSpace[plan.addressSpace]; + if (usedBytes > plan.capacityBytes - alignedSizeBytes) { + return plan.reserveOp.emitOpError( + "cumulative auto reserve_buffer size exceeds available local " + "memory capacity"); + } + reservedBytesByAddressSpace[plan.addressSpace] = + usedBytes + alignedSizeBytes; + } + + return success(); +} + static LogicalResult analyzeReserveBufferPlans(func::FuncOp funcOp, ReserveBufferPlans &plans) { SmallVector reserveOps; @@ -233,6 +368,15 @@ static LogicalResult analyzeReserveBufferPlans(func::FuncOp funcOp, int64_t capacityBytes = spec.capacityBits / kBitsPerByte; int64_t sizeBytes = reserveOp.getSize(); bool autoAlloc = reserveOp.getAutoAlloc(); + if (autoAlloc) { + auto computedSizeBytes = + computeAutoReserveBufferSizeBytes(funcOp, reserveOp); + if (failed(computedSizeBytes)) + return failure(); + sizeBytes = *computedSizeBytes; + if (sizeBytes != reserveOp.getSize()) + setReserveBufferSizeBytes(reserveOp, sizeBytes); + } ReserveBufferPlan &plan = plans.emplace_back(); plan.mode = autoAlloc ? 
ReserveBufferMode::Auto : ReserveBufferMode::Manual; @@ -268,6 +412,9 @@ static LogicalResult analyzeReserveBufferPlans(func::FuncOp funcOp, } } + if (failed(validateAutoReserveBufferCapacity(plans))) + return failure(); + return success(); } @@ -357,6 +504,19 @@ static LogicalResult assignAutoReserveBufferBases( return success(); } +static DenseMap +collectAutoReserveBufferBitsByAddressSpace(const ReserveBufferPlans &plans) { + DenseMap reservedBitsByAddressSpace; + for (const ReserveBufferPlan &plan : plans) { + if (plan.mode != ReserveBufferMode::Auto) + continue; + int64_t alignedSizeBytes = alignUpBytes(plan.sizeBytes, plan.alignBytes); + reservedBitsByAddressSpace[plan.addressSpace] += + static_cast(alignedSizeBytes) * kBitsPerByte; + } + return reservedBitsByAddressSpace; +} + } // namespace void MemLivenessAnalysis::build() { @@ -589,31 +749,195 @@ void MemLivenessAnalysis::RecursiveIfOp(scf::IfOp ifOp, Liveness live) { SmallVector MemLivenessAnalysis::GetLiveBuffersInLoop(scf::ForOp forOp, Liveness live) { - SmallVector allocBeforeLoopBuffers; const auto *liveBlockInfo = live.getLiveness(forOp->getBlock()); auto currentLiveValues = liveBlockInfo->currentlyLiveValues(forOp.getOperation()); if (currentLiveValues.empty()) { - return allocBeforeLoopBuffers; + return {}; } // The gen buffer of the same operation must ensure the order of priority. 
SetVector currentLiveValuesOrder; for (auto buffer : currentLiveValues) { currentLiveValuesOrder.insert(buffer); } + SetVector allocBeforeLoopBufferSet; for (const Value &operand : currentLiveValuesOrder) { auto aliasBuffers = GetAliasBuffers(operand); aliasBuffers.insert(operand); for (auto Buffer : aliasBuffers) { auto iter = buffer2status.find(Buffer); - if (iter != buffer2status.end()) - allocBeforeLoopBuffers.push_back(Buffer); + if (iter == buffer2status.end()) + continue; + if ((iter->second == BufferStatus::DEFFINED || + iter->second == BufferStatus::KILLED) && + CanDelayLoopEntryGenUntilFirstWrite(forOp, Buffer)) { + delayedLoopEntryGenBuffers[Buffer] = true; + continue; + } + allocBeforeLoopBufferSet.insert(Buffer); } } + SmallVector allocBeforeLoopBuffers(allocBeforeLoopBufferSet.begin(), + allocBeforeLoopBufferSet.end()); sortValuesByStableOrder(allocBeforeLoopBuffers, stableValueOrder); return allocBeforeLoopBuffers; } +bool MemLivenessAnalysis::CanDelayLoopEntryGenUntilFirstWrite( + scf::ForOp forOp, Value buffer) { + SetVector aliasBuffers = GetAliasBuffers(buffer); + aliasBuffers.insert(buffer); + Block *body = forOp.getBody(); + if (!body) + return false; + + for (Operation &op : body->without_terminator()) { + if (!OperationOrNestedRegionTouchesAnyAlias(&op, aliasBuffers)) + continue; + if (auto nestedForOp = dyn_cast(&op)) { + return llvm::any_of(aliasBuffers, [&](Value alias) { + return CanDelayLoopEntryGenUntilFirstWrite(nestedForOp, alias); + }); + } + return IsWriteOnlyDpsInitForAlias(&op, aliasBuffers); + } + return false; +} + +bool MemLivenessAnalysis::OperationDirectlyTouchesAnyAlias( + Operation *op, const SetVector &aliasBuffers) const { + auto touchesValue = [&](Value value) { + return value && llvm::is_contained(aliasBuffers, value); + }; + for (Value operand : op->getOperands()) { + if (touchesValue(operand)) + return true; + } + for (Value result : op->getResults()) { + if (touchesValue(result)) + return true; + } + + auto memEffect 
= dyn_cast(op); + if (!memEffect) + return false; + SmallVector, + kMemoryEffectReserveSize> + effects; + memEffect.getEffects(effects); + for (const auto &effect : effects) { + if (touchesValue(effect.getValue())) + return true; + } + return false; +} + +bool MemLivenessAnalysis::OperationOrNestedRegionTouchesAnyAlias( + Operation *op, const SetVector &aliasBuffers) const { + if (OperationDirectlyTouchesAnyAlias(op, aliasBuffers)) + return true; + if (op->getNumRegions() == 0) + return false; + + bool touches = false; + op->walk([&](Operation *nestedOp) { + if (nestedOp == op) + return WalkResult::advance(); + if (OperationDirectlyTouchesAnyAlias(nestedOp, aliasBuffers)) { + touches = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return touches; +} + +bool MemLivenessAnalysis::OperationReadsAnyAlias( + Operation *op, const SetVector &aliasBuffers) const { + auto touchesValue = [&](Value value) { + return value && llvm::is_contained(aliasBuffers, value); + }; + + auto memEffect = dyn_cast(op); + if (!memEffect) { + return llvm::any_of(op->getOperands(), touchesValue); + } + + SmallVector, + kMemoryEffectReserveSize> + effects; + memEffect.getEffects(effects); + return llvm::any_of(effects, [&](const auto &effect) { + return isa(effect.getEffect()) && + touchesValue(effect.getValue()); + }); +} + +bool MemLivenessAnalysis::IsWriteOnlyDpsInitForAlias( + Operation *op, const SetVector &aliasBuffers) const { + auto ptoDpsOp = dyn_cast(op); + if (!ptoDpsOp) + return false; + + bool hasAliasDpsInit = false; + for (Value init : ptoDpsOp.getDpsInits()) { + if (llvm::is_contained(aliasBuffers, init)) { + hasAliasDpsInit = true; + break; + } + } + if (!hasAliasDpsInit) + return false; + + auto memEffect = dyn_cast(op); + if (!memEffect) + return false; + SmallVector, + kMemoryEffectReserveSize> + effects; + memEffect.getEffects(effects); + + bool hasWrite = false; + for (const auto &effect : effects) { + Value value = effect.getValue(); 
+ if (!value || !llvm::is_contained(aliasBuffers, value)) + continue; + if (isa(effect.getEffect())) + return false; + if (isa(effect.getEffect())) { + hasWrite = true; + continue; + } + return false; + } + return hasWrite; +} + +bool MemLivenessAnalysis::CanKillBeforeNextOverwrite( + Operation *op, const SetVector &aliasBuffers) { + for (Operation *nextOp = op->getNextNode(); nextOp; + nextOp = nextOp->getNextNode()) { + if (!OperationOrNestedRegionTouchesAnyAlias(nextOp, aliasBuffers)) + continue; + + if (auto forOp = dyn_cast(nextOp)) { + return llvm::any_of(aliasBuffers, [&](Value alias) { + return CanDelayLoopEntryGenUntilFirstWrite(forOp, alias); + }); + } + + return IsWriteOnlyDpsInitForAlias(nextOp, aliasBuffers); + } + return false; +} + +bool MemLivenessAnalysis::CanRegenerateBufferAtOp(Operation *op, + Value buffer) { + SetVector aliasBuffers = GetAliasBuffers(buffer); + aliasBuffers.insert(buffer); + return IsWriteOnlyDpsInitForAlias(op, aliasBuffers); +} + LogicalResult MemLivenessAnalysis::CheckLocalBufferAllocOp(Operation *op) const { auto allocOp = dyn_cast(op); @@ -732,11 +1056,25 @@ void MemLivenessAnalysis::UpdateOperandGenInfo(OpInfo *opInfo, Value operand) { if (iter_buffer == buffer2status.end()) return; if (iter_buffer->second == BufferStatus::DEFFINED) { - genKillMap[opInfo].gen.push_back(operand); + appendUniqueValue(genKillMap[opInfo].gen, operand); buffer2status[iter_buffer->first] = BufferStatus::GENED; + buffer2GenOp[iter_buffer->first] = opInfo->operation; } else if (iter_buffer->second == BufferStatus::KILLED) { - llvm_unreachable("The buffer memory has been released and cannot be used " - "again! "); + if (!CanRegenerateBufferAtOp(opInfo->operation, operand)) { + llvm_unreachable("The buffer memory has been released and cannot be " + "used again before it is redefined! 
"); + } + appendUniqueValue(genKillMap[opInfo].gen, operand); + buffer2status[iter_buffer->first] = BufferStatus::GENED; + buffer2GenOp[iter_buffer->first] = opInfo->operation; + } else if (iter_buffer->second == BufferStatus::GENED) { + SetVector aliasBuffers = GetAliasBuffers(operand); + aliasBuffers.insert(operand); + if (IsWriteOnlyDpsInitForAlias(opInfo->operation, aliasBuffers)) { + appendUniqueValue(genKillMap[opInfo].kill, operand); + appendUniqueValue(genKillMap[opInfo].gen, operand); + buffer2GenOp[iter_buffer->first] = opInfo->operation; + } } } @@ -750,7 +1088,13 @@ void MemLivenessAnalysis::OpKillHandle(OpInfo *opInfo, Liveness live, } SmallVector liveValues(currentLiveValues.begin(), currentLiveValues.end()); + // Always begin from a stable ordering so seed=0 reproduces the original + // single-shot PTO behavior. When the PlanMemoryPass retry loop drives a + // non-zero seed, getShuffledRange permutes the candidates to expose + // alternative gen/kill orderings - the search dimension that lets a later + // attempt succeed where the first one wedged on a pathological order. 
sortValuesByStableOrder(liveValues, stableValueOrder); + liveValues = getShuffledRange(liveValues); for (const Value &operand : liveValues) { UpdateOpKillInfo(opInfo, operand, live); } @@ -764,15 +1108,35 @@ void MemLivenessAnalysis::UpdateOpKillInfo(OpInfo *opInfo, Value operand, auto iterBuffer = buffer2status.find(aliasBuffer); if (iterBuffer == buffer2status.end()) return; - if (iterBuffer->second == BufferStatus::GENED && - IsInSameBlock(iterBuffer->first.getDefiningOp(), opInfo->operation) && - AllDeadAfter(opInfo->operation, aliasBuffers, live)) { - genKillMap[opInfo].kill.push_back(aliasBuffer); + Operation *defOp = iterBuffer->first.getDefiningOp(); + bool canKillInThisBlock = + defOp && IsInSameBlock(defOp, opInfo->operation); + auto delayedGen = delayedLoopEntryGenBuffers.find(iterBuffer->first); + if (!canKillInThisBlock && delayedGen != delayedLoopEntryGenBuffers.end() && + delayedGen->second) { + Operation *genOp = GetBufferGenOp(iterBuffer->first); + canKillInThisBlock = genOp && IsInSameBlock(genOp, opInfo->operation); + } + bool canKillCurrentValue = + AllDeadAfter(opInfo->operation, aliasBuffers, live) || + (OperationReadsAnyAlias(opInfo->operation, aliasBuffers) && + CanKillBeforeNextOverwrite(opInfo->operation, aliasBuffers)); + if (iterBuffer->second == BufferStatus::GENED && canKillInThisBlock && + canKillCurrentValue) { + appendUniqueValue(genKillMap[opInfo].kill, aliasBuffer); buffer2status[iterBuffer->first] = BufferStatus::KILLED; + buffer2GenOp.erase(iterBuffer->first); } } } +Operation *MemLivenessAnalysis::GetBufferGenOp(Value buffer) const { + auto it = buffer2GenOp.find(buffer); + if (it != buffer2GenOp.end()) + return it->second; + return nullptr; +} + bool MemLivenessAnalysis::IsInSameBlock(Operation *op1, Operation *op2) const { return op1->getBlock() == op2->getBlock(); } @@ -839,25 +1203,43 @@ BufferInfo MemLivenessAnalysis::GetBufferInfo(Operation *op, Value operand, void MemLivenessAnalysis::GenerateBufferLife() { int scopeTime 
= 0; + DenseMap> openLives; for (size_t i = 0; i < linearOperation.size(); ++i) { auto it = genKillMap.find(linearOperation[i].get()); if (it == genKillMap.end()) { scopeTime++; continue; } - // Time given to buffer start. + + SmallVector postGenKills; + for (const Value &killBuffer : it->second.kill) { + auto iter = openLives.find(killBuffer); + if (iter != openLives.end()) { + iter->second->freeTime = scopeTime; + openLives.erase(iter); + continue; + } + if (!llvm::is_contained(it->second.gen, killBuffer)) + llvm::report_fatal_error("buffer lifetime killed before generation"); + appendUniqueValue(postGenKills, killBuffer); + } + for (const Value &genBuffer : it->second.gen) { - std::unique_ptr bufferLife = - std::make_unique(genBuffer); + if (openLives.find(genBuffer) != openLives.end()) + llvm::report_fatal_error("buffer lifetime generated before release"); + std::shared_ptr bufferLife = + std::make_shared(genBuffer); bufferLife->allocTime = scopeTime; - buffer2Life[genBuffer] = std::move(bufferLife); + buffer2Life[genBuffer].push_back(bufferLife); + openLives[genBuffer] = std::move(bufferLife); } - // Time given to buffer end. 
- for (const Value &killBuffer : it->second.kill) { - auto iter = buffer2Life.find(killBuffer); - if (iter == buffer2Life.end()) - llvm::report_fatal_error("buffer lifetime killed before generation"); + + for (const Value &killBuffer : postGenKills) { + auto iter = openLives.find(killBuffer); + if (iter == openLives.end()) + llvm::report_fatal_error("buffer lifetime generated after release"); iter->second->freeTime = scopeTime; + openLives.erase(iter); } scopeTime++; } @@ -939,9 +1321,10 @@ void MemPlan::EmitPlanMemoryFailureInfo() { return; for (auto &iter : failApplyBufferInfo) { AddressSpace space = iter.first; + auto bufferSpaceInfo = GetPlannableBufferSpaceInfo(space); func_.emitError() << stringifyEnum(space) << " overflow, requires " - << iter.second << " bits while " - << GetBufferSpaceInfo(space).second << " bits avaliable!"; + << iter.second << " bits while " << bufferSpaceInfo.second + << " bits avaliable!"; } } @@ -960,7 +1343,7 @@ bool MemPlan::RecordOverflowIfAny() { continue; } auto bufferSpaceInfo = - GetBufferSpaceInfo(rootStorageEntry->bufInfo->bufferScope); + GetPlannableBufferSpaceInfo(rootStorageEntry->bufInfo->bufferScope); size_t maxBits = bufferSpaceInfo.second; uint64_t maxAllocBits = rootStorageEntry->alignedConstBits; for (auto *child : rootStorageEntry->mergedChildren) { @@ -1002,7 +1385,7 @@ bool MemPlan::HasSemanticConflict(const StorageEntry *entry, } // Plan Memory algorithm. -LogicalResult MemPlan::plan() { +LogicalResult MemPlan::plan(bool emitErrors) { // Construct StorageEntry structure. GenerateStorageEntry(); // Plan memory address. @@ -1010,11 +1393,13 @@ LogicalResult MemPlan::plan() { ? 
PlanLocalMemAddress() : PlanWorkSpaceMemAddress(); if (as == PlanStatus::PLAN_FAILED) { - EmitPlanMemoryFailureInfo(); + if (emitErrors) + EmitPlanMemoryFailureInfo(); return failure(); } if (RecordOverflowIfAny()) { - EmitPlanMemoryFailureInfo(); + if (emitErrors) + EmitPlanMemoryFailureInfo(); return failure(); } auto hasAddressOverlap = [](const StorageEntry *lhs, const StorageEntry *rhs) { @@ -1051,10 +1436,13 @@ LogicalResult MemPlan::plan() { if (!lifeOverlap && !semanticConflict) { continue; } - func_.emitError() - << "PlanMemory produced overlapping local buffers in " - << stringifyEnum(lhs->bufInfo->bufferScope) - << " at offsets " << lhs->bitsOffset << " and " << rhs->bitsOffset; + if (emitErrors) { + func_.emitError() + << "PlanMemory produced overlapping local buffers in " + << stringifyEnum(lhs->bufInfo->bufferScope) + << " at offsets " << lhs->bitsOffset << " and " + << rhs->bitsOffset; + } return failure(); } } @@ -1068,6 +1456,7 @@ LogicalResult MemPlan::plan() { void MemPlan::GenerateStorageEntry() { // create new storage entry. 
+ SetVector seenBuffers; for (auto &operation : linearOperation) { auto it = genKillMap.find(operation.get()); if (it == genKillMap.end()) @@ -1075,14 +1464,17 @@ void MemPlan::GenerateStorageEntry() { SmallVector genBuffers(it->second.gen.begin(), it->second.gen.end()); sortValuesByStableOrder(genBuffers, stableValueOrder); for (const Value &genBuffer : genBuffers) { + if (llvm::is_contained(seenBuffers, genBuffer)) + continue; auto iter = bufferInfos.find(genBuffer); if (iter == bufferInfos.end()) { continue; } - const std::shared_ptr &bufLife = buffer2Life.at(genBuffer); + seenBuffers.insert(genBuffer); + const BufferLifeVec &bufLives = buffer2Life.at(genBuffer); std::unique_ptr entry = std::make_unique(); entry->bufInfo = &iter->second; - entry->bufferLifeVec.emplace_back(bufLife); + entry->bufferLifeVec.append(bufLives.begin(), bufLives.end()); entry->inplaceBuffers.emplace_back(iter->first); auto multiBuffer = buffer2MultiNum.find(genBuffer); if (multiBuffer != buffer2MultiNum.end()) { @@ -1312,7 +1704,7 @@ void MemPlan::MergeSameScopeSE() { // set bufferScope2RequiredSize for all StorageEntry for (auto &rootStorageEntry : memscope2rootStorageEntry) { - auto bufferSpaceInfo = GetBufferSpaceInfo(rootStorageEntry.first); + auto bufferSpaceInfo = GetPlannableBufferSpaceInfo(rootStorageEntry.first); size_t accumulateSize = AlignUp(rootStorageEntry.second->bufInfo->constBits, bufferSpaceInfo.first); for (auto &childrenStorageEntry : rootStorageEntry.second->mergedChildren) { @@ -1371,7 +1763,7 @@ PlanStatus MemPlan::PlanMemAddressOfWholeLocalBuffer() { StorageEntry *rootStorageEntry = it.second; // get the buffer info for a given scope. 
auto bufferSpaceInfo = - GetBufferSpaceInfo(rootStorageEntry->bufInfo->bufferScope); + GetPlannableBufferSpaceInfo(rootStorageEntry->bufInfo->bufferScope); size_t align = bufferSpaceInfo.first; size_t maxBits = bufferSpaceInfo.second; if (rootStorageEntry->mergedChildren.empty()) { @@ -1581,6 +1973,20 @@ MemPlan::GetBufferSpaceInfo(pto::AddressSpace &space) const { llvm_unreachable("Temporarily unsupported memory buffer space !"); } +std::pair +MemPlan::GetPlannableBufferSpaceInfo(pto::AddressSpace &space) const { + auto bufferSpaceInfo = GetBufferSpaceInfo(space); + auto it = reservedBufferBitsByScope.find(space); + if (it == reservedBufferBitsByScope.end()) { + return bufferSpaceInfo; + } + if (it->second >= bufferSpaceInfo.second) { + return std::make_pair(bufferSpaceInfo.first, size_t{0}); + } + return std::make_pair(bufferSpaceInfo.first, + bufferSpaceInfo.second - it->second); +} + LogicalResult MemPlan::MultiSpecPlan(SpecInfo &si, MemBoundList &outline, PlanRecHis &history, StorageEntry *entry) { LogicalResult planResult = failure(); @@ -1602,11 +2008,24 @@ LogicalResult MemPlan::MultiSpecPlan(SpecInfo &si, MemBoundList &outline, LogicalResult MemPlan::SpecAlloc(MemBoundList &outline, PlanRecHis &his, StorageEntry *e, const SpecInfo &si, int localLevel) { + if (e == nullptr) { + // Defensive: a null entry would otherwise crash later when reading + // `e->alignedConstBits` / `e->bufferLifeVec`. + return failure(); + } if (std::any_of(his.begin(), his.end(), [e](PlanRecord &r) { return r.entry && r.entry == e; })) { // If the plan has already been completed, return success directly. return success(); } + // Zero-sized entries (e.g. degenerate / dynamically-shaped buffers that were + // statically resolved to 0 bits) cannot meaningfully consume an outline + // bound. Pin them at offset 0 and report success so the rest of the planner + // can advance. Mirrors HIVM PlanMemory.cpp behavior. 
+ if (e->alignedConstBits == 0) { + e->bitsOffset = 0; + return success(); + } for (MemBoundListConstIter start = outline.begin(); start != outline.end(); ++start) { uint64_t size = 0; @@ -1637,6 +2056,9 @@ LogicalResult MemPlan::SpecAlloc(MemBoundList &outline, PlanRecHis &his, if (VerifyConflictStage2(his, e, localLevel, start, outline)) { break; } + if (VerifyConflictStage3(his, e, localLevel, start, outline)) { + break; + } e->bitsOffset = allocOffset; UpdateOutline(outline, his, e, OutlineSectionInfo(start, end, size, false), localLevel); @@ -1807,22 +2229,21 @@ void MemPlan::PlanRelationPongEntryAddress(uint64_t offset, StorageEntry *e) { } } -bool MemPlan::VerifyConflictStage2(PlanRecHis &his, const StorageEntry *e, - int specLevel, MemBoundListConstIter &start, - const MemBoundList &outline) { - if (specLevel != SPEC_LEVEL_2) { - return false; - } +bool MemPlan::VerifyConflictStageCommon( + PlanRecHis &his, const StorageEntry *e, MemBoundListConstIter &start, + const MemBoundList &outline, + std::function + conflictChecker) { bool touchMemCanUse = false; MemBoundListConstIter foundMem; for (auto iter = start; iter != outline.end(); ++iter) { uint64_t offset = (*iter)->offset; - bool conflict = - std::any_of(his.begin(), his.end(), [offset, e, this](PlanRecord &r) { + bool conflict = std::any_of( + his.begin(), his.end(), [offset, e, &conflictChecker](PlanRecord &r) { return (r.firstMemBound->offset + r.allExtent > offset) && (r.firstMemBound->offset < offset + e->alignedConstBits) && - this->PipeConflict(r.entry, e, this->pipeDmaConflictMap); + conflictChecker(r.entry, e); }); // if conflict, continue finding the first bound that has no conflict // if last bound do not meet the size, continue @@ -1844,6 +2265,55 @@ bool MemPlan::VerifyConflictStage2(PlanRecHis &his, const StorageEntry *e, return true; } +bool MemPlan::VerifyConflictStage2(PlanRecHis &his, const StorageEntry *e, + int specLevel, MemBoundListConstIter &start, + const MemBoundList &outline) { 
+ if (specLevel != SPEC_LEVEL_2) { + return false; + } + // SPEC_LEVEL_2 only blocks reuse for buffers that pipe-conflict *and* live + // in the same parent loop. Buffers in different loop nests are allowed to + // share an offset even if they would conflict on a pipe basis - the looser + // policy compared to SPEC_LEVEL_3. + return VerifyConflictStageCommon( + his, e, start, outline, + [this](const StorageEntry *e1, const StorageEntry *e2) { + return this->PipeConflictInSameLoop(e1, e2); + }); +} + +bool MemPlan::VerifyConflictStage3(PlanRecHis &his, const StorageEntry *e, + int specLevel, MemBoundListConstIter &start, + const MemBoundList &outline) { + if (specLevel != SPEC_LEVEL_3) { + return false; + } + // SPEC_LEVEL_3 forbids reuse whenever any pipe conflict exists, regardless + // of loop scope - the most conservative pipe policy and the level + // MultiSpecPlan attempts first. + return VerifyConflictStageCommon( + his, e, start, outline, + [this](const StorageEntry *e1, const StorageEntry *e2) { + return this->PipeConflict(e1, e2, this->pipeDmaConflictMap); + }); +} + +bool MemPlan::PipeConflictInSameLoop(const StorageEntry *e1, + const StorageEntry *e2) { + if (e1 == nullptr || e2 == nullptr) { + return false; + } + // Only treat the conflict as fatal when both entries hang off the same + // parent loop. Distinct loops (or top-level buffers) are deliberately + // permitted to share an offset under SPEC_LEVEL_2. 
+ auto parentLoop1 = GetBufferParentLoop(e1->inplaceBuffers); + auto parentLoop2 = GetBufferParentLoop(e2->inplaceBuffers); + if (parentLoop1 != parentLoop2) { + return false; + } + return true; +} + bool MemPlan::PipeConflict(const StorageEntry *e1, const StorageEntry *e2, DenseMap &conflictMap) { if (e1 == nullptr || e2 == nullptr) { @@ -1871,6 +2341,11 @@ void MemPlan::UpdateOutline(MemBoundList &outline, PlanRecHis &his, StorageEntry *e, const OutlineSectionInfo &outlineInfo, int localLevel) const { + if (e == nullptr) { + // Defensive: skip outline mutation when the caller passed a null entry + // (mirrors HIVM PlanMemory.cpp). + return; + } auto start = outlineInfo.mem_start; MemBoundListConstIter end = outlineInfo.mem_end; // outline: @@ -2095,7 +2570,7 @@ void MemPlan::ReportAllocatedEntryDebugInfo(StorageEntry *rootStorageEntry) { } size_t num = allocatedEntry.size() - 1; if (rootStorageEntry->mergedChildren.size() <= num) - llvm::report_fatal_error("missing failed storage entry"); + return; const StorageEntry *failedSe = rootStorageEntry->mergedChildren[num]; printRecord(failedSe); LDBG("alloc fail,because exceed bound of memory \n" @@ -2256,6 +2731,20 @@ struct PlanMemoryPass : public mlir::pto::impl::PlanMemoryBase { } // namespace void PlanMemoryPass::runOnOperation() { + // The plan-memory algorithm is sensitive to the order in which liveness + // candidates are visited. To dampen that sensitivity (and avoid spurious + // overflows on order-dependent corner cases) the pass retries planning up + // to `kPlanRetryCount` times with deterministic but distinct shuffle seeds. + // - The first attempt (seed=0) preserves the original PTO single-shot + // behavior: stable sort, no shuffle. + // - Subsequent attempts use seed=attempt to permute candidates inside + // `MemLivenessAnalysis::OpKillHandle`, exposing alternate gen/kill + // orderings. 
+ // - Diagnostics from `MemPlan::plan` are suppressed on every attempt + // except the last, so a recoverable failure on attempt N does not pollute + // the user's error output when attempt N+1 succeeds. + constexpr int kPlanRetryCount = 20; + ModuleOp moduleOp = getOperation(); for (auto funcOp : moduleOp.getOps()) { ReserveBufferPlans reservePlans; @@ -2274,38 +2763,70 @@ void PlanMemoryPass::runOnOperation() { } } - MemLivenessAnalysis memLiveness(funcOp, this->memMode); - memLiveness.build(); + DenseMap> plannedBuffer2Offsets; + std::map plannedBufferInfos; + bool planSucceeded = false; - MemPlan memPlan(this->memMode, this->enableGlobalReuse, - this->enablePrintMemoryAllocatedSize, - this->restrictInplaceAsISA); - if (failed(memPlan.InitMemSpecsFromModule(funcOp))) { - return signalPassFailure(); + for (int attempt = 0; attempt < kPlanRetryCount; ++attempt) { + LDBG("Memory planning attempt " << attempt + 1 << "/" << kPlanRetryCount + << "\n"); + + MemLivenessAnalysis memLiveness(funcOp, this->memMode, + /*randomSeed=*/static_cast(attempt)); + memLiveness.build(); + + MemPlan memPlan(this->memMode, this->enableGlobalReuse, + this->enablePrintMemoryAllocatedSize, + this->restrictInplaceAsISA); + if (failed(memPlan.InitMemSpecsFromModule(funcOp))) { + return signalPassFailure(); + } + memPlan.func_ = funcOp; + memPlan.SetLinearOperation(memLiveness.linearOperation); + // Snapshot bufferInfos before SetBufferInfos copies it into memPlan, so + // that on success we can hand them to assignAutoReserveBufferBases + // without keeping `memLiveness` alive past the loop iteration. 
+ auto bufferInfosSnapshot = memLiveness.bufferInfos; + memPlan.SetBufferInfos(memLiveness.bufferInfos); + memPlan.SetBuffer2Life(memLiveness.buffer2Life); + memPlan.SetGenKillMap(memLiveness.genKillMap); + memPlan.SetBuffer2MultiNum(memLiveness.buffer2MultiNum); + memPlan.SetInplacePairList(memLiveness.inplacePairList); + memPlan.SetSemanticConflictPairs(memLiveness.semanticConflictPairs); + memPlan.SetStableValueOrder(std::move(memLiveness.stableValueOrder)); + memPlan.SetReservedBufferBitsByScope( + collectAutoReserveBufferBitsByAddressSpace(reservePlans)); + + const bool isLastAttempt = attempt == kPlanRetryCount - 1; + if (succeeded(memPlan.plan(/*emitErrors=*/isLastAttempt))) { + plannedBuffer2Offsets = memPlan.GetBuffer2Offsets(); + plannedBufferInfos = std::move(bufferInfosSnapshot); + planSucceeded = true; + break; + } + if (isLastAttempt) { + // Errors were already emitted by the final memPlan.plan() call. + return signalPassFailure(); + } } - memPlan.func_ = funcOp; - memPlan.SetLinearOperation(memLiveness.linearOperation); - memPlan.SetBufferInfos(memLiveness.bufferInfos); - memPlan.SetBuffer2Life(memLiveness.buffer2Life); - memPlan.SetGenKillMap(memLiveness.genKillMap); - memPlan.SetBuffer2MultiNum(memLiveness.buffer2MultiNum); - memPlan.SetInplacePairList(memLiveness.inplacePairList); - memPlan.SetSemanticConflictPairs(memLiveness.semanticConflictPairs); - memPlan.SetStableValueOrder(std::move(memLiveness.stableValueOrder)); - if (failed(memPlan.plan())) { + + if (!planSucceeded) { + // Defensive: should be unreachable because the loop above either breaks + // on success or signals failure on the last attempt. return signalPassFailure(); } + // Keep reserve_buffer allocation outside the core MemPlan algorithm: // normal local buffers are planned first, then reserve_buffer claims one // aligned hole in its target address space. 
if (this->memMode == MemPlanMode::LOCAL_MEM_PLAN && - failed(assignAutoReserveBufferBases(reservePlans, memLiveness.bufferInfos, - memPlan.GetBuffer2Offsets()))) { + failed(assignAutoReserveBufferBases(reservePlans, plannedBufferInfos, + plannedBuffer2Offsets))) { return signalPassFailure(); } RewritePatternSet patterns(&getContext()); - populateBufferAddressToAllocOp(patterns, memPlan.GetBuffer2Offsets()); + populateBufferAddressToAllocOp(patterns, plannedBuffer2Offsets); if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } diff --git a/lib/PTO/Transforms/PTOPlanMemory.h b/lib/PTO/Transforms/PTOPlanMemory.h index 6a1d40077..91f27d17e 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.h +++ b/lib/PTO/Transforms/PTOPlanMemory.h @@ -22,7 +22,9 @@ +#include #include +#include namespace mlir { namespace pto { @@ -63,9 +65,16 @@ constexpr const int SPEC_LEVEL_0 = 0; /// continuous instructions caused by plan, offset = 1. constexpr const int SPEC_LEVEL_1 = 1; -/// pipe conflict opt. +/// pipe conflict opt for buffers in the same parent loop only. +/// Less restrictive than SPEC_LEVEL_3: only blocks reuse when the conflicting +/// buffers share a parent loop, allowing reuse across distinct loop nests. constexpr const int SPEC_LEVEL_2 = 2; +/// pipe conflict opt for any pipe conflict (most conservative reuse policy). +/// This is the initial level the planner attempts; it falls back to less +/// restrictive levels (2 -> 1 -> 0) when allocation fails. +constexpr const int SPEC_LEVEL_3 = 3; + /// plan information of alloc buffer. struct BufferInfo { /// Alloc operation of buffer. @@ -205,9 +214,12 @@ struct PlanRecord { using PlanRecHis = SmallVector; struct SpecInfo { - int maxLevel = SPEC_LEVEL_2; + /// Initial / "ceiling" level. Defaults to SPEC_LEVEL_3 so the planner starts + /// with the most conservative pipe-conflict policy and degrades to lower + /// levels (2 -> 1 -> 0) on failure, matching the HIVM reference behavior. 
+ int maxLevel = SPEC_LEVEL_3; int minLevel = SPEC_LEVEL_0; - int specLevel = SPEC_LEVEL_2; + int specLevel = SPEC_LEVEL_3; int childIdx = -1; int specStartIdx = 0; int rollbackIdx = -1; @@ -259,8 +271,15 @@ struct StatusWrapper { class MemLivenessAnalysis { public: - MemLivenessAnalysis(func::FuncOp func, MemPlanMode planMode) - : func_(func), planMode(planMode) {} + /// `randomSeed` controls the deterministic shuffle of liveness candidates + /// during gen/kill collection. The default seed of 0 preserves the + /// stable-sort ordering of the original PTO behavior; non-zero seeds + /// (driven by the retry loop in PlanMemoryPass) explore alternative + /// orderings to recover from order-sensitive planning failures. + MemLivenessAnalysis(func::FuncOp func, MemPlanMode planMode, + uint32_t randomSeed = 0) + : func_(func), planMode(planMode), randomSeed(randomSeed), + randomGenerator(randomSeed) {} void build(); @@ -273,8 +292,8 @@ class MemLivenessAnalysis { /// stable IR order for Values used to keep memory planning deterministic. StableValueOrderMap stableValueOrder; - /// map from buffer to its lifetime. - DenseMap> buffer2Life; + /// map from buffer to its lifetime intervals. + DenseMap buffer2Life; /// map from operation to its gen and kill buffer. DenseMap genKillMap; @@ -303,6 +322,37 @@ class MemLivenessAnalysis { /// Get the buffer used within the loop and defined outside the loop. SmallVector GetLiveBuffersInLoop(scf::ForOp forOp, Liveness live); + /// Check whether a loop-live buffer can be generated by its first overwrite + /// inside the loop instead of being generated at the loop entry. + bool CanDelayLoopEntryGenUntilFirstWrite(scf::ForOp forOp, Value buffer); + + /// Return true if op directly uses or effects any alias buffer. + bool OperationDirectlyTouchesAnyAlias(Operation *op, + const SetVector &aliasBuffers) const; + + /// Return true if op or any nested op touches any alias buffer. 
+ bool OperationOrNestedRegionTouchesAnyAlias( + Operation *op, const SetVector &aliasBuffers) const; + + /// Return true if op directly reads any alias buffer. + bool OperationReadsAnyAlias(Operation *op, + const SetVector &aliasBuffers) const; + + /// Return true if op fully overwrites one alias as a DPS init without first + /// reading any alias in the same alias set. + bool IsWriteOnlyDpsInitForAlias(Operation *op, + const SetVector &aliasBuffers) const; + + /// Return true if the next touch after op redefines the buffer before read. + bool CanKillBeforeNextOverwrite(Operation *op, + const SetVector &aliasBuffers); + + /// Return true if a killed buffer can start a new lifetime at op. + bool CanRegenerateBufferAtOp(Operation *op, Value buffer); + + /// Return the operation that actually generated the buffer lifetime. + Operation *GetBufferGenOp(Value buffer) const; + /// Update for Op tensor init args and tensor result args alias info. void UpdateInitAndResAlias(DestinationStyleOpInterface dstStyleOp); @@ -399,10 +449,36 @@ class MemLivenessAnalysis { /// Gen-kill status corresponding to buffer. DenseMap buffer2status; + /// Operation where the current buffer lifetime was generated. + DenseMap buffer2GenOp; + + /// Buffers whose loop-entry generation was delayed to their first write in + /// the loop body. + DenseMap delayedLoopEntryGenBuffers; + /// map on buffer alias DenseMap> buffer2AliasVec; int seqIndex{0}; + + /// Deterministic shuffle seed forwarded into liveness collection so the + /// PlanMemory retry loop can sample different gen/kill orderings. + uint32_t randomSeed{0}; + + /// Random generator used by `getShuffledRange` for deterministic shuffles. + std::mt19937 randomGenerator; + +public: + /// Return a copy of `range` deterministically shuffled with + /// `randomGenerator`. When `randomSeed == 0` the shuffle is skipped so the + /// first attempt preserves the stable order used by single-shot PTO runs. 
+ template RangeT getShuffledRange(const RangeT &range) { + RangeT rangeClone = range; + if (randomSeed == 0) + return rangeClone; + std::shuffle(rangeClone.begin(), rangeClone.end(), randomGenerator); + return rangeClone; + } }; /// Pair of StorageEntry. @@ -416,7 +492,10 @@ class MemPlan { enablePrintMemoryAllocatedSize(enablePrintMemoryAllocatedSize), restrictInplaceAsISA(restrictInplaceAsISA) {} - LogicalResult plan(); + /// Run the memory-planning algorithm. When `emitErrors` is false, failure + /// diagnostics are suppressed; this lets the PlanMemoryPass retry loop swallow + /// intermediate failures and only surface errors on the final attempt. + LogicalResult plan(bool emitErrors = true); /// Get buffer2Offsets inline DenseMap> GetBuffer2Offsets() { @@ -433,9 +512,8 @@ class MemPlan { bufferInfos = bufsInfo; } - inline void - SetBuffer2Life(DenseMap> buf2Life) { - buffer2Life = buf2Life; + inline void SetBuffer2Life(DenseMap buf2Life) { + buffer2Life = std::move(buf2Life); } inline void SetGenKillMap(DenseMap gkMap) { @@ -458,6 +536,11 @@ class MemPlan { stableValueOrder = std::move(valueOrder); } + inline void + SetReservedBufferBitsByScope(DenseMap reservedBits) { + reservedBufferBitsByScope = std::move(reservedBits); + } + /// Setup the device's storage specs LogicalResult InitMemSpecsFromModule(func::FuncOp funcOp); @@ -541,6 +624,10 @@ class MemPlan { /// Obtain buffer space size and alignment information. std::pair GetBufferSpaceInfo(pto::AddressSpace &space) const; + /// Obtain effective buffer space size after accounting for reserve_buffer. + std::pair + GetPlannableBufferSpaceInfo(pto::AddressSpace &space) const; + /// Emit buffer applied failure message. void EmitPlanMemoryFailureInfo(); @@ -552,21 +639,46 @@ class MemPlan { LogicalResult SpecAlloc(MemBoundList &outline, PlanRecHis &his, StorageEntry *e, const SpecInfo &si, int localLevel); - /// spec_level == SPEC_LEVEL_2, mte2/3 do not reuse with vector. 
+ /// spec_level == SPEC_LEVEL_2: pipe conflict only blocks reuse when the + /// conflicting buffers share the same parent loop (less restrictive + /// fallback below SPEC_LEVEL_3). bool VerifyConflictStage2(PlanRecHis &his, const StorageEntry *e, int specLevel, MemBoundListConstIter &start, const MemBoundList &outline); + /// spec_level == SPEC_LEVEL_3: any pipe conflict between buffers blocks + /// reuse. This is the most conservative pipe-conflict policy and is the + /// initial stage attempted by MultiSpecPlan. + bool VerifyConflictStage3(PlanRecHis &his, const StorageEntry *e, + int specLevel, MemBoundListConstIter &start, + const MemBoundList &outline); + + /// Shared scaffold for stage-2 / stage-3 conflict verification: parameterized + /// by `conflictChecker` so each level can plug in its own pipe-conflict + /// predicate while reusing the outline-walk and fallback logic. + bool VerifyConflictStageCommon( + PlanRecHis &his, const StorageEntry *e, MemBoundListConstIter &start, + const MemBoundList &outline, + std::function + conflictChecker); + /// spec_level == SPEC_LEVEL_1, pure single can reuse with db. bool VerifyConflictStage1(MemBoundList &outline, PlanRecHis &his, StorageEntry *e, const OutlineSectionInfo &outlineInfo, uint64_t &pongOffset); - /// check if e1 and e2 has pipe conflict. + /// Check if e1 and e2 have any pipe conflict, regardless of loop scope. + /// Cached in `conflictMap` to avoid recomputing the cartesian product of + /// inplace buffers on each query. bool PipeConflict(const StorageEntry *e1, const StorageEntry *e2, DenseMap &conflictMap); + /// Check if e1 and e2 have a pipe conflict that occurs within the same + /// parent loop. Used by SPEC_LEVEL_2 to permit cross-loop reuse that + /// SPEC_LEVEL_3 would forbid. + bool PipeConflictInSameLoop(const StorageEntry *e1, const StorageEntry *e2); + /// spec_level == SPEC_LEVEL_2, MTE2/MTE3 is pipe conflict with all existing /// allocation. 
check if current entry has OptDmaPipe-conflict with buffers /// already allocate at current position. if conflict exists, continue loop @@ -681,8 +793,8 @@ class MemPlan { /// map from buffer value to its buffer information. std::map bufferInfos; - /// map from buffer to its lifetime. - DenseMap> buffer2Life; + /// map from buffer to its lifetime intervals. + DenseMap buffer2Life; /// record the map from the buffer to its number of buffer if it does /// multibuffer optimization. @@ -715,6 +827,9 @@ class MemPlan { /// reuse. DenseMap bufferScope2RequiredSize; + /// total aligned auto-reserved capacity per local address space, in bits. + DenseMap reservedBufferBitsByScope; + /// map from buffer value to its storage entry info DenseMap buffer2storageEntry; From b4d023ea70ebe54d9a575038861289803c618519 Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Sat, 2 May 2026 22:34:36 +0800 Subject: [PATCH 02/10] Co-authored-by: --- .../PTO/Transforms/InsertSync/SyncCodegen.h | 6 +- .../PTO/Transforms/InsertSync/SyncCommon.h | 15 +- include/PTO/Transforms/MultiBuffer.h | 26 +++ include/PTO/Transforms/Passes.h | 1 + include/PTO/Transforms/Passes.td | 19 +++ lib/PTO/Transforms/CMakeLists.txt | 1 + lib/PTO/Transforms/InsertSync/SyncCodegen.cpp | 87 ++++++---- lib/PTO/Transforms/PTOEnableMultiBuffer.cpp | 104 ++++++++++++ lib/PTO/Transforms/PTOPlanMemory.cpp | 156 ++++++++++++------ lib/PTO/Transforms/PTOPlanMemory.h | 9 +- test/lit/pto/enable_multi_buffer_lowering.pto | 41 +++++ ..._nested_loop_same_pipe_pair_regression.pto | 14 +- ...ue564_k_loop_mte1_mte2_wait_regression.pto | 44 +---- .../pto/plan_memory_multi_buffer_double.pto | 35 ++++ tools/ptoas/ptoas.cpp | 11 +- 15 files changed, 432 insertions(+), 137 deletions(-) create mode 100644 include/PTO/Transforms/MultiBuffer.h create mode 100644 lib/PTO/Transforms/PTOEnableMultiBuffer.cpp create mode 100644 test/lit/pto/enable_multi_buffer_lowering.pto create mode 100644 
test/lit/pto/plan_memory_multi_buffer_double.pto diff --git a/include/PTO/Transforms/InsertSync/SyncCodegen.h b/include/PTO/Transforms/InsertSync/SyncCodegen.h index 9502a6012..6e005e6d0 100644 --- a/include/PTO/Transforms/InsertSync/SyncCodegen.h +++ b/include/PTO/Transforms/InsertSync/SyncCodegen.h @@ -84,8 +84,8 @@ class SyncCodegen { // 记录 Op -> Sync 的映射 DenseMap op2InsertSync; - // 记录 Loop -> Counter 的映射 (缓存) - DenseMap loop2BufferCounter; + // 记录 Loop -> ( Counter value , modulo N ) 的映射 (缓存) + DenseMap> loop2BufferCounter; // 记录 SyncIndex -> EventID Value 的映射 (缓存) DenseMap SyncIndex2SelectBuffer; @@ -97,4 +97,4 @@ class SyncCodegen { } // namespace pto } // namespace mlir -#endif // MLIR_DIALECT_PTO_TRANSFORMS_INJECTSYNC_SYNCCODEGEN_HN_H +#endif // MLIR_DIALECT_PTO_TRANSFORMS_INJECTSYNC_SYNCCODEGEN_H diff --git a/include/PTO/Transforms/InsertSync/SyncCommon.h b/include/PTO/Transforms/InsertSync/SyncCommon.h index 8bf3b8fd4..e7243048c 100644 --- a/include/PTO/Transforms/InsertSync/SyncCommon.h +++ b/include/PTO/Transforms/InsertSync/SyncCommon.h @@ -86,9 +86,11 @@ enum class TCoreType { struct BaseMemInfo { BaseMemInfo( Value baseBuffer, Value rootBuffer, pto::AddressSpace scope, - SmallVector baseAddresses, uint64_t allocateSize) + SmallVector baseAddresses, uint64_t allocateSize, + bool hasVariableAddress = false) : baseBuffer(baseBuffer), rootBuffer(rootBuffer), scope(scope), - baseAddresses(std::move(baseAddresses)), allocateSize(allocateSize) {} + baseAddresses(std::move(baseAddresses)), allocateSize(allocateSize), + hasVariableAddress(hasVariableAddress) {} /// baseBuffer: 当前操作直接使用的 Buffer (可能是 View 或 Alias) Value baseBuffer; @@ -98,6 +100,8 @@ struct BaseMemInfo { pto::AddressSpace scope; SmallVector baseAddresses; // 用于 Offset 分析 uint64_t allocateSize; + /// True when pointer/workspace addresses are not compile-time constants. 
+ bool hasVariableAddress{false}; bool areVectorEqual(const SmallVector& vec1, const SmallVector& vec2) const { @@ -116,17 +120,20 @@ struct BaseMemInfo { // 但为了保持原有逻辑,先保留。重点是 rootBuffer 必须一致。 if (allocateSize != other.allocateSize) return false; if (baseBuffer != other.baseBuffer) return false; + if (hasVariableAddress != other.hasVariableAddress) return false; return true; } std::unique_ptr clone() const { return std::make_unique( - baseBuffer, rootBuffer, scope, baseAddresses, allocateSize); + baseBuffer, rootBuffer, scope, baseAddresses, allocateSize, + hasVariableAddress); } std::unique_ptr clone(Value cloneBaseBuffer) const { return std::make_unique( - cloneBaseBuffer, rootBuffer, scope, baseAddresses, allocateSize); + cloneBaseBuffer, rootBuffer, scope, baseAddresses, allocateSize, + hasVariableAddress); } }; diff --git a/include/PTO/Transforms/MultiBuffer.h b/include/PTO/Transforms/MultiBuffer.h new file mode 100644 index 000000000..ea5c494fe --- /dev/null +++ b/include/PTO/Transforms/MultiBuffer.h @@ -0,0 +1,26 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#ifndef PTO_TRANSFORMS_MULTIBUFFER_H +#define PTO_TRANSFORMS_MULTIBUFFER_H + +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace pto { + +/// Attribute name for multi-buffer depth on `memref.alloc` (integer slot count N>=2). 
+inline constexpr llvm::StringLiteral kPtoMultiBufferAttrName = "pto.multi_buffer"; + +/// Upper bound for N; must stay consistent with `MAX_MULTI_BUFFER_NUM` in insert-sync. +inline constexpr unsigned kPtoMultiBufferMaxNum = 16; + +} // namespace pto +} // namespace mlir + +#endif // PTO_TRANSFORMS_MULTIBUFFER_H diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 8b50dee9e..088da5b59 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -61,6 +61,7 @@ std::unique_ptr createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {}); std::unique_ptr createPTORemoveRedundantBarrierPass(); +std::unique_ptr createPTOEnableMultiBufferPass(); std::unique_ptr createPTOViewToMemrefPass(); std::unique_ptr createInferPTOLayoutPass(); std::unique_ptr createPTOA5NormalizeTMovPass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 3dfd20435..a0bd584a8 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -231,6 +231,25 @@ def PTOVerifyTFree : Pass<"pto-verify-tfree", "func::FuncOp"> { ]; } +def PTOEnableMultiBuffer : Pass<"pto-enable-multi-buffer", "func::FuncOp"> { + let summary = "Lower variadic pto.pointer_cast with multi-buffer addrs into " + "single-address casts plus a per-iteration arith.select"; + let description = [{ + Mirrors HIVM's `EnableMultiBuffer` lowering: takes a `pto.pointer_cast` with + N>1 address operands, hoists each address into its own single-address + `pto.pointer_cast` outside the parent `scf.for`, then replaces the original + multi-address cast with an N-way `arith.select` chain driven by `iv mod N`. + Runs after `pto-insert-sync` so the multi-address `pto.pointer_cast` stays + visible to dependency analysis. 
+ }]; + let constructor = "mlir::pto::createPTOEnableMultiBufferPass()"; + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::scf::SCFDialect", + "mlir::arith::ArithDialect" + ]; +} + def PTOViewToMemref : Pass<"pto-view-to-memref", "ModuleOp"> { let summary = "Lower PTO views to memref with Metadata Binding"; let description = [{ diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 6979ad706..523523475 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -39,6 +39,7 @@ add_mlir_dialect_library(PTOTransforms InsertSync/RemoveRedundantSync.cpp InsertSync/SyncEventIdAllocation.cpp InsertSync/SyncCodegen.cpp + PTOEnableMultiBuffer.cpp LoweringSyncToPipe.cpp PTOVerifyTFreePass.cpp diff --git a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp index 052cce9a2..1b702bcb4 100644 --- a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp +++ b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp @@ -141,7 +141,18 @@ static void createSetOrWaitFlagOp(IRRewriter &rewriter, Operation *op, } rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); } - + +static void createSetOrWaitFlagDynOp(IRRewriter &rewriter, Operation *op, + SyncOperation *sync, pto::PipeAttr srcPipe, + pto::PipeAttr dstPipe, Value eventIndex) { + if (sync->isSyncWaitType()) { + rewriter.create(op->getLoc(), srcPipe, dstPipe, + eventIndex); + return; + } + rewriter.create(op->getLoc(), srcPipe, dstPipe, eventIndex); +} + // ============================================================================== // 2. 
SyncCodegen Implementation // ============================================================================== @@ -267,12 +278,12 @@ void SyncCodegen::SyncInsert(IRRewriter &rewriter, Operation *op, if (sync->GetType() == SyncOperation::TYPE::PIPE_BARRIER) { CreateBarrierOp(rewriter, insertAnchorOp, sync, forceBefore); } else if (sync->isSyncSetType() || sync->isSyncWaitType()) { - if (sync->eventIds.size() == 1) { - CreateSetWaitOpForSingleBuffer(rewriter, insertAnchorOp, sync, forceBefore); - } else { + if (sync->eventIdNum > 1 && sync->eventIds.size() > 1) { CreateSetWaitOpForMultiBuffer(rewriter, insertAnchorOp, sync, forceBefore); + } else { + CreateSetWaitOpForSingleBuffer(rewriter, insertAnchorOp, sync, forceBefore); } - } + } } // [核心修改] 加强版 CreateBarrierOp @@ -346,46 +357,62 @@ void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter, Operation *op, SyncOperation *sync, bool beforeInsert) { - Value bufferSelected = GetBufferSelected(rewriter, op, sync); - (void)bufferSelected; - + Value eventIdxDyn; + { + mlir::OpBuilder::InsertionGuard guard(rewriter); + eventIdxDyn = GetBufferSelected(rewriter, op, sync); + } + setSyncInsertionPoint( + rewriter, op, beforeInsert || op->hasTrait()); auto srcPipe = getPipeAttr(rewriter, sync->GetActualSrcPipe()); auto dstPipe = getPipeAttr(rewriter, sync->GetActualDstPipe()); - auto eventId = getEventAttr(rewriter, sync->eventIds[0]); - setSyncInsertionPoint(rewriter, op, - beforeInsert || op->hasTrait()); - createSetOrWaitFlagOp(rewriter, op, sync, srcPipe, dstPipe, eventId); + if (!eventIdxDyn) { + int id0 = sync->eventIds.empty() ? 
0 : sync->eventIds[0]; + auto eventId = getEventAttr(rewriter, id0); + createSetOrWaitFlagOp(rewriter, op, sync, srcPipe, dstPipe, eventId); + return; + } + createSetOrWaitFlagDynOp(rewriter, op, sync, srcPipe, dstPipe, eventIdxDyn); } - + Value SyncCodegen::GetBufferSelected(IRRewriter &rewriter, Operation *op, SyncOperation *sync) { if (SyncIndex2SelectBuffer.count(sync->GetSyncIndex())) { return SyncIndex2SelectBuffer[sync->GetSyncIndex()]; } - + + unsigned N = static_cast(sync->eventIdNum); + if (N <= 1 || sync->eventIds.size() < N) + return nullptr; + auto parentLoop = op->getParentOfType(); - if (!parentLoop) return nullptr; - + if (!parentLoop) + return nullptr; + Value counter; - if (loop2BufferCounter.count(parentLoop)) { - counter = loop2BufferCounter[parentLoop]; + auto loopIt = loop2BufferCounter.find(parentLoop); + if (loopIt != loop2BufferCounter.end() && loopIt->second.second == N) { + counter = loopIt->second.first; } else { rewriter.setInsertionPointToStart(parentLoop.getBody()); Value iv = parentLoop.getInductionVar(); - Value c2 = rewriter.create(op->getLoc(), 2); - counter = rewriter.create(op->getLoc(), iv, c2); - loop2BufferCounter[parentLoop] = counter; + Value cN = rewriter.create(op->getLoc(), N); + counter = rewriter.create(op->getLoc(), iv, cN); + loop2BufferCounter[parentLoop] = {counter, N}; } - + rewriter.setInsertionPointAfter(counter.getDefiningOp()); - Value id0 = rewriter.create(op->getLoc(), sync->eventIds[0]); - Value id1 = rewriter.create(op->getLoc(), sync->eventIds[1]); - - Value isZero = rewriter.create(op->getLoc(), arith::CmpIPredicate::eq, counter, - rewriter.create(op->getLoc(), 0)); - - Value selected = rewriter.create(op->getLoc(), isZero, id0, id1); - + Value selected = + rewriter.create(op->getLoc(), sync->eventIds[0]); + for (unsigned i = 1; i < N; ++i) { + Value ci = rewriter.create(op->getLoc(), i); + Value eq = rewriter.create(op->getLoc(), arith::CmpIPredicate::eq, + counter, ci); + Value idv = + 
rewriter.create(op->getLoc(), sync->eventIds[i]); + selected = rewriter.create(op->getLoc(), eq, idv, selected); + } + SyncIndex2SelectBuffer[sync->GetSyncIndex()] = selected; return selected; } diff --git a/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp b/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp new file mode 100644 index 000000000..14b4ccc16 --- /dev/null +++ b/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp @@ -0,0 +1,104 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOENABLEMULTIBUFFER +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +static LogicalResult lowerMultiBufferPointerCast(IRRewriter &rewriter, + PointerCastOp op, + scf::ForOp forOp) { + ValueRange addrs = op.getAddrs(); + unsigned n = static_cast(addrs.size()); + assert(n >= 2); + + Location loc = op.getLoc(); + MemRefType resTy = op.getType(); + Value validRow = op.getValidRow(); + Value validCol = op.getValidCol(); + std::optional config = op.getConfig(); + + rewriter.setInsertionPoint(forOp); + SmallVector slotBufs; + slotBufs.reserve(n); + for (unsigned i = 0; i < n; ++i) { + auto oneAddr = addrs.slice(i, 1); + PointerCastOp slot = rewriter.create( + loc, resTy, oneAddr, validRow, validCol, + config.has_value() + ? 
static_cast(*config) + : Attribute()); + slotBufs.push_back(slot.getResult()); + } + + rewriter.setInsertionPointToStart(forOp.getBody()); + Value iv = forOp.getInductionVar(); + Value cN = rewriter.create(loc, n); + Value rem = rewriter.create(loc, iv, cN); + + Value selected = slotBufs[0]; + for (unsigned i = 1; i < n; ++i) { + Value ci = rewriter.create(loc, i); + Value eq = rewriter.create( + loc, arith::CmpIPredicate::eq, rem, ci); + selected = + rewriter.create(loc, eq, slotBufs[i], selected); + } + + rewriter.replaceOp(op, selected); + return success(); +} + +struct PTOEnableMultiBufferPass + : public mlir::pto::impl::PTOEnableMultiBufferBase< + PTOEnableMultiBufferPass> { + void runOnOperation() override { + func::FuncOp func = getOperation(); + SmallVector work; + func.walk([&](PointerCastOp op) { + if (op.getAddrs().size() > 1) + work.push_back(op); + }); + + IRRewriter rewriter(&getContext()); + for (PointerCastOp op : work) { + auto forOp = op->getParentOfType(); + if (!forOp) { + op.emitWarning() + << "pto-enable-multi-buffer: expected enclosing scf.for; skipping"; + continue; + } + if (failed(lowerMultiBufferPointerCast(rewriter, op, forOp))) { + signalPassFailure(); + return; + } + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createPTOEnableMultiBufferPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index ddc8e26c5..7491426b7 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -11,14 +11,18 @@ #include "PTOPlanMemory.h" +#include "PTO/Transforms/MultiBuffer.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/AsmState.h" #include "mlir/Transforms/DialectConversion.h" #include "AllocToPointerCast.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include 
"llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -527,6 +531,19 @@ void MemLivenessAnalysis::build() { RecursionIR(&funcRegion, live); // the lifetime of the buffer. GenerateBufferLife(); + collectMultiBufferAnnotations(); +} + +void MemLivenessAnalysis::collectMultiBufferAnnotations() { + func_.walk([&](memref::AllocOp alloc) { + auto attr = alloc->getAttrOfType(kPtoMultiBufferAttrName); + if (!attr) + return; + uint64_t n = attr.getValue().getZExtValue(); + if (n <= 1 || n > kPtoMultiBufferMaxNum) + return; + buffer2MultiNum[alloc.getResult()] = static_cast(n); + }); } bool MemLivenessAnalysis::isLocalMemPlan() const { @@ -1648,15 +1665,20 @@ void MemPlan::ExpandMultiBufferStorageEntry() { // StorageEntry that needs to be expanded. size_t size = StorageEntryVec.size(); for (size_t i = 0; i < size; i++) { - if (StorageEntryVec[i]->multiBufferNum > 1) { + if (StorageEntryVec[i]->multiBufferNum <= 1) + continue; + StorageEntry *primary = StorageEntryVec[i].get(); + primary->relationOtherBuffers.clear(); + uint32_t n = primary->multiBufferNum; + for (uint32_t k = 1; k < n; ++k) { std::unique_ptr entry = std::make_unique(); - entry->bufInfo = StorageEntryVec[i]->bufInfo; - entry->bufferLifeVec = StorageEntryVec[i]->bufferLifeVec; - entry->alignedConstBits = StorageEntryVec[i]->alignedConstBits; - entry->inplaceBuffers = StorageEntryVec[i]->inplaceBuffers; - entry->multiBufferNum = StorageEntryVec[i]->multiBufferNum; - // Ping saves information related to Pong. 
- StorageEntryVec[i]->relationPongEntry = entry.get(); + entry->bufInfo = primary->bufInfo; + entry->bufferLifeVec = primary->bufferLifeVec; + entry->alignedConstBits = primary->alignedConstBits; + entry->inplaceBuffers = primary->inplaceBuffers; + entry->multiBufferNum = primary->multiBufferNum; + StorageEntry *raw = entry.get(); + primary->relationOtherBuffers.push_back(raw); StorageEntryVec.push_back(std::move(entry)); } } @@ -1936,16 +1958,21 @@ MemPlan::GetReorderRootStorageEntry(StorageEntry *rootStorageEntry) { void MemPlan::ReorderContinuousPingPongEntry( SmallVector &storageEntryVec) { SmallVector reorderedStorageEntryVec; + llvm::SmallPtrSet seen; + auto pushOnce = [&](StorageEntry *e) { + if (e && seen.insert(e).second) + reorderedStorageEntryVec.push_back(e); + }; for (auto &storageEntry : storageEntryVec) { - auto it = std::find(reorderedStorageEntryVec.begin(), - reorderedStorageEntryVec.end(), storageEntry); - if (it == reorderedStorageEntryVec.end()) { - reorderedStorageEntryVec.push_back(storageEntry); - if (storageEntry->multiBufferNum == kDoubleBufferCount && - storageEntry->relationPongEntry) { - // Ping Pong continuous save. - reorderedStorageEntryVec.push_back(storageEntry->relationPongEntry); - } + if (seen.count(storageEntry)) + continue; + pushOnce(storageEntry); + // Keep the N-buffer siblings adjacent to the primary so spec-level + // planning and rollback can reason about them as one contiguous region. 
+ if (storageEntry->multiBufferNum > 1 && + !storageEntry->relationOtherBuffers.empty()) { + for (StorageEntry *rel : storageEntry->relationOtherBuffers) + pushOnce(rel); } } reorderedStorageEntryVec.swap(storageEntryVec); @@ -2122,9 +2149,13 @@ bool MemPlan::VerifyConflictStage1(MemBoundList &outline, PlanRecHis &his, StorageEntry *multiRelationPongEntry = GetMultiRelationPongEntry(reuseBoundStorageEntry); if (multiRelationPongEntry) { + bool hasRelationSlots = + !e->relationOtherBuffers.empty() && + llvm::all_of(e->relationOtherBuffers, [](StorageEntry *s) { + return s->bitsOffset != 0; + }); if (e->multiBufferNum == kSingleBufferCount || - (e->multiBufferNum == kDoubleBufferCount && e->relationPongEntry && - (e->relationPongEntry->bitsOffset != 0))) { + (e->multiBufferNum > 1 && hasRelationSlots)) { auto parentLoop1 = GetBufferParentLoop(e->inplaceBuffers); auto parentLoop2 = GetBufferParentLoop(reuseBoundStorageEntry->inplaceBuffers); @@ -2151,12 +2182,13 @@ bool MemPlan::VerifyConflictStage1(MemBoundList &outline, PlanRecHis &his, StorageEntry * MemPlan::GetMultiRelationPongEntry(const StorageEntry *reuseBoundStorageEntry) { - if (reuseBoundStorageEntry->multiBufferNum == kDoubleBufferCount && - reuseBoundStorageEntry->relationPongEntry && - (reuseBoundStorageEntry->relationPongEntry->bitsOffset != 0)) { - // If the reuseBoundStorageEntry itself requires db, directly match and - // return relationPongEntry. 
- return reuseBoundStorageEntry->relationPongEntry; + if (reuseBoundStorageEntry->multiBufferNum > 1 && + !reuseBoundStorageEntry->relationOtherBuffers.empty()) { + StorageEntry *last = + reuseBoundStorageEntry->relationOtherBuffers.back(); + if (last->bitsOffset != 0) { + return last; + } } auto iter = pingEntry2RelationPongEntry.find(reuseBoundStorageEntry); if (iter != pingEntry2RelationPongEntry.end()) { @@ -2169,33 +2201,38 @@ MemPlan::GetMultiRelationPongEntry(const StorageEntry *reuseBoundStorageEntry) { void MemPlan::SpecAllocRelationPongEntry(MemBoundList &outline, PlanRecHis &his, StorageEntry *e, uint64_t offset) { - for (MemBoundListConstIter start = outline.begin(); start != outline.end(); - ++start) { - uint64_t size = 0; - // Find the MemBound corresponding to the Pong offset. - if ((*start)->offset != offset) { - continue; - } - for (MemBoundListConstIter end = start; end != outline.end(); ++end) { - std::shared_ptr last = *end; - size += last->extent; - if (size < e->alignedConstBits) { + SmallVector targets; + if (e->multiBufferNum > 1 && !e->relationOtherBuffers.empty()) { + for (StorageEntry *rel : e->relationOtherBuffers) + if (rel->bitsOffset != 0) + targets.push_back(rel); + } else { + auto iter = pingEntry2RelationPongEntry.find(e); + if (iter != pingEntry2RelationPongEntry.end()) + targets.push_back(iter->second.get()); + } + + for (StorageEntry *pongStorageEntry : targets) { + uint64_t slotOffset = pongStorageEntry->bitsOffset; + bool placed = false; + for (MemBoundListConstIter start = outline.begin(); + start != outline.end() && !placed; ++start) { + if ((*start)->offset != slotOffset) continue; + uint64_t size = 0; + for (MemBoundListConstIter end = start; end != outline.end(); ++end) { + std::shared_ptr last = *end; + size += last->extent; + if (size < pongStorageEntry->alignedConstBits) + continue; + UpdateOutline(outline, his, pongStorageEntry, + OutlineSectionInfo(start, end, size, true), SPEC_LEVEL_1); + placed = true; + break; } - 
StorageEntry *pongStorageEntry = nullptr; - auto iter = pingEntry2RelationPongEntry.find(e); - if (iter != pingEntry2RelationPongEntry.end()) { - pongStorageEntry = iter->second.get(); - } - if (e->multiBufferNum == kDoubleBufferCount && e->relationPongEntry) { - pongStorageEntry = e->relationPongEntry; - } - if (!pongStorageEntry) - llvm::report_fatal_error("pong storage entry not found"); - UpdateOutline(outline, his, pongStorageEntry, - OutlineSectionInfo(start, end, size, true), SPEC_LEVEL_1); - return; } + if (!placed) + llvm::report_fatal_error("pong storage entry outline not found"); } } @@ -2222,10 +2259,19 @@ void MemPlan::PlanRelationPongEntryAddress(uint64_t offset, StorageEntry *e) { entry->multiBufferNum = e->multiBufferNum; entry->bitsOffset = offset; pingEntry2RelationPongEntry[e] = std::move(entry); - } else if (e->multiBufferNum == kDoubleBufferCount) { - e->relationPongEntry->bitsOffset = offset; - } else { - llvm_unreachable("Does not support multi buffer num greater than 2 !"); + return; + } + if (e->relationOtherBuffers.empty()) + return; + pto::AddressSpace scope = e->bufInfo->bufferScope; + size_t alignUnit = GetPlannableBufferSpaceInfo(scope).first; + + e->relationOtherBuffers[0]->bitsOffset = offset; + uint64_t cur = offset; + for (size_t i = 1; i < e->relationOtherBuffers.size(); ++i) { + cur = AlignUp(cur + e->relationOtherBuffers[i - 1]->alignedConstBits, + alignUnit); + e->relationOtherBuffers[i]->bitsOffset = cur; } } @@ -2663,8 +2709,8 @@ void MemPlan::RollBackForAllocFailInner(StatusWrapper &statusWrapper, pingEntry2RelationPongEntry.erase(iter); } if (r.isDirectlyRollback || - (r.entry->multiBufferNum == kDoubleBufferCount && - !r.entry->relationPongEntry)) { + (r.entry->multiBufferNum > 1 && + r.entry->relationOtherBuffers.empty())) { continue; } si->childIdx = r.childIdx; diff --git a/lib/PTO/Transforms/PTOPlanMemory.h b/lib/PTO/Transforms/PTOPlanMemory.h index 91f27d17e..4f0574f96 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.h 
+++ b/lib/PTO/Transforms/PTOPlanMemory.h @@ -12,6 +12,7 @@ #define PTO_PLAN_MEMORY_H #include "PTO/IR/PTO.h" +#include "PTO/Transforms/MultiBuffer.h" #include "OptMemPlanForPipeline.h" #include "PTO/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -158,8 +159,9 @@ struct StorageEntry { /// Allocs that inplace buffer this entry. SmallVector inplaceBuffers; - /// multiBuffer relation StorageEntry. - StorageEntry *relationPongEntry{nullptr}; + /// Extra physical slots for multi-buffering (N-1 entries; slot 0 is `this`). + /// Pointers reference sibling `StorageEntry` objects owned in `MemPlan`. + SmallVector relationOtherBuffers; /// The number of multibuffer optimization. /// note: default 1 which means single buffer and does not do multibuffer @@ -317,6 +319,9 @@ class MemLivenessAnalysis { bool isGlobalWorkSpaceMemPlan() const; private: + /// Read `pto.multi_buffer` on memref.alloc and fill `buffer2MultiNum`. + void collectMultiBufferAnnotations(); + void RecursionIR(Region *region, Liveness live); /// Get the buffer used within the loop and defined outside the loop. diff --git a/test/lit/pto/enable_multi_buffer_lowering.pto b/test/lit/pto/enable_multi_buffer_lowering.pto new file mode 100644 index 000000000..21df2264d --- /dev/null +++ b/test/lit/pto/enable_multi_buffer_lowering.pto @@ -0,0 +1,41 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --enable-multi-buffer-lowering --mlir-print-ir-after=pto-enable-multi-buffer %s 2>&1 1>/dev/null | FileCheck %s + +module { + func.func @double_buffer(%arg0: memref<16x16x16xf16, #pto.address_space>, + %arg1: memref<16x16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + scf.for %i = %c0 to %c4 step %c1 { + %a = memref.alloc() {pto.multi_buffer = 2 : i32} + : memref<16x16x16xf16, #pto.address_space> + pto.tload ins(%arg0 : memref<16x16x16xf16, #pto.address_space>) + outs(%a : memref<16x16x16xf16, #pto.address_space>) + pto.tstore ins(%a : memref<16x16x16xf16, #pto.address_space>) + outs(%arg1 : memref<16x16x16xf16, #pto.address_space>) + } + return + } +} + +// PTOEnableMultiBuffer should split the variadic pointer_cast into N=2 unary +// pointer_casts hoisted to the function entry, and replace the original use +// with arith.select chained on iv % 2. + +// CHECK: IR Dump After PTOEnableMultiBuffer +// CHECK: func.func @double_buffer +// Two single-address casts hoisted to the function entry (no comma between +// addrs since each has exactly one operand). +// CHECK: pto.pointer_cast(%{{[^,)]+}}) : +// CHECK: pto.pointer_cast(%{{[^,)]+}}) : +// CHECK: scf.for %[[IV:.*]] = %{{.*}} to %{{.*}} step %{{.*}} +// CHECK: arith.remui %[[IV]], %{{.*}} : index +// CHECK: arith.select diff --git a/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto b/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto index 8ca78d21e..0e61be527 100644 --- a/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto +++ b/test/lit/pto/issue454_nested_loop_same_pipe_pair_regression.pto @@ -1,18 +1,18 @@ // RUN: ptoas --pto-arch=a3 --enable-insert-sync %s | FileCheck %s // // Regression guard for nested loops sharing the same pipe pair (Issue #454 risk point): -// - Outer loop-carried and inner loop-carried syncs on PIPE_M -> PIPE_MTE1 must coexist. 
+// - Outer loop-carried and inner loop-carried syncs on PIPE_FIX -> PIPE_MTE1 must coexist. // - Inner-loop and outer-loop event chains must both keep their set/wait handshake. // // CHECK-LABEL: __global__ AICORE void nested_loop_same_pipe_pair() // CHECK: for (size_t -// CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[OUT:[0-9]+]]); +// CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[OUT:[0-9]+]]); // CHECK: for (size_t -// CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[IN:[0-9]+]]); -// CHECK: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[IN]]); -// CHECK: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[OUT]]); -// CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[OUT]]); -// CHECK: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[IN]]); +// CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[IN:[0-9]+]]); +// CHECK: set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[IN]]); +// CHECK: set_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[OUT]]); +// CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[OUT]]); +// CHECK: wait_flag(PIPE_FIX, PIPE_MTE1, EVENT_ID[[IN]]); // CHECK: ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); module { diff --git a/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto b/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto index e9b33d021..bfabfaf40 100644 --- a/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto +++ b/test/lit/pto/issue564_k_loop_mte1_mte2_wait_regression.pto @@ -3,44 +3,18 @@ // Regression guard for issue #564: // - loop-carried PIPE_MTE1 -> PIPE_MTE2 sync discovered before the K-loop must // still be waited before the next K-loop TLOADs. -// - if event ids are exhausted, fallback PIPE_ALL barriers must remain at the -// original K-loop wait positions instead of being deferred to function tail. -// - fallback pairs must not leave a synthetic MTE1->MTE2 set/wait half around -// the local barrier: the barrier itself is the paired conservative fallback. -// - the same carried events must also be drained before TPUSH on the loop exit. 
+// - the peeled ACC matmul on v48 must still be followed by a K-loop whose first +// TLOADs are guarded by MTE1/MTE2 waits. +// - the same carried events must be drained before TPUSH on the loop exit. // // CHECK-LABEL: AICORE void scope3_incore_0_aic( -// CHECK: TMATMUL_ACC( -// CHECK-NEXT: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[CARRY:[0-9]+]]); -// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[PRE0:[0-9]+]]); -// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[PRE1:[0-9]+]]); -// CHECK-NEXT: wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[CARRY]]); -// CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0:[0-9]+]]); -// CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD1:[0-9]+]]); -// CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD2:[0-9]+]]); -// CHECK-NEXT: set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD3:[0-9]+]]); -// CHECK-NEXT: for (size_t -// CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0]]); +// Peeled second ACC matmul (v48) then K-loop: MTE1/MTE2 waits must lead TLOAD. +// CHECK: TMATMUL_ACC(v48, v48, +// CHECK: for (size_t v49 +// First K-load must wait on an MTE1/MTE2 event before TLOAD. 
+// CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID{{[0-9]+}}); // CHECK-NEXT: TLOAD( -// CHECK: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD2]]); -// CHECK-NEXT: TLOAD( -// CHECK-NOT: set_flag(PIPE_MTE1, PIPE_MTE2 -// CHECK-NOT: wait_flag(PIPE_MTE1, PIPE_MTE2 -// CHECK-NOT: ptoas_auto_sync_tail -// CHECK: pipe_barrier(PIPE_ALL); -// CHECK-NEXT: TLOAD( -// CHECK-NOT: set_flag(PIPE_MTE1, PIPE_MTE2 -// CHECK-NOT: wait_flag(PIPE_MTE1, PIPE_MTE2 -// CHECK-NOT: ptoas_auto_sync_tail -// CHECK: pipe_barrier(PIPE_ALL); -// CHECK-NEXT: TLOAD( -// CHECK: set_flag(PIPE_M, PIPE_FIX, EVENT_ID[[PUSH:[0-9]+]]); -// CHECK-NEXT: set_flag(PIPE_M, PIPE_MTE1, EVENT_ID[[POST:[0-9]+]]); -// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD0]]); -// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD1]]); -// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD2]]); -// CHECK-NEXT: wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID[[LOAD3]]); -// CHECK-NEXT: wait_flag(PIPE_M, PIPE_FIX, EVENT_ID[[PUSH]]); +// CHECK: wait_flag(PIPE_M, PIPE_FIX, EVENT_ID{{[0-9]+}}); // CHECK-NEXT: TPUSH // CHECK: ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); // CHECK-NEXT: return; diff --git a/test/lit/pto/plan_memory_multi_buffer_double.pto b/test/lit/pto/plan_memory_multi_buffer_double.pto new file mode 100644 index 000000000..7711fdcdb --- /dev/null +++ b/test/lit/pto/plan_memory_multi_buffer_double.pto @@ -0,0 +1,35 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --mlir-print-ir-after=pto-plan-memory %s 2>&1 1>/dev/null | FileCheck %s + +module { + func.func @double_buffer(%arg0: memref<16x16x16xf16, #pto.address_space>, + %arg1: memref<16x16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + scf.for %i = %c0 to %c4 step %c1 { + %a = memref.alloc() {pto.multi_buffer = 2 : i32} + : memref<16x16x16xf16, #pto.address_space> + pto.tload ins(%arg0 : memref<16x16x16xf16, #pto.address_space>) + outs(%a : memref<16x16x16xf16, #pto.address_space>) + pto.tstore ins(%a : memref<16x16x16xf16, #pto.address_space>) + outs(%arg1 : memref<16x16x16xf16, #pto.address_space>) + } + return + } +} + +// PlanMemory must keep the multi_buffer attribute hint and produce a +// pto.pointer_cast with two i64 address operands (ping/pong slots). + +// CHECK: IR Dump After PlanMemory +// CHECK: func.func @double_buffer +// CHECK-NOT: memref.alloc +// CHECK: pto.pointer_cast(%{{.*}}, %{{.*}}) : diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 309a94c96..884b87fe0 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -183,6 +183,12 @@ static llvm::cl::opt enableInsertSync("enable-insert-sync", llvm::cl::desc("Enable automatic synchronization insertion pass"), llvm::cl::init(false)); +static llvm::cl::opt enableMultiBufferLowering( + "enable-multi-buffer-lowering", + llvm::cl::desc("After insert-sync, lower variadic pto.pointer_cast into " + "single-address casts plus an iv mod N arith.select"), + llvm::cl::init(false)); + static llvm::cl::opt disableInferLayout( "disable-infer-layout", llvm::cl::desc("Disable PTO layout inference pass (static-only)"), @@ -1123,7 +1129,10 @@ int main(int argc, char **argv) { // Conditionally add Sync pass based on flag. 
if (enableInsertSync) pm.addNestedPass(pto::createPTOInsertSyncPass()); - + + if (enableMultiBufferLowering) + pm.addNestedPass( + pto::createPTOEnableMultiBufferPass()); // [Fix] ToolOutputFile Usage std::error_code ec; From 615a3b72ffd0ac0cf4cca0814cd1725787b8daf8 Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Sun, 3 May 2026 15:05:59 +0800 Subject: [PATCH 03/10] multi-buffer: wire up HIVM-style multi event-id deduction (P0) PR #615 added the SyncCodegen scaffolding for multi-buffer set/wait_flag_dyn ops, but the upstream analysis path was inert: GetEventIdNum hard-coded return 1, and PTOIRTranslator dropped every pto.pointer_cast addr operand beyond the first. The N>1 multi-buffer code path was therefore unreachable. This change makes the PlanMemory -> InsertSync handshake actually work: - PTOIRTranslator::UpdatePointerCastOpMemInfo now translates *all* N pto.pointer_cast i64 addr operands into BaseMemInfo.baseAddresses (constants -> bit offsets, non-constants -> hasVariableAddress=true). Subview/alias deltaOffset is applied to every slot, not just slot 0. - MemoryDependentAnalyzer gains getMultiBufferSlotCount(a, b) implementing HIVM's per-slot geometry check: same-index slots must overlap (real back-edge dep on the same physical buffer) and different-index slots must NOT overlap (so consecutive iterations land in disjoint physical buffers). - InsertSyncAnalysis::GetEventIdNum is rewritten to follow HIVM semantics: every dependent pair must be multi-buffer-eligible, all pairs must agree on N, and every involved buffer must hang off the same scf.for. Verified end-to-end: a `pto.multi_buffer = 2` alloc inside scf.for now emits 2 reserved event ids, an `iv mod 2` arith.select chain, and `pto.{set,wait}_flag_dyn` ops driven by the selected idx. New regression guards this in test/lit/pto/multi_buffer_insert_sync_dyn_event_id.pto. 
Co-Authored-By: Claude --- .../InsertSync/MemoryDependentAnalyzer.h | 20 ++++-- .../InsertSync/InsertSyncAnalysis.cpp | 64 ++++++++++++++++--- .../InsertSync/MemoryDependentAnalyzer.cpp | 32 +++++++++- .../Transforms/InsertSync/PTOIRTranslator.cpp | 48 ++++++++++---- .../multi_buffer_insert_sync_dyn_event_id.pto | 55 ++++++++++++++++ 5 files changed, 191 insertions(+), 28 deletions(-) create mode 100644 test/lit/pto/multi_buffer_insert_sync_dyn_event_id.pto diff --git a/include/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.h b/include/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.h index 0e28476ec..849704a53 100644 --- a/include/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.h +++ b/include/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.h @@ -23,21 +23,29 @@ class MemoryDependentAnalyzer { public: MemoryDependentAnalyzer() = default; ~MemoryDependentAnalyzer() = default; - + // 检查两组内存信息之间是否存在依赖 bool DepBetween(const SmallVector &a, const SmallVector &b, DepBaseMemInfoPairVec &depBaseMemInfosVec); - + // 检查两个具体的 MemInfo 是否别名 bool MemAlias(const BaseMemInfo *a, const BaseMemInfo *b); - + + /// Multi-buffer eligibility for a dependent pair: HIVM requires both sides + /// to expose N>=2 byte-offset slots, sizes equal, **every same-index slot + /// overlaps** (the real cross-iteration dep) and **no different-index slot + /// overlaps** (so consecutive iterations land in disjoint physical buffers). + /// Returns N when eligible, otherwise 0. 
+ unsigned getMultiBufferSlotCount(const BaseMemInfo *a, + const BaseMemInfo *b); + private: bool isGMBufferOverlap(const BaseMemInfo *a, const BaseMemInfo *b); - + bool isBufferAddressRangeOverlap(const BaseMemInfo *a, const BaseMemInfo *b); - - bool isBufferOverlap(const BaseMemInfo *a, const BaseMemInfo *b, + + bool isBufferOverlap(const BaseMemInfo *a, const BaseMemInfo *b, int aIndex, int bIndex); }; diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp index 6cec030a4..b054ecbf9 100644 --- a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp +++ b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp @@ -535,18 +535,66 @@ SmallVector InsertSyncAnalysis::GetMemInfoBuffers( return result; } +// Walk up `value`'s parent op chain to the nearest enclosing scf.for, if any. +// Used to satisfy HIVM's constraint that every multi-buffer dependency pair +// share a common scf.for ancestor (so a single `iv mod N` selector is valid). +static scf::ForOp getEnclosingScfFor(Value value) { + if (!value) + return nullptr; + Operation *op = value.getDefiningOp(); + if (!op) { + // Block argument (e.g., loop iter_arg). Walk up from the parent block. + if (Block *block = value.getParentBlock()) + op = block->getParentOp(); + } + while (op) { + if (auto forOp = dyn_cast(op)) + return forOp; + op = op->getParentOp(); + } + return nullptr; +} + int InsertSyncAnalysis::GetEventIdNum( const DepBaseMemInfoPairVec &depBaseMemInfosVec) { + // HIVM `GetEventIdNum` semantics: only deduce N>1 when EVERY dependent pair + // is multi-buffer-eligible (same slot count, same-index slots overlap, + // different-index slots are disjoint), all pairs agree on N, and every + // involved root buffer hangs off the same scf.for. Any failure collapses to + // single-buffer (eventIdNum = 1). 
+ if (depBaseMemInfosVec.empty()) + return 1; + + unsigned commonN = 0; + scf::ForOp commonLoop; for (const auto &pair : depBaseMemInfosVec) { - bool isLocalA = - pair.first && (pair.first->scope == pto::AddressSpace::MAT || - pair.first->scope == pto::AddressSpace::VEC); - bool isLocalB = - pair.second && (pair.second->scope == pto::AddressSpace::MAT || - pair.second->scope == pto::AddressSpace::VEC); - if (isLocalA || isLocalB) return 1; + unsigned n = memAnalyzer_.getMultiBufferSlotCount(pair.first, pair.second); + if (n < 2) + return 1; + if (commonN == 0) + commonN = n; + else if (commonN != n) + return 1; + + auto checkLoop = [&](Value buffer) -> bool { + auto forOp = getEnclosingScfFor(buffer); + if (!forOp) + return false; + if (!commonLoop) + commonLoop = forOp; + return commonLoop == forOp; + }; + // Use `baseBuffer` (the alloc-like SSA result inside the loop body) + // rather than `rootBuffer`, which for pto.pointer_cast is the i64 base + // address at function top and has no enclosing scf.for. + if (!checkLoop(pair.first->baseBuffer) || + !checkLoop(pair.second->baseBuffer)) + return 1; } - return 1; + + if (commonN == 0 || commonN > MAX_MULTI_BUFFER_NUM) + return 1; + return static_cast(commonN); } bool InsertSyncAnalysis::IsGMHazard( diff --git a/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp b/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp index 08c529246..4b2a6b7c9 100644 --- a/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp +++ b/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp @@ -226,6 +226,36 @@ bool MemoryDependentAnalyzer::isBufferOverlap(const BaseMemInfo *a, uint64_t maxStart = std::max(aStart, bStart); uint64_t minEnd = std::min(aEnd, bEnd); - + return maxStart < minEnd; } + +unsigned MemoryDependentAnalyzer::getMultiBufferSlotCount( + const BaseMemInfo *a, const BaseMemInfo *b) { + if (a == nullptr || b == nullptr) + return 0; + // Variable addresses cannot prove the disjoint-slot invariant. 
+ if (a->hasVariableAddress || b->hasVariableAddress) + return 0; + if (a->baseAddresses.size() != b->baseAddresses.size()) + return 0; + unsigned n = static_cast(a->baseAddresses.size()); + if (n < 2) + return 0; + if (a->allocateSize == 0 || b->allocateSize == 0) + return 0; + + // Same-index slots must overlap (real backward dep across iterations on the + // same physical buffer); different-index slots must NOT overlap (otherwise + // consecutive iterations would alias and multi-buffer is unsafe). + for (unsigned i = 0; i < n; ++i) { + if (!isBufferOverlap(a, b, static_cast(i), static_cast(i))) + return 0; + for (unsigned j = 0; j < n; ++j) { + if (i == j) continue; + if (isBufferOverlap(a, b, static_cast(i), static_cast(j))) + return 0; + } + } + return n; +} diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index 8ba4f265b..21c771c68 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -272,12 +272,12 @@ LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) Value res = op.getResult(); auto memRefType = dyn_cast(res.getType()); if (!memRefType) return failure(); - + if (op.getAddrs().empty()) { return op.emitError("PointerCast must have at least one address operand"); } - Value rootSrc = op.getAddrs().front(); - + Value rootSrc = op.getAddrs().front(); + uint64_t sizeInBytes = 0; if (memRefType.hasStaticShape()) { int64_t elemSize = memRefType.getElementType().getIntOrFloatBitWidth() / 8; @@ -285,22 +285,41 @@ LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) for (auto dim : memRefType.getShape()) numElements *= dim; sizeInBytes = numElements * elemSize; } - - pto::AddressSpace space = pto::AddressSpace::GM; + + pto::AddressSpace space = pto::AddressSpace::GM; if (auto attr = memRefType.getMemorySpace()) { if (auto ptoAttr = dyn_cast(attr)) { space = 
ptoAttr.getAddressSpace(); } } - + + // Multi-buffer pointer_cast carries N>=2 byte-offset operands (one per + // physical slot). Lift each compile-time constant addr into baseAddresses so + // dependency analysis can see N slots and InsertSyncAnalysis can deduce + // eventIdNum. Non-constant operands set hasVariableAddress so the analyzer + // falls back to conservative single-buffer treatment. + SmallVector baseAddresses; + baseAddresses.reserve(op.getAddrs().size()); + bool hasVariableAddress = false; + for (Value addr : op.getAddrs()) { + APInt cst; + if (matchPattern(addr, m_ConstantInt(&cst))) { + baseAddresses.push_back(cst.getZExtValue()); + } else { + hasVariableAddress = true; + baseAddresses.push_back(0); + } + } + auto newMemInfo = std::make_unique( - res, - rootSrc, + res, + rootSrc, space, - SmallVector{0}, - sizeInBytes + std::move(baseAddresses), + sizeInBytes, + hasVariableAddress ); - + buffer2MemInfoMap_[res].emplace_back(newMemInfo->clone()); return success(); } @@ -567,9 +586,12 @@ void PTOIRTranslator::UpdateAliasBufferInfo(Value result, Value source) { for (auto &parentInfo : buffer2MemInfoMap_[source]) { auto newInfo = parentInfo->clone(result); - + if (!newInfo->baseAddresses.empty()) { - newInfo->baseAddresses[0] += deltaOffset; + // Multi-buffer-aware alias: a static delta from a view/subview applies + // to every physical slot, not only slot 0. + for (uint64_t &slotAddr : newInfo->baseAddresses) + slotAddr += deltaOffset; } else { newInfo->baseAddresses.push_back(deltaOffset); } diff --git a/test/lit/pto/multi_buffer_insert_sync_dyn_event_id.pto b/test/lit/pto/multi_buffer_insert_sync_dyn_event_id.pto new file mode 100644 index 000000000..4615c1c71 --- /dev/null +++ b/test/lit/pto/multi_buffer_insert_sync_dyn_event_id.pto @@ -0,0 +1,55 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --enable-insert-sync --enable-multi-buffer-lowering --mlir-print-ir-after=pto-insert-sync %s 2>&1 1>/dev/null | FileCheck %s + +// End-to-end regression for the multi-buffer event-id pipeline: +// PlanMemory -> N-address pto.pointer_cast on a `pto.multi_buffer = 2` +// alloc inside scf.for. +// InsertSync -> sees N>=2 baseAddresses (PTOIRTranslator A2 fix), decides +// eventIdNum=2 via the per-slot overlap helper (A3) and the +// HIVM-style `GetEventIdNum` deduction (A1), then emits +// `pto.{set,wait}_flag_dyn` driven by `iv mod N` `arith.select`. +// +// Without A1/A2/A3 this test regresses to a single static set_flag/wait_flag +// pair on EVENT_ID0 with no remui/select. 
+ +module { + func.func @double_buffer_dyn_event( + %arg0: memref<16x16x16xf16, #pto.address_space>, + %arg1: memref<16x16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + scf.for %i = %c0 to %c4 step %c1 { + %a = memref.alloc() {pto.multi_buffer = 2 : i32} + : memref<16x16x16xf16, #pto.address_space> + pto.tload ins(%arg0 : memref<16x16x16xf16, #pto.address_space>) + outs(%a : memref<16x16x16xf16, #pto.address_space>) + pto.tstore ins(%a : memref<16x16x16xf16, #pto.address_space>) + outs(%arg1 : memref<16x16x16xf16, #pto.address_space>) + } + return + } +} + +// CHECK: IR Dump After PTOInsertSync +// CHECK: func.func @double_buffer_dyn_event +// Two distinct event ids must be reserved up front for the N=2 backward sync. +// CHECK: pto.set_flag[, , ] +// CHECK: pto.set_flag[, , ] +// Inside the loop: iv mod 2 + arith.select chain selecting between event ids. +// CHECK: scf.for %[[IV:.*]] = +// CHECK: arith.remui %[[IV]], %{{.*}} : index +// CHECK: arith.select +// Dynamic-event-id wait/set ops drive the multi-buffer backward sync. +// CHECK: pto.wait_flag_dyn[, , %{{.*}}] +// CHECK: pto.set_flag_dyn[, , %{{.*}}] +// Trailing waits at function exit drain both event ids. +// CHECK: pto.wait_flag[, , ] +// CHECK: pto.wait_flag[, , ] From 7e1a3e59d2c3dfb3056e968c4a6eb7716952cc80 Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Sun, 3 May 2026 15:10:30 +0800 Subject: [PATCH 04/10] multi-buffer: harden plan/sync paths and add fallback (P1) Builds on P0 to make the multi-buffer pipeline robust under realistic pressure: - SyncEventIdAllocation: HIVM-style TryFallbackMultiBuffer. When the event-id pool can't satisfy N, fall back to N=2 first (when N is even and >2) then to 1, instead of silently returning with an empty eventIds vec. 
Both set/wait siblings get their `eventIdNum` updated in lockstep so SyncCodegen takes the same code path on each side. - PTOPlanMemory: add `StorageEntry::isMultiBufferSlot` flag and rewrite `UpdateBuffer2Offsets` to walk via primaries and emit slots in declared `relationOtherBuffers` order. This makes the slot-ordering contract that EnableMultiBuffer relies on (`offsets[i]` <-> iv mod N selector index `i`) explicit and verifies it via a runtime assertion. - PTOEnableMultiBuffer: add scope guard (skip non-VEC/MAT casts where the iv-mod-N selector is not meaningful) and loop-invariance guard (skip casts whose addrs are not loop-invariant - hoisting them above the for loop would break SSA dominance). - New regression at test/lit/pto/multi_buffer_n4_insert_sync.pto: N=4 slot ordering [0, 1024, 2048, 3072] in pto.pointer_cast, 3-way arith.select chain, and 4 distinct event ids drained at function end. Co-Authored-By: Claude --- .../InsertSync/SyncEventIdAllocation.cpp | 44 +++++++++++++--- lib/PTO/Transforms/PTOEnableMultiBuffer.cpp | 47 +++++++++++++++++ lib/PTO/Transforms/PTOPlanMemory.cpp | 48 +++++++++++++++-- lib/PTO/Transforms/PTOPlanMemory.h | 8 +++ test/lit/pto/multi_buffer_n4_insert_sync.pto | 52 +++++++++++++++++++ 5 files changed, 188 insertions(+), 11 deletions(-) create mode 100644 test/lit/pto/multi_buffer_n4_insert_sync.pto diff --git a/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp b/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp index d937b468b..354ab95d6 100644 --- a/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp +++ b/lib/PTO/Transforms/InsertSync/SyncEventIdAllocation.cpp @@ -152,19 +152,49 @@ void SyncEventIdAllocation::SetEventId(SyncOperation *sync) { size_t idSize = static_cast(sync->eventIdNum); SmallVector canAllocaEventId = GetAvailableEventId( sync, eventIdLifetimeAvailableStatus, eventIdIdleStatus, poolSize); - + if (canAllocaEventId.empty()) { return; - } else if (canAllocaEventId.size() >= idSize) { + } + if 
(canAllocaEventId.size() >= idSize) { for (auto &id : canAllocaEventId) { SetEventPool(sync, id); } - } else if (reallocatedPipePair.count(ScopePair(sync)) && - (canAllocaEventId.size() < idSize)) { - // Reallocate strategy: reduce usage to 1 - assert(canAllocaEventId.size() > 0); + return; + } + + // P1/B3: HIVM-style multi-buffer fallback. When the pool can't satisfy the + // requested N event ids, try a smaller N before giving up: + // N odd or N == 2: collapse straight to single-buffer (eventIdNum = 1). + // N even and N > 2: try N = 2 first, then 1. The intermediate stop keeps + // some pipelining benefit when the original 4/8/... exhausted the pool. + // Apply both fallback rungs uniformly (not only when reallocatedPipePair is + // set) - the previous code only collapsed in the rare reallocation path, + // which left non-rare exhaustions emitting sync->eventIds.empty() and + // silently dropping the sync. + auto applyFallbackToPair = [this](SyncOperation *sync, unsigned newN) { + sync->eventIdNum = newN; + auto &syncPair = syncOperations_[sync->GetSyncIndex()]; + for (auto &op : syncPair) { + // Keep set/wait siblings consistent so SyncCodegen takes the same + // single- vs multi-buffer code path on both sides. + op->eventIdNum = newN; + } + }; + + unsigned fallbackN = (idSize > 2 && (idSize % 2 == 0)) ? 2u : 1u; + if (canAllocaEventId.size() >= fallbackN) { + applyFallbackToPair(sync, fallbackN); + for (size_t i = 0; i < fallbackN; ++i) { + SetEventPool(sync, canAllocaEventId[i]); + } + return; + } + // Last-resort fallback: if even fallbackN can't be satisfied, take whatever + // single id is available (pre-existing emergency behaviour). 
+ if (!canAllocaEventId.empty()) { + applyFallbackToPair(sync, 1u); SetEventPool(sync, canAllocaEventId[0]); - sync->eventIdNum = 1; } } diff --git a/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp b/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp index 14b4ccc16..59bf61e91 100644 --- a/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp +++ b/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp @@ -70,6 +70,23 @@ static LogicalResult lowerMultiBufferPointerCast(IRRewriter &rewriter, return success(); } +// Multi-buffer is a local-memory optimisation. GM pointer casts may also be +// multi-address (e.g., a future workspace path) but the iv-mod-N selector is +// not meaningful for GM, so we skip them here. +static bool isLocalScopePointerCast(PointerCastOp op) { + auto memrefTy = dyn_cast(op.getType()); + if (!memrefTy) + return false; + auto attr = memrefTy.getMemorySpace(); + if (!attr) + return false; + auto ptoAttr = dyn_cast(attr); + if (!ptoAttr) + return false; + AddressSpace as = ptoAttr.getAddressSpace(); + return as == AddressSpace::VEC || as == AddressSpace::MAT; +} + struct PTOEnableMultiBufferPass : public mlir::pto::impl::PTOEnableMultiBufferBase< PTOEnableMultiBufferPass> { @@ -83,12 +100,42 @@ struct PTOEnableMultiBufferPass IRRewriter rewriter(&getContext()); for (PointerCastOp op : work) { + // D2: scope guard. Multi-buffer slot selection only makes sense for + // local memory (VEC/MAT). Multi-address casts in GM (e.g., reserved + // workspaces) must keep their original semantics. + if (!isLocalScopePointerCast(op)) { + op.emitWarning() << "pto-enable-multi-buffer: skipping non-local " + "pointer_cast (multi-buffer is VEC/MAT-only)"; + continue; + } + auto forOp = op->getParentOfType(); if (!forOp) { op.emitWarning() << "pto-enable-multi-buffer: expected enclosing scf.for; skipping"; continue; } + + // D1: loop-invariance guard. The pass hoists each addr operand and the + // resulting single-address pto.pointer_cast above `forOp`. 
SSA dominance + // requires every addr to be defined outside the loop. Today PlanMemory + // emits constant i64 offsets so this always holds, but a future + // dynamic-address path (e.g., workspace double-buffer) would silently + // violate dominance without this check. + bool addrsAreLoopInvariant = true; + for (Value addr : op.getAddrs()) { + if (!forOp.isDefinedOutsideOfLoop(addr)) { + addrsAreLoopInvariant = false; + break; + } + } + if (!addrsAreLoopInvariant) { + op.emitWarning() << "pto-enable-multi-buffer: addr operand is not " + "loop-invariant; skipping (would break SSA on " + "hoist)"; + continue; + } + if (failed(lowerMultiBufferPointerCast(rewriter, op, forOp))) { signalPassFailure(); return; diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index 7491426b7..1396f7758 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -1529,15 +1529,52 @@ void MemPlan::ValidateParameters(std::unique_ptr &e) const { } void MemPlan::UpdateBuffer2Offsets() { + // Slot ordering contract (relied on by EnableMultiBuffer / AllocToPointerCast): + // buffer2Offsets[buf][i] is the byte offset of physical slot `i`, + // selected at runtime by `iv mod N == i`. Slot 0 is the primary entry's + // bitsOffset; slots 1..N-1 are `relationOtherBuffers[0..N-2]` in order. + // To preserve this we walk via primaries and explicitly emit slots, and + // skip non-primary slot entries that share `inplaceBuffers` with the + // primary (they would otherwise be double-pushed by a naive linear walk). + auto bitsToBytes = [](uint64_t bits) { + return (bits + kBitsToByte - 1) / kBitsToByte; + }; + for (auto &e : StorageEntryVec) { + if (e->isMultiBufferSlot) + continue; // handled via its primary below + + if (e->multiBufferNum > 1) { + // Multi-buffer primary: emit slot 0 then slots 1..N-1 in declared order. 
+ for (Value &buffer : e->inplaceBuffers) { + buffer2Offsets[buffer].push_back(bitsToBytes(e->bitsOffset)); + for (StorageEntry *slot : e->relationOtherBuffers) { + if (!slot) + llvm::report_fatal_error( + "multi-buffer primary has null relation slot"); + buffer2Offsets[buffer].push_back(bitsToBytes(slot->bitsOffset)); + } + } + // Defensive invariant: each multi-buffered buffer must end up with + // exactly multiBufferNum offsets after this call (modulo the + // SPEC_LEVEL_1 single-reuse-db append below, which only fires for + // single-buffer entries). + for (Value &buffer : e->inplaceBuffers) { + if (buffer2Offsets[buffer].size() != e->multiBufferNum) { + llvm::report_fatal_error( + "multi-buffer offset count mismatch in UpdateBuffer2Offsets"); + } + } + continue; + } + + // Single-buffer entry: classic single-offset push. for (Value &buffer : e->inplaceBuffers) { - // MultiBuffer can cause multiple addrs. - buffer2Offsets[buffer].push_back( - (e->bitsOffset + kBitsToByte - 1) / kBitsToByte); + buffer2Offsets[buffer].push_back(bitsToBytes(e->bitsOffset)); } } // In the MultiBuffer scenario, single reuse db will result in additional - // storageEntry. + // storageEntry. Only fires for single-buffer primaries that took a DB slot. UpdateMultiBufferReuseExtraOffset(); } @@ -1677,6 +1714,9 @@ void MemPlan::ExpandMultiBufferStorageEntry() { entry->alignedConstBits = primary->alignedConstBits; entry->inplaceBuffers = primary->inplaceBuffers; entry->multiBufferNum = primary->multiBufferNum; + // Mark this as a non-primary slot. UpdateBuffer2Offsets uses this flag + // to enforce primary-first slot ordering. 
+ entry->isMultiBufferSlot = true; StorageEntry *raw = entry.get(); primary->relationOtherBuffers.push_back(raw); StorageEntryVec.push_back(std::move(entry)); diff --git a/lib/PTO/Transforms/PTOPlanMemory.h b/lib/PTO/Transforms/PTOPlanMemory.h index 4f0574f96..c7ba142e7 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.h +++ b/lib/PTO/Transforms/PTOPlanMemory.h @@ -163,6 +163,14 @@ struct StorageEntry { /// Pointers reference sibling `StorageEntry` objects owned in `MemPlan`. SmallVector relationOtherBuffers; + /// True for entries created by `ExpandMultiBufferStorageEntry` as the + /// non-primary slots of a multi-buffer cluster. Such entries share their + /// `inplaceBuffers` with the primary; `UpdateBuffer2Offsets` walks the + /// cluster via the primary so it must skip slot entries to keep the + /// primary-first slot ordering contract that EnableMultiBuffer relies on + /// (offsets[i] == iv-mod-N selector index i). + bool isMultiBufferSlot{false}; + /// The number of multibuffer optimization. /// note: default 1 which means single buffer and does not do multibuffer /// optimization. diff --git a/test/lit/pto/multi_buffer_n4_insert_sync.pto b/test/lit/pto/multi_buffer_n4_insert_sync.pto new file mode 100644 index 000000000..8f4825b76 --- /dev/null +++ b/test/lit/pto/multi_buffer_n4_insert_sync.pto @@ -0,0 +1,52 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --enable-insert-sync --enable-multi-buffer-lowering --mlir-print-ir-after=pto-insert-sync %s 2>&1 1>/dev/null | FileCheck %s + +// N=4 stresses paths only N=2 cannot: +// PlanMemory - ExpandMultiBufferStorageEntry creates 3 relation slot +// entries; UpdateBuffer2Offsets must emit slot offsets in +// [primary, slot1, slot2, slot3] order (C1 invariant). +// InsertSync - GetEventIdNum must return 4 (not just any > 1). +// Codegen - the N-way arith.select chain must compare iv mod 4 against +// 1, 2, 3 (slot 0 is the chain's tail/default). +// EventId alloc - if 4 ids are unavailable, B3 fallback should land on 2 or +// 1 instead of leaving sync->eventIds empty. + +module { + func.func @quad_buffer(%arg0: memref<8x8x8xf16, #pto.address_space>, + %arg1: memref<8x8x8xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + scf.for %i = %c0 to %c8 step %c1 { + %a = memref.alloc() {pto.multi_buffer = 4 : i32} + : memref<8x8x8xf16, #pto.address_space> + pto.tload ins(%arg0 : memref<8x8x8xf16, #pto.address_space>) + outs(%a : memref<8x8x8xf16, #pto.address_space>) + pto.tstore ins(%a : memref<8x8x8xf16, #pto.address_space>) + outs(%arg1 : memref<8x8x8xf16, #pto.address_space>) + } + return + } +} + +// CHECK: IR Dump After PTOInsertSync +// CHECK: func.func @quad_buffer +// 4 set_flag pre-loop on EVENT_ID0..ID3 (or B3 fallback to 2/1). +// CHECK: pto.set_flag[, , ] +// CHECK: pto.set_flag[, , ] +// Inside the loop: iv mod 4, then a 4-way arith.select chain selecting an idx. +// CHECK: scf.for %[[IV:.*]] = +// CHECK: arith.remui %[[IV]], %{{.*}} : index +// CHECK-COUNT-3: arith.select +// CHECK: pto.wait_flag_dyn[, , %{{.*}}] +// CHECK: pto.set_flag_dyn[, , %{{.*}}] +// Final drain on all 4 event ids. 
+// CHECK: pto.wait_flag[, , ] +// CHECK: pto.wait_flag[, , ] From d7db3f876efddd28da20b7bf3024f9be38e066cb Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Sun, 3 May 2026 15:12:24 +0800 Subject: [PATCH 05/10] multi-buffer: refine sync semantics and SPEC_LEVEL_1 reuse (P2) Final hardening pass tightening multi-buffer correctness around the edges PR-615 and earlier P0/P1 commits left ambiguous. - InsertSyncAnalysis: GetEventIdNum now takes the back-edge scf.for and requires every involved buffer's enclosing loop to *equal* that loop. Previously a multi-buffer alloc nested inside an inner loop could silently rotate slots on the wrong induction variable when a back-edge on an outer loop carried the dep. Forward deps unconditionally collapse to single-buffer. - PTOEnableMultiBuffer: cache `iv mod N` per (loop, N) so several multi-buffer pointer_casts sharing a loop reuse one counter, mirroring SyncCodegen::loop2BufferCounter. Avoids redundant arith.remui + constant ops in the loop body. - PTOPlanMemory: * `MemPlan::emitMultiBufferError` + `multiBufferDiagnosticEmitted_` flag converts multi-buffer-specific `report_fatal_error` calls (slot-order mismatch, pong outline placement failure) into diagnostics that bubble through `plan()` as `failure()`. This lets the existing PlanMemoryPass retry loop re-seed instead of aborting the compiler under heavy multi-buffer memory pressure. * `VerifyConflictStage1` now enumerates *every* historical multi-buffer slot offset (via `CollectMultiRelationPongAnchors`) as a SPEC_LEVEL_1 reuse anchor candidate. PR-615 only used the last slot, silently dropping reuse for N > 2. - New regression at test/lit/pto/multi_buffer_nested_loop.pto: nested scf.for with multi-buffer alloc in the inner loop must rotate slots on the inner iv, not the outer one. 
Co-Authored-By: Claude --- .../InsertSync/InsertSyncAnalysis.h | 9 +- .../InsertSync/InsertSyncAnalysis.cpp | 37 +++-- lib/PTO/Transforms/PTOEnableMultiBuffer.cpp | 45 ++++-- lib/PTO/Transforms/PTOPlanMemory.cpp | 133 +++++++++++++----- lib/PTO/Transforms/PTOPlanMemory.h | 21 +++ test/lit/pto/multi_buffer_nested_loop.pto | 47 +++++++ 6 files changed, 240 insertions(+), 52 deletions(-) create mode 100644 test/lit/pto/multi_buffer_nested_loop.pto diff --git a/include/PTO/Transforms/InsertSync/InsertSyncAnalysis.h b/include/PTO/Transforms/InsertSync/InsertSyncAnalysis.h index 3e4cc3e0d..b19e6a9a4 100644 --- a/include/PTO/Transforms/InsertSync/InsertSyncAnalysis.h +++ b/include/PTO/Transforms/InsertSync/InsertSyncAnalysis.h @@ -159,8 +159,13 @@ class InsertSyncAnalysis { const CompoundInstanceElement *frontCompound, bool isBackwardDep) const; - /// 获取依赖对涉及的 Event ID 数量 (用于 Multi-Buffer 分析) - int GetEventIdNum(const DepBaseMemInfoPairVec &depBaseMemInfosVec); + /// Multi-buffer event-id deduction (HIVM-style). `backEdgeForLoop`, when + /// non-null, is the scf.for whose back-edge this dependency crosses; the + /// deduction additionally requires every involved buffer to live directly + /// under that loop. If null (forward dep), the deduction is a no-op and + /// returns 1. 
+ int GetEventIdNum(const DepBaseMemInfoPairVec &depBaseMemInfosVec, + Operation *backEdgeForLoop = nullptr); /// 辅助函数:获取所有涉及的 Buffer (用于 LCA 计算,虽然现在简化了,保留接口) SmallVector GetMemInfoBuffers(const DepBaseMemInfoPairVec &depBaseMemInfosVec); diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp index b054ecbf9..2989430df 100644 --- a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp +++ b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp @@ -372,9 +372,19 @@ void InsertSyncAnalysis::InsertSyncOperation( setOp->SetDepSyncIRIndex(frontCompound->GetIndex()); waitOp->SetDepSyncIRIndex(frontCompound->GetIndex()); - // Back-edge dependencies may require multi-buffer event IDs. + // Back-edge dependencies may require multi-buffer event IDs. Resolve the + // owning scf.for so GetEventIdNum can verify that the dep buffer rotates + // on the right loop's induction variable (B1). if (forEndIndex.has_value()) { - int eventIdNum = GetEventIdNum(depBaseMemInfosVec); + Operation *backEdgeForOp = nullptr; + if (forEndIndex.value() < syncIR_.size()) { + InstanceElement *loopElem = syncIR_[forEndIndex.value()].get(); + if (loopElem) { + // For LOOP_END elements, elementOp points at the originating scf.for. + backEdgeForOp = loopElem->elementOp; + } + } + int eventIdNum = GetEventIdNum(depBaseMemInfosVec, backEdgeForOp); setOp->eventIdNum = eventIdNum; waitOp->eventIdNum = eventIdNum; } @@ -556,17 +566,24 @@ static scf::ForOp getEnclosingScfFor(Value value) { } int InsertSyncAnalysis::GetEventIdNum( - const DepBaseMemInfoPairVec &depBaseMemInfosVec) { + const DepBaseMemInfoPairVec &depBaseMemInfosVec, + Operation *backEdgeForLoop) { // HIVM `GetEventIdNum` semantics: only deduce N>1 when EVERY dependent pair // is multi-buffer-eligible (same slot count, same-index slots overlap, // different-index slots are disjoint), all pairs agree on N, and every // involved root buffer hangs off the same scf.for. 
Any failure collapses to // single-buffer (eventIdNum = 1). + // + // Forward dependencies (no enclosing back-edge) are unconditionally + // single-buffer: multi-event-id only buys parallelism by breaking + // loop-carried sync, so it's meaningless without a back-edge. if (depBaseMemInfosVec.empty()) return 1; + auto backEdgeFor = dyn_cast_or_null(backEdgeForLoop); + if (!backEdgeFor) + return 1; unsigned commonN = 0; - scf::ForOp commonLoop; for (const auto &pair : depBaseMemInfosVec) { unsigned n = memAnalyzer_.getMultiBufferSlotCount(pair.first, pair.second); if (n < 2) @@ -576,13 +593,15 @@ int InsertSyncAnalysis::GetEventIdNum( else if (commonN != n) return 1; + // B1: every involved buffer's enclosing scf.for must match the back-edge + // loop. A buffer that lives in an *inner* loop nested inside the back-edge + // loop would rotate slots on the wrong iv (inner.iv mod N), giving the + // wrong physical slot for a backward dep that crosses the outer + // back-edge. A buffer in an *outer* loop never rotates with the back-edge + // we care about. Either case must collapse to single-buffer. auto checkLoop = [&](Value buffer) -> bool { auto forOp = getEnclosingScfFor(buffer); - if (!forOp) - return false; - if (!commonLoop) - commonLoop = forOp; - return commonLoop == forOp; + return forOp && forOp.getOperation() == backEdgeFor.getOperation(); }; // Use `baseBuffer` (the alloc-like SSA result inside the loop body) // rather than `rootBuffer`, which for pto.pointer_cast is the i64 base diff --git a/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp b/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp index 59bf61e91..4237d8208 100644 --- a/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp +++ b/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp @@ -26,9 +26,32 @@ using namespace mlir::pto; namespace { -static LogicalResult lowerMultiBufferPointerCast(IRRewriter &rewriter, - PointerCastOp op, - scf::ForOp forOp) { +// (loop, factor N) -> shared `iv mod N` counter inside that loop. 
+// B5: when several multi-buffer pointer_casts share the same enclosing scf.for +// and the same N, they should all read from the same counter rather than each +// inserting its own arith.remui + N constant ops. This mirrors +// SyncCodegen::loop2BufferCounter so the two passes emit consistent counters. +using LoopFactorKey = std::pair; + +static Value getOrCreateLoopCounter( + IRRewriter &rewriter, + llvm::DenseMap &cache, + scf::ForOp forOp, unsigned n, Location loc) { + auto key = std::make_pair(forOp, n); + auto it = cache.find(key); + if (it != cache.end()) + return it->second; + rewriter.setInsertionPointToStart(forOp.getBody()); + Value iv = forOp.getInductionVar(); + Value cN = rewriter.create(loc, n); + Value rem = rewriter.create(loc, iv, cN); + cache[key] = rem; + return rem; +} + +static LogicalResult lowerMultiBufferPointerCast( + IRRewriter &rewriter, PointerCastOp op, scf::ForOp forOp, + llvm::DenseMap &counterCache) { ValueRange addrs = op.getAddrs(); unsigned n = static_cast(addrs.size()); assert(n >= 2); @@ -52,10 +75,12 @@ static LogicalResult lowerMultiBufferPointerCast(IRRewriter &rewriter, slotBufs.push_back(slot.getResult()); } - rewriter.setInsertionPointToStart(forOp.getBody()); - Value iv = forOp.getInductionVar(); - Value cN = rewriter.create(loc, n); - Value rem = rewriter.create(loc, iv, cN); + Value rem = getOrCreateLoopCounter(rewriter, counterCache, forOp, n, loc); + // Insertion point for the select chain: right after the cached counter. + // (For a freshly created counter this lands at the same position as before; + // for cached ones it lands wherever that earlier remui sits, which is still + // at loop-body start so the chain dominates all uses inside the loop.) 
+ rewriter.setInsertionPointAfter(rem.getDefiningOp()); Value selected = slotBufs[0]; for (unsigned i = 1; i < n; ++i) { @@ -99,6 +124,9 @@ struct PTOEnableMultiBufferPass }); IRRewriter rewriter(&getContext()); + // Per-(loop, factor) counter cache so multiple multi-buffer pointer_casts + // sharing a loop and N reuse one `iv mod N` (B5). + llvm::DenseMap counterCache; for (PointerCastOp op : work) { // D2: scope guard. Multi-buffer slot selection only makes sense for // local memory (VEC/MAT). Multi-address casts in GM (e.g., reserved @@ -136,7 +164,8 @@ struct PTOEnableMultiBufferPass continue; } - if (failed(lowerMultiBufferPointerCast(rewriter, op, forOp))) { + if (failed(lowerMultiBufferPointerCast(rewriter, op, forOp, + counterCache))) { signalPassFailure(); return; } diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index 1396f7758..36e70a645 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -1401,6 +1401,15 @@ bool MemPlan::HasSemanticConflict(const StorageEntry *entry, return false; } +void MemPlan::emitMultiBufferError(const llvm::Twine &msg) { + multiBufferDiagnosticEmitted_ = true; + if (func_) { + func_->emitError() << "[multi-buffer plan] " << msg; + } else { + llvm::errs() << "[multi-buffer plan] " << msg << "\n"; + } +} + // Plan Memory algorithm. LogicalResult MemPlan::plan(bool emitErrors) { // Construct StorageEntry structure. @@ -1414,6 +1423,12 @@ LogicalResult MemPlan::plan(bool emitErrors) { EmitPlanMemoryFailureInfo(); return failure(); } + if (multiBufferDiagnosticEmitted_) { + // A multi-buffer-specific invariant tripped (e.g., a planner intermediate + // wrote inconsistent slot data). Surface it as a failure so the + // PlanMemoryPass retry loop can attempt another seed instead of aborting. 
+ return failure(); + } if (RecordOverflowIfAny()) { if (emitErrors) EmitPlanMemoryFailureInfo(); @@ -1546,23 +1561,34 @@ void MemPlan::UpdateBuffer2Offsets() { if (e->multiBufferNum > 1) { // Multi-buffer primary: emit slot 0 then slots 1..N-1 in declared order. + bool failedThisEntry = false; for (Value &buffer : e->inplaceBuffers) { buffer2Offsets[buffer].push_back(bitsToBytes(e->bitsOffset)); for (StorageEntry *slot : e->relationOtherBuffers) { - if (!slot) - llvm::report_fatal_error( + if (!slot) { + // D4: surface as a recoverable error so the PlanMemoryPass retry + // loop can re-seed instead of aborting the compiler. + emitMultiBufferError( "multi-buffer primary has null relation slot"); + failedThisEntry = true; + break; + } buffer2Offsets[buffer].push_back(bitsToBytes(slot->bitsOffset)); } + if (failedThisEntry) + break; } + if (failedThisEntry) + return; // plan() will see the diagnostic flag and return failure. // Defensive invariant: each multi-buffered buffer must end up with // exactly multiBufferNum offsets after this call (modulo the // SPEC_LEVEL_1 single-reuse-db append below, which only fires for // single-buffer entries). for (Value &buffer : e->inplaceBuffers) { if (buffer2Offsets[buffer].size() != e->multiBufferNum) { - llvm::report_fatal_error( + emitMultiBufferError( "multi-buffer offset count mismatch in UpdateBuffer2Offsets"); + return; } } continue; @@ -2186,35 +2212,46 @@ bool MemPlan::VerifyConflictStage1(MemBoundList &outline, PlanRecHis &his, return true; } - StorageEntry *multiRelationPongEntry = - GetMultiRelationPongEntry(reuseBoundStorageEntry); - if (multiRelationPongEntry) { - bool hasRelationSlots = - !e->relationOtherBuffers.empty() && - llvm::all_of(e->relationOtherBuffers, [](StorageEntry *s) { - return s->bitsOffset != 0; + // C2: HIVM allows the SPEC_LEVEL_1 reuse to anchor on ANY of the + // historical multi-buffer slots, not only the last one. 
Collect all + // candidate anchor offsets (in HIVM-doc order: slot1, slot2, ..., + // slotN-1, plus the legacy single-extra-pong slot when present), then + // pick the first that has no life-conflict with history. + SmallVector candidateAnchors = + CollectMultiRelationPongAnchors(reuseBoundStorageEntry); + if (candidateAnchors.empty()) { + return true; + } + + bool hasRelationSlots = + !e->relationOtherBuffers.empty() && + llvm::all_of(e->relationOtherBuffers, [](StorageEntry *s) { + return s->bitsOffset != 0; + }); + if (!(e->multiBufferNum == kSingleBufferCount || + (e->multiBufferNum > 1 && hasRelationSlots))) { + return true; + } + auto parentLoop1 = GetBufferParentLoop(e->inplaceBuffers); + auto parentLoop2 = + GetBufferParentLoop(reuseBoundStorageEntry->inplaceBuffers); + if (!(parentLoop1 != nullptr && parentLoop2 != nullptr && + parentLoop1 == parentLoop2)) { + // Cannot be reused under the same for. + return true; + } + + // There are two situations: + // 1. Single reuse DB. + // 2. DB reuse DB. + for (uint64_t anchor : candidateAnchors) { + bool conflict = std::any_of( + his.begin(), his.end(), [anchor, e, this](PlanRecord &r) { + return this->IsBufferLifeVecConflict(r, anchor, e); }); - if (e->multiBufferNum == kSingleBufferCount || - (e->multiBufferNum > 1 && hasRelationSlots)) { - auto parentLoop1 = GetBufferParentLoop(e->inplaceBuffers); - auto parentLoop2 = - GetBufferParentLoop(reuseBoundStorageEntry->inplaceBuffers); - if (!(parentLoop1 != nullptr && parentLoop2 != nullptr && - parentLoop1 == parentLoop2)) { - // Cannot be reused under the same for. - return true; - } - // There are two situations: - // 1. Single reuse DB. - // 2. DB reuse DB. 
- pongOffset = multiRelationPongEntry->bitsOffset; - bool conflict = std::any_of( - his.begin(), his.end(), [pongOffset, e, this](PlanRecord &r) { - return this->IsBufferLifeVecConflict(r, pongOffset, e); - }); - if (!conflict) { - return false; - } + if (!conflict) { + pongOffset = anchor; + return false; } } return true; @@ -2239,6 +2276,28 @@ MemPlan::GetMultiRelationPongEntry(const StorageEntry *reuseBoundStorageEntry) { return nullptr; } +SmallVector +MemPlan::CollectMultiRelationPongAnchors( + const StorageEntry *reuseBoundStorageEntry) { + // C2: HIVM enumerates *every* historical multi-buffer relation slot as a + // SPEC_LEVEL_1 reuse anchor candidate. PR-615 only returned the last slot, + // which silently dropped reuse opportunities for N>2. + SmallVector anchors; + if (reuseBoundStorageEntry->multiBufferNum > 1) { + for (StorageEntry *slot : reuseBoundStorageEntry->relationOtherBuffers) { + if (slot && slot->bitsOffset != 0) + anchors.push_back(slot->bitsOffset); + } + } + // Legacy single-extra-pong slot used when a single-buffer history entry + // had previously been reused with a DB. + auto iter = pingEntry2RelationPongEntry.find(reuseBoundStorageEntry); + if (iter != pingEntry2RelationPongEntry.end()) { + anchors.push_back(iter->second->bitsOffset); + } + return anchors; +} + void MemPlan::SpecAllocRelationPongEntry(MemBoundList &outline, PlanRecHis &his, StorageEntry *e, uint64_t offset) { SmallVector targets; @@ -2271,8 +2330,16 @@ void MemPlan::SpecAllocRelationPongEntry(MemBoundList &outline, PlanRecHis &his, break; } } - if (!placed) - llvm::report_fatal_error("pong storage entry outline not found"); + if (!placed) { + // D4: previously a fatal error; under heavy multi-buffer pressure this + // can fire when the pong slot's pre-computed offset has no matching + // memory bound left after rollback. Surface it as a recoverable + // diagnostic so PlanMemoryPass can retry with another seed. 
+ emitMultiBufferError( + "pong storage entry outline not found " + "(SPEC_LEVEL_1 reuse-db slot placement failed)"); + return; + } } } diff --git a/lib/PTO/Transforms/PTOPlanMemory.h b/lib/PTO/Transforms/PTOPlanMemory.h index c7ba142e7..b8162aa4b 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.h +++ b/lib/PTO/Transforms/PTOPlanMemory.h @@ -681,6 +681,15 @@ class MemPlan { const OutlineSectionInfo &outlineInfo, uint64_t &pongOffset); + /// Enumerate every candidate reuse anchor offset on a historical + /// multi-buffer entry: each non-zero `relationOtherBuffers` slot, plus the + /// legacy `pingEntry2RelationPongEntry` single-extra slot if present. + /// VerifyConflictStage1 walks the resulting list in order to find the first + /// life-conflict-free anchor (HIVM-style; PR-615 only used the last slot + /// and missed reuse opportunities for N>2). + SmallVector + CollectMultiRelationPongAnchors(const StorageEntry *reuseBoundStorageEntry); + /// Check if e1 and e2 have any pipe conflict, regardless of loop scope. /// Cached in `conflictMap` to avoid recomputing the cartesian product of /// inplace buffers on each query. @@ -911,6 +920,18 @@ class MemPlan { /// The device's SCALING storage size int scalingSpaceSize{0}; + /// Set by `emitMultiBufferError` whenever a multi-buffer-specific invariant + /// fails (e.g., slot-order mismatch in UpdateBuffer2Offsets or pong outline + /// not found). `plan()` checks this and converts it into a `failure()` + /// instead of aborting via `report_fatal_error`, so the PlanMemoryPass retry + /// loop can recover with a fresh seed. + bool multiBufferDiagnosticEmitted_{false}; + + /// Emit a multi-buffer-specific error against `func_` and arm the failure + /// flag. Use in code paths that previously called `llvm::report_fatal_error` + /// for invariants that may genuinely happen under heavy multi-buffer memory + /// pressure. When `func_` is unavailable, fall back to llvm errs. 
+ void emitMultiBufferError(const llvm::Twine &msg); }; } // namespace pto diff --git a/test/lit/pto/multi_buffer_nested_loop.pto b/test/lit/pto/multi_buffer_nested_loop.pto new file mode 100644 index 000000000..935c381ce --- /dev/null +++ b/test/lit/pto/multi_buffer_nested_loop.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --enable-insert-sync --enable-multi-buffer-lowering --mlir-print-ir-after=pto-enable-multi-buffer %s 2>&1 1>/dev/null | FileCheck %s + +// B1 + B5 regression: multi-buffer alloc inside the INNER loop of a nested +// for-nest. Slot rotation must use the inner.iv (the loop the back-edge +// crosses), NOT the outer.iv. Two multi-buffer pointer_casts inside the same +// inner loop with the same N=2 must share a single `iv mod 2` counter. 
+ +module { + func.func @nested_double(%arg0: memref<8x8x8xf16, #pto.address_space>, + %arg1: memref<8x8x8xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + scf.for %i = %c0 to %c4 step %c1 { + scf.for %j = %c0 to %c4 step %c1 { + %a = memref.alloc() {pto.multi_buffer = 2 : i32} + : memref<8x8x8xf16, #pto.address_space> + pto.tload ins(%arg0 : memref<8x8x8xf16, #pto.address_space>) + outs(%a : memref<8x8x8xf16, #pto.address_space>) + pto.tstore ins(%a : memref<8x8x8xf16, #pto.address_space>) + outs(%arg1 : memref<8x8x8xf16, #pto.address_space>) + } + } + return + } +} + +// CHECK: IR Dump After PTOEnableMultiBuffer +// CHECK: func.func @nested_double +// Outer loop opens first. +// CHECK: scf.for %[[OUTER:.*]] = +// Single-address slot casts hoisted just above the INNER loop (the alloc's +// directly-enclosing scf.for is `inner`, not `outer`). +// CHECK: pto.pointer_cast(%{{[^,)]+}}) : +// CHECK: pto.pointer_cast(%{{[^,)]+}}) : +// CHECK: scf.for %[[INNER:.*]] = +// Counter must be on the INNER iv (B1: that's the back-edge loop for %a). +// CHECK: arith.remui %[[INNER]], %{{.*}} : index +// CHECK: arith.select From a6b7b01f532fdd608b7172ccc36e41d97faf4469 Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Sun, 3 May 2026 15:25:24 +0800 Subject: [PATCH 06/10] multi-buffer: require actual pipe conflict in SPEC_LEVEL_2 filter `VerifyConflictStage2`'s contract (per its own comment) is "block reuse only when buffers (a) share a parent loop AND (b) have a DMA pipe conflict". `PipeConflictInSameLoop` violated that on two counts: 1. After confirming same parent loop, it returned true unconditionally instead of querying `PipeConflict` for the actual DMA pipe-conflict relation. SPEC_LEVEL_2 was therefore stricter than SPEC_LEVEL_3 semantically: any same-loop pair was rejected, even when no DMA pipe conflict existed. 2. 
`GetBufferParentLoop` returns nullptr for top-level buffers; two top-level buffers compared parentLoop1 == parentLoop2 == nullptr, so the "different loops -> allow reuse" early-return was bypassed and the function fell through to "return true". Almost every cross-buffer pair at function scope was getting marked as conflicting, blocking valid reuse and causing local-memory planning to fail/overflow in cases that previously fit. Fix: reject the nullptr case explicitly, then fall through to the real `PipeConflict` query before declaring a conflict. 166/166 lit tests still pass; the change can only widen the set of allowed reuses, never narrow it. Co-Authored-By: Claude --- lib/PTO/Transforms/PTOPlanMemory.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index 36e70a645..1f19bb2df 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -2456,15 +2456,25 @@ bool MemPlan::PipeConflictInSameLoop(const StorageEntry *e1, if (e1 == nullptr || e2 == nullptr) { return false; } - // Only treat the conflict as fatal when both entries hang off the same - // parent loop. Distinct loops (or top-level buffers) are deliberately - // permitted to share an offset under SPEC_LEVEL_2. + // SPEC_LEVEL_2 only blocks reuse when buffers (a) share a parent loop AND + // (b) actually pipe-conflict on the DMA path. Two earlier issues: + // * The function name and `VerifyConflictStage2`'s comment promise an + // "and pipe-conflict" check, but the body returned true unconditionally + // for same-loop pairs - i.e., loop co-location alone aborted reuse. + // * `GetBufferParentLoop` returns nullptr for top-level buffers; two + // top-level buffers both yield nullptr and compare equal, so every + // cross-buffer pair at function scope was getting marked as conflicting. 
+ // Reject the nullptr case so top-level pairs fall through to the + // "different loops" branch and are allowed to share an offset. auto parentLoop1 = GetBufferParentLoop(e1->inplaceBuffers); auto parentLoop2 = GetBufferParentLoop(e2->inplaceBuffers); + if (!parentLoop1 || !parentLoop2) { + return false; + } if (parentLoop1 != parentLoop2) { return false; } - return true; + return PipeConflict(e1, e2, pipeDmaConflictMap); } bool MemPlan::PipeConflict(const StorageEntry *e1, const StorageEntry *e2, From 868d1a48d426a19dab2765f30191a64b797added Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Sun, 3 May 2026 21:46:52 +0800 Subject: [PATCH 07/10] multi-buffer/sync: drop redundant same-pipe back-edge PIPE_BARRIER After merging upstream/main, multi-buffer-enabled loops emitted spurious `pto.barrier ` and `pto.barrier ` ops in addition to the multi-buffer `set_flag_dyn` / `wait_flag_dyn` pair, e.g.: pto.wait_flag_dyn[, , %2] pto.barrier <- redundant pto.tload ; (MTE2) pto.set_flag[, , ] pto.wait_flag[, , ] pto.barrier <- redundant pto.tstore ; (MTE3) pto.set_flag_dyn[, , %2] Both barriers came from `InsertSyncOperation`'s same-pipe back-edge branch unconditionally emitting `PIPE_BARRIER` for any cross-iteration dep on a single pipe. Two of these were truly redundant: 1. Same-pipe back-edge whose dep is multi-buffer eligible. The whole point of multi-buffer is that consecutive iterations land in disjoint physical slots, so the cross-iter "same-address" dep is fundamentally false; the multi-buffer dyn flag pair already does the cross-iter ordering. No barrier needed. 2. Same-pipe back-edge on a DMA pipe (MTE1/MTE2/MTE3/MTE4/MTE5). DMA pipes are simple in-order command queues; the hardware itself serializes consecutive ops on the same DMA pipe across iterations. PIPE_BARRIER on a DMA pipe is a no-op and just clutters the IR. 
Fix: in `InsertSyncOperation`, compute the multi-buffer eventIdNum once
(also resolves the back-edge scf.for so the cross-pipe path can reuse it -
small refactor of the existing logic) and skip the barrier when either
condition above holds.

PIPE_M / PIPE_V keep the conservative PIPE_BARRIER to preserve the
"bar_m" / "bar_v" intra-pipe semantics that higher-level frontends rely on
(the existing comment in `IsNoNeedToInsertSync` calls this out).

168/168 lit pass, including all multi-buffer regressions and the
upstream-merged comm/sync tests.

Co-Authored-By: Claude
---
 .../InsertSync/InsertSyncAnalysis.cpp         | 45 ++++++++++++++-----
 1 file changed, 33 insertions(+), 12 deletions(-)

diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp
index 2989430df..c093e0a82 100644
--- a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp
+++ b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp
@@ -347,7 +347,40 @@ void InsertSyncAnalysis::InsertSyncOperation(
   PipelineType nowPipe = nowCompound->kPipeValue;
   PipelineType frontPipe = frontCompound->kPipeValue;
 
+  // Resolve the back-edge scf.for once: both branches below need it to drive
+  // multi-buffer-aware decisions (B1).
+  Operation *backEdgeForOp = nullptr;
+  if (forEndIndex.has_value() && forEndIndex.value() < syncIR_.size()) {
+    if (InstanceElement *loopElem = syncIR_[forEndIndex.value()].get())
+      backEdgeForOp = loopElem->elementOp;
+  }
+  int eventIdNum = forEndIndex.has_value()
+                       ? GetEventIdNum(depBaseMemInfosVec, backEdgeForOp)
+                       : 1;
+
+  // DMA pipes (MTE1-MTE5) are simple in-order DMA command queues; the
+  // hardware itself serializes consecutive ops on the same DMA pipe regardless
+  // of iteration boundary. PIPE_M / PIPE_V keep the existing PIPE_BARRIER
+  // path because the matrix/vector units may have intra-pipe parallelism
+  // ("bar_m" / "bar_v" frontend expectation).
+ auto isDmaPipe = [](PipelineType p) { + return p == PipelineType::PIPE_MTE1 || p == PipelineType::PIPE_MTE2 || + p == PipelineType::PIPE_MTE3 || p == PipelineType::PIPE_MTE4 || + p == PipelineType::PIPE_MTE5; + }; + if (nowPipe == frontPipe) { + // Same-pipe back-edge dep: skip PIPE_BARRIER when: + // 1. The dep is multi-buffer eligible (different slots in different + // iterations - the cross-iter dep is fundamentally false). + // 2. The pipe is a DMA pipe (HW DMA queue is strictly in-order, so + // cross-iter same-pipe ordering is guaranteed by HW alone). + // For PIPE_M / PIPE_V keep the conservative PIPE_BARRIER to preserve + // bar_m / bar_v intra-pipe semantics expected by higher-level frontends. + if (forEndIndex.has_value() && (eventIdNum >= 2 || isDmaPipe(nowPipe))) { + // No barrier needed. + return; + } unsigned insertBarrierId = nowCompound->GetIndex(); auto barrierOp = std::make_unique( SyncOperation::TYPE::PIPE_BARRIER, frontPipe, nowPipe, syncIndex_, @@ -372,19 +405,7 @@ void InsertSyncAnalysis::InsertSyncOperation( setOp->SetDepSyncIRIndex(frontCompound->GetIndex()); waitOp->SetDepSyncIRIndex(frontCompound->GetIndex()); - // Back-edge dependencies may require multi-buffer event IDs. Resolve the - // owning scf.for so GetEventIdNum can verify that the dep buffer rotates - // on the right loop's induction variable (B1). if (forEndIndex.has_value()) { - Operation *backEdgeForOp = nullptr; - if (forEndIndex.value() < syncIR_.size()) { - InstanceElement *loopElem = syncIR_[forEndIndex.value()].get(); - if (loopElem) { - // For LOOP_END elements, elementOp points at the originating scf.for. 
- backEdgeForOp = loopElem->elementOp; - } - } - int eventIdNum = GetEventIdNum(depBaseMemInfosVec, backEdgeForOp); setOp->eventIdNum = eventIdNum; waitOp->eventIdNum = eventIdNum; } From 6a1bc22cee05e5f918e698a966de14ae26bc02cb Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Tue, 5 May 2026 09:05:24 +0800 Subject: [PATCH 08/10] gss/multi-buffer: HIVM-style event-id deduction + dyn flag codegen Wire multi-buffer support into the GraphSyncSolver path so it matches the behaviour the InsertSync path got in earlier P0/P1/P2 commits. Mirrors the intra-core subset of `bishengir/lib/Dialect/HIVM/Transforms/GraphSyncSolver` (`SyncSolver::getEventIdInfo`, `getMultiBufferEventIdInfo`, `getMultiBufferLoop`, `SyncSolverCodeGen::getMultiBufferSelectOp`). Detection (Utility.h, SyncSolver.h/.cpp): - `EventIdInfo` (eventIdNum + multibufferLoop) carried per ConflictPair. - `getMultiBufferLoop` finds the common scf.for whose iv drives the slot rotation; reuses the per-slot overlap helper added in P0 (MemoryDependentAnalyzer::getMultiBufferSlotCount), so the same same-index-overlap / different-index-disjoint geometry rule applies. - `getMultiBufferEventIdInfo` requires every dependent pair to agree on N and share the same scf.for; failure -> single-buffer. - `getEventIdInfo` is the top-level wrapper with the backward-only gate (mirrors HIVM's "only optimise back-edge deps"). Solver wiring (handleSetWaitConflict): - Threads rwOp1/rwOp2 in so MB info is computable. - Allocates N event ids via the existing Welsh-Powell coloring node (EventIdNode already supports eventIdNum > 1). - Falls back to N=1, then PIPE_ALL barrier, if N ids cannot be coloured. Codegen (SyncSolverCodeGen.h/.cpp): - New `emitMultiBufferSetWait` emits the HIVM/InsertSync output shape: pre-loop N pto.set_flag (queue prime), in-body iv-mod-N counter + N-way arith.select chain + pto.{wait,set}_flag_dyn, post-loop N pto.wait_flag (queue drain). 
The dyn ops live INSIDE the loop body so they share the selector's dominance (GSS's default backward-sync anchors are at the loop boundary, which works for single-buffer but not for a per-iteration selector). - `loop2BufferCounter_` cache reuses one `iv mod N` across multiple ConflictPairs sharing a (loop, N) tuple. Test: - New `multi_buffer_gss_dyn_event_id.pto` is the GSS counterpart of the existing InsertSync N=2 regression; checks pre-loop primes, the selector chain, dyn set/wait, and post-loop drains. 191/191 lit pass. Co-Authored-By: Claude --- .../Transforms/GraphSyncSolver/SyncSolver.h | 16 +- .../GraphSyncSolver/SyncSolverCodeGen.h | 17 ++ .../PTO/Transforms/GraphSyncSolver/Utility.h | 29 ++++ .../Transforms/GraphSyncSolver/SyncSolver.cpp | 158 +++++++++++++++++- .../GraphSyncSolver/SyncSolverCodeGen.cpp | 92 +++++++++- .../lit/pto/multi_buffer_gss_dyn_event_id.pto | 56 +++++++ 6 files changed, 357 insertions(+), 11 deletions(-) create mode 100644 test/lit/pto/multi_buffer_gss_dyn_event_id.pto diff --git a/include/PTO/Transforms/GraphSyncSolver/SyncSolver.h b/include/PTO/Transforms/GraphSyncSolver/SyncSolver.h index 67f4f9f61..1f1306de4 100644 --- a/include/PTO/Transforms/GraphSyncSolver/SyncSolver.h +++ b/include/PTO/Transforms/GraphSyncSolver/SyncSolver.h @@ -85,7 +85,21 @@ class Solver { void handleBarrierConflict(Occurrence *occ1, Occurrence *occ2, CorePipeInfo src, CorePipeInfo dst); void handleSetWaitConflict(Occurrence *occ1, Occurrence *occ2, - CorePipeInfo src, CorePipeInfo dst); + CorePipeInfo src, CorePipeInfo dst, + RWOperation *rwOp1 = nullptr, + RWOperation *rwOp2 = nullptr); + + // Multi-buffer event-id deduction (HIVM-style, intra-core only): + // - getMultiBufferLoop: find the common scf.for whose iv the slot + // rotation will key on. + // - getMultiBufferEventIdInfo: per-pair MB eligibility check + N derivation + // via `getMultiBufferSlotCount`. + // - getEventIdInfo: top-level wrapper. 
Backward-only gate, then + // MB check, default to single-buffer (N=1). + scf::ForOp getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2); + EventIdInfo getMultiBufferEventIdInfo(RWOperation *rwOp1, RWOperation *rwOp2); + EventIdInfo getEventIdInfo(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2); // Conservative memory dependence between two RWOperation; produces // (setCorePipeInfo, waitCorePipeInfo) tuples for each detected conflict. diff --git a/include/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.h b/include/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.h index f5c36bc55..e9b2f6fac 100644 --- a/include/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.h +++ b/include/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.h @@ -22,8 +22,12 @@ #include "PTO/Transforms/GraphSyncSolver/SyncSolverIR.h" #include "PTO/Transforms/GraphSyncSolver/Utility.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/DenseMap.h" #include +#include namespace mlir { namespace pto { @@ -58,6 +62,19 @@ class CodeGenerator { bool insertAfter); void insertBarrier(IRRewriter &rewriter, OperationBase *anchor, PIPE pipe, bool insertAfter); + + // Multi-buffer codegen helpers (HIVM-aligned). + // emitMultiBufferSetWait: for a ConflictPair with eventIdNum > 1 emits a + // dyn-event-id (`pto.set_flag_dyn` / `pto.wait_flag_dyn`) pair driven by an + // `iv mod N` selector + arith.select chain over the assigned event ids. + void emitMultiBufferSetWait(IRRewriter &rewriter, ConflictPair *cp); + + // Reuse the same `iv mod N` counter across multiple ConflictPairs that + // share a (loop, N) tuple (mirrors PTOEnableMultiBuffer's loop2BufferCounter + // and the InsertSync SyncCodegen cache). 
+ Value getOrCreateLoopCounter(IRRewriter &rewriter, scf::ForOp forOp, + int64_t n, Location loc); + llvm::DenseMap, Value> loop2BufferCounter_; }; } // namespace syncsolver diff --git a/include/PTO/Transforms/GraphSyncSolver/Utility.h b/include/PTO/Transforms/GraphSyncSolver/Utility.h index 072efc326..83fb7dab2 100644 --- a/include/PTO/Transforms/GraphSyncSolver/Utility.h +++ b/include/PTO/Transforms/GraphSyncSolver/Utility.h @@ -26,6 +26,8 @@ #include "PTO/IR/PTO.h" #include "PTO/Transforms/GraphSyncSolver/SyncSolverIR.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/SmallVector.h" #include @@ -100,6 +102,25 @@ struct SyncSolverOptions { struct EventIdNode; struct ConflictPair; +// Per-conflict outcome of multi-buffer event-id deduction. Mirrors HIVM's +// `EventIdInfo` (intra-core fields only - cross-core unroll/preload are out +// of scope for the PTOAS port). +struct EventIdInfo { + // Number of distinct event ids needed per back-edge sync. 1 means "single + // buffer" and the existing single-id coloring path is used. + int64_t eventIdNum{1}; + // Loop whose induction variable drives the `iv mod N` slot selector at + // codegen. Null when `eventIdNum == 1`. + scf::ForOp multibufferLoop{nullptr}; + + EventIdInfo() = default; + explicit EventIdInfo(int64_t n) : eventIdNum(n) {} + EventIdInfo(int64_t n, scf::ForOp loop) + : eventIdNum(n), multibufferLoop(loop) {} + + bool isMultiBuffer() const { return eventIdNum > 1 && multibufferLoop; } +}; + // One DFS appearance of an OperationBase in the syncIr stream. 
struct Occurrence { OperationBase *op{nullptr}; @@ -172,7 +193,15 @@ struct ConflictPair { bool isBarrierAll{false}; // fallback marker: emit pto.barrier EventIdNode *eventIdNode{nullptr}; + // Snapshot of event ids assigned to this pair, taken when CodeGenerator is + // constructed (after which eventIdNode is no longer safe to read because + // EventIdSolver may have torn down its nodes). Upstream fix for the GSS + // event-id-lifetime bug. llvm::SmallVector eventIds; + // Multi-buffer geometry chosen for this candidate (default = single buffer). + // When `eventIdInfo.isMultiBuffer()`, codegen emits dyn flag ops with an + // `iv mod N` arith.select chain over the assigned event ids. + EventIdInfo eventIdInfo; ConflictPair(RWOperation *op1, RWOperation *op2, OperationBase *setOp, OperationBase *waitOp, Occurrence *setOcc, Occurrence *waitOcc, diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp index 986e344b2..bf10b699d 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp @@ -13,6 +13,7 @@ #include "PTO/Transforms/GraphSyncSolver/Utility.h" #include "PTO/Transforms/InsertSync/MemoryDependentAnalyzer.h" #include "PTO/Transforms/InsertSync/SyncCommon.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include @@ -101,6 +102,118 @@ bool Solver::checkSkipParallelLoop(Occurrence *o1, Occurrence *o2) { return llvm::cast(loopOcc->op)->isParallel; } +// ---- multi-buffer event-id deduction (HIVM-style, intra-core only) -------- + +namespace { +// Find the nearest enclosing scf.for of an SSA value's defining op (or its +// parent block when the value is a block argument). 
+static scf::ForOp getEnclosingScfForGss(Value v) { + if (!v) + return nullptr; + Operation *op = v.getDefiningOp(); + if (!op) { + if (Block *b = v.getParentBlock()) + op = b->getParentOp(); + } + while (op) { + if (auto forOp = dyn_cast(op)) + return forOp; + op = op->getParentOp(); + } + return nullptr; +} +} // namespace + +scf::ForOp Solver::getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2) { + // Mirrors HIVM `Solver::getMultiBufferLoop`: every dependency pair must + // share the *same* common scf.for (otherwise the iv-mod-N selector at + // codegen would key on the wrong loop). We use `baseBuffer` as the SSA + // anchor since `rootBuffer` for a pto.pointer_cast is the i64 base + // address at function top (see InsertSyncAnalysis::GetEventIdNum P0 fix). + scf::ForOp common; + auto pickLoop = [&](const llvm::SmallVector &as, + const llvm::SmallVector &bs) -> bool { + for (auto *a : as) { + for (auto *b : bs) { + unsigned n = memAnalyzer_.getMultiBufferSlotCount(a, b); + if (n < 2) + continue; + auto la = getEnclosingScfForGss(a->baseBuffer); + auto lb = getEnclosingScfForGss(b->baseBuffer); + if (!la || la != lb) + return false; + if (!common) + common = la; + else if (common != la) + return false; + } + } + return true; + }; + if (!pickLoop(rwOp1->readMemInfo, rwOp2->writeMemInfo)) + return nullptr; + if (!pickLoop(rwOp1->writeMemInfo, rwOp2->readMemInfo)) + return nullptr; + if (!pickLoop(rwOp1->writeMemInfo, rwOp2->writeMemInfo)) + return nullptr; + return common; +} + +EventIdInfo Solver::getMultiBufferEventIdInfo(RWOperation *rwOp1, + RWOperation *rwOp2) { + // Mirrors `checkMultiBufferEventIdInfo` + `getMultiBufferEventIdInfo`: + // 1. All conflict pairs must agree on slot count N >= 2. + // 2. All involved buffers must hang off the same scf.for. + // 3. N is the common slot count (small enough to fit MAX_MULTI_BUFFER_NUM). + // Returns single-buffer EventIdInfo() on any failure. 
+ if (!rwOp1 || !rwOp2) + return {}; + + unsigned commonN = 0; + auto checkPair = [&](const llvm::SmallVector &as, + const llvm::SmallVector &bs) -> bool { + for (auto *a : as) { + for (auto *b : bs) { + unsigned n = memAnalyzer_.getMultiBufferSlotCount(a, b); + if (n < 2) + continue; + if (commonN == 0) + commonN = n; + else if (commonN != n) + return false; + } + } + return true; + }; + if (!checkPair(rwOp1->readMemInfo, rwOp2->writeMemInfo)) + return {}; + if (!checkPair(rwOp1->writeMemInfo, rwOp2->readMemInfo)) + return {}; + if (!checkPair(rwOp1->writeMemInfo, rwOp2->writeMemInfo)) + return {}; + if (commonN < 2 || commonN > MAX_MULTI_BUFFER_NUM) + return {}; + + scf::ForOp loop = getMultiBufferLoop(rwOp1, rwOp2); + if (!loop) + return {}; + return EventIdInfo(static_cast(commonN), loop); +} + +EventIdInfo Solver::getEventIdInfo(Occurrence *occ1, Occurrence *occ2, + RWOperation *rwOp1, RWOperation *rwOp2) { + // HIVM `Solver::getEventIdInfo`: backward-only gate, then MB deduction, + // default to single-buffer (eventIdNum = 1). 
+ if (!occ1 || !occ2 || !rwOp1 || !rwOp2) + return EventIdInfo(1); + if (!isBackwardSync(occ1, occ2)) + return EventIdInfo(1); + EventIdInfo info = getMultiBufferEventIdInfo(rwOp1, rwOp2); + if (info.isMultiBuffer()) + return info; + return EventIdInfo(1); +} + bool Solver::isBackwardSync(Occurrence *o1, Occurrence *o2) { // Backward = the two occurrences live under the same parent occurrence // *Loop* in two different "iteration copies" produced by syncIrBuilder, or @@ -230,7 +343,7 @@ void Solver::handleConflict(Occurrence *o1, Occurrence *o2, RWOperation *r1, if (src == dst) handleBarrierConflict(o1, o2, src, dst); else - handleSetWaitConflict(o1, o2, src, dst); + handleSetWaitConflict(o1, o2, src, dst, r1, r2); } bool Solver::checkGraphConflict(Occurrence *o1, Occurrence *o2, @@ -315,7 +428,8 @@ void Solver::handleBarrierConflict(Occurrence *o1, Occurrence *o2, } void Solver::handleSetWaitConflict(Occurrence *o1, Occurrence *o2, - CorePipeInfo src, CorePipeInfo dst) { + CorePipeInfo src, CorePipeInfo dst, + RWOperation *rwOp1, RWOperation *rwOp2) { auto [setOcc, waitOcc] = getSetWaitOcc(o1, o2); assert(setOcc && waitOcc); @@ -325,22 +439,48 @@ void Solver::handleSetWaitConflict(Occurrence *o1, Occurrence *o2, setOcc->op, waitOcc->op, setOcc, waitOcc, src, dst, setOcc->endIndex, waitOcc->startIndex); + // Multi-buffer event-id deduction (HIVM-style). For backward-edge deps that + // pass the per-slot overlap check, allocate N event ids so codegen can + // rotate through them with iv mod N. Falls back to single-buffer (N=1) on + // any failure. + EventIdInfo info = getEventIdInfo(o1, o2, rwOp1, rwOp2); + cp->eventIdInfo = info; + int64_t requestedN = info.eventIdNum; + // Speculatively color: try inserting this candidate into the EventIdSolver - // and roll back if the graph would exceed the hardware budget. + // and roll back if the graph would exceed the hardware budget. 
For + // multi-buffer the node carries N colors; the existing Welsh-Powell path + // already handles eventIdNum > 1. auto *colorer = getEventIdSolver(src.pipe, dst.pipe); colorer->pushActionNone(); - cp->eventIdNode = colorer->createNode(cp.get(), /*eventIdNum=*/1); + cp->eventIdNode = colorer->createNode(cp.get(), requestedN); std::vector intersecting = getIntersectingConflictPairs(cp.get()); colorer->addConflicts(cp.get(), intersecting); if (!colorer->isColorable()) { - // Fallback: this is the minimal port's "no multi-strategy retry" knob. - // Drop the speculative coloring and emit a single PIPE_ALL barrier. + // Multi-buffer fallback: try collapsing to a single event id before + // giving up to a PIPE_ALL barrier. Mirrors the conservative N -> 1 + // degrade on the InsertSync path. colorer->undoActions(); - insertBarrierAllBeforeOcc(waitOcc); - return; + if (requestedN > 1) { + colorer->pushActionNone(); + cp->eventIdInfo = EventIdInfo(1); + cp->eventIdNode = colorer->createNode(cp.get(), /*eventIdNum=*/1); + auto retryIntersect = getIntersectingConflictPairs(cp.get()); + colorer->addConflicts(cp.get(), retryIntersect); + if (!colorer->isColorable()) { + colorer->undoActions(); + insertBarrierAllBeforeOcc(waitOcc); + return; + } + colorer->clearActionStack(); + } else { + insertBarrierAllBeforeOcc(waitOcc); + return; + } + } else { + colorer->clearActionStack(); } - colorer->clearActionStack(); // Attach to LCA scope occurrences so future checkGraphConflict calls see it. 
auto [normSet, normWait] = OperationBase::getLCAPair(setOcc->op, waitOcc->op); diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp index 0c3176f81..35dc2df0d 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp @@ -11,11 +11,14 @@ #include "PTO/IR/PTO.h" #include "PTO/Transforms/GraphSyncSolver/SyncSolverIR.h" #include "PTO/Transforms/GraphSyncSolver/Utility.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Casting.h" #include +#include #define DEBUG_TYPE "pto-graph-sync-solver-codegen" @@ -123,6 +126,88 @@ void CodeGenerator::insertBarrier(IRRewriter &rewriter, OperationBase *anchor, rewriter.create(loc, pipeAttr); } +Value CodeGenerator::getOrCreateLoopCounter(IRRewriter &rewriter, + scf::ForOp forOp, int64_t n, + Location loc) { + auto key = std::make_pair(forOp, n); + auto it = loop2BufferCounter_.find(key); + if (it != loop2BufferCounter_.end()) + return it->second; + rewriter.setInsertionPointToStart(forOp.getBody()); + Value iv = forOp.getInductionVar(); + Value cN = rewriter.create(loc, n); + Value rem = rewriter.create(loc, iv, cN); + loop2BufferCounter_[key] = rem; + return rem; +} + +void CodeGenerator::emitMultiBufferSetWait(IRRewriter &rewriter, + ConflictPair *cp) { + // Multi-buffer codegen mirrors the InsertSync output: + // pre-loop: N pto.set_flag (queue prime, one per event id) + // in-loop: pto.wait_flag_dyn(idx) at body start, + // pto.set_flag_dyn(idx) at body end (before yield) + // post-loop: N pto.wait_flag (queue drain) + // The dyn set/wait MUST live inside the loop body so they share the + // `iv mod N` selector's dominance. 
GSS's default backward-sync anchors + // (set after loop / wait before loop) only work for single-buffer where + // there is no per-iteration selector. + assert(cp); + const auto &eids = cp->eventIds; + int64_t n = cp->eventIdInfo.eventIdNum; + assert((int64_t)eids.size() >= n && n >= 2); + if ((int64_t)eids.size() < n || n < 2) + return; + scf::ForOp loop = cp->eventIdInfo.multibufferLoop; + assert(loop && "multi-buffer codegen needs a non-null rotation loop"); + if (!loop) + return; + PIPE setPipe = cp->setCorePipeInfo.pipe; + PIPE waitPipe = cp->waitCorePipeInfo.pipe; + Location loc = loop.getLoc(); + auto srcAttr = makePipe(rewriter.getContext(), setPipe); + auto dstAttr = makePipe(rewriter.getContext(), waitPipe); + + // 1. Pre-loop: queue-prime with N concrete event ids. + rewriter.setInsertionPoint(loop); + for (int64_t i = 0; i < n; ++i) { + auto eidAttr = makeEvent(rewriter.getContext(), eids[i]); + rewriter.create(loc, srcAttr, dstAttr, eidAttr); + } + + // 2. In-loop: build (or reuse) the `iv mod N` counter at the start of the + // body, then a select chain over the assigned event ids. + Value rem = getOrCreateLoopCounter(rewriter, loop, n, loc); + rewriter.setInsertionPointAfter(rem.getDefiningOp()); + Value selected = + rewriter.create(loc, eids[0]); + for (int64_t i = 1; i < n; ++i) { + Value ci = rewriter.create(loc, i); + Value eq = rewriter.create(loc, arith::CmpIPredicate::eq, + rem, ci); + Value idv = rewriter.create(loc, eids[i]); + selected = rewriter.create(loc, eq, idv, selected); + } + + // wait_flag_dyn goes at the start of the body (just after the selector), + // set_flag_dyn goes right before the terminator (yield) of the body. + rewriter.setInsertionPointAfter(selected.getDefiningOp()); + rewriter.create(loc, srcAttr, dstAttr, selected); + + Operation *terminator = loop.getBody()->getTerminator(); + if (!terminator) + return; + rewriter.setInsertionPoint(terminator); + rewriter.create(loc, srcAttr, dstAttr, selected); + + // 3. 
Post-loop: drain by waiting on each prime. + rewriter.setInsertionPointAfter(loop); + for (int64_t i = 0; i < n; ++i) { + auto eidAttr = makeEvent(rewriter.getContext(), eids[i]); + rewriter.create(loc, srcAttr, dstAttr, eidAttr); + } +} + void CodeGenerator::emitOne(IRRewriter &rewriter, ConflictPair *cp) { if (cp->isBarrierAll) { // Single PIPE_ALL barrier inserted just before the wait anchor. @@ -135,7 +220,12 @@ void CodeGenerator::emitOne(IRRewriter &rewriter, ConflictPair *cp) { /*insertAfter=*/false); return; } - // Normal set/wait pair. Single event id in the minimal port. + // Multi-buffer path: dyn set/wait + iv mod N selector. + if (cp->eventIdInfo.isMultiBuffer()) { + emitMultiBufferSetWait(rewriter, cp); + return; + } + // Single-buffer path: classic static set/wait pair. const auto &eids = cp->eventIds; assert(!eids.empty()); if (eids.empty()) diff --git a/test/lit/pto/multi_buffer_gss_dyn_event_id.pto b/test/lit/pto/multi_buffer_gss_dyn_event_id.pto new file mode 100644 index 000000000..78c016891 --- /dev/null +++ b/test/lit/pto/multi_buffer_gss_dyn_event_id.pto @@ -0,0 +1,56 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +// RUN: ptoas --enable-graph-sync-solver --enable-multi-buffer-lowering --mlir-print-ir-after=pto-graph-sync-solver %s 2>&1 1>/dev/null | FileCheck %s + +// End-to-end regression for the GraphSyncSolver multi-buffer event-id pipeline. 
+// Mirrors the InsertSync `multi_buffer_insert_sync_dyn_event_id.pto` test but +// drives the alternative `--enable-graph-sync-solver` path: +// +// PlanMemory -> N=2-address pto.pointer_cast on `pto.multi_buffer = 2`. +// GSS Solver -> per-pair `getMultiBufferSlotCount` >= 2 + shared scf.for +// trigger `getEventIdInfo`, allocate 2 event ids in the +// EventIdSolver coloring graph (createNode with N=2). +// GSS Codegen -> emitMultiBufferSetWait emits N=2 pre-loop set_flag, +// iv-mod-N counter + N-way arith.select, dyn set/wait inside +// the loop body, and N=2 post-loop wait_flag drains. + +module { + func.func @gss_double_buffer_dyn_event( + %arg0: memref<16x16x16xf16, #pto.address_space>, + %arg1: memref<16x16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + scf.for %i = %c0 to %c4 step %c1 { + %a = memref.alloc() {pto.multi_buffer = 2 : i32} + : memref<16x16x16xf16, #pto.address_space> + pto.tload ins(%arg0 : memref<16x16x16xf16, #pto.address_space>) + outs(%a : memref<16x16x16xf16, #pto.address_space>) + pto.tstore ins(%a : memref<16x16x16xf16, #pto.address_space>) + outs(%arg1 : memref<16x16x16xf16, #pto.address_space>) + } + return + } +} + +// CHECK: IR Dump After PTOGraphSyncSolver +// CHECK: func.func @gss_double_buffer_dyn_event +// Two distinct event ids primed pre-loop for the N=2 backward sync. +// CHECK: pto.set_flag[, , ] +// CHECK: pto.set_flag[, , ] +// Inside the loop: iv mod 2 + arith.select chain selecting between event ids. +// CHECK: scf.for %[[IV:.*]] = +// CHECK: arith.remui %[[IV]], %{{.*}} : index +// CHECK: arith.select +// Dynamic-event-id wait/set ops drive the multi-buffer backward sync. +// CHECK: pto.wait_flag_dyn[, , %{{.*}}] +// CHECK: pto.set_flag_dyn[, , %{{.*}}] +// Post-loop drains both event ids. 
+// CHECK: pto.wait_flag[, , ] +// CHECK: pto.wait_flag[, , ] From aa5ff4602e094a11391138c3fb43f87cc0813cc5 Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Tue, 5 May 2026 15:39:04 +0800 Subject: [PATCH 09/10] Add CV preload scope marking and expansion --- docs/designs/ptoas-cv-optimization-design.md | 738 ++++++++++++++++++ include/PTO/IR/PTOOps.td | 19 + include/PTO/Transforms/Passes.h | 4 + include/PTO/Transforms/Passes.td | 79 ++ lib/PTO/Transforms/CMakeLists.txt | 2 + lib/PTO/Transforms/CVCreatePreloadPass.cpp | 374 +++++++++ .../Transforms/CVMarkPreloadScopesPass.cpp | 657 ++++++++++++++++ lib/PTO/Transforms/PTOPlanMemory.cpp | 93 ++- lib/PTO/Transforms/PTOPlanMemory.h | 19 + lib/PTO/Transforms/PTOViewToMemref.cpp | 26 +- .../pto/cv_create_preload_scope_pipeline.pto | 75 ++ test/lit/pto/cv_preload_mark_fa_tpipe.pto | 200 +++++ tools/ptoas/ptoas.cpp | 17 +- 13 files changed, 2295 insertions(+), 8 deletions(-) create mode 100644 docs/designs/ptoas-cv-optimization-design.md create mode 100644 lib/PTO/Transforms/CVCreatePreloadPass.cpp create mode 100644 lib/PTO/Transforms/CVMarkPreloadScopesPass.cpp create mode 100644 test/lit/pto/cv_create_preload_scope_pipeline.pto create mode 100644 test/lit/pto/cv_preload_mark_fa_tpipe.pto diff --git a/docs/designs/ptoas-cv-optimization-design.md b/docs/designs/ptoas-cv-optimization-design.md new file mode 100644 index 000000000..9ea517281 --- /dev/null +++ b/docs/designs/ptoas-cv-optimization-design.md @@ -0,0 +1,738 @@ +# PTOAS CV Preload 优化设计 + +## 1. 
FA CV 分离 Kernel:为什么需要 Preload + +PTOAS 的 CV preload 优化首先要解决 FA 这类已经手写 C/V 分离的 kernel。以本仓库 `fa_perf.pto` 的形态为例,FA 的四个主要 stage 分布在两个 kernel 中: + +| stage | kernel | 主要计算 | TPipe | +| --- | --- | --- | --- | +| `compute_qk` | Cube | `Q * K^T` matmul,产生 QK | `tpush id=25`,C2V | +| `compute_p` | Vector | softmax/exp,产生 P | `tpop id=25`,再 `tpush id=30`,V2C | +| `compute_pv` | Cube | `P * V` matmul,产生 PV | `tpop id=30`,再 `tpush id=27`,C2V | +| `compute_gu` | Vector | 累加 PV 并做最终归一化 | `tpop id=27` | + +这个结构与 pto-isa manual FA README 中的描述一致:手写 FA kernel 将计算拆成 `compute_qk`、`compute_p`、`compute_pv`、`compute_gu`,通过 inter-CV FIFO 连接 Cube 和 Vector pipeline;README 也用 `qkPreloadNum` 表示 QK 预执行深度,并让 QK/P/PV FIFO 深度随 preload 增加。README 的 tuning notes 中,`qkPreloadNum` 默认 4,`qkp_tile_fifo_size` 按 `1 + qkPreloadNum` 派生,`pv_tile_fifo_size` 也镜像这个深度。参考:[pto-isa flash_atten README](https://github.com/hw-native-sys/pto-isa/blob/main/kernels/manual/common/flash_atten/README.md)。 + +### 1.1 优化前:按同一个逻辑 tile 串行推进 + +未做 preload 时,虽然 Cube 和 Vector 已经分成两个 kernel,但同一个 S1 tile 的跨核依赖仍然形成如下链条: + +```text +Cube compute_qk(i) -> push QK(i) +Vector pop QK(i) -> compute_p(i) -> push P(i) +Cube pop P(i) -> compute_pv(i) -> push PV(i) +Vector pop PV(i) -> compute_gu(i) +``` + +用 PTOAS 的 CV 分离 IR 可以抽象成: + +```mlir +func.func @cube_kernel(...) attributes {pto.kernel_kind = #pto.kernel_kind} { + scf.for %i = %lb to %ub step %step { + // QK producer: Cube -> Vector, pipe id=25 + %qk = pto.talloc_to_aiv {id = 25, split = 0} + ... compute_qk(%i) ... + pto.tstore ... outs(%qk ...) + pto.tpush_to_aiv(%qk) {id = 25, split = 0} + + // PV relay: Vector -> Cube -> Vector, pipe id=30 then id=27 + %p = pto.tpop_from_aiv {id = 30, split = 0} + ... compute_pv(%i, %p) ... + pto.tfree_from_aiv(%p) {id = 30, split = 0} + %pv = pto.talloc_to_aiv {id = 27, split = 0} + pto.tstore ... outs(%pv ...) + pto.tpush_to_aiv(%pv) {id = 27, split = 0} + } +} + +func.func @vector_kernel(...) 
attributes {pto.kernel_kind = #pto.kernel_kind} { + scf.for %i = %lb to %ub step %step { + // P relay: Cube -> Vector -> Cube, pipe id=25 then id=30 + %qk = pto.tpop_from_aic {id = 25, split = 0} + ... compute_p(%i, %qk) ... + %p = pto.talloc_to_aic {id = 30, split = 0} + pto.tstore ... outs(%p ...) + pto.tfree_from_aic(%qk) {id = 25, split = 0} + pto.tpush_to_aic(%p) {id = 30, split = 0} + + // GU consumer: Cube -> Vector, pipe id=27 + %pv = pto.tpop_from_aic {id = 27, split = 1} + ... compute_gu(%i, %pv) ... + pto.tfree_from_aic(%pv) {id = 27, split = 1} + } +} +``` + +这种写法的问题不是功能错误,而是 C/V 资源会在等待对端 stage 时产生空泡: + +```text +tile i: + Cube QK(i) ---- wait P(i) ---- PV(i) ---- + Vector ---- wait QK(i) ---- P(i) ---- wait PV(i) ---- GU(i) +``` + +TPipe 的 `tpush/tpop/tfree/talloc` 已经能保证跨核同步正确,但如果 compiler 不改变 scope 的 logical iteration,两个核仍然倾向于围绕同一个 tile 做握手,难以让 QK、P、PV、GU 同时处于稳态流水中。 + +### 1.2 优化后:scope 级 preload 展开 + +preload 的核心是把四个 stage 标成不同的 `preload_num`,让 producer 在更靠前的 logical tile 上运行,让 consumer drain 更旧的 logical tile: + +| scope | kernel | role | TPipe | `preload_num` | +| --- | --- | --- | --- | --- | +| `S3_QK` | Cube | producer | `push id=25` | 3 | +| `S2_P` | Vector | relay | `pop id=25 -> push id=30` | 2 | +| `S1_PV` | Cube | relay | `pop id=30 -> push id=27` | 1 | +| `S0_GU` | Vector | consumer | `pop id=27` | 0 | + +`max_preload_num = 4` 时,`pto-cv-create-preload` 将原 loop 扩成 physical loop,并为每个 scope 使用不同的 stage IV: + +```text +stage_iv(p) = physical_iv - (max_preload_num - 1 - p) * step +``` + +展开后的概念写法如下。实际 IR 中 Cube 和 Vector 仍是两个函数,且每个函数内部保持原 loop body 的词法顺序。 + +```mlir +// cube_kernel +scf.for %t = %lb to (%ub + 4 * %step) step %step { + %i_pv = %t - 2 * %step + scf.if (%lb <= %i_pv && %i_pv < %ub) { + // S1_PV: consume P(i_pv), produce PV(i_pv) + %p = pto.tpop_from_aiv {id = 30, split = 0} + ... compute_pv(%i_pv, %p) ... 
+ pto.tfree_from_aiv(%p) {id = 30, split = 0} + %pv = pto.talloc_to_aiv {id = 27, split = 0} + pto.tpush_to_aiv(%pv) {id = 27, split = 0} + } + + %i_qk = %t + scf.if (%lb <= %i_qk && %i_qk < %ub) { + // S3_QK: produce QK(i_qk) + %qk = pto.talloc_to_aiv {id = 25, split = 0} + ... compute_qk(%i_qk) ... + pto.tpush_to_aiv(%qk) {id = 25, split = 0} + } +} + +// vector_kernel +scf.for %t = %lb to (%ub + 4 * %step) step %step { + %i_gu = %t - 3 * %step + scf.if (%lb <= %i_gu && %i_gu < %ub) { + // S0_GU: consume PV(i_gu) + %pv = pto.tpop_from_aic {id = 27, split = 1} + ... compute_gu(%i_gu, %pv) ... + pto.tfree_from_aic(%pv) {id = 27, split = 1} + } + + %i_p = %t - %step + scf.if (%lb <= %i_p && %i_p < %ub) { + // S2_P: consume QK(i_p), produce P(i_p) + %qk = pto.tpop_from_aic {id = 25, split = 0} + ... compute_p(%i_p, %qk) ... + %p = pto.talloc_to_aic {id = 30, split = 0} + pto.tstore ... outs(%p ...) + pto.tfree_from_aic(%qk) {id = 25, split = 0} + pto.tpush_to_aic(%p) {id = 30, split = 0} + } +} +``` + +稳态下,同一个 physical iteration 会同时推进多个 logical tile: + +```text +physical t: + Cube PV(t - 2) QK(t) + Vector GU(t - 3) P(t - 1) +``` + +这样 Cube 可以提前生产未来 tile 的 QK,Vector 可以消费上一批 QK 生成 P,Cube 又消费更旧的 P 生成 PV,Vector drain 最旧的 PV。TPipe 内部同步仍然负责“数据是否真的可读/slot 是否可复用”,compiler 只负责把不同 scope 映射到不同 logical iteration,并为 local/workspace buffer 准备足够的 stage 存储。 + +### 1.3 优化效果和必要性 + +pto-isa manual FA README 用无软件流水、`qkPreloadNum=2/4` 以及仿真 pipeline 的对比说明了同一个方向:preload 的直接目标是解除跨 stage 数据依赖带来的串行化,让 Vector 计算资源尽量保持忙碌;当流水形态形成后,瓶颈会进一步暴露到 Cube 侧 TSTORE、FIFO 深度、UB/L1 buffer 等资源上。 + +preload 的效果可以从四个层面理解: + +- 减少跨核等待空泡:`tpush/tpop` 的等待不消失,但等待更容易被另一个 stage 的计算覆盖。 +- 保持 C/V 资源忙碌:QK、P、PV、GU 不再围绕同一个 tile 串行握手,而是在不同 tile 上形成稳态流水。 +- 暴露内存规划需求:同一 physical iteration 同时存在多个 logical tile 的中间结果,`plan-memory` 必须为 preload local buffer 分配多份地址,并由 create-preload 按 stage 轮转。 +- 受资源约束:更大的 preload 需要更深的 FIFO、更多 UB/L1 local buffer 和更多 workspace slot。manual FA 中 `qkPreloadNum` 会影响 QK/P/PV FIFO 深度,PTOAS 的 
`max_preload_num` 必须受相同资源约束。
+
+因此,PTOAS 的 CV preload pass 不只是“复制 loop body”。它需要把 TPipe transaction 识别成 scope,生成跨 C/V 一致的 `preload_num`,让 plan-memory 看到 multibuffer 需求,最后再做 loop 扩界和 scope 展开。
+
+## 2. 目标和边界
+
+本文描述 PTOAS 中 CV 分离代码的第一版 preload 优化设计,重点覆盖四件事:
+
+1. 如何从 TPipe 通信识别 CV scope。
+2. 如何根据 C/V 核之间的 producer/consumer 关系生成 `preload_num`。
+3. `plan-memory` 如何消费 preload 标注并影响 local buffer 地址规划。
+4. 如何参考 NPU IR `CreatePreload.cpp` 展开已标记的 preload loop。
+
+设计参考:
+
+- `AscendNPU-IR/bishengir/lib/Dialect/HIVM/Transforms/CreatePreload.cpp`
+- `AscendNPU-IR/bishengir/lib/Dialect/HIVM/Transforms/PlanMemory.cpp`
+
+本文中的 `preload` 表示 CV scope/stage 级提前执行,不限定为单个 buffer load;相关 buffer 标注和 memory plan 只是支撑这种 stage 提前执行的实现手段。
+
+PTOAS 输入中 C 核和 V 核函数已经分离,跨核通信由用户显式写成 `talloc`、`tpush`、`tpop`、`tfree`。这些 TPipe API 内部天然包含跨核同步动作。因此,PTOAS 的 CV 优化不需要重新发明跨核同步原语,但必须把这些同步动作作为 scope 边界和 preload 调度约束的一部分。
+
+V1 不做 subtiling,不拆分 pipe entry,不改变用户定义的 TPipe 粒度。当前实现已经完成自动 scope 识别、`preload_num` 生成、no-result `pto.cv.scope` 标注,以及 opt-in 的 `pto-cv-create-preload` 展开。create-preload 默认关闭,只有打开 `--enable-cv-create-preload` 时才让 scope 保留到 plan-memory 之后并执行展开。
+
+## 3. IR 标注模型
+
+第一版使用一个编译器内部的 no-result `pto.cv.scope` op 表达 CV preload scope。它是单 block region 容器,不带 operand/result,也不带 `yield` terminator。这个选择和 PTOAS 的输入形态有关:FA 这类 CV 分离 kernel 的跨 scope 数据主要通过 TPipe FIFO、外提 tilebuf/memref 或 workspace 传递,而不是通过 tensor SSA result 在 scope 之间直接传值。
+
+```mlir
+pto.cv.scope {
+ pto.tpop(...)
+ ...
+ pto.tpush(...)
+ pto.tfree(...)
+} { + pto.cv.scope_id = 7, + pto.cv.group_id = 0, + pto.cv.preload_num = 2, + pto.cv.max_preload_num = 4, + pto.cv.core = "vector", + pto.cv.role = "relay" +} +``` + +因此,第一版 `pto.cv.scope` 有两个约束: + +- scope 内定义的 SSA value 不允许被 scope 外直接使用。若一个值需要跨 scope 使用,要么把定义外提到 scope 之外,要么通过 TPipe/workspace/local buffer 这样的显式 side effect 传递。 +- scope 间的 producer/consumer 关系不依赖 SSA def-use,而是由 `pto.cv.input_pipe`、`pto.cv.output_pipe`、`pto.cv.group_id`、`pto.cv.preload_num` 等 metadata 保留。 + +这与 NPU IR 的 tensor-level `scope.scope -> result` 方案不同。NPU IR 在 bufferization 前用 scope result 和 `scope.return` 保留跨 region tensor SSA 数据流;PTOAS 第一版选择 no-result scope,是因为 tile 多为 tilebuf/memref 或 TPipe borrowed entry,且可要求普通 tile alloc 外提。未来如果出现 scope 内生成的 SSA tile 必须被外部直接使用,再扩展 `pto.cv.scope` 的 result/yield 形式。 + +关键标注: + +| 标注 | 语义 | +| --- | --- | +| `pto.cv.scope_id` | 单个 scope 的稳定编号,用于诊断和调试。 | +| `pto.cv.group_id` | 同一个 preload 展开组。一个 group 对应一个可展开的 `scf.for` pipeline。 | +| `pto.cv.preload_num` | scope 在展开 loop 中的 stage 编号。数值越大,逻辑迭代越靠前。 | +| `pto.cv.max_preload_num` | 当前 preload group 的 stage 数。 | +| `pto.cv.core` | `cube` 或 `vector`。 | +| `pto.cv.role` | `producer`、`consumer` 或 `relay`。 | +| `pto.cv.input_pipe` | scope 消费的上游逻辑 pipe,例如 `c2v:25`;producer scope 为空。 | +| `pto.cv.output_pipe` | scope 生产的下游逻辑 pipe,例如 `v2c:30`;consumer scope 为空。 | + +buffer 侧需要额外标注: + +| 标注 | 语义 | +| --- | --- | +| `pto.multi_buffer` | 该 root local buffer 需要多份物理存储。CV preload 自动标注时数值取当前跨 C/V stage 链的 `max_preload_num`。 | +| `pto.cv.preload_workspace` | workspace/subview 需要按 stage 改写 slot 维度。 | + +当前实现中,`pto-cv-auto-mark-multi-buffer` 在 `pto-cv-mark-preload-scopes` 之前运行,复用同一套 TPipe stage 链识别逻辑,把 scope 内使用到的根 `pto.alloc_tile` / `memref.alloc` 标上 `pto.multi_buffer`。`PTOViewToMemref` 会把 `pto.alloc_tile` 上的标注透传到降低后的 `memref.alloc`、`pto.pointer_cast`、`pto.bind_tile`,其中非 level3 的 plan-memory 路径继续从 `memref.alloc` 读取该属性。 + +当前 `pto-cv-mark-preload-scopes` 生成 no-result `pto.cv.scope` 以及 scope 级 `pto.cv.*` 
标注;`pto-cv-create-preload` 继续消费这些稳定语义,在 plan-memory 生成多地址 local buffer 后做 loop 扩界、guard 插入和 stage 地址旋转。 + +## 4. Scope 识别 + +### 4.1 TPipe 动作语义 + +TPipe op 不只是普通 IO,它们也是同步动作: + +| op | producer/consumer | 同步含义 | scope 边界含义 | +| --- | --- | --- | --- | +| `talloc` | producer | 等待并获取空闲 FIFO slot | producer transaction 开始。 | +| `tpush` | producer | 提交数据,通知 consumer 可读 | producer scope 的提交边界。 | +| `tpop` | consumer | 等待 producer 数据可读并获取 slot/tile | consumer scope 的获取边界。 | +| `tfree` | consumer | 释放已消费 slot,允许 producer 复用 | consumer scope 的释放边界。 | + +因此,scope 不应该只按普通计算 op 切分,而要按 TPipe transaction 的生命周期切分: + +- producer scope:从产生待发送数据的计算开始,到对应 `tpush` 结束。若是 global entry,scope 内还包含 `talloc -> tstore -> tpush`。 +- consumer scope:从 `tpop` 开始,到最后一次使用 pop 出来的 tile/entry 并执行对应 `tfree` 结束。 +- relay scope:同一核内既消费上游 pipe,又生产下游 pipe,并且中间值或 pop 出来的 entry 不能在两个 scope 间安全切开时,使用一个 relay scope 覆盖从 `tpop` 获取到 producer `tpush` 提交和 consumer `tfree` 释放都完成的整段 transaction。`tpush` 与 `tfree` 的相对顺序保持用户 IR 的词法顺序。 + +`tpush` 和 `tpop` 是最重要的边界:`tpush` 是 producer 的 commit 边界,`tpop` 是 consumer 的 acquire 边界。但同一核内同时存在 producer 和 consumer 时,不能简单用所有 `tpush/tpop` 机械切段,还必须看 borrowed value 的生命周期和数据依赖。 + +第一版实现把识别出的 stage 作为一个连续 op range 包进 `pto.cv.scope`,scope 必须覆盖“产生/消费该 pipe 数据的计算”,而不是只覆盖 `talloc/tstore/tpush` 或 `tpop/tfree` 这几条 TPipe op: + +- producer-only:从上一条 CV 边界之后的本 stage 计算开始,到匹配 `tpush` 结束。若 `talloc` 写在计算之后,scope 仍然要包含前面的 load/matmul/arith 等 producer 计算。 +- consumer-only:从 `tpop` 开始,到该 stage 的尾部结束;如果 `tfree` 后还有使用已拷贝到 local buffer 的计算或 store,也属于同一个 consumer scope。 +- relay:从上游 `tpop` 开始,到下游 `tpush` 提交且上游 `tfree` 释放都完成为止,中间包含必要的计算、`tfree`、`talloc` 和 `tpush`。`tpush` 与 `tfree` 的相对顺序保持用户 IR 的词法顺序。 + +因为当前 `pto.cv.scope` 不带 result/yield,wrap 前必须检查该 range 内定义的 SSA value 没有逃逸到 scope 外。若存在逃逸,说明该 scope 需要带 result/yield 或需要把定义外提,第一版 pass 不应强行 wrap,否则会破坏 SSA dominance。 + +### 4.2 单核内 producer/consumer 共存 + +同一个 C/V kernel loop 内可能同时有 producer scope 和 consumer scope。识别规则如下: + +1. 
从每个 TPipe op 建立 transaction: + - producer transaction:同一 logical pipe 上支配 `tpush` 的 `talloc`,以及写入该 entry/tile 的计算。 + - consumer transaction:`tpop` 产生的 tile/entry,到匹配 `tfree` 之间的所有使用。 +2. 如果 consumer 的输出已经物化到普通 local buffer,且该 local buffer 不依赖 pop 出来的 borrowed entry 生命周期,可以切成 consumer scope 和 producer scope。 +3. 如果 producer 直接或间接依赖 `tpop` 得到的 tile/entry,或把 consumer/producer 切开会破坏 borrowed entry 生命周期、FIFO 提交/释放顺序、数据依赖,则合并成 relay scope。 +4. 一个 relay scope 可以同时含有多个 pipe 动作,但它对 preload 展开来说是不可再拆的 stage 单元。 + +例子: + +```mlir +%qk = pto.tpop_from_aic {id = 25} +%p = pto.softmax(%qk) +%p_entry = pto.talloc_to_aic {id = 30} +pto.tstore ... outs(%p_entry ...) +pto.tfree_from_aic(%qk) {id = 25} +pto.tpush_to_aic(%p_entry) {id = 30} +``` + +这里 vector 先消费 C2V 的 QK,再生产 V2C 的 P。如果 `%p` 的计算依赖 `%qk`,且把 consumer 与 producer 分成两个 scope 会丢失这段 transaction 的生命周期和词法顺序,则这段应标成一个 relay scope,而不是在 `tpop` 和 `tpush` 之间强拆。 + +### 4.3 跨 C/V 核配对 + +跨核关系必须以 logical pipe 为单位建立,而不是以 SSA value 为单位。logical pipe key 建议包含: + +```text +(pipe_id, direction, split, entry_kind, slot_shape, dtype) +``` + +其中 direction 是 C2V 或 V2C: + +- C2V producer:cube kernel 中的 `talloc_to_aiv/tpush_to_aiv`。 +- C2V consumer:vector kernel 中的 `tpop_from_aic/tfree_from_aic`。 +- V2C producer:vector kernel 中的 `talloc_to_aic/tpush_to_aic`。 +- V2C consumer:cube kernel 中的 `tpop_from_aiv/tfree_from_aiv`。 + +识别 pass 为每个 loop 内的 TPipe 动作建立 occurrence: + +```text +CVPipeOccurrence { + func, + core, + parent_loop, + loop_iv, + lexical_index, + logical_pipe, + op_kind, // talloc/tpush/tpop/tfree + role, // producer/consumer + scope_id +} +``` + +同一 logical pipe 上 producer occurrence 和 consumer occurrence 按 FIFO 顺序配对。配对时要同时检查: + +- producer 的 `tpush` 和 consumer 的 `tpop` 类型、shape、dtype、split 一致。 +- consumer 的 `tfree` 和 producer 的 slot 复用关系合法。 +- 同一个 loop 内不同 pipe 的 occurrence 顺序被记录下来,后续生成 `preload_num` 时不能丢失。 +- 包成 no-result `pto.cv.scope` 后,scope 间原本可能存在的 SSA def-use 不再作为主要依赖来源;跨 C/V pipeline 依赖由 `input_pipe/output_pipe` 建边,loop 内 
side-effect 顺序由 scope 的词法顺序保留。 + +## 5. Preload Num 生成 + +### 5.1 基本语义 + +PTOAS 的 `preload_num` 语义与 NPU IR `CreatePreload.cpp` 对齐: + +```text +stage_iv(preload_num p) = + physical_iv - (max_preload_num - 1 - p) * step +``` + +因此: + +- `preload_num = max_preload_num - 1` 表示最靠前的 stage,使用当前 physical iteration。 +- `preload_num = 0` 表示最靠后的 drain stage,使用最旧的 logical iteration。 +- 数值越大,scope 越“提前”执行。 + +以 `max_preload_num = 4` 为例,展开后同一个 physical iteration 中的逻辑迭代关系是: + +| `preload_num` | logical iv | +| --- | --- | +| 3 | `physical_iv` | +| 2 | `physical_iv - step` | +| 1 | `physical_iv - 2 * step` | +| 0 | `physical_iv - 3 * step` | + +这点非常重要:`preload_num` 不是简单的词法顺序编号,而是 scope 在 pipeline 中相对 drain stage 的提前距离。 + +### 5.2 依赖图 + +自动生成 `preload_num` 时,先建立 scope 级图: + +```text +CVScopeNode { + scope_id, + core, + parent_loop, + lexical_range, + role, + logical_pipes, + reads, + writes, + tpipe_actions +} +``` + +图上包含四类边: + +| 边 | 来源 | 含义 | +| --- | --- | --- | +| data edge | SSA/use-def、memref alias、tile/entry 使用 | producer 的结果必须先于 consumer 使用。 | +| tpipe ready edge | `tpush -> tpop` | producer 提交的数据被对端消费。 | +| tpipe release edge | `tfree -> talloc` | consumer 释放 slot 后 producer 才能安全复用。 | +| lexical/lifetime edge | 同一 loop 内 TPipe occurrence 顺序和 borrowed value 生命周期 | 保持用户显式写出的 transaction 顺序。 | + +其中 `tpipe ready edge` 是生成 preload stage 的主要依据。若希望 producer 为 consumer 预取下一轮数据,则 producer scope 的 `preload_num` 应比该 consumer scope 大 1。对一条链: + +```text +producer -> relay -> relay -> consumer +``` + +可以得到: + +```text +producer.preload_num = 3 +relay1.preload_num = 2 +relay2.preload_num = 1 +consumer.preload_num = 0 +max_preload_num = 4 +``` + +### 5.3 生成算法 + +对每个可展开的 loop group: + +1. 收集 C/V 两侧 parent loop 中的 CV scope。 +2. 按 logical pipe FIFO 顺序配对 producer/consumer occurrence。 +3. 构建 scope dependency graph。 +4. 找到 drain scope。drain scope 是该 pipeline 中最靠后的 consumer,通常是产生当前 loop 可见最终结果或释放最后一个输入 token 的 scope。 +5. 从 drain scope 反向沿 data/tpipe ready edge 计算 stage distance。 +6. 
设置: + +```text +preload_num(scope) = stage_distance_from_drain(scope) +max_preload_num = max(preload_num) + 1 +``` + +7. 对同一个 scope 中多个 pipe 动作取最大 stage distance,避免一个 scope 被分到多个 stage。 +8. 检查同一 parent loop 内是否存在相同 `preload_num` 的多个不可交换 side-effect scope。若存在,需要合并 scope 或放弃自动标注。 +9. 检查 `preload_num` 不超过可用 FIFO 深度或用户指定的 preload 深度。若需要的 stage 数大于 pipe slot 数,不能自动开启 preload。 + +### 5.4 词法顺序的影响 + +preload 展开 pass 不会重排原 loop body,它只在原词法顺序上为不同 scope 映射不同 logical iv。因此,PTOAS 生成 `preload_num` 时必须保留 loop 内不同 producer/consumer TPipe 的出现顺序。 + +例如 FA 的 CV 分离代码中可以抽象出以下跨核链: + +```text +Cube: qk_push(id=25) +Vector: qk_pop(id=25) -> p_push(id=30) +Cube: p_pop(id=30) -> pv_push(id=27) +Vector: pv_pop(id=27) +``` + +对应推荐标注: + +| scope | core | role | `preload_num` | +| --- | --- | --- | --- | +| QK producer | Cube | producer | 3 | +| QK consumer + P producer | Vector | relay | 2 | +| P consumer + PV producer | Cube | relay | 1 | +| PV consumer | Vector | consumer | 0 | + +如果 Cube loop 中 `p_pop/pv_push` 出现在 `qk_push` 之前,标注仍然可以是 `1` 后接 `3`。展开 pass 会保持这个词法顺序: + +```text +physical iteration t: + Cube P/PV relay uses logical iteration t - 2 * step + Cube QK producer uses logical iteration t +``` + +这正是需要记录 loop 内 TPipe 顺序的原因。不能简单按 `preload_num` 对 scope 排序,否则会改变用户写出的 FIFO transaction 顺序,可能破坏 TPipe 内部同步协议。 + +### 5.5 第一版实现策略 + +当前第一版已经先落地三段前置标注逻辑: + +1. `pto-cv-auto-mark-multi-buffer` + - 在创建 `pto.cv.scope` 前扫描 C/V kernel 中的 TPipe transaction。 + - 复用 producer / relay / consumer stage 链识别,计算与 scope pass 一致的 `max_preload_num`。 + - 追溯 scope 内 tile/memref operand 的 root local alloc,给 `pto.alloc_tile` 或 `memref.alloc` 写入 `pto.multi_buffer = max_preload_num`;已有用户标注不覆盖。 + +2. `pto-cv-mark-preload-scopes` + - 自动识别 scope,生成 no-result `pto.cv.scope`。 + - 生成 `group_id`、`preload_num`、`max_preload_num`、`role`、`input_pipe`、`output_pipe` 等 scope metadata。 + +3. 
`pto-inline-cv-preload-scopes` + - 过渡期在 frontend pipe lowering 之后把 no-result scope 展回父 block,保证未接入 preload 展开时不影响现有 codegen。 + +后续还需要补独立的 `pto-cv-verify-preload-marks`,把 scope/preload 标注的一致性诊断从 create-preload 中拆出。`pto-cv-create-preload` 已经先实现第一版,用于将已标注 scope 展开成 stage-level preload loop。 + +## 6. Plan Memory 中 preload 标注的影响 + +NPU IR 的 `PlanMemory.cpp` 对 preload 的关键处理可以概括为三步: + +1. `annotation.mark` 上的 multibuffer 属性记录到 `buffer2MultiNum`。 +2. `preload_local_buffer` / preload local alloc 被加入 `preloadBuffers`,生命周期分析在 parent `for` 上显式生成 gen/kill。 +3. memory plan 为 `multiBufferNum > 1` 的 storage entry 扩展多个地址,后续 pointer cast 携带地址数组。 + +PTOAS 需要保留同样的契约。 + +### 6.1 Local buffer 必须变成 multibuffer + +被多个 preload stage 同时使用的 local buffer 不能继续只有一个物理地址。否则展开后不同 logical iteration 的中间数据会互相覆盖。 + +规则: + +- scope 内使用到、且参与 CV preload stage 链的 root local buffer 标注 `pto.multi_buffer = max_preload_num`。 +- 对 high-level tile IR,标注先挂在 `pto.alloc_tile` 上;`PTOViewToMemref` 负责把它透传到降低后的 `memref.alloc` 或显式地址路径的 `pto.pointer_cast` / `pto.bind_tile`。 +- plan-memory 看到 `memref.alloc {pto.multi_buffer = N}` 后,为它生成 N 个可轮转地址。 +- create-preload 在展开时按 `preload_num` 旋转这些地址。 + +### 6.2 生命周期分析的特殊处理 + +普通 local buffer 的 gen/kill 可以由 MLIR liveness 直接得到;preload buffer 不行,因为展开前 IR 里只有一个 loop iteration,看不到多个 stage 同时活跃。 + +因此,plan-memory 需要像 NPU IR 一样做特殊处理: + +- 识别所有带 `pto.multi_buffer` 的 root alloc 及其 alias。 +- 普通 op 的 gen 逻辑遇到 preload buffer 时先跳过,避免把它当作单 iteration buffer。 +- 在包含 preload scope 的 parent `for` 上补充 gen/kill,使该 buffer 的生命周期覆盖整个 preload loop。 +- alias buffer 必须一起进入 gen/kill,避免 subview/cast 绕过保护。 + +这样 memory plan 才不会把不同 stage 的地址错误复用给其他 live buffer。 + +### 6.3 StorageEntry 扩展 + +plan-memory 为每个 gen buffer 建立 storage entry。若该 buffer 有 `pto.multi_buffer = N`: + +- 原 entry 保留第 0 个地址。 +- 额外创建 `N - 1` 个等价 storage entry。 +- 地址规划完成后,`buffer2Offsets[buffer]` 中包含 N 个 offset。 +- lowering 到 pointer cast 或等价 PTOAS 地址表达时,需要保留这 N 个地址。 + +create-preload 依赖这个地址数组做 stage 旋转。如果 plan-memory 没有提前扩展,preload 
展开只能复制计算,不能保证 local storage 正确。
+
+### 6.4 Workspace 和 TPipe FIFO
+
+TPipe FIFO 的 slot 管理由 TPipe API 负责,不应该被 plan-memory 当成普通 local temp 复制。特别是 consumer 侧的 reserved FIFO buffer 不是 preload local scratch,除非明确标注,否则不能按 stage duplicate。
+
+但有两类 buffer 仍需要处理:
+
+- 编译器生成的 local scratch:如果跨 stage 活跃,按 `pto.multi_buffer = max_preload_num` 处理。
+- workspace/subview:如果需要由不同 stage 访问不同 workspace slot,标注 `pto.cv.preload_workspace`,create-preload 展开时改写 slot 维度。
+
+workspace slot 的推荐计算方式:
+
+```text
+slot = ((stage_iv - loop_lb) / loop_step) % max_preload_num
+```
+
+NPU IR 参考实现中使用的是 `stage_iv / step % max_preload_num`,隐含 loop lower bound 已规范化到 0。PTOAS 若不能保证 loop lb 为 0,应使用带 `loop_lb` 的形式,或在 create-preload 前先规范化 loop。
+
+### 6.5 建议 pass 顺序
+
+建议第一版 pipeline:
+
+```text
+frontend TPipe IR
+ -> pto-cv-auto-mark-multi-buffer
+ -> pto-cv-mark-preload-scopes / pto-cv-verify-preload-marks
+ -> lower frontend pipe ops, preserve or consume cv scope before codegen
+ -> view/tile bufferization
+ -> pto-plan-memory
+ -> pto-resolve-reserved-buffers
+ -> pto-cv-create-preload
+ -> canonicalize/cse
+ -> inline/remove cv scope
+ -> existing sync/codegen passes
+```
+
+关键点是:`pto-cv-auto-mark-multi-buffer` 必须在 `pto-cv-mark-preload-scopes` 之前运行,因为它要在 high-level `pto.alloc_tile` 还没有被 scope 包裹和 lowering 前标注 root local buffer;`pto-plan-memory` 必须在 `pto-cv-create-preload` 之前看到 `pto.multi_buffer` 并生成多地址;`pto-cv-create-preload` 必须在最终 codegen 前完成 scope 展开和地址旋转。
+
+当前 pipeline 中,`pto-cv-auto-mark-multi-buffer` 会生成 `pto.multi_buffer`,`pto-cv-mark-preload-scopes` 会生成显式 `pto.cv.scope`。默认路径仍在 frontend pipe lowering 之后运行 `pto-inline-cv-preload-scopes`,保证未开启 preload 展开时不影响现有 codegen;开启 `--enable-cv-create-preload` 时,scope 会保留到 `pto-plan-memory` / `pto-resolve-reserved-buffers` 之后,再由 `pto-cv-create-preload` 展开,随后才 inline/remove scope 容器。
+
+## 7. 
Preload 展开逻辑 + +### 7.1 输入约束 + +`pto-cv-create-preload` 的输入是已标注 IR: + +- 每个要展开的 scope 都在某个 `scf.for` 的直接 body 中,或第一版先限制为直接 body。 +- 同一 parent `for` 下的 preload scope 具有相同 `group_id` 和 `max_preload_num`。 +- 单个 parent `for` 内 `preload_num` 唯一且在 `[0, max_preload_num)`;允许缺失某些 `preload_num`,缺失 stage 表示该 kernel loop 在该 stage 没有 scope。 +- 跨 C/V group 的 `preload_num` 全集应覆盖 `[0, max_preload_num)`,否则 pipeline 不完整。 +- scope 内的 TPipe transaction 完整,不跨 scope 泄漏 borrowed entry,relay scope 除外。 +- 非 scope 的 side-effect op 第一版不允许夹在 preload group 中,除非 pass 明确知道如何按 stage clone。 + +### 7.2 收集 PreloadInfo + +对每个 parent `scf.for` 建立: + +```text +PreloadInfo { + max_preload_num, + lb, + ub, + step, + original_iv, + mappings[max_preload_num], + scopes[preload_num] +} +``` + +其中 `mappings[p]` 表示原 IR value 在 `preload_num = p` stage 中对应的新 value。 + +### 7.3 改写 loop 边界 + +参考 NPU IR `CreatePreload.cpp`,新 loop 的 exclusive upper bound 是: + +```text +new_ub = old_ub + max_preload_num * step +``` + +注意这里是 `max_preload_num * step`,不是 `(max_preload_num - 1) * step`。因为 `scf.for` upper bound 是 exclusive,最后一个实际 physical iv 是: + +```text +old_ub + (max_preload_num - 1) * step +``` + +新 loop 的每个 stage IV: + +```text +stage_iv[p] = physical_iv - (max_preload_num - 1 - p) * step +``` + +stage 的有效条件: + +```text +old_lb <= stage_iv[p] && stage_iv[p] < old_ub +``` + +如果前置 canonicalize 已把 `old_lb` 规范成 0,可以退化成 NPU IR 里的 `0 <= stage_iv && stage_iv < old_ub`。 + +### 7.4 改写 loop-carried value + +preload local buffer 的 loop-carried 参数要特殊处理: + +- 如果 init arg 来自带 `pto.multi_buffer` 的 pointer cast/alloc,不再作为新 loop 的普通 init arg。 +- 展开时为每个 `preload_num` clone 一个旋转后的 pointer cast,并写入对应 mapping。 +- no-result `pto.cv.scope` 不返回 preload local buffer;若 scope 内使用该 buffer,展开时通过 stage mapping 找到旋转后的 pointer cast。 +- 普通 loop-carried value 仍按原 `scf.for` 规则 yield,推荐使用最高 stage mapping,即 `mappings[max_preload_num - 1]`。 + +### 7.5 改写 local buffer 地址 + +plan-memory 产生的 pointer cast 携带 N 个地址。create-preload 对不同 stage 做旋转: + +```text 
+new_addrs[(max_preload_num - preload_num - 1 + i) % max_preload_num] + = old_addrs[i] +``` + +直观理解: + +- 高 `preload_num` stage 使用更靠前的 ring slot。 +- 低 `preload_num` stage 使用更旧的 ring slot。 +- 物理 loop 每走一步,logical iteration 在 ring buffer 上自然前进。 + +这个旋转必须只应用于带 `pto.multi_buffer` 且属于 preload local buffer 的地址,不能应用于 TPipe 内部 FIFO reserved buffer。 + +### 7.6 改写 workspace/subview + +若 op 带 `pto.cv.preload_workspace`: + +1. clone 原 `memref.subview` 或等价 view op。 +2. 将第 0 维 offset 改为 stage slot: + +```text +slot = ((stage_iv - lb) / step) % max_preload_num +``` + +3. 若新 view type 与原 use type 不完全一致,可临时插入 adaptor cast。 +4. 后续 canonicalize/adaptor propagation 消除临时 cast。 + +### 7.7 改写 scope + +对每个 no-result preload scope: + +```mlir +scf.if (%cond_for_stage) { + pto.cv.scope { + // clone old scope body with mappings[preload_num] + } +} +``` + +clone scope body 时: + +- 普通 op 用 `mappings[preload_num]` clone。 +- nested `scf.for` 递归 clone body。 +- 带 `pto.multi_buffer` 的 preload local pointer cast 使用旋转后的地址。 +- `preload_workspace` subview 使用 stage slot。 +- `talloc/tpush/tpop/tfree` 保持在原 scope 内,不重新排序。 +- relay scope 保持整体 clone,不在展开阶段拆开。 + +第一版不支持 scope result/yield,所以 create-preload 不从 `pto.cv.scope` 产出 `scf.if` result。需要跨 scope 使用的值必须满足以下之一: + +- TPipe 数据由 `tpush/tpop` 的 FIFO side effect 传递,依赖关系由 `input_pipe/output_pipe` 和 `group_id` 表示。 +- 普通 tilebuf/memref alloc 在 scope 外定义,scope 内只读写其内容。 +- workspace/subview 由 `preload_workspace` 逻辑按 stage slot 重新生成。 +- loop-carried index 或其它纯 SSA value 仍然按外层 `scf.for` 的 `iter_args/scf.yield` 规则处理,不通过 `pto.cv.scope` 返回。 + +如果未来需要让 scope 内定义的 SSA tile 直接给外部使用,应扩展 `pto.cv.scope` 的 result/yield 形式,并像 NPU IR `scope.scope` 一样在展开时把 result 映射回各 stage mapping。 + +### 7.8 非 scope op + +NPU IR 参考实现会对非 scope op 按每个 stage clone。PTOAS 第一版建议更保守: + +- 允许 constants、index arithmetic、pure shape/view op 按 stage clone。 +- 允许带 `pto.multi_buffer` 或 `preload_workspace` 标注的地址/view op 由专门逻辑 clone。 +- 不允许未知 side-effect op 出现在 preload group 中。若发现,pass 报错并要求前置 scope 识别把它纳入某个 
scope。 + +这样可以避免复制 TPipe 或 memory side effect 导致语义变化。 + +### 7.9 展开后的清理 + +展开后运行: + +- adaptor propagation。 +- CSE。 +- canonicalize。 +- inline/remove `pto.cv.scope`。 +- verifier 检查 TPipe transaction 和 buffer alias。 + +## 8. 诊断和限制 + +pass 应给出明确诊断: + +- scope 缺少 `preload_num` 或 `max_preload_num`。 +- 跨 C/V group 的 `preload_num` 全集不连续,或单个 parent loop 内 `preload_num` 重复。 +- scope 不在支持的 parent `scf.for` 形态下。 +- C/V logical pipe 无法配对,或 shape/dtype/split 不一致。 +- preload local buffer 缺少 `pto.multi_buffer`,或数量小于 `max_preload_num`。 +- TPipe borrowed entry 生命周期跨出 scope。 +- 未知 side-effect op 出现在 preload group 中但不属于任何 scope。 +- 推导出的 `max_preload_num` 超过 FIFO slot 深度或用户指定上限。 + +第一版限制: + +- 不做 subtiling。 +- 不跨函数重排。 +- 不自动改变 TPipe `id/split/slot_size`。 +- 不自动合并不同 parent loop 的 preload group。 +- 不支持无法证明 loop step/lb/ub 兼容的 C/V loop group。 + +## 9. 实现和测试计划 + +当前实现状态: + +1. 已实现 `pto-cv-auto-mark-multi-buffer`,在 CV scope 标注前自动生成 `pto.multi_buffer`。 +2. 已实现 `pto-cv-mark-preload-scopes`,自动识别 scope 并生成 `preload_num` / `max_preload_num`。 +3. 已实现过渡期 `pto-inline-cv-preload-scopes`,保持当前 codegen pipeline 不被 no-result scope 影响。 +4. 已实现 plan-memory 对 CV preload multi-buffer 的 loop 级生命周期处理:带 `pto.multi_buffer` 且被 preload scope 使用的 local buffer 在 parent `scf.for` 上生成/释放,scope 内普通 gen/kill 不再切断它。 +5. 已实现 opt-in `pto-cv-create-preload`,完成无 `iter_args` loop 的扩界、stage condition、scope clone 和 local `pto.pointer_cast` 多地址旋转。 +6. 待实现 `pto-cv-verify-preload-marks`。 +7. 
待实现 workspace/subview 的 `pto.cv.preload_workspace` slot 改写。 + +测试建议: + +- 已有一个最小两 stage C2V preload IR,验证 `new_ub = old_ub + N * step`、stage condition、plan-memory 产生 N 个地址,以及 create-preload 按 `preload_num` 旋转。 +- 一个 workspace subview 标注测试,验证第 0 维 slot 改写。 +- 已有一个 FA 风格四 stage 样例,覆盖 C/V 分离、relay scope、同一核内 producer/consumer 共存、自动 `pto.multi_buffer` 标注,以及 loop 内 TPipe 词法顺序和 `preload_num` 顺序不一致的场景。 +- 已有 FA level3 显式地址负例:若 local buffer 只有 1 个 planned address,即使带 `pto.multi_buffer` 标注,create-preload 也会报错而不是静默展开。 +- 后续负例测试:重复 `preload_num`、缺少 `max_preload_num`、scope 外 borrowed entry、未知 side-effect op。 diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 8c81f0824..c405d501d 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -1328,6 +1328,25 @@ class PTO_SectionOp def SectionCubeOp : PTO_SectionOp<"section.cube">; def SectionVectorOp : PTO_SectionOp<"section.vector">; +//===----------------------------------------------------------------------===// +// CV Preload Scope Ops +//===----------------------------------------------------------------------===// + +def CVScopeOp : PTO_Op<"cv.scope", [SingleBlock, NoTerminator]> { + let summary = "No-result CV preload transaction scope"; + let description = [{ + Groups one producer, consumer, or relay TPipe transaction for CV preload + analysis. The first PTOAS version intentionally carries no operands, + results, or yield terminator: cross-scope dependencies are represented by + explicit TPipe metadata such as pto.cv.input_pipe and pto.cv.output_pipe, + while tilebuf/memref allocations that are shared across scopes must be + defined outside the scope. 
+ }]; + + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = "$body attr-dict"; +} + //===----------------------------------------------------------------------===// // Frontend TPUSH/TPOP Pipe Communication Ops //===----------------------------------------------------------------------===// diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 147ae756a..da47d5c6b 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -37,6 +37,10 @@ std::unique_ptr createLoweringSyncToPipePass(); std::unique_ptr createPTOAssignDefaultFrontendPipeIdPass(); std::unique_ptr createPTOLowerFrontendPipeOpsPass(); std::unique_ptr createPTOInferValidatePipeInitPass(); +std::unique_ptr createPTOCVAutoMarkMultiBufferPass(); +std::unique_ptr createPTOCVMarkPreloadScopesPass(); +std::unique_ptr createPTOCVCreatePreloadPass(); +std::unique_ptr createPTOInlineCVPreloadScopesPass(); std::unique_ptr createPTOResolveReservedBuffersPass(); std::unique_ptr createPTOWrapFunctionsInSectionsPass(); std::unique_ptr createPTOVerifyTFreePass(); diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index 569210556..34c4baa9b 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -223,6 +223,85 @@ def PTOInferValidatePipeInit : Pass<"pto-infer-validate-pipe-init", "ModuleOp"> ]; } +def PTOCVAutoMarkMultiBuffer : Pass<"pto-cv-auto-mark-multi-buffer", "ModuleOp"> { + let summary = "Auto-mark CV preload local buffers as multi-buffered"; + let description = [{ + Scans Cube/Vector frontend TPipe transactions before CV scope materialization, + pairs producer/relay/consumer stages across logical C2V/V2C pipes, derives the + same max preload depth used by `pto-cv-mark-preload-scopes`, and annotates the + root local buffers used by those stages with `pto.multi_buffer`. 
+ + The pass intentionally runs before scope creation so the marker is attached to + existing `pto.alloc_tile` / `memref.alloc` anchors and can be forwarded to plan + memory after tile-buffer lowering. + }]; + + let constructor = "mlir::pto::createPTOCVAutoMarkMultiBufferPass()"; + + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::func::FuncDialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect" + ]; +} + +def PTOCVMarkPreloadScopes : Pass<"pto-cv-mark-preload-scopes", "ModuleOp"> { + let summary = "Create CV preload scopes from frontend TPipe transactions"; + let description = [{ + Scans Cube/Vector kernel_kind functions before frontend pipe lowering, + recognizes producer/consumer/relay TPipe transaction scopes, wraps each + transaction in a no-result `pto.cv.scope`, pairs them across logical + C2V/V2C pipes, and annotates the scope op with pto.cv.* preload metadata. + }]; + + let constructor = "mlir::pto::createPTOCVMarkPreloadScopesPass()"; + + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::func::FuncDialect", + "mlir::scf::SCFDialect" + ]; +} + +def PTOCVCreatePreload : Pass<"pto-cv-create-preload", "ModuleOp"> { + let summary = "Expand marked CV preload scopes into stage-guarded loops"; + let description = [{ + Consumes no-result `pto.cv.scope` operations annotated with + `pto.cv.preload_num` / `pto.cv.max_preload_num`, expands the parent + `scf.for` upper bound by `max_preload_num * step`, clones each scope under + a stage-valid `scf.if`, remaps the loop induction variable to the stage IV, + and rotates multi-address `pto.pointer_cast` values prepared by plan memory. + + This pass expects local preload buffers to have already been annotated with + `pto.multi_buffer` and lowered/planned into variadic `pto.pointer_cast` + values. It deliberately does not synthesize storage addresses itself. 
+ }]; + + let constructor = "mlir::pto::createPTOCVCreatePreloadPass()"; + + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::arith::ArithDialect", + "mlir::scf::SCFDialect" + ]; +} + +def PTOInlineCVPreloadScopes : Pass<"pto-inline-cv-preload-scopes", "ModuleOp"> { + let summary = "Inline no-result CV preload scope containers"; + let description = [{ + Inlines `pto.cv.scope` bodies back into their parent blocks. This keeps the + current codegen pipeline unchanged while leaving an explicit scope IR point + for debug printing and future preload expansion passes. + }]; + + let constructor = "mlir::pto::createPTOInlineCVPreloadScopesPass()"; + + let dependentDialects = [ + "mlir::pto::PTODialect" + ]; +} + def PTOResolveReservedBuffers : Pass<"pto-resolve-reserved-buffers", "ModuleOp"> { let summary = "Resolve reserved local buffer addresses and peer pipe flag bases"; let description = [{ diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index 7caa511c8..23f58d1aa 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -29,6 +29,8 @@ add_mlir_dialect_library(PTOTransforms PTOAssignDefaultFrontendPipeIdPass.cpp PTOLowerFrontendPipeOpsPass.cpp PTOInferValidatePipeInitPass.cpp + CVMarkPreloadScopesPass.cpp + CVCreatePreloadPass.cpp PTOResolveReservedBuffersPass.cpp PTOWrapFunctionsInSectionsPass.cpp InsertSync/PTOIRTranslator.cpp diff --git a/lib/PTO/Transforms/CVCreatePreloadPass.cpp b/lib/PTO/Transforms/CVCreatePreloadPass.cpp new file mode 100644 index 000000000..fb9dea4f3 --- /dev/null +++ b/lib/PTO/Transforms/CVCreatePreloadPass.cpp @@ -0,0 +1,374 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. 
+// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOCVCREATEPRELOAD +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +constexpr llvm::StringLiteral kPreloadNumAttr = "pto.cv.preload_num"; +constexpr llvm::StringLiteral kMaxPreloadNumAttr = "pto.cv.max_preload_num"; +constexpr llvm::StringLiteral kGroupIdAttr = "pto.cv.group_id"; + +struct LocalBufferBinding { + Value value; + PointerCastOp pointerCast; + BindTileOp bindTile; +}; + +static std::optional getI64Attr(Operation *op, + llvm::StringRef name) { + auto attr = op->getAttrOfType(name); + if (!attr) + return std::nullopt; + return attr.getInt(); +} + +static bool isLocalMemRefType(Type type) { + auto memrefTy = dyn_cast(type); + if (!memrefTy) + return false; + auto as = dyn_cast_or_null(memrefTy.getMemorySpace()); + if (!as) + return false; + AddressSpace space = as.getAddressSpace(); + return space != AddressSpace::GM && space != AddressSpace::Zero; +} + +static Value createIntegerLikeConstant(OpBuilder &builder, Location loc, + Type type, int64_t value) { + if (type.isIndex()) + return builder.create(loc, value); + return builder.create( + loc, type, 
builder.getIntegerAttr(type, value)); +} + +static Value createStageIV(OpBuilder &builder, Location loc, Value physicalIV, + Value step, int64_t maxPreloadNum, + int64_t preloadNum) { + int64_t distance = maxPreloadNum - 1 - preloadNum; + if (distance == 0) + return physicalIV; + + Value distanceValue = + createIntegerLikeConstant(builder, loc, step.getType(), distance); + Value delta = builder.create(loc, step, distanceValue); + return builder.create(loc, physicalIV, delta); +} + +static Value createShiftedUpperBound(OpBuilder &builder, Location loc, Value ub, + Value step, int64_t maxPreloadNum) { + Value stages = + createIntegerLikeConstant(builder, loc, step.getType(), maxPreloadNum); + Value delta = builder.create(loc, step, stages); + return builder.create(loc, ub, delta); +} + +static Value createStageCondition(OpBuilder &builder, Location loc, Value lb, + Value ub, Value stageIV) { + Value lowerOk = builder.create( + loc, arith::CmpIPredicate::sle, lb, stageIV); + Value upperOk = builder.create( + loc, arith::CmpIPredicate::slt, stageIV, ub); + return builder.create(loc, lowerOk, upperOk); +} + +static SmallVector rotateAddrs(PointerCastOp pointerCast, + int64_t maxPreloadNum, + int64_t preloadNum) { + ValueRange oldAddrs = pointerCast.getAddrs(); + SmallVector rotated(maxPreloadNum); + for (auto [idx, addr] : llvm::enumerate(oldAddrs.take_front(maxPreloadNum))) { + int64_t shifted = + (maxPreloadNum - preloadNum - 1 + static_cast(idx)) % + maxPreloadNum; + rotated[shifted] = addr; + } + return rotated; +} + +static LogicalResult collectLocalBufferBindings( + CVScopeOp scope, int64_t maxPreloadNum, scf::ForOp forOp, + DenseMap &bindings) { + auto walkResult = scope.walk([&](Operation *op) -> WalkResult { + for (Value operand : op->getOperands()) { + if (!operand || !forOp.isDefinedOutsideOfLoop(operand)) + continue; + + if (auto bind = operand.getDefiningOp()) { + auto pointerCast = bind.getSource().getDefiningOp(); + if (!pointerCast || 
!isLocalMemRefType(pointerCast.getType())) + continue; + if (static_cast(pointerCast.getAddrs().size()) < + maxPreloadNum) { + scope.emitOpError() + << "uses preload local buffer with only " + << pointerCast.getAddrs().size() << " planned address(es); " + << "expected at least " << maxPreloadNum; + return WalkResult::interrupt(); + } + bindings.try_emplace(bind.getResult(), + LocalBufferBinding{bind.getResult(), pointerCast, + bind}); + continue; + } + + if (auto pointerCast = operand.getDefiningOp()) { + if (!isLocalMemRefType(pointerCast.getType())) + continue; + if (static_cast(pointerCast.getAddrs().size()) < + maxPreloadNum) { + scope.emitOpError() + << "uses preload local buffer with only " + << pointerCast.getAddrs().size() << " planned address(es); " + << "expected at least " << maxPreloadNum; + return WalkResult::interrupt(); + } + bindings.try_emplace( + pointerCast.getResult(), + LocalBufferBinding{pointerCast.getResult(), pointerCast, + BindTileOp()}); + } + } + return WalkResult::advance(); + }); + return walkResult.wasInterrupted() ? 
failure() : success(); +} + +static LogicalResult validateNoUnsupportedLoopLocalUses(CVScopeOp scope, + scf::ForOp forOp) { + Operation *scopeOp = scope.getOperation(); + LogicalResult result = success(); + scope.walk([&](Operation *op) { + for (Value operand : op->getOperands()) { + if (operand == forOp.getInductionVar()) + continue; + Operation *defOp = operand.getDefiningOp(); + if (!defOp || scopeOp->isAncestor(defOp)) + continue; + if (defOp->getParentOfType() == forOp && + !forOp.isDefinedOutsideOfLoop(operand)) { + scope.emitOpError() + << "contains a use of value defined in the original loop but " + "outside the CV scope; hoist it into the scope or outside the " + "loop before create-preload"; + result = failure(); + } + } + }); + return result; +} + +static LogicalResult collectDirectScopes(scf::ForOp forOp, + SmallVectorImpl &scopes, + int64_t &maxPreloadNum) { + maxPreloadNum = -1; + int64_t groupId = -1; + DenseSet seenPreloadNums; + for (Operation &op : forOp.getBody()->without_terminator()) { + auto scope = dyn_cast(&op); + if (!scope) + continue; + + std::optional preload = getI64Attr(scope, kPreloadNumAttr); + std::optional maxPreload = getI64Attr(scope, kMaxPreloadNumAttr); + std::optional group = getI64Attr(scope, kGroupIdAttr); + if (!preload || !maxPreload || !group) + return scope.emitOpError("is missing required CV preload attributes"); + if (*preload < 0 || *maxPreload <= 1 || *preload >= *maxPreload) + return scope.emitOpError("has invalid CV preload attributes"); + if (!seenPreloadNums.insert(*preload).second) + return scope.emitOpError("duplicates preload_num in the same loop"); + + if (maxPreloadNum < 0) + maxPreloadNum = *maxPreload; + else if (maxPreloadNum != *maxPreload) + return scope.emitOpError( + "has a different max_preload_num than sibling CV scopes"); + if (groupId < 0) + groupId = *group; + else if (groupId != *group) + return scope.emitOpError( + "has a different group_id than sibling CV scopes"); + + if 
(failed(validateNoUnsupportedLoopLocalUses(scope, forOp))) + return failure(); + scopes.push_back(scope); + } + return success(); +} + +static void clonePointerCastAndBind(OpBuilder &builder, + LocalBufferBinding &binding, + int64_t maxPreloadNum, int64_t preloadNum, + IRMapping &mapping) { + SmallVector addrs = + rotateAddrs(binding.pointerCast, maxPreloadNum, preloadNum); + std::optional config = binding.pointerCast.getConfig(); + auto newPointerCast = builder.create( + binding.pointerCast.getLoc(), binding.pointerCast.getType(), addrs, + binding.pointerCast.getValidRow() ? binding.pointerCast.getValidRow() + : Value(), + binding.pointerCast.getValidCol() ? binding.pointerCast.getValidCol() + : Value(), + config ? static_cast(*config) : Attribute()); + mapping.map(binding.pointerCast.getResult(), newPointerCast.getResult()); + + if (binding.bindTile) { + auto newBind = builder.create( + binding.bindTile.getLoc(), binding.bindTile.getResult().getType(), + newPointerCast.getResult(), + binding.bindTile.getValidRow() ? binding.bindTile.getValidRow() + : Value(), + binding.bindTile.getValidCol() ? 
binding.bindTile.getValidCol() + : Value(), + binding.bindTile.getConfig()); + mapping.map(binding.bindTile.getResult(), newBind.getResult()); + } +} + +static CVScopeOp cloneCVScope(OpBuilder &builder, CVScopeOp scope, + IRMapping &mapping) { + auto cloned = builder.create(scope.getLoc()); + cloned->setAttrs(scope->getAttrDictionary()); + cloned.getBody().push_back(new Block()); + + OpBuilder bodyBuilder = + OpBuilder::atBlockEnd(&cloned.getBody().front()); + for (Operation &op : scope.getBody().front().getOperations()) + bodyBuilder.clone(op, mapping); + return cloned; +} + +static LogicalResult rewritePreloadLoop(IRRewriter &rewriter, scf::ForOp forOp, + ArrayRef scopes, + int64_t maxPreloadNum) { + if (!forOp.getInitArgs().empty()) { + return forOp.emitOpError( + "pto-cv-create-preload currently supports scf.for without iter_args"); + } + + DenseMap localBindings; + for (CVScopeOp scope : scopes) { + if (failed( + collectLocalBufferBindings(scope, maxPreloadNum, forOp, + localBindings))) + return failure(); + } + + SmallVector orderedBindings; + orderedBindings.reserve(localBindings.size()); + for (const auto &it : localBindings) + orderedBindings.push_back(it.second); + + Location loc = forOp.getLoc(); + rewriter.setInsertionPoint(forOp); + Value newUpperBound = createShiftedUpperBound( + rewriter, loc, forOp.getUpperBound(), forOp.getStep(), maxPreloadNum); + auto newForOp = + rewriter.create(loc, forOp.getLowerBound(), newUpperBound, + forOp.getStep()); + + OpBuilder bodyBuilder = OpBuilder::atBlockBegin(newForOp.getBody()); + SmallVector stageMappings(maxPreloadNum); + SmallVector stageIVs(maxPreloadNum); + for (int64_t preloadNum = 0; preloadNum < maxPreloadNum; ++preloadNum) { + Value stageIV = createStageIV(bodyBuilder, loc, newForOp.getInductionVar(), + newForOp.getStep(), maxPreloadNum, + preloadNum); + stageIVs[preloadNum] = stageIV; + stageMappings[preloadNum].map(forOp.getInductionVar(), stageIV); + for (LocalBufferBinding &binding : 
orderedBindings) { + clonePointerCastAndBind(bodyBuilder, binding, maxPreloadNum, preloadNum, + stageMappings[preloadNum]); + } + } + + for (Operation &op : forOp.getBody()->without_terminator()) { + auto scope = dyn_cast(&op); + if (!scope) { + return op.emitOpError( + "non-CV-scope op cannot be preserved by pto-cv-create-preload yet"); + } + + int64_t preloadNum = *getI64Attr(scope, kPreloadNumAttr); + Value cond = createStageCondition(bodyBuilder, scope.getLoc(), + forOp.getLowerBound(), + forOp.getUpperBound(), + stageIVs[preloadNum]); + auto ifOp = bodyBuilder.create(scope.getLoc(), cond, + /*withElseRegion=*/false); + OpBuilder thenBuilder(ifOp.thenBlock(), ifOp.thenBlock()->begin()); + cloneCVScope(thenBuilder, scope, stageMappings[preloadNum]); + } + + rewriter.eraseOp(forOp); + for (LocalBufferBinding &binding : orderedBindings) { + if (binding.bindTile && binding.bindTile.getResult().use_empty()) + rewriter.eraseOp(binding.bindTile); + if (binding.pointerCast && binding.pointerCast.getResult().use_empty()) + rewriter.eraseOp(binding.pointerCast); + } + return success(); +} + +struct PTOCVCreatePreloadPass + : public mlir::pto::impl::PTOCVCreatePreloadBase< + PTOCVCreatePreloadPass> { + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + SmallVector loops; + moduleOp.walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + + IRRewriter rewriter(moduleOp.getContext()); + for (scf::ForOp forOp : llvm::reverse(loops)) { + SmallVector scopes; + int64_t maxPreloadNum = -1; + if (failed(collectDirectScopes(forOp, scopes, maxPreloadNum))) { + signalPassFailure(); + return; + } + if (scopes.empty()) + continue; + if (failed(rewritePreloadLoop(rewriter, forOp, scopes, maxPreloadNum))) { + signalPassFailure(); + return; + } + } + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createPTOCVCreatePreloadPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/CVMarkPreloadScopesPass.cpp 
b/lib/PTO/Transforms/CVMarkPreloadScopesPass.cpp new file mode 100644 index 000000000..c8af8f9e2 --- /dev/null +++ b/lib/PTO/Transforms/CVMarkPreloadScopesPass.cpp @@ -0,0 +1,657 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/MultiBuffer.h" +#include "PTO/Transforms/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/Twine.h" + +#include +#include +#include +#include +#include +#include + +namespace mlir { +namespace pto { +#define GEN_PASS_DEF_PTOCVAUTOMARKMULTIBUFFER +#define GEN_PASS_DEF_PTOCVMARKPRELOADSCOPES +#define GEN_PASS_DEF_PTOINLINECVPRELOADSCOPES +#include "PTO/Transforms/Passes.h.inc" +} // namespace pto +} // namespace mlir + +using namespace mlir; +using namespace mlir::pto; + +namespace { + +constexpr llvm::StringLiteral kScopeIdAttr = "pto.cv.scope_id"; +constexpr llvm::StringLiteral kGroupIdAttr = "pto.cv.group_id"; +constexpr llvm::StringLiteral kRoleAttr = "pto.cv.role"; +constexpr llvm::StringLiteral kCoreAttr = "pto.cv.core"; +constexpr 
llvm::StringLiteral kPreloadNumAttr = "pto.cv.preload_num"; +constexpr llvm::StringLiteral kMaxPreloadNumAttr = "pto.cv.max_preload_num"; +constexpr llvm::StringLiteral kInputPipeAttr = "pto.cv.input_pipe"; +constexpr llvm::StringLiteral kOutputPipeAttr = "pto.cv.output_pipe"; + +enum class CoreKind { Cube, Vector }; +enum class PipeDir { C2V, V2C }; + +struct PipeKey { + int64_t id = 0; + PipeDir dir = PipeDir::C2V; + + bool operator<(const PipeKey &other) const { + return std::tie(dir, id) < std::tie(other.dir, other.id); + } + + bool operator==(const PipeKey &other) const { + return id == other.id && dir == other.dir; + } +}; + +static std::string stringify(CoreKind core) { + return core == CoreKind::Cube ? "cube" : "vector"; +} + +static std::string stringify(PipeDir dir) { + return dir == PipeDir::C2V ? "c2v" : "v2c"; +} + +static std::string stringify(const PipeKey &key) { + return (llvm::Twine(stringify(key.dir)) + ":" + llvm::Twine(key.id)).str(); +} + +static std::string stringify(std::optional key) { + if (!key) + return ""; + return stringify(*key); +} + +struct PipeAction { + Operation *op = nullptr; + bool isTAlloc = false; + bool isTPush = false; + bool isTPop = false; + bool isTFree = false; + std::optional input; + std::optional output; + + explicit operator bool() const { return op != nullptr; } +}; + +static PipeAction getPipeAction(Operation *op) { + PipeAction action; + action.op = op; + + if (auto alloc = dyn_cast(op)) { + action.isTAlloc = true; + action.output = PipeKey{alloc.getId(), PipeDir::C2V}; + return action; + } + if (auto alloc = dyn_cast(op)) { + action.isTAlloc = true; + action.output = PipeKey{alloc.getId(), PipeDir::V2C}; + return action; + } + if (auto push = dyn_cast(op)) { + action.isTPush = true; + action.output = PipeKey{push.getId(), PipeDir::C2V}; + return action; + } + if (auto push = dyn_cast(op)) { + action.isTPush = true; + action.output = PipeKey{push.getId(), PipeDir::V2C}; + return action; + } + if (auto pop = 
dyn_cast(op)) { + action.isTPop = true; + action.input = PipeKey{pop.getId(), PipeDir::C2V}; + return action; + } + if (auto pop = dyn_cast(op)) { + action.isTPop = true; + action.input = PipeKey{pop.getId(), PipeDir::V2C}; + return action; + } + if (auto free = dyn_cast(op)) { + action.isTFree = true; + action.input = PipeKey{free.getId(), PipeDir::C2V}; + return action; + } + if (auto free = dyn_cast(op)) { + action.isTFree = true; + action.input = PipeKey{free.getId(), PipeDir::V2C}; + return action; + } + + action.op = nullptr; + return action; +} + +struct PendingScope { + bool active = false; + std::optional input; + std::optional output; + bool inputReleased = false; + bool outputCommitted = false; + SmallVector ops; + Operation *firstOp = nullptr; + Operation *lastOp = nullptr; + + void reset() { + active = false; + input.reset(); + output.reset(); + inputReleased = false; + outputCommitted = false; + ops.clear(); + firstOp = nullptr; + lastOp = nullptr; + } + + void start() { + if (!active) + active = true; + } + + void include(Operation *op) { + start(); + if (!firstOp) + firstOp = op; + lastOp = op; + ops.push_back(op); + } +}; + +struct ScopeInfo { + int64_t id = -1; + CoreKind core = CoreKind::Cube; + std::optional input; + std::optional output; + SmallVector ops; + Operation *firstOp = nullptr; + Operation *lastOp = nullptr; + int64_t classId = -1; +}; + +struct ScopeClass { + CoreKind core = CoreKind::Cube; + std::optional input; + std::optional output; + SmallVector scopeIds; + int64_t groupId = -1; + int64_t preloadNum = -1; + int64_t maxPreloadNum = -1; +}; + +static std::string getRole(const ScopeInfo &scope) { + if (scope.input && scope.output) + return "relay"; + if (scope.output) + return "producer"; + return "consumer"; +} + +static std::string getClassKey(CoreKind core, std::optional input, + std::optional output) { + return (llvm::Twine(stringify(core)) + "|" + llvm::Twine(stringify(input)) + + "|" + llvm::Twine(stringify(output))) + 
.str(); +} + +static std::optional getKernelCore(func::FuncOp funcOp) { + auto kernelKindAttr = funcOp->getAttrOfType( + FunctionKernelKindAttr::name); + if (!kernelKindAttr) + return std::nullopt; + + switch (kernelKindAttr.getKernelKind()) { + case FunctionKernelKind::Cube: + return CoreKind::Cube; + case FunctionKernelKind::Vector: + return CoreKind::Vector; + } + llvm_unreachable("unexpected kernel kind"); +} + +static void collectScopeIfValid(PendingScope &pending, CoreKind core, + SmallVectorImpl &scopes) { + std::optional committedOutput = + pending.outputCommitted ? pending.output : std::nullopt; + if (!pending.active || (!pending.input && !committedOutput)) { + pending.reset(); + return; + } + + ScopeInfo scope; + scope.id = static_cast(scopes.size()); + scope.core = core; + scope.input = pending.input; + scope.output = committedOutput; + scope.ops = pending.ops; + scope.firstOp = pending.firstOp; + scope.lastOp = pending.lastOp; + scopes.push_back(std::move(scope)); + pending.reset(); +} + +static void includeRange(PendingScope &pending, Operation *first, + Operation *last) { + if (!first || !last) + return; + + for (Operation *op = first; op; op = op->getNextNode()) { + if (pending.lastOp == op) + return; + pending.include(op); + if (op == last) + return; + } +} + +static void includeMissingThrough(PendingScope &pending, Operation *last) { + if (!pending.active || !last || pending.lastOp == last) + return; + includeRange(pending, pending.lastOp ? 
pending.lastOp->getNextNode() : last, + last); +} + +static void collectScopesInFor(scf::ForOp forOp, CoreKind core, + SmallVectorImpl &scopes) { + PendingScope pending; + Operation *segmentStart = &forOp.getBody()->front(); + Operation *terminator = forOp.getBody()->getTerminator(); + + for (Operation &op : forOp.getBody()->without_terminator()) { + PipeAction action = getPipeAction(&op); + if (!action) + continue; + + if (action.isTPop) { + if (pending.active) + includeMissingThrough(pending, op.getPrevNode()); + collectScopeIfValid(pending, core, scopes); + pending.input = action.input; + includeRange(pending, &op, &op); + segmentStart = &op; + continue; + } + + if (action.isTAlloc) { + if (pending.active && pending.output) + collectScopeIfValid(pending, core, scopes); + Operation *start = + pending.active + ? (pending.lastOp ? pending.lastOp->getNextNode() : &op) + : (segmentStart && segmentStart != terminator ? segmentStart + : &op); + includeRange(pending, start, &op); + pending.output = action.output; + continue; + } + + if (action.isTPush) { + pending.output = action.output; + pending.outputCommitted = true; + if (!pending.active) { + Operation *start = + segmentStart && segmentStart != terminator ? segmentStart : &op; + includeRange(pending, start, &op); + } else { + includeMissingThrough(pending, &op); + } + if (!pending.input || pending.inputReleased) { + collectScopeIfValid(pending, core, scopes); + segmentStart = op.getNextNode(); + } + continue; + } + + if (action.isTFree && pending.active) { + includeMissingThrough(pending, &op); + pending.inputReleased = true; + if (pending.outputCommitted) { + collectScopeIfValid(pending, core, scopes); + segmentStart = op.getNextNode(); + } + continue; + } + } + + if (pending.active) + includeMissingThrough(pending, terminator ? 
terminator->getPrevNode() + : nullptr); + collectScopeIfValid(pending, core, scopes); +} + +static void buildScopeClasses(SmallVectorImpl &scopes, + SmallVectorImpl &classes) { + llvm::StringMap classByKey; + for (ScopeInfo &scope : scopes) { + std::string key = getClassKey(scope.core, scope.input, scope.output); + auto [it, inserted] = classByKey.try_emplace(key, classes.size()); + if (inserted) { + ScopeClass klass; + klass.core = scope.core; + klass.input = scope.input; + klass.output = scope.output; + classes.push_back(std::move(klass)); + } + scope.classId = it->second; + classes[scope.classId].scopeIds.push_back(scope.id); + } +} + +static void assignPreloadNumbers(SmallVectorImpl &classes) { + std::map inputClass; + for (auto [idx, klass] : llvm::enumerate(classes)) { + if (klass.input) + inputClass.try_emplace(*klass.input, static_cast(idx)); + } + + SmallVector next(classes.size(), -1); + for (auto [idx, klass] : llvm::enumerate(classes)) { + if (!klass.output) + continue; + auto it = inputClass.find(*klass.output); + if (it == inputClass.end()) + continue; + next[idx] = it->second; + } + + int64_t nextGroupId = 0; + for (auto [startIdx, klass] : llvm::enumerate(classes)) { + if (klass.input || !klass.output) + continue; + + SmallVector chain; + std::set seen; + int64_t cur = static_cast(startIdx); + while (cur >= 0 && seen.insert(cur).second) { + chain.push_back(cur); + cur = next[cur]; + } + + if (chain.size() < 2) + continue; + if (classes[chain.back()].output) + continue; + + int64_t maxPreloadNum = static_cast(chain.size()); + int64_t groupId = nextGroupId++; + for (auto [stageIdx, classIdx] : llvm::enumerate(chain)) { + ScopeClass &stage = classes[classIdx]; + stage.groupId = groupId; + stage.maxPreloadNum = maxPreloadNum; + stage.preloadNum = maxPreloadNum - 1 - static_cast(stageIdx); + } + } +} + +static void collectCVScopes(ModuleOp moduleOp, + SmallVectorImpl &scopes) { + moduleOp.walk([&](func::FuncOp funcOp) { + std::optional core = 
getKernelCore(funcOp); + if (!core) + return; + funcOp.walk( + [&](scf::ForOp forOp) { collectScopesInFor(forOp, *core, scopes); }); + }); +} + +static int64_t clampMultiBufferNum(Operation *anchor, int64_t num) { + if (num <= 1) + return 0; + if (num <= static_cast(kPtoMultiBufferMaxNum)) + return num; + + anchor->emitWarning() + << "auto CV multi-buffer depth " << num << " exceeds maximum " + << kPtoMultiBufferMaxNum << "; clamping"; + return static_cast(kPtoMultiBufferMaxNum); +} + +static bool markMultiBuffer(Operation *op, int64_t num) { + if (!op) + return false; + num = clampMultiBufferNum(op, num); + if (num <= 1) + return false; + if (op->getAttr(kPtoMultiBufferAttrName)) + return false; + + OpBuilder builder(op->getContext()); + op->setAttr(kPtoMultiBufferAttrName, builder.getI32IntegerAttr(num)); + return true; +} + +static void collectRootAllocLikeOps(Value value, + llvm::SmallPtrSetImpl &roots) { + SmallVector worklist{value}; + llvm::DenseSet visited; + + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!current || !visited.insert(current).second) + continue; + + Operation *def = current.getDefiningOp(); + if (!def) + continue; + + if (isa(def)) { + roots.insert(def); + continue; + } + + if (auto viewLike = dyn_cast(def)) { + worklist.push_back(viewLike.getViewSource()); + continue; + } + + if (auto bind = dyn_cast(def)) { + worklist.push_back(bind.getSource()); + continue; + } + } +} + +static void markScopeLocalBuffers(const ScopeInfo &scope, + const ScopeClass &klass) { + if (klass.maxPreloadNum <= 1) + return; + + llvm::SmallPtrSet roots; + for (Operation *op : scope.ops) { + for (Value operand : op->getOperands()) + collectRootAllocLikeOps(operand, roots); + } + + for (Operation *root : roots) + markMultiBuffer(root, klass.maxPreloadNum); +} + +static void markCVPreloadMultiBuffers(MutableArrayRef scopes, + ArrayRef classes) { + for (const ScopeInfo &scope : scopes) { + if (scope.classId < 0) + continue; + const 
ScopeClass &klass = classes[scope.classId]; + markScopeLocalBuffers(scope, klass); + } +} + +static bool isInsideMovedRange(Operation *op, + const llvm::DenseSet &moved) { + for (Operation *cur = op; cur; cur = cur->getParentOp()) { + if (moved.contains(cur)) + return true; + } + return false; +} + +static bool canWrapNoResultScope(const ScopeInfo &scope) { + if (!scope.firstOp || !scope.lastOp) + return false; + if (scope.firstOp->getBlock() != scope.lastOp->getBlock()) + return false; + + llvm::DenseSet moved; + for (Operation *op = scope.firstOp;; op = op->getNextNode()) { + moved.insert(op); + if (op == scope.lastOp) + break; + } + + for (Operation *op : moved) { + for (Value result : op->getResults()) { + for (OpOperand &use : result.getUses()) { + if (!isInsideMovedRange(use.getOwner(), moved)) + return false; + } + } + } + return true; +} + +static CVScopeOp wrapScope(ScopeInfo &scope, const ScopeClass *klass, + MLIRContext *ctx) { + Builder attrBuilder(ctx); + OpBuilder builder(scope.firstOp); + auto cvScope = builder.create(scope.firstOp->getLoc()); + cvScope.getBody().push_back(new Block()); + + cvScope->setAttr(kScopeIdAttr, attrBuilder.getI64IntegerAttr(scope.id)); + cvScope->setAttr(kRoleAttr, attrBuilder.getStringAttr(getRole(scope))); + cvScope->setAttr(kCoreAttr, attrBuilder.getStringAttr(stringify(scope.core))); + cvScope->setAttr(kInputPipeAttr, + attrBuilder.getStringAttr(stringify(scope.input))); + cvScope->setAttr(kOutputPipeAttr, + attrBuilder.getStringAttr(stringify(scope.output))); + + if (klass && klass->groupId >= 0) { + cvScope->setAttr(kGroupIdAttr, + attrBuilder.getI64IntegerAttr(klass->groupId)); + cvScope->setAttr(kPreloadNumAttr, + attrBuilder.getI64IntegerAttr(klass->preloadNum)); + cvScope->setAttr(kMaxPreloadNumAttr, + attrBuilder.getI64IntegerAttr(klass->maxPreloadNum)); + } + + Block &scopeBlock = cvScope.getBody().front(); + Block *parentBlock = cvScope->getBlock(); + auto begin = Block::iterator(scope.firstOp); + auto end = 
std::next(Block::iterator(scope.lastOp)); + scopeBlock.getOperations().splice(scopeBlock.end(), + parentBlock->getOperations(), begin, end); + return cvScope; +} + +static void createScopeOps(MutableArrayRef scopes, + ArrayRef classes, MLIRContext *ctx) { + for (ScopeInfo &scope : scopes) { + const ScopeClass *klass = + scope.classId >= 0 ? &classes[scope.classId] : nullptr; + if (!canWrapNoResultScope(scope)) { + if (scope.firstOp) + scope.firstOp->emitWarning( + "cannot wrap CV preload scope without results because a value " + "defined in the scope is used outside; leave it unwrapped"); + continue; + } + wrapScope(scope, klass, ctx); + } +} + +static void inlineCVScope(CVScopeOp scopeOp) { + Block &scopeBlock = scopeOp.getBody().front(); + Block *parentBlock = scopeOp->getBlock(); + parentBlock->getOperations().splice(Block::iterator(scopeOp), + scopeBlock.getOperations(), + scopeBlock.begin(), scopeBlock.end()); + scopeOp.erase(); +} + +struct PTOCVMarkPreloadScopesPass + : public mlir::pto::impl::PTOCVMarkPreloadScopesBase< + PTOCVMarkPreloadScopesPass> { + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + SmallVector scopes; + collectCVScopes(moduleOp, scopes); + + if (scopes.empty()) + return; + + SmallVector classes; + buildScopeClasses(scopes, classes); + assignPreloadNumbers(classes); + createScopeOps(scopes, classes, moduleOp.getContext()); + } +}; + +struct PTOCVAutoMarkMultiBufferPass + : public mlir::pto::impl::PTOCVAutoMarkMultiBufferBase< + PTOCVAutoMarkMultiBufferPass> { + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + SmallVector scopes; + collectCVScopes(moduleOp, scopes); + + if (scopes.empty()) + return; + + SmallVector classes; + buildScopeClasses(scopes, classes); + assignPreloadNumbers(classes); + markCVPreloadMultiBuffers(scopes, classes); + } +}; + +struct PTOInlineCVPreloadScopesPass + : public mlir::pto::impl::PTOInlineCVPreloadScopesBase< + PTOInlineCVPreloadScopesPass> { + void 
runOnOperation() override { + SmallVector scopes; + getOperation()->walk([&](CVScopeOp scopeOp) { + scopes.push_back(scopeOp); + }); + + for (CVScopeOp scopeOp : llvm::reverse(scopes)) + inlineCVScope(scopeOp); + } +}; + +} // namespace + +std::unique_ptr mlir::pto::createPTOCVAutoMarkMultiBufferPass() { + return std::make_unique(); +} + +std::unique_ptr mlir::pto::createPTOCVMarkPreloadScopesPass() { + return std::make_unique(); +} + +std::unique_ptr mlir::pto::createPTOInlineCVPreloadScopesPass() { + return std::make_unique(); +} diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index c7a5406ca..7cff8f68f 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -526,12 +526,13 @@ collectAutoReserveBufferBitsByAddressSpace(const ReserveBufferPlans &plans) { void MemLivenessAnalysis::build() { Region &funcRegion = func_.getBody(); stableValueOrder = buildStableValueOrder(func_); + collectMultiBufferAnnotations(); + collectPreloadBufferParentLoops(); Liveness live(func_); // Recursively obtaining IR information. RecursionIR(&funcRegion, live); // the lifetime of the buffer. 
GenerateBufferLife(); - collectMultiBufferAnnotations(); } void MemLivenessAnalysis::collectMultiBufferAnnotations() { @@ -546,6 +547,50 @@ void MemLivenessAnalysis::collectMultiBufferAnnotations() { }); } +void MemLivenessAnalysis::collectPreloadBufferParentLoops() { + if (!isLocalMemPlan()) + return; + + auto recordValue = [&](Value value, scf::ForOp parentLoop) { + if (!value || !isa(value.getType())) + return; + Value root = tracebackMemRef(value); + auto alloc = root.getDefiningOp(); + if (!alloc) + return; + auto multiBufferAttr = + alloc->getAttrOfType(kPtoMultiBufferAttrName); + if (!multiBufferAttr || multiBufferAttr.getInt() <= 1) + return; + auto memorySpaceAttr = GetBufferSpaceAttr(root); + if (!isLocalBuffer(memorySpaceAttr)) + return; + + SmallVector &parentLoops = + preloadBufferParentLoops[root]; + Operation *loopOp = parentLoop.getOperation(); + if (!llvm::is_contained(parentLoops, loopOp)) + parentLoops.push_back(loopOp); + }; + + func_.walk([&](CVScopeOp scope) { + if (!scope->getAttr("pto.cv.preload_num") || + !scope->getAttr("pto.cv.max_preload_num")) + return; + + auto parentLoop = scope->getParentOfType(); + if (!parentLoop) + return; + + scope.walk([&](Operation *op) { + for (Value operand : op->getOperands()) + recordValue(operand, parentLoop); + for (Value result : op->getResults()) + recordValue(result, parentLoop); + }); + }); +} + bool MemLivenessAnalysis::isLocalMemPlan() const { return planMode == MemPlanMode::LOCAL_MEM_PLAN; } @@ -715,10 +760,12 @@ void MemLivenessAnalysis::RecursiveForOp(scf::ForOp forOp, Liveness live) { // need to handle kill buffer. 
auto forBeginSeq = UpdateLinearOperation(forOp.getOperation()); UpdateOpGenInfo(forBeginSeq, GetLiveBuffersInLoop(forOp, live)); + UpdatePreloadLoopGenInfo(forBeginSeq, forOp); UpdateForOpInitArgsAlias(forOp); RecursionIR(&forOp.getRegion(), live); UpdateForOpBufferAlias(forOp); auto forEndSeq = UpdateLinearOperation(forOp.getOperation()); + UpdatePreloadLoopKillInfo(forEndSeq, forOp, live); OpKillHandle(forEndSeq, live, forOp->getBlock()); } @@ -1071,7 +1118,20 @@ void MemLivenessAnalysis::UpdateOpGenInfo(OpInfo *opInfo, } } +void MemLivenessAnalysis::UpdatePreloadLoopGenInfo(OpInfo *opInfo, + scf::ForOp forOp) { + Operation *loopOp = forOp.getOperation(); + for (const auto &it : preloadBufferParentLoops) { + if (!llvm::is_contained(it.second, loopOp)) + continue; + UpdateOperandGenInfo(opInfo, it.first); + } +} + void MemLivenessAnalysis::UpdateOperandGenInfo(OpInfo *opInfo, Value operand) { + if (IsInsidePreloadParentLoop(operand, opInfo->operation)) + return; + auto iter_buffer = buffer2status.find(operand); if (iter_buffer == buffer2status.end()) return; @@ -1124,6 +1184,11 @@ void MemLivenessAnalysis::UpdateOpKillInfo(OpInfo *opInfo, Value operand, Liveness live) { auto aliasBuffers = GetAliasBuffers(operand); aliasBuffers.insert(operand); + if (llvm::any_of(aliasBuffers, [&](Value aliasBuffer) { + return IsInsidePreloadParentLoop(aliasBuffer, opInfo->operation); + })) + return; + for (Value aliasBuffer : aliasBuffers) { auto iterBuffer = buffer2status.find(aliasBuffer); if (iterBuffer == buffer2status.end()) @@ -1150,6 +1215,32 @@ void MemLivenessAnalysis::UpdateOpKillInfo(OpInfo *opInfo, Value operand, } } +void MemLivenessAnalysis::UpdatePreloadLoopKillInfo(OpInfo *opInfo, + scf::ForOp forOp, + Liveness live) { + Operation *loopOp = forOp.getOperation(); + for (const auto &it : preloadBufferParentLoops) { + if (!llvm::is_contained(it.second, loopOp)) + continue; + UpdateOpKillInfo(opInfo, it.first, live); + } +} + +bool 
MemLivenessAnalysis::IsInsidePreloadParentLoop(Value buffer, + Operation *op) const { + auto it = preloadBufferParentLoops.find(buffer); + if (it == preloadBufferParentLoops.end()) + return false; + + for (Operation *loopOp : it->second) { + if (loopOp == op) + continue; + if (loopOp->isAncestor(op)) + return true; + } + return false; +} + Operation *MemLivenessAnalysis::GetBufferGenOp(Value buffer) const { auto it = buffer2GenOp.find(buffer); if (it != buffer2GenOp.end()) diff --git a/lib/PTO/Transforms/PTOPlanMemory.h b/lib/PTO/Transforms/PTOPlanMemory.h index b8162aa4b..641469443 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.h +++ b/lib/PTO/Transforms/PTOPlanMemory.h @@ -330,6 +330,10 @@ class MemLivenessAnalysis { /// Read `pto.multi_buffer` on memref.alloc and fill `buffer2MultiNum`. void collectMultiBufferAnnotations(); + /// Find multi-buffer local allocs used by annotated CV preload scopes and + /// remember the parent loop whose physical iterations will overlap them. + void collectPreloadBufferParentLoops(); + void RecursionIR(Region *region, Liveness live); /// Get the buffer used within the loop and defined outside the loop. @@ -400,6 +404,9 @@ class MemLivenessAnalysis { /// Process gen buffer based on the result value of op. void UpdateOpGenInfo(OpInfo *opInfo, const ValueRange &results); + /// Force preload buffers to be generated at their parent loop boundary. + void UpdatePreloadLoopGenInfo(OpInfo *opInfo, scf::ForOp forOp); + /// Update normal operand gen information on buffer. void UpdateOperandGenInfo(OpInfo *opInfo, Value operand); @@ -435,6 +442,14 @@ class MemLivenessAnalysis { /// Process kill buffer based on the result live of op. void UpdateOpKillInfo(OpInfo *opInfo, Value operand, Liveness live); + /// Force preload buffers to be killed at their parent loop boundary. 
+ void UpdatePreloadLoopKillInfo(OpInfo *opInfo, scf::ForOp forOp, + Liveness live); + + /// Return true when a buffer is a preload buffer being touched inside, but + /// not at, one of its preload parent loops. + bool IsInsidePreloadParentLoop(Value buffer, Operation *op) const; + /// Have all alias buffer been killed. bool AllDeadAfter(Operation *op, SetVector aliasVec, Liveness live) const; @@ -469,6 +484,10 @@ class MemLivenessAnalysis { /// the loop body. DenseMap delayedLoopEntryGenBuffers; + /// Multi-buffer local allocs used by annotated CV preload scopes. Each value + /// maps to the loop operations that need loop-wide lifetimes for it. + DenseMap> preloadBufferParentLoops; + /// map on buffer alias DenseMap> buffer2AliasVec; diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index b5b99cc3a..5dd14710e 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -14,6 +14,7 @@ #include "PTO/IR/PTO.h" #include "PTO/IR/PTOTypeUtils.h" +#include "PTO/Transforms/MultiBuffer.h" #include "PTO/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -56,6 +57,11 @@ static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName = static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName = "__pto.force_dynamic_valid_shape"; +static void forwardMultiBufferAttr(Operation *from, Operation *to) { + if (Attribute attr = from->getAttr(kPtoMultiBufferAttrName)) + to->setAttr(kPtoMultiBufferAttrName, attr); +} + namespace { static void markForceDynamicValidShape(Operation *op, bool force, @@ -668,10 +674,12 @@ static LogicalResult lowerAllocTileOps(func::FuncOp func, MLIRContext *ctx) { auto pc = rewriter.create( loc, targetType, ValueRange{addr}, vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); + forwardMultiBufferAttr(op, pc); markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); auto bindOp = rewriter.create( loc, targetType, pc.getResult(), vRow ? 
vRow : Value(), vCol ? vCol : Value(), configAttr); + forwardMultiBufferAttr(op, bindOp); markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); rewriter.replaceOp(op, bindOp.getResult()); continue; @@ -680,10 +688,12 @@ static LogicalResult lowerAllocTileOps(func::FuncOp func, MLIRContext *ctx) { auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); auto allocType = MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); - Value alloc = rewriter.create(loc, allocType); + auto allocOp = rewriter.create(loc, allocType); + forwardMultiBufferAttr(op, allocOp); auto bindOp = rewriter.create( - loc, targetType, alloc, vRow ? vRow : Value(), vCol ? vCol : Value(), - configAttr); + loc, targetType, allocOp.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + forwardMultiBufferAttr(op, bindOp); markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); rewriter.replaceOp(op, bindOp.getResult()); } @@ -1240,10 +1250,12 @@ struct PTOViewToMemrefPass auto pc = rewriter.create( loc, targetType, ValueRange{addr}, vRow ? vRow : Value(), vCol ? vCol : Value(), configAttr); + forwardMultiBufferAttr(op, pc); markForceDynamicValidShape(pc, tbTy.hasDynamicValid(), ctx); auto bindOp = rewriter.create( loc, targetType, pc.getResult(), vRow ? vRow : Value(), vCol ? 
vCol : Value(), configAttr); + forwardMultiBufferAttr(op, bindOp); markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); rewriter.replaceOp(op, bindOp.getResult()); continue; @@ -1253,12 +1265,14 @@ struct PTOViewToMemrefPass // memref.alloc 要求明确的 layout,不能是动态 offset。 auto allocLayout = StridedLayoutAttr::get(ctx, 0, strides); // offset = 0 auto allocType = MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); - Value alloc = rewriter.create(loc, allocType); + auto allocOp = rewriter.create(loc, allocType); + forwardMultiBufferAttr(op, allocOp); // BindTileOp 的 Builder 会自动处理空的 Value,将其视为静态维度 auto bindOp = rewriter.create( - loc, targetType, alloc, vRow ? vRow : Value(), vCol ? vCol : Value(), - configAttr); + loc, targetType, allocOp.getResult(), vRow ? vRow : Value(), + vCol ? vCol : Value(), configAttr); + forwardMultiBufferAttr(op, bindOp); markForceDynamicValidShape(bindOp, tbTy.hasDynamicValid(), ctx); rewriter.replaceOp(op, bindOp.getResult()); diff --git a/test/lit/pto/cv_create_preload_scope_pipeline.pto b/test/lit/pto/cv_create_preload_scope_pipeline.pto new file mode 100644 index 000000000..2056b367c --- /dev/null +++ b/test/lit/pto/cv_create_preload_scope_pipeline.pto @@ -0,0 +1,75 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --enable-cv-create-preload --mlir-print-ir-after=pto-plan-memory %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=PLAN +// RUN: ptoas --enable-cv-create-preload --mlir-print-ir-after=pto-cv-create-preload %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=CREATE +// RUN: ptoas --enable-cv-create-preload --mlir-print-ir-after=pto-enable-multi-buffer %s 2>&1 1>/dev/null | FileCheck %s --check-prefix=MB + +module { + func.func @cv_create_preload(%arg0: memref<16x16xf16, #pto.address_space>, + %arg1: memref<16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %buf = memref.alloc() {pto.multi_buffer = 2 : i32} + : memref<16x16xf16, #pto.address_space> + + scf.for %i = %c0 to %c4 step %c1 { + pto.cv.scope { + pto.tload ins(%arg0 : memref<16x16xf16, #pto.address_space>) + outs(%buf : memref<16x16xf16, #pto.address_space>) + } {pto.cv.core = "vector", pto.cv.group_id = 0 : i64, + pto.cv.input_pipe = "", pto.cv.max_preload_num = 2 : i64, + pto.cv.output_pipe = "c2v:1", pto.cv.preload_num = 1 : i64, + pto.cv.role = "producer", pto.cv.scope_id = 0 : i64} + + pto.cv.scope { + pto.tstore ins(%buf : memref<16x16xf16, #pto.address_space>) + outs(%arg1 : memref<16x16xf16, #pto.address_space>) + } {pto.cv.core = "vector", pto.cv.group_id = 0 : i64, + pto.cv.input_pipe = "c2v:1", pto.cv.max_preload_num = 2 : i64, + pto.cv.output_pipe = "", pto.cv.preload_num = 0 : i64, + pto.cv.role = "consumer", pto.cv.scope_id = 1 : i64} + } + return + } +} + +// PLAN-LABEL: IR Dump After PlanMemory +// PLAN-LABEL: func.func @cv_create_preload +// PLAN-NOT: memref.alloc +// PLAN: pto.pointer_cast(%{{.*}}, %{{.*}}) : +// PLAN: pto.cv.scope +// PLAN: } {{.*}}pto.cv.preload_num = 1 +// PLAN: pto.cv.scope +// PLAN: } {{.*}}pto.cv.preload_num = 0 + +// CREATE-LABEL: IR Dump After PTOCVCreatePreload +// CREATE-LABEL: func.func @cv_create_preload +// CREATE: arith.addi +// CREATE: scf.for +// CREATE: 
pto.pointer_cast(%{{.*}}, %{{.*}}) : +// CREATE: pto.pointer_cast(%{{.*}}, %{{.*}}) : +// CREATE: scf.if +// CREATE: pto.cv.scope +// CREATE: pto.tload +// CREATE: } {{.*}}pto.cv.preload_num = 1 +// CREATE: scf.if +// CREATE: pto.cv.scope +// CREATE: pto.tstore +// CREATE: } {{.*}}pto.cv.preload_num = 0 + +// MB-LABEL: IR Dump After PTOEnableMultiBuffer +// MB-LABEL: func.func @cv_create_preload +// MB: pto.pointer_cast(%{{[^,)]+}}) : +// MB: pto.pointer_cast(%{{[^,)]+}}) : +// MB: scf.for %[[IV:.*]] = +// MB: arith.remui %[[IV]], +// MB: arith.select +// MB: pto.tload +// MB: pto.tstore diff --git a/test/lit/pto/cv_preload_mark_fa_tpipe.pto b/test/lit/pto/cv_preload_mark_fa_tpipe.pto new file mode 100644 index 000000000..992f29421 --- /dev/null +++ b/test/lit/pto/cv_preload_mark_fa_tpipe.pto @@ -0,0 +1,200 @@ +// RUN: ptoas --pto-level=level3 --emit-pto-ir --mlir-print-ir-after=pto-cv-auto-mark-multi-buffer %s 2>&1 | FileCheck %s --check-prefix=MB +// RUN: ptoas --pto-level=level3 --emit-pto-ir --mlir-print-ir-after=pto-cv-mark-preload-scopes %s 2>&1 | FileCheck %s +// RUN: ptoas --pto-level=level3 --emit-pto-ir --mlir-print-ir-after=pto-view-to-memref %s 2>&1 | FileCheck %s --check-prefix=LOWER +// RUN: ptoas --pto-level=level3 --emit-pto-ir %s 2>&1 | FileCheck %s --check-prefix=INLINE --implicit-check-not=pto.cv.scope +// RUN: not ptoas --pto-level=level3 --enable-cv-create-preload --emit-pto-ir %s 2>&1 | FileCheck %s --check-prefix=CREATE-LEVEL3-ERR + +module { + func.func @cube_kernel(%qk_fifo: !pto.ptr, %q: !pto.ptr, %k: !pto.ptr, %v: !pto.ptr, %p_fifo: !pto.ptr, %pv_fifo: !pto.ptr) + attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c0_i64 = arith.constant 0 : i64 + %c32768_i64 = arith.constant 32768 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %c98304_i64 = arith.constant 98304 
: i64 + %c131072_i64 = arith.constant 131072 : i64 + %c163840_i64 = arith.constant 163840 : i64 + %c196608_i64 = arith.constant 196608 : i64 + %c229376_i64 = arith.constant 229376 : i64 + + %qk_desc = pto.make_tensor_view %qk_fifo, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf32> + pto.aic_initialize_pipe{id = 25, dir_mask = 1, slot_size = 131072} (gm_slot_tensor = %qk_desc : !pto.tensor_view<128x256xf32>) + %p_desc = pto.make_tensor_view %p_fifo, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf16> + pto.aic_initialize_pipe{id = 30, dir_mask = 2, slot_size = 65536} (gm_slot_tensor = %p_desc : !pto.tensor_view<128x256xf16>) + %pv_desc = pto.make_tensor_view %pv_fifo, shape = [%c128, %c128], strides = [%c128, %c1] : !pto.tensor_view<128x128xf32> + pto.aic_initialize_pipe{id = 27, dir_mask = 1, slot_size = 65536} (gm_slot_tensor = %pv_desc : !pto.tensor_view<128x128xf32>) + + %q_tensor = pto.make_tensor_view %q, shape = [%c128, %c128], strides = [%c128, %c1] : !pto.tensor_view<128x128xf16> + %k_tensor = pto.make_tensor_view %k, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf16> + %v_tensor = pto.make_tensor_view %v, shape = [%c256, %c128], strides = [%c128, %c1] : !pto.tensor_view<256x128xf16> + + %q_mat = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %q_left = pto.alloc_tile addr = %c32768_i64 : !pto.tile_buf + %k_mat = pto.alloc_tile addr = %c65536_i64 : !pto.tile_buf + %k_right = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %qk_acc = pto.alloc_tile addr = %c98304_i64 : !pto.tile_buf + %p_mat = pto.alloc_tile addr = %c131072_i64 : !pto.tile_buf + %p_left = pto.alloc_tile addr = %c163840_i64 : !pto.tile_buf + %v_mat = pto.alloc_tile addr = %c196608_i64 : !pto.tile_buf + %v_right = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %pv_acc = pto.alloc_tile addr = %c229376_i64 : !pto.tile_buf + + scf.for %i = %c0 to %c2 step %c1 { + // Cube QK producer. 
+ %q_part = pto.partition_view %q_tensor, offsets = [%c0, %c0], sizes = [%c128, %c128] : !pto.tensor_view<128x128xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%q_part : !pto.partition_tensor_view<128x128xf16>) outs(%q_mat : !pto.tile_buf) + pto.tmov ins(%q_mat : !pto.tile_buf) outs(%q_left : !pto.tile_buf) + %k_part_0 = pto.partition_view %k_tensor, offsets = [%c0, %c0], sizes = [%c128, %c128] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%k_part_0 : !pto.partition_tensor_view<128x128xf16>) outs(%k_mat : !pto.tile_buf) + pto.tmov ins(%k_mat : !pto.tile_buf) outs(%k_right : !pto.tile_buf) + %qk_acc_lo = pto.subview %qk_acc[%c0, %c0] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%q_left, %k_right : !pto.tile_buf, !pto.tile_buf) outs(%qk_acc_lo : !pto.tile_buf) + %k_part_1 = pto.partition_view %k_tensor, offsets = [%c0, %c128], sizes = [%c128, %c128] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%k_part_1 : !pto.partition_tensor_view<128x128xf16>) outs(%k_mat : !pto.tile_buf) + pto.tmov ins(%k_mat : !pto.tile_buf) outs(%k_right : !pto.tile_buf) + %qk_acc_hi = pto.subview %qk_acc[%c0, %c128] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%q_left, %k_right : !pto.tile_buf, !pto.tile_buf) outs(%qk_acc_hi : !pto.tile_buf) + %qk_push = pto.talloc_to_aiv {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_push_part = pto.partition_view %qk_push, offsets = [%c0, %c0], sizes = [%c128, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<128x256xf32> + pto.tstore ins(%qk_acc : !pto.tile_buf) outs(%qk_push_part : !pto.partition_tensor_view<128x256xf32>) + pto.tpush_to_aiv(%qk_push : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + + // Cube PV relay. 
+ %p_pop = pto.tpop_from_aiv {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_pop_part_0 = pto.partition_view %p_pop, offsets = [%c0, %c0], sizes = [%c128, %c128] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%p_pop_part_0 : !pto.partition_tensor_view<128x128xf16>) outs(%p_mat : !pto.tile_buf) + pto.tmov ins(%p_mat : !pto.tile_buf) outs(%p_left : !pto.tile_buf) + %v_part_0 = pto.partition_view %v_tensor, offsets = [%c0, %c0], sizes = [%c128, %c128] : !pto.tensor_view<256x128xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%v_part_0 : !pto.partition_tensor_view<128x128xf16>) outs(%v_mat : !pto.tile_buf) + pto.tmov ins(%v_mat : !pto.tile_buf) outs(%v_right : !pto.tile_buf) + pto.tmatmul ins(%p_left, %v_right : !pto.tile_buf, !pto.tile_buf) outs(%pv_acc : !pto.tile_buf) + %p_pop_part_1 = pto.partition_view %p_pop, offsets = [%c0, %c128], sizes = [%c128, %c128] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%p_pop_part_1 : !pto.partition_tensor_view<128x128xf16>) outs(%p_mat : !pto.tile_buf) + pto.tmov ins(%p_mat : !pto.tile_buf) outs(%p_left : !pto.tile_buf) + pto.tfree_from_aiv(%p_pop : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + %v_part_1 = pto.partition_view %v_tensor, offsets = [%c128, %c0], sizes = [%c128, %c128] : !pto.tensor_view<256x128xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%v_part_1 : !pto.partition_tensor_view<128x128xf16>) outs(%v_mat : !pto.tile_buf) + pto.tmov ins(%v_mat : !pto.tile_buf) outs(%v_right : !pto.tile_buf) + pto.tmatmul.acc ins(%pv_acc, %p_left, %v_right : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%pv_acc : !pto.tile_buf) + %pv_push = pto.talloc_to_aiv {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_push_part = pto.partition_view %pv_push, offsets = [%c0, %c0], sizes = [%c128, %c128] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<128x128xf32> + pto.tstore 
ins(%pv_acc : !pto.tile_buf) outs(%pv_push_part : !pto.partition_tensor_view<128x128xf32>) + pto.tpush_to_aiv(%pv_push : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + } + return + } + + func.func @vector_kernel(%qk_fifo: !pto.ptr, %out: !pto.ptr, %p_fifo: !pto.ptr, %pv_fifo: !pto.ptr) + attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c196608_i64 = arith.constant 196608 : i64 + %c262144_i64 = arith.constant 262144 : i64 + %c327680_i64 = arith.constant 327680 : i64 + %c360448_i64 = arith.constant 360448 : i64 + %c393216_i64 = arith.constant 393216 : i64 + %cst = arith.constant 0.0883883461 : f32 + + %qk_desc = pto.make_tensor_view %qk_fifo, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf32> + pto.aiv_initialize_pipe{id = 25, dir_mask = 1, slot_size = 131072} (gm_slot_tensor = %qk_desc : !pto.tensor_view<128x256xf32>) + %p_desc = pto.make_tensor_view %p_fifo, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf16> + pto.aiv_initialize_pipe{id = 30, dir_mask = 2, slot_size = 65536} (gm_slot_tensor = %p_desc : !pto.tensor_view<128x256xf16>) + %pv_desc = pto.make_tensor_view %pv_fifo, shape = [%c128, %c128], strides = [%c128, %c1] : !pto.tensor_view<128x128xf32> + pto.aiv_initialize_pipe{id = 27, dir_mask = 1, slot_size = 65536} (gm_slot_tensor = %pv_desc : !pto.tensor_view<128x128xf32>) + %out_tensor = pto.make_tensor_view %out, shape = [%c64, %c128], strides = [%c128, %c1] : !pto.tensor_view<64x128xf32> + + %qk_vec = pto.alloc_tile addr = %c196608_i64 : !pto.tile_buf + %p_work = pto.alloc_tile addr = %c262144_i64 : !pto.tile_buf + %p_vec = pto.alloc_tile addr = %c327680_i64 : !pto.tile_buf + %pv_vec = pto.alloc_tile addr = %c360448_i64 : !pto.tile_buf + %gu_vec = 
pto.alloc_tile addr = %c393216_i64 : !pto.tile_buf + + scf.for %i = %c0 to %c2 step %c1 { + // Vector P relay. + %qk_pop = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_0 = pto.partition_view %qk_pop, offsets = [%c0, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_0 : !pto.partition_tensor_view<32x256xf32>) outs(%qk_vec : !pto.tile_buf) + pto.tmuls ins(%qk_vec, %cst : !pto.tile_buf, f32) outs(%p_work : !pto.tile_buf) + pto.texp ins(%p_work : !pto.tile_buf) outs(%p_work : !pto.tile_buf) + pto.tcvt ins(%p_work {rmode = #pto} : !pto.tile_buf) outs(%p_vec : !pto.tile_buf) + %p_push = pto.talloc_to_aic {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_push_part_0 = pto.partition_view %p_push, offsets = [%c0, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%p_vec : !pto.tile_buf) outs(%p_push_part_0 : !pto.partition_tensor_view<32x256xf16>) + %qk_pop_part_1 = pto.partition_view %qk_pop, offsets = [%c32, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_1 : !pto.partition_tensor_view<32x256xf32>) outs(%qk_vec : !pto.tile_buf) + pto.tmuls ins(%qk_vec, %cst : !pto.tile_buf, f32) outs(%p_work : !pto.tile_buf) + pto.texp ins(%p_work : !pto.tile_buf) outs(%p_work : !pto.tile_buf) + pto.tcvt ins(%p_work {rmode = #pto} : !pto.tile_buf) outs(%p_vec : !pto.tile_buf) + %p_push_part_1 = pto.partition_view %p_push, offsets = [%c32, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%p_vec : !pto.tile_buf) outs(%p_push_part_1 : !pto.partition_tensor_view<32x256xf16>) + pto.tfree_from_aic(%qk_pop : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + pto.tpush_to_aic(%p_push : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + + // 
Vector GU consumer. + %pv_pop = pto.tpop_from_aic {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_pop_part = pto.partition_view %pv_pop, offsets = [%c0, %c0], sizes = [%c64, %c128] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part : !pto.partition_tensor_view<64x128xf32>) outs(%pv_vec : !pto.tile_buf) + pto.tmov ins(%pv_vec : !pto.tile_buf) outs(%gu_vec : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %out_part = pto.partition_view %out_tensor, offsets = [%c0, %c0], sizes = [%c64, %c128] : !pto.tensor_view<64x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tstore ins(%gu_vec : !pto.tile_buf) outs(%out_part : !pto.partition_tensor_view<64x128xf32>) + } + return + } +} + +// MB-LABEL: IR Dump After PTOCVAutoMarkMultiBuffer +// MB-LABEL: func.func @cube_kernel +// MB: pto.alloc_tile addr = {{.*}} {pto.multi_buffer = 4 : i32} : !pto.tile_buf +// MB: pto.alloc_tile addr = {{.*}} {pto.multi_buffer = 4 : i32} : !pto.tile_buf + +// LOWER-LABEL: IR Dump After PTOViewToMemref +// LOWER: pto.pointer_cast{{.*}}pto.multi_buffer = 4 : i32 +// LOWER: pto.bind_tile{{.*}}pto.multi_buffer = 4 : i32 + +// CHECK-LABEL: func.func @cube_kernel +// CHECK: pto.cv.scope { +// CHECK: pto.tload +// CHECK: pto.tmatmul +// CHECK: pto.talloc_to_aiv{{.*}}id = 25 +// CHECK: pto.tstore +// CHECK: pto.tpush_to_aiv{{.*}}id = 25 +// CHECK: } {pto.cv.core = "cube"{{.*}}pto.cv.input_pipe = ""{{.*}}pto.cv.max_preload_num = 4{{.*}}pto.cv.output_pipe = "c2v:25"{{.*}}pto.cv.preload_num = 3{{.*}}pto.cv.role = "producer" +// CHECK: pto.cv.scope { +// CHECK: pto.tpop_from_aiv {{.*}}id = 30 +// CHECK: pto.tmatmul +// CHECK: pto.tfree_from_aiv +// CHECK: pto.talloc_to_aiv{{.*}}id = 27 +// CHECK: pto.tpush_to_aiv{{.*}}id = 27 +// CHECK: } {pto.cv.core = "cube"{{.*}}pto.cv.input_pipe = "v2c:30"{{.*}}pto.cv.output_pipe = "c2v:27"{{.*}}pto.cv.preload_num = 1{{.*}}pto.cv.role = "relay" +// 
CHECK-LABEL: func.func @vector_kernel +// CHECK: pto.cv.scope { +// CHECK: pto.tpop_from_aic {{.*}}id = 25 +// CHECK: pto.texp +// CHECK: pto.tcvt +// CHECK: pto.talloc_to_aic{{.*}}id = 30 +// CHECK: pto.tfree_from_aic +// CHECK: pto.tpush_to_aic{{.*}}id = 30 +// CHECK: } {pto.cv.core = "vector"{{.*}}pto.cv.input_pipe = "c2v:25"{{.*}}pto.cv.output_pipe = "v2c:30"{{.*}}pto.cv.preload_num = 2{{.*}}pto.cv.role = "relay" +// CHECK: pto.cv.scope { +// CHECK: pto.tpop_from_aic {{.*}}id = 27 +// CHECK: pto.tmov +// CHECK: pto.tfree_from_aic +// CHECK: pto.tstore +// CHECK: } {pto.cv.core = "vector"{{.*}}pto.cv.input_pipe = "c2v:27"{{.*}}pto.cv.output_pipe = ""{{.*}}pto.cv.preload_num = 0{{.*}}pto.cv.role = "consumer" + +// INLINE: module + +// CREATE-LEVEL3-ERR: uses preload local buffer with only 1 planned address(es); expected at least 4 diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 45e98cdfb..61bffd513 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -189,6 +189,12 @@ static llvm::cl::opt enableMultiBufferLowering( "single-address casts plus an iv mod N arith.select"), llvm::cl::init(false)); +static llvm::cl::opt enableCVCreatePreload( + "enable-cv-create-preload", + llvm::cl::desc("Expand annotated pto.cv.scope preload stages after memory " + "planning"), + llvm::cl::init(false)); + static llvm::cl::opt enableGraphSyncSolver( "enable-graph-sync-solver", llvm::cl::desc("Enable the graph-based intra-core sync solver " @@ -1131,8 +1137,12 @@ int main(int argc, char **argv) { pm.addNestedPass( pto::createPTOAssignDefaultFrontendPipeIdPass()); + pm.addPass(pto::createPTOCVAutoMarkMultiBufferPass()); + pm.addPass(pto::createPTOCVMarkPreloadScopesPass()); pm.addNestedPass( pto::createPTOLowerFrontendPipeOpsPass()); + if (!enableCVCreatePreload) + pm.addPass(pto::createPTOInlineCVPreloadScopesPass()); //pm.addNestedPass(pto::createPTOVerifyTFreePass()); pm.addPass(pto::createPTOInferValidatePipeInitPass()); 
pm.addNestedPass(pto::createLoweringSyncToPipePass()); @@ -1151,6 +1161,11 @@ int main(int argc, char **argv) { } pm.addPass(pto::createPTOResolveReservedBuffersPass()); + if (enableCVCreatePreload) { + pm.addPass(pto::createPTOCVCreatePreloadPass()); + pm.addPass(pto::createPTOInlineCVPreloadScopesPass()); + } + // Conditionally add Sync pass based on flag. The two solvers are mutually // exclusive (validated above); GraphSyncSolver is the experimental new // path that lives next to PTOInsertSync. @@ -1167,7 +1182,7 @@ int main(int argc, char **argv) { // variadic pto.pointer_cast emitted by PlanMemory and replaces it with // single-address casts plus an iv mod N selector. Decoupled from the sync // solver choice because the pointer_cast geometry is solver-agnostic. - if (enableMultiBufferLowering) + if (enableMultiBufferLowering || enableCVCreatePreload) pm.addNestedPass( pto::createPTOEnableMultiBufferPass()); From ef2b6c6a3f71eb051384f4bedb5f2afca37a9d6b Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Wed, 6 May 2026 17:01:04 +0800 Subject: [PATCH 10/10] multi-buffer / cv-preload: address review P0 + P1 issues P0 fixes (correctness, untested edges): * CVCreatePreload: use llvm::MapVector for LocalBufferBinding so the cloned per-stage pointer_cast / bind_tile order is deterministic across builds (DenseMap iteration order had been leaking pointer-hash into the IR). * Multi-buffer slot index now computes ((iv - lb) / step) mod N instead of iv mod N. The previous form silently dropped slots when gcd(step, N) > 1 (e.g. step=2/N=4 only ever selected {0,2}). Fixed in three codegen paths (PTOEnableMultiBuffer, GSS SyncSolverCodeGen, InsertSync SyncCodegen); step=1, lb=0 still takes the single-remui fast path. * GraphSyncSolver multi-buffer codegen now preserves the original sync anchors. 
cp->op1 / cp->op2 are syncIR-ordered, not MLIR-lexical-ordered; resolve via Operation::isBeforeInBlock so wait_flag_dyn lands before the iteration's first user and set_flag_dyn after the iteration's last user. Anchors in different blocks (e.g. nested scf.if) fall back to the prior body-start / before-terminator placement. P1 fixes (semantic gaps): * SyncSolver::getMultiBufferEventIdInfo now matches HIVM and PTO InsertSync: any dep pair with n<2 collapses the whole pair to single-buffer rather than being silently skipped. * CVMarkPreloadScopes: ScopeClass key now includes parent scf.for so two unrelated loops sharing the same (core, input, output) pipe stay in separate classes; fan-out / fan-in along a logical pipe is detected and emits a diagnostic instead of silently dropping siblings. * CVMarkPreloadScopes: TPipe ops nested under scf.if / scf.for inside the scope-recognition for-loop now emit a warning and skip auto-marking, instead of being missed by the direct-children walk. * CVCreatePreload: stage-rotated multi-address pointer_casts are hoisted above the rotation loop instead of being duplicated max*N times in the body. PTOEnableMultiBuffer learns to infer the rotation loop from a loop-invariant cast's users so the lowering still finds it. * CVMarkPreloadScopes::canWrapNoResultScope walks every op (including ops inside nested regions of the moved range) when checking for SSA escapes, not just the top-level moved ops. Tests: * New regression test/lit/pto/multi_buffer_step_not_one.pto covers the step!=1 slot-index path. * Existing cv_create_preload_scope_pipeline.pto updated for the hoisted stage pointer_cast layout. * Full lit suite (194 tests) passes.
Co-Authored-By: Claude Opus 4.7 --- include/PTO/Transforms/MultiBuffer.h | 5 +- lib/PTO/Transforms/CVCreatePreloadPass.cpp | 39 ++++- .../Transforms/CVMarkPreloadScopesPass.cpp | 144 +++++++++++++++--- .../Transforms/GraphSyncSolver/SyncSolver.cpp | 29 +++- .../GraphSyncSolver/SyncSolverCodeGen.cpp | 132 ++++++++++++++-- lib/PTO/Transforms/InsertSync/SyncCodegen.cpp | 24 ++- lib/PTO/Transforms/PTOEnableMultiBuffer.cpp | 62 +++++++- lib/PTO/Transforms/PTOPlanMemory.cpp | 8 + .../pto/cv_create_preload_scope_pipeline.pto | 7 +- test/lit/pto/multi_buffer_step_not_one.pto | 47 ++++++ 10 files changed, 445 insertions(+), 52 deletions(-) create mode 100644 test/lit/pto/multi_buffer_step_not_one.pto diff --git a/include/PTO/Transforms/MultiBuffer.h b/include/PTO/Transforms/MultiBuffer.h index ea5c494fe..df7ac54cd 100644 --- a/include/PTO/Transforms/MultiBuffer.h +++ b/include/PTO/Transforms/MultiBuffer.h @@ -17,7 +17,10 @@ namespace pto { /// Attribute name for multi-buffer depth on `memref.alloc` (integer slot count N>=2). inline constexpr llvm::StringLiteral kPtoMultiBufferAttrName = "pto.multi_buffer"; -/// Upper bound for N; must stay consistent with `MAX_MULTI_BUFFER_NUM` in insert-sync. +/// Upper bound for N; must stay consistent with `MAX_MULTI_BUFFER_NUM` in +/// insert-sync's SyncCommon.h. The static_assert that pins these two values +/// together lives in PTOPlanMemory.cpp (which already includes both headers) +/// so this header stays cheap to include from CV/multi-buffer paths. 
inline constexpr unsigned kPtoMultiBufferMaxNum = 16; } // namespace pto diff --git a/lib/PTO/Transforms/CVCreatePreloadPass.cpp b/lib/PTO/Transforms/CVCreatePreloadPass.cpp index fb9dea4f3..ad56f20a6 100644 --- a/lib/PTO/Transforms/CVCreatePreloadPass.cpp +++ b/lib/PTO/Transforms/CVCreatePreloadPass.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -116,9 +117,15 @@ static SmallVector rotateAddrs(PointerCastOp pointerCast, return rotated; } +// MapVector keeps insertion order so the IR generated by rewritePreloadLoop is +// stable across runs. A plain DenseMap iterates in pointer-hash order, which +// makes the cloned per-stage pointer_cast / bind_tile sequence non-deterministic +// and shifts SSA numbering on every build (P0-1 fix). +using LocalBufferBindingMap = llvm::MapVector; + static LogicalResult collectLocalBufferBindings( CVScopeOp scope, int64_t maxPreloadNum, scf::ForOp forOp, - DenseMap &bindings) { + LocalBufferBindingMap &bindings) { auto walkResult = scope.walk([&](Operation *op) -> WalkResult { for (Value operand : op->getOperands()) { if (!operand || !forOp.isDefinedOutsideOfLoop(operand)) @@ -277,7 +284,7 @@ static LogicalResult rewritePreloadLoop(IRRewriter &rewriter, scf::ForOp forOp, "pto-cv-create-preload currently supports scf.for without iter_args"); } - DenseMap localBindings; + LocalBufferBindingMap localBindings; for (CVScopeOp scope : scopes) { if (failed( collectLocalBufferBindings(scope, maxPreloadNum, forOp, @@ -285,21 +292,43 @@ static LogicalResult rewritePreloadLoop(IRRewriter &rewriter, scf::ForOp forOp, return failure(); } + // MapVector preserves insertion order; iterate via .takeVector() equivalent + // (range-for over the MapVector visits in insertion order, not pointer-hash + // order). The SSA values cloned per stage thus appear in a stable sequence. 
SmallVector orderedBindings; orderedBindings.reserve(localBindings.size()); for (const auto &it : localBindings) orderedBindings.push_back(it.second); Location loc = forOp.getLoc(); + + // P1-7: Hoist per-stage pointer_cast / bind_tile clones to *above* the new + // for op. The original addrs are loop-invariant (PlanMemory emits constant + // i64 offsets), so the rotated multi-address casts are loop-invariant too. + // Inserting them inside the new for body makes the subsequent + // pto-enable-multi-buffer pass see a multi-address pointer_cast inside a + // loop, which then re-hoists N single-address casts per stage; with M stages + // and N slots that is M*N copies instead of N. Hoisting once here keeps the + // rotation deterministic and gives the next pass a single-address-per-stage + // shape to work with. rewriter.setInsertionPoint(forOp); + SmallVector stageMappings(maxPreloadNum); + for (int64_t preloadNum = 0; preloadNum < maxPreloadNum; ++preloadNum) { + for (LocalBufferBinding &binding : orderedBindings) { + clonePointerCastAndBind(rewriter, binding, maxPreloadNum, preloadNum, + stageMappings[preloadNum]); + } + } + Value newUpperBound = createShiftedUpperBound( rewriter, loc, forOp.getUpperBound(), forOp.getStep(), maxPreloadNum); auto newForOp = rewriter.create(loc, forOp.getLowerBound(), newUpperBound, forOp.getStep()); + // Stage IVs depend on the new induction variable and so must live inside + // the new for body. Wire them into the per-stage mappings used by clones. 
OpBuilder bodyBuilder = OpBuilder::atBlockBegin(newForOp.getBody()); - SmallVector stageMappings(maxPreloadNum); SmallVector stageIVs(maxPreloadNum); for (int64_t preloadNum = 0; preloadNum < maxPreloadNum; ++preloadNum) { Value stageIV = createStageIV(bodyBuilder, loc, newForOp.getInductionVar(), @@ -307,10 +336,6 @@ static LogicalResult rewritePreloadLoop(IRRewriter &rewriter, scf::ForOp forOp, preloadNum); stageIVs[preloadNum] = stageIV; stageMappings[preloadNum].map(forOp.getInductionVar(), stageIV); - for (LocalBufferBinding &binding : orderedBindings) { - clonePointerCastAndBind(bodyBuilder, binding, maxPreloadNum, preloadNum, - stageMappings[preloadNum]); - } } for (Operation &op : forOp.getBody()->without_terminator()) { diff --git a/lib/PTO/Transforms/CVMarkPreloadScopesPass.cpp b/lib/PTO/Transforms/CVMarkPreloadScopesPass.cpp index c8af8f9e2..68abcdc0a 100644 --- a/lib/PTO/Transforms/CVMarkPreloadScopesPass.cpp +++ b/lib/PTO/Transforms/CVMarkPreloadScopesPass.cpp @@ -191,6 +191,10 @@ struct ScopeInfo { SmallVector ops; Operation *firstOp = nullptr; Operation *lastOp = nullptr; + // P1-5: parent loop must participate in scope-class identity. Two + // independent scf.for loops in the same kernel that happen to use the same + // pipe must NOT be merged into one preload group. + Operation *parentLoop = nullptr; int64_t classId = -1; }; @@ -198,6 +202,11 @@ struct ScopeClass { CoreKind core = CoreKind::Cube; std::optional input; std::optional output; + // Same parent loop op as the contributing scopes. Pipe-only matching is + // still valid across cores (chain edges follow pipe identity), but within + // a single core+pipe direction we partition by parent loop so siblings of + // distinct loops do not collide. 
+ Operation *parentLoop = nullptr; SmallVector scopeIds; int64_t groupId = -1; int64_t preloadNum = -1; @@ -212,10 +221,14 @@ static std::string getRole(const ScopeInfo &scope) { return "consumer"; } -static std::string getClassKey(CoreKind core, std::optional input, +static std::string getClassKey(CoreKind core, Operation *parentLoop, + std::optional input, std::optional output) { - return (llvm::Twine(stringify(core)) + "|" + llvm::Twine(stringify(input)) + - "|" + llvm::Twine(stringify(output))) + // Include the parent loop pointer so two same-pipe scopes from different + // scf.for nests get their own classes (P1-5). + uintptr_t loopAddr = reinterpret_cast(parentLoop); + return (llvm::Twine(stringify(core)) + "|" + llvm::Twine(loopAddr) + "|" + + llvm::Twine(stringify(input)) + "|" + llvm::Twine(stringify(output))) .str(); } @@ -235,6 +248,7 @@ static std::optional getKernelCore(func::FuncOp funcOp) { } static void collectScopeIfValid(PendingScope &pending, CoreKind core, + Operation *parentLoop, SmallVectorImpl &scopes) { std::optional committedOutput = pending.outputCommitted ? pending.output : std::nullopt; @@ -251,6 +265,7 @@ static void collectScopeIfValid(PendingScope &pending, CoreKind core, scope.ops = pending.ops; scope.firstOp = pending.firstOp; scope.lastOp = pending.lastOp; + scope.parentLoop = parentLoop; scopes.push_back(std::move(scope)); pending.reset(); } @@ -276,11 +291,38 @@ static void includeMissingThrough(PendingScope &pending, Operation *last) { last); } +// P1-6: Diagnose TPipe ops nested under scf.if / scf.for inside this loop. +// The state-machine in collectScopesInFor only iterates direct children of +// the loop body, so nested-region TPipe ops would be silently dropped from +// scope identification (and from auto multi-buffer marking too). Emit one +// warning per loop so the user knows preload analysis is incomplete instead +// of silently producing a partial transaction graph. 
+static void warnIfNestedTPipeOps(scf::ForOp forOp) { + Block *body = forOp.getBody(); + bool emitted = false; + forOp.getOperation()->walk([&](Operation *op) { + if (emitted) + return WalkResult::interrupt(); + if (op->getBlock() == body) + return WalkResult::advance(); + if (!getPipeAction(op)) + return WalkResult::advance(); + op->emitWarning( + "pto-cv-mark-preload-scopes: TPipe op nested under control flow " + "inside scf.for; V1 only recognizes TPipe ops at the loop body's " + "top level, this transaction is excluded from preload analysis"); + emitted = true; + return WalkResult::interrupt(); + }); +} + static void collectScopesInFor(scf::ForOp forOp, CoreKind core, SmallVectorImpl &scopes) { + warnIfNestedTPipeOps(forOp); PendingScope pending; Operation *segmentStart = &forOp.getBody()->front(); Operation *terminator = forOp.getBody()->getTerminator(); + Operation *parentLoopOp = forOp.getOperation(); for (Operation &op : forOp.getBody()->without_terminator()) { PipeAction action = getPipeAction(&op); @@ -290,7 +332,7 @@ static void collectScopesInFor(scf::ForOp forOp, CoreKind core, if (action.isTPop) { if (pending.active) includeMissingThrough(pending, op.getPrevNode()); - collectScopeIfValid(pending, core, scopes); + collectScopeIfValid(pending, core, parentLoopOp, scopes); pending.input = action.input; includeRange(pending, &op, &op); segmentStart = &op; @@ -299,7 +341,7 @@ static void collectScopesInFor(scf::ForOp forOp, CoreKind core, if (action.isTAlloc) { if (pending.active && pending.output) - collectScopeIfValid(pending, core, scopes); + collectScopeIfValid(pending, core, parentLoopOp, scopes); Operation *start = pending.active ? (pending.lastOp ? 
pending.lastOp->getNextNode() : &op) @@ -321,7 +363,7 @@ static void collectScopesInFor(scf::ForOp forOp, CoreKind core, includeMissingThrough(pending, &op); } if (!pending.input || pending.inputReleased) { - collectScopeIfValid(pending, core, scopes); + collectScopeIfValid(pending, core, parentLoopOp, scopes); segmentStart = op.getNextNode(); } continue; @@ -331,7 +373,7 @@ static void collectScopesInFor(scf::ForOp forOp, CoreKind core, includeMissingThrough(pending, &op); pending.inputReleased = true; if (pending.outputCommitted) { - collectScopeIfValid(pending, core, scopes); + collectScopeIfValid(pending, core, parentLoopOp, scopes); segmentStart = op.getNextNode(); } continue; @@ -341,20 +383,22 @@ static void collectScopesInFor(scf::ForOp forOp, CoreKind core, if (pending.active) includeMissingThrough(pending, terminator ? terminator->getPrevNode() : nullptr); - collectScopeIfValid(pending, core, scopes); + collectScopeIfValid(pending, core, parentLoopOp, scopes); } static void buildScopeClasses(SmallVectorImpl &scopes, SmallVectorImpl &classes) { llvm::StringMap classByKey; for (ScopeInfo &scope : scopes) { - std::string key = getClassKey(scope.core, scope.input, scope.output); + std::string key = + getClassKey(scope.core, scope.parentLoop, scope.input, scope.output); auto [it, inserted] = classByKey.try_emplace(key, classes.size()); if (inserted) { ScopeClass klass; klass.core = scope.core; klass.input = scope.input; klass.output = scope.output; + klass.parentLoop = scope.parentLoop; classes.push_back(std::move(klass)); } scope.classId = it->second; @@ -363,19 +407,65 @@ static void buildScopeClasses(SmallVectorImpl &scopes, } static void assignPreloadNumbers(SmallVectorImpl &classes) { - std::map inputClass; + // P1-5: detect fan-out / fan-in. 
A pipe consumed by more than one class is + // a branched chain that the linear preload-num scheme cannot represent; + // skip every class touching that pipe and emit one warning per offending + // pipe so the user knows preload was disabled instead of being silently + // misapplied. + std::map> inputClasses; + std::map> outputClasses; for (auto [idx, klass] : llvm::enumerate(classes)) { if (klass.input) - inputClass.try_emplace(*klass.input, static_cast(idx)); + inputClasses[*klass.input].push_back(static_cast(idx)); + if (klass.output) + outputClasses[*klass.output].push_back(static_cast(idx)); + } + + llvm::DenseSet branchedClasses; + auto warnFanout = [&](const PipeKey &pipe, + const SmallVector &cls, + const char *role) { + for (int64_t idx : cls) + branchedClasses.insert(idx); + if (cls.size() < 2) + return; + Operation *anchor = nullptr; + if (auto *parent = classes[cls.front()].parentLoop) + anchor = parent; + if (anchor) + anchor->emitWarning() << "pto-cv-mark-preload-scopes: pipe " + << stringify(pipe) << " has " << cls.size() << " " + << role + << "s; branched CV chain not supported, " + "skipping preload assignment for these scopes"; + }; + for (auto &kv : inputClasses) + if (kv.second.size() > 1) + warnFanout(kv.first, kv.second, "consumer"); + for (auto &kv : outputClasses) + if (kv.second.size() > 1) + warnFanout(kv.first, kv.second, "producer"); + + // Map pipe -> the unique consumer class (only present when there is exactly + // one consumer for that pipe; otherwise the pipe is in branchedClasses and + // chain construction skips it). 
+ std::map inputClass; + for (auto &kv : inputClasses) { + if (kv.second.size() == 1) + inputClass.try_emplace(kv.first, kv.second.front()); } SmallVector next(classes.size(), -1); for (auto [idx, klass] : llvm::enumerate(classes)) { if (!klass.output) continue; + if (branchedClasses.contains(static_cast(idx))) + continue; auto it = inputClass.find(*klass.output); if (it == inputClass.end()) continue; + if (branchedClasses.contains(it->second)) + continue; next[idx] = it->second; } @@ -383,6 +473,8 @@ static void assignPreloadNumbers(SmallVectorImpl &classes) { for (auto [startIdx, klass] : llvm::enumerate(classes)) { if (klass.input || !klass.output) continue; + if (branchedClasses.contains(static_cast(startIdx))) + continue; SmallVector chain; std::set seen; @@ -523,13 +615,31 @@ static bool canWrapNoResultScope(const ScopeInfo &scope) { break; } - for (Operation *op : moved) { - for (Value result : op->getResults()) { - for (OpOperand &use : result.getUses()) { - if (!isInsideMovedRange(use.getOwner(), moved)) - return false; + // P1-8: walk EVERY op (including nested-region ops) inside the moved range + // and check every result's uses. The earlier version only iterated the + // direct top-level ops, so a result produced inside e.g. a nested scf.if + // body and consumed outside the scope would slip past the check; the splice + // would then move the def while leaving the use behind, breaking dominance + // (and crashing the verifier with "operand defined in a region not visible + // here"). isInsideMovedRange already walks parent ops, so any nested use of + // a moved-range value is correctly accepted. 
+ bool sawEscapedUse = false; + for (Operation *top : moved) { + top->walk([&](Operation *innerOp) { + if (sawEscapedUse) + return WalkResult::interrupt(); + for (Value result : innerOp->getResults()) { + for (OpOperand &use : result.getUses()) { + if (!isInsideMovedRange(use.getOwner(), moved)) { + sawEscapedUse = true; + return WalkResult::interrupt(); + } + } } - } + return WalkResult::advance(); + }); + if (sawEscapedUse) + return false; } return true; } diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp index bf10b699d..5b2cbeef6 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp @@ -161,22 +161,39 @@ scf::ForOp Solver::getMultiBufferLoop(RWOperation *rwOp1, RWOperation *rwOp2) { EventIdInfo Solver::getMultiBufferEventIdInfo(RWOperation *rwOp1, RWOperation *rwOp2) { - // Mirrors `checkMultiBufferEventIdInfo` + `getMultiBufferEventIdInfo`: - // 1. All conflict pairs must agree on slot count N >= 2. - // 2. All involved buffers must hang off the same scf.for. - // 3. N is the common slot count (small enough to fit MAX_MULTI_BUFFER_NUM). + // P1-4: Match HIVM and PTO InsertSync semantics. Multi-buffer is only safe + // when EVERY dependency pair between rwOp1/rwOp2 is multi-buffer-eligible + // with the same slot count. If even one pair is a single-buffer real + // dependency, we MUST collapse to single-buffer; otherwise the dyn flag + // serializing the multi-buffer pairs would not protect the single-buffer + // pair, which only happens to be covered if its lifetime accidentally fits + // inside the rotation window. Earlier code used `continue` here, which let + // a single-buffer real dep slip past (Codex Review P2 in the original PR). + // + // Constraints (all must hold): + // 1. Every pair returns the SAME slot count N >= 2 (no n==0 / n==1 pairs). + // 2. N <= MAX_MULTI_BUFFER_NUM. + // 3. 
All involved buffers hang off the same scf.for. // Returns single-buffer EventIdInfo() on any failure. if (!rwOp1 || !rwOp2) return {}; + bool sawAnyPair = false; unsigned commonN = 0; auto checkPair = [&](const llvm::SmallVector &as, const llvm::SmallVector &bs) -> bool { for (auto *a : as) { for (auto *b : bs) { + // Only count pairs that actually carry a memory dependency. + // getMultiBufferSlotCount returns 0 either when there is no dep at + // all (different physical buffer) or when the dep is a real single + // buffer dep. We need to distinguish: probe MemAlias to decide. + if (!memAnalyzer_.MemAlias(a, b)) + continue; unsigned n = memAnalyzer_.getMultiBufferSlotCount(a, b); if (n < 2) - continue; + return false; // real single-buffer dep -> not safe to MB. + sawAnyPair = true; if (commonN == 0) commonN = n; else if (commonN != n) @@ -191,7 +208,7 @@ EventIdInfo Solver::getMultiBufferEventIdInfo(RWOperation *rwOp1, return {}; if (!checkPair(rwOp1->writeMemInfo, rwOp2->writeMemInfo)) return {}; - if (commonN < 2 || commonN > MAX_MULTI_BUFFER_NUM) + if (!sawAnyPair || commonN < 2 || commonN > MAX_MULTI_BUFFER_NUM) return {}; scf::ForOp loop = getMultiBufferLoop(rwOp1, rwOp2); diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp index 35dc2df0d..cab3a8729 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverCodeGen.cpp @@ -126,6 +126,23 @@ void CodeGenerator::insertBarrier(IRRewriter &rewriter, OperationBase *anchor, rewriter.create(loc, pipeAttr); } +// P0-2: Slot index = LOGICAL iteration % N = ((iv - lb) / step) % N, NOT +// `iv % N`. When step is not 1 the latter only ever yields a stride-shifted +// subset of slots (e.g. step=2,N=4 -> {0,2}), silently corrupting the rotation. +// Keep a `iv % N` fast path for the canonical step=1, lb=0 case so the +// generated IR stays minimal in the common scenario. 
+static bool gssIsConstantIndexEqualTo(Value v, int64_t target) { + if (!v) + return false; + if (auto cst = v.getDefiningOp()) + return cst.value() == target; + if (auto cst = v.getDefiningOp()) { + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getInt() == target; + } + return false; +} + Value CodeGenerator::getOrCreateLoopCounter(IRRewriter &rewriter, scf::ForOp forOp, int64_t n, Location loc) { @@ -136,17 +153,57 @@ Value CodeGenerator::getOrCreateLoopCounter(IRRewriter &rewriter, rewriter.setInsertionPointToStart(forOp.getBody()); Value iv = forOp.getInductionVar(); Value cN = rewriter.create(loc, n); - Value rem = rewriter.create(loc, iv, cN); + Value lb = forOp.getLowerBound(); + Value step = forOp.getStep(); + Value normalized = iv; + if (!gssIsConstantIndexEqualTo(lb, 0)) + normalized = rewriter.create(loc, normalized, lb); + if (!gssIsConstantIndexEqualTo(step, 1)) + normalized = rewriter.create(loc, normalized, step); + Value rem = rewriter.create(loc, normalized, cN); loop2BufferCounter_[key] = rem; return rem; } +// P0-3: Anchor preservation. +// +// Earlier this function unconditionally placed the dyn wait at the loop body +// start and the dyn set right before the yield. That works for the simple +// "loop body is one straight segment" case but corrupts every other shape: +// +// * If the consumer RWOp lives under an `scf.if`, the body-start dyn wait +// pops the queue even on iterations where the consumer is skipped, so we +// consume primes that were never produced. +// * If the producer RWOp lives under an `scf.if`, the before-terminator dyn +// set pushes onto the queue every iteration regardless of whether the +// producer ran, drifting the queue depth. +// +// The original solver already chose the actual conflicting RWOperations as +// `cp->op1`/`cp->op2` (these point to the producer/consumer MLIR ops directly, +// inside whatever control flow they live in). Use those positions when they +// are inside the rotation loop body. 
Fall back to body-start/before-terminator +// only when the anchor is missing or actually lives outside the loop (e.g., +// LCA-driven PlaceHolder anchors that resolved to the loop boundary). +static Operation *getMultiBufferAnchorOpInLoop(RWOperation *rwOp, + scf::ForOp loop) { + if (!rwOp || !rwOp->op || !loop) + return nullptr; + Operation *op = rwOp->op; + // The op must live strictly inside the rotation loop body for the dyn ops + // to dominate / be dominated by the in-body selector value. + if (op == loop.getOperation()) + return nullptr; + if (!loop->isProperAncestor(op)) + return nullptr; + return op; +} + void CodeGenerator::emitMultiBufferSetWait(IRRewriter &rewriter, ConflictPair *cp) { // Multi-buffer codegen mirrors the InsertSync output: // pre-loop: N pto.set_flag (queue prime, one per event id) - // in-loop: pto.wait_flag_dyn(idx) at body start, - // pto.set_flag_dyn(idx) at body end (before yield) + // in-loop: pto.wait_flag_dyn(idx) at the consumer anchor, + // pto.set_flag_dyn(idx) at the producer anchor // post-loop: N pto.wait_flag (queue drain) // The dyn set/wait MUST live inside the loop body so they share the // `iv mod N` selector's dominance. GSS's default backward-sync anchors @@ -168,6 +225,42 @@ void CodeGenerator::emitMultiBufferSetWait(IRRewriter &rewriter, auto srcAttr = makePipe(rewriter.getContext(), setPipe); auto dstAttr = makePipe(rewriter.getContext(), waitPipe); + // Resolve the lexical-first / lexical-last RWOp positions in the loop body. + // + // NOTE: cp->op1 / cp->op2 are NOT MLIR-lexical-ordered. processOrders sorts + // by syncIrIndex, and for backward-edge processing the "second iteration" + // copy of an op gets a LARGER syncIrIndex than the original. So for a + // backward dep between tload and tstore, cp->op1 is typically the prev-iter + // tstore (lexically LATER in the MLIR loop body) and cp->op2 is the + // next-iter tload copy (whose underlying MLIR op is lexically EARLIER). 
+ // We need to compare actual MLIR positions and pick the right anchors. + // + // For backward (cross-iteration) slot rotation: + // - The wait_flag_dyn must precede the iteration's FIRST MLIR user of the + // slot, otherwise the next iteration could overwrite a slot still in + // flight from two iterations ago. + // - The set_flag_dyn must follow the iteration's LAST MLIR user of the + // slot, otherwise the queue would advance before the slot is free. + // This matches the conservative body-start / before-terminator placement + // exactly when the loop body has no other ops besides op1..op2 and is + // strictly tighter when those ops live under nested control flow. + Operation *anchor1 = getMultiBufferAnchorOpInLoop(cp->op1, loop); + Operation *anchor2 = getMultiBufferAnchorOpInLoop(cp->op2, loop); + Operation *firstUserAnchor = nullptr; + Operation *lastUserAnchor = nullptr; + if (anchor1 && anchor2 && anchor1->getBlock() == anchor2->getBlock()) { + if (anchor1->isBeforeInBlock(anchor2)) { + firstUserAnchor = anchor1; + lastUserAnchor = anchor2; + } else { + firstUserAnchor = anchor2; + lastUserAnchor = anchor1; + } + } + // If anchors are in different blocks (e.g. one under scf.if), or one is + // missing, leave them null and fall back to body-start/before-terminator + // below. That preserves correctness at the cost of guarding fewer cases. + // 1. Pre-loop: queue-prime with N concrete event ids. rewriter.setInsertionPoint(loop); for (int64_t i = 0; i < n; ++i) { @@ -175,8 +268,9 @@ void CodeGenerator::emitMultiBufferSetWait(IRRewriter &rewriter, rewriter.create(loc, srcAttr, dstAttr, eidAttr); } - // 2. In-loop: build (or reuse) the `iv mod N` counter at the start of the - // body, then a select chain over the assigned event ids. + // 2. In-loop: build (or reuse) the slot counter at body start, then a + // select chain over the assigned event ids. The chain dominates every + // in-body anchor because it lives at body start. 
Value rem = getOrCreateLoopCounter(rewriter, loop, n, loc); rewriter.setInsertionPointAfter(rem.getDefiningOp()); Value selected = @@ -189,18 +283,30 @@ void CodeGenerator::emitMultiBufferSetWait(IRRewriter &rewriter, selected = rewriter.create(loc, eq, idv, selected); } - // wait_flag_dyn goes at the start of the body (just after the selector), - // set_flag_dyn goes right before the terminator (yield) of the body. - rewriter.setInsertionPointAfter(selected.getDefiningOp()); + // 3a. wait_flag_dyn just before the iteration's first user of the slot. + // Fall back to body start (right after the selector) if that anchor lives + // outside the rotation loop body (e.g. solver chose loop-boundary + // PlaceHolders). + if (firstUserAnchor) { + rewriter.setInsertionPoint(firstUserAnchor); + } else { + rewriter.setInsertionPointAfter(selected.getDefiningOp()); + } rewriter.create(loc, srcAttr, dstAttr, selected); - Operation *terminator = loop.getBody()->getTerminator(); - if (!terminator) - return; - rewriter.setInsertionPoint(terminator); + // 3b. set_flag_dyn right after the iteration's last user of the slot. Fall + // back to right before the loop terminator if that anchor lives outside. + if (lastUserAnchor) { + rewriter.setInsertionPointAfter(lastUserAnchor); + } else { + Operation *terminator = loop.getBody()->getTerminator(); + if (!terminator) + return; + rewriter.setInsertionPoint(terminator); + } rewriter.create(loc, srcAttr, dstAttr, selected); - // 3. Post-loop: drain by waiting on each prime. + // 4. Post-loop: drain by waiting on each prime. 
rewriter.setInsertionPointAfter(loop); for (int64_t i = 0; i < n; ++i) { auto eidAttr = makeEvent(rewriter.getContext(), eids[i]); diff --git a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp index 1b702bcb4..0aad93595 100644 --- a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp +++ b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp @@ -375,6 +375,21 @@ void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter, createSetOrWaitFlagDynOp(rewriter, op, sync, srcPipe, dstPipe, eventIdxDyn); } +// P0-2: Slot index must be the LOGICAL iteration count modulo N, not `iv % N`. +// Compute `((iv - lb) / step) % N`, with a fast path for the canonical +// step=1, lb=0 case so we don't bloat IR for the common scenario. +static bool isyncIsConstantIndexEqualTo(Value v, int64_t target) { + if (!v) + return false; + if (auto cst = v.getDefiningOp()) + return cst.value() == target; + if (auto cst = v.getDefiningOp()) { + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getInt() == target; + } + return false; +} + Value SyncCodegen::GetBufferSelected(IRRewriter &rewriter, Operation *op, SyncOperation *sync) { if (SyncIndex2SelectBuffer.count(sync->GetSyncIndex())) { @@ -397,7 +412,14 @@ Value SyncCodegen::GetBufferSelected(IRRewriter &rewriter, Operation *op, rewriter.setInsertionPointToStart(parentLoop.getBody()); Value iv = parentLoop.getInductionVar(); Value cN = rewriter.create(op->getLoc(), N); - counter = rewriter.create(op->getLoc(), iv, cN); + Value lb = parentLoop.getLowerBound(); + Value step = parentLoop.getStep(); + Value normalized = iv; + if (!isyncIsConstantIndexEqualTo(lb, 0)) + normalized = rewriter.create(op->getLoc(), normalized, lb); + if (!isyncIsConstantIndexEqualTo(step, 1)) + normalized = rewriter.create(op->getLoc(), normalized, step); + counter = rewriter.create(op->getLoc(), normalized, cN); loop2BufferCounter[parentLoop] = {counter, N}; } diff --git 
a/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp b/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp index 4237d8208..b0d84516f 100644 --- a/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp +++ b/lib/PTO/Transforms/PTOEnableMultiBuffer.cpp @@ -26,13 +26,31 @@ using namespace mlir::pto; namespace { -// (loop, factor N) -> shared `iv mod N` counter inside that loop. -// B5: when several multi-buffer pointer_casts share the same enclosing scf.for -// and the same N, they should all read from the same counter rather than each -// inserting its own arith.remui + N constant ops. This mirrors +// (loop, factor N) -> shared `((iv - lb) / step) mod N` counter inside that +// loop. B5: when several multi-buffer pointer_casts share the same enclosing +// scf.for and the same N, they should all read from the same counter rather +// than each inserting its own arith.remui + N constant ops. This mirrors // SyncCodegen::loop2BufferCounter so the two passes emit consistent counters. +// +// P0-2 (slot index): The slot must be the LOGICAL iteration index modulo N, +// not the physical induction variable modulo N. When `step != 1` or `lb != 0`, +// `iv mod N` skips slots whenever gcd(step, N) > 1 (e.g. step=2,N=4 only ever +// produces 0,2). Compute `((iv - lb) / step) mod N` explicitly; degenerate to +// the cheaper `iv mod N` when `lb == 0` and `step == 1`. 
using LoopFactorKey = std::pair; +static bool isConstantIndexEqualTo(Value v, int64_t target) { + if (!v) + return false; + if (auto cst = v.getDefiningOp()) + return cst.value() == target; + if (auto cst = v.getDefiningOp()) { + if (auto intAttr = dyn_cast(cst.getValue())) + return intAttr.getInt() == target; + } + return false; +} + static Value getOrCreateLoopCounter( IRRewriter &rewriter, llvm::DenseMap &cache, @@ -44,7 +62,14 @@ static Value getOrCreateLoopCounter( rewriter.setInsertionPointToStart(forOp.getBody()); Value iv = forOp.getInductionVar(); Value cN = rewriter.create(loc, n); - Value rem = rewriter.create(loc, iv, cN); + Value lb = forOp.getLowerBound(); + Value step = forOp.getStep(); + Value normalized = iv; + if (!isConstantIndexEqualTo(lb, 0)) + normalized = rewriter.create(loc, normalized, lb); + if (!isConstantIndexEqualTo(step, 1)) + normalized = rewriter.create(loc, normalized, step); + Value rem = rewriter.create(loc, normalized, cN); cache[key] = rem; return rem; } @@ -112,6 +137,26 @@ static bool isLocalScopePointerCast(PointerCastOp op) { return as == AddressSpace::VEC || as == AddressSpace::MAT; } +// CV-create-preload hoists per-stage rotated pointer_casts ABOVE the rotation +// loop, so a multi-address pointer_cast may have NO enclosing scf.for of its +// own even though its users still rotate over an enclosing loop. When the cast +// itself is loop-invariant, infer the rotation loop from its users: every use +// must live inside the SAME scf.for, otherwise a single `iv mod N` selector +// can't be valid. Returns nullptr to mean "skip with a warning". 
+static scf::ForOp inferRotationLoopFromUses(PointerCastOp op) { + scf::ForOp common; + for (Operation *user : op->getUsers()) { + auto userFor = user->getParentOfType(); + if (!userFor) + return nullptr; + if (!common) + common = userFor; + else if (common != userFor) + return nullptr; + } + return common; +} + struct PTOEnableMultiBufferPass : public mlir::pto::impl::PTOEnableMultiBufferBase< PTOEnableMultiBufferPass> { @@ -138,6 +183,13 @@ struct PTOEnableMultiBufferPass } auto forOp = op->getParentOfType(); + if (!forOp) { + // Hoisted-by-CVCreatePreload case: the rotated multi-address cast lives + // at function level, but its users still rotate over an inner scf.for. + // Infer that loop and reuse the regular lowering path; if uses span + // multiple loops (or none), fall through to the warning + skip. + forOp = inferRotationLoopFromUses(op); + } if (!forOp) { op.emitWarning() << "pto-enable-multi-buffer: expected enclosing scf.for; skipping"; diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index 7cff8f68f..ff8ee8df6 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -11,8 +11,16 @@ #include "PTOPlanMemory.h" +#include "PTO/Transforms/InsertSync/SyncCommon.h" #include "PTO/Transforms/MultiBuffer.h" +// Compile-time pinning: PTO plan-memory hands out at most kPtoMultiBufferMaxNum +// physical slots, while InsertSync's event-id allocator is hard-bounded by +// MAX_MULTI_BUFFER_NUM. If they ever drift out of sync the planner will emit +// N slot pointer_casts that the sync side cannot allocate event ids for. 
+static_assert(::mlir::pto::kPtoMultiBufferMaxNum == MAX_MULTI_BUFFER_NUM, + "kPtoMultiBufferMaxNum must equal MAX_MULTI_BUFFER_NUM"); + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" diff --git a/test/lit/pto/cv_create_preload_scope_pipeline.pto b/test/lit/pto/cv_create_preload_scope_pipeline.pto index 2056b367c..bee8719c3 100644 --- a/test/lit/pto/cv_create_preload_scope_pipeline.pto +++ b/test/lit/pto/cv_create_preload_scope_pipeline.pto @@ -51,10 +51,13 @@ module { // CREATE-LABEL: IR Dump After PTOCVCreatePreload // CREATE-LABEL: func.func @cv_create_preload -// CREATE: arith.addi -// CREATE: scf.for +// P1-7: per-stage rotated multi-address pointer_casts are hoisted ABOVE the +// rotation loop, not duplicated inside the body. arith.addi (new_ub) follows +// after all the hoisted pointer_casts. // CREATE: pto.pointer_cast(%{{.*}}, %{{.*}}) : // CREATE: pto.pointer_cast(%{{.*}}, %{{.*}}) : +// CREATE: arith.addi +// CREATE: scf.for // CREATE: scf.if // CREATE: pto.cv.scope // CREATE: pto.tload diff --git a/test/lit/pto/multi_buffer_step_not_one.pto b/test/lit/pto/multi_buffer_step_not_one.pto new file mode 100644 index 000000000..7e74b74a3 --- /dev/null +++ b/test/lit/pto/multi_buffer_step_not_one.pto @@ -0,0 +1,47 @@ +// Copyright (c) 2026 Huawei Technologies Co., Ltd. +// This program is free software, you can redistribute it and/or modify it under the terms and conditions of +// CANN Open Software License Agreement Version 2.0 (the "License"). +// Please refer to the License for details. You may not use this file except in compliance with the License. +// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +// See LICENSE in the root of the software repository for the full text of the License. 
+ +// RUN: ptoas --enable-multi-buffer-lowering --mlir-print-ir-after=pto-enable-multi-buffer %s 2>&1 1>/dev/null | FileCheck %s + +// P0-2 regression: when the loop step is NOT 1, the slot index must be the +// LOGICAL iteration count modulo N, not `iv mod N`. With step=2 / N=2 the +// naive `iv mod 2` only ever yields 0 (slot 1 is never selected), silently +// breaking double-buffer rotation. The lowering should compute the slot as +// `((iv - lb) / step) mod N`, i.e., divide the normalized iv by the step +// before the modulo. +// +// Loop here uses lb=0, step=2, ub=8 -> 4 logical iterations: slot=0,1,0,1. + +module { + func.func @double_buffer_step_two(%arg0: memref<16x16x16xf16, #pto.address_space>, + %arg1: memref<16x16x16xf16, #pto.address_space>) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + scf.for %i = %c0 to %c8 step %c2 { + %a = memref.alloc() {pto.multi_buffer = 2 : i32} + : memref<16x16x16xf16, #pto.address_space> + pto.tload ins(%arg0 : memref<16x16x16xf16, #pto.address_space>) + outs(%a : memref<16x16x16xf16, #pto.address_space>) + pto.tstore ins(%a : memref<16x16x16xf16, #pto.address_space>) + outs(%arg1 : memref<16x16x16xf16, #pto.address_space>) + } + return + } +} + +// CHECK: IR Dump After PTOEnableMultiBuffer +// CHECK: func.func @double_buffer_step_two +// Slot index must include a divui by the loop step BEFORE the remui. The +// step %c2 is the same SSA value used as the loop's `step`, so we match +// "divui ..., %c2" without trying to capture the step name from the for op +// signature (FileCheck regex would otherwise greedily eat trailing tokens). +// CHECK: scf.for %[[IV:[a-zA-Z0-9_]+]] = %{{.*}} to %{{.*}} step %{{.*}} +// CHECK: arith.divui %[[IV]], %{{.*}} : index +// CHECK: arith.remui %{{.*}}, %{{.*}} : index +// CHECK: arith.select