From f0ad0d25cbb0944c195760bfb5345a517101dd2d Mon Sep 17 00:00:00 2001
From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com>
Date: Thu, 30 Apr 2026 13:29:49 +0800
Subject: [PATCH] Enhance PlanMemory reuse liveness

---
 lib/PTO/Transforms/PTOPlanMemory.cpp          | 299 ++++++++++++++++--
 lib/PTO/Transforms/PTOPlanMemory.h            |  52 ++-
 .../pto/issue601_loop_first_write_reuse.pto   |  32 ++
 ...ory_loop_may_zero_no_first_write_reuse.pto |  37 +++
 .../plan_memory_right_next_write_reuse.pto    |  84 +++++
 5 files changed, 474 insertions(+), 30 deletions(-)
 create mode 100644 test/lit/pto/issue601_loop_first_write_reuse.pto
 create mode 100644 test/lit/pto/plan_memory_loop_may_zero_no_first_write_reuse.pto
 create mode 100644 test/lit/pto/plan_memory_right_next_write_reuse.pto

diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp
index 38d0fd703..2cdb9c389 100644
--- a/lib/PTO/Transforms/PTOPlanMemory.cpp
+++ b/lib/PTO/Transforms/PTOPlanMemory.cpp
@@ -153,6 +153,36 @@ static void sortValuesByStableOrder(
   });
 }
 
+static void appendUniqueValue(SmallVectorImpl<Value> &values, Value value) {
+  if (!llvm::is_contained(values, value))
+    values.push_back(value);
+}
+
+static std::optional<int64_t> getConstantIndexLike(Value value) {
+  if (auto constantIndexOp = value.getDefiningOp<arith::ConstantIndexOp>())
+    return constantIndexOp.value();
+  if (auto constantIntOp = value.getDefiningOp<arith::ConstantIntOp>())
+    return constantIntOp.value();
+  if (auto constantOp = value.getDefiningOp<arith::ConstantOp>()) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
+      return intAttr.getInt();
+  }
+  if (auto castOp = value.getDefiningOp<arith::IndexCastOp>())
+    return getConstantIndexLike(castOp.getIn());
+  return std::nullopt;
+}
+
+static bool isForLoopKnownNonEmpty(scf::ForOp forOp) {
+  std::optional<int64_t> lowerBound =
+      getConstantIndexLike(forOp.getLowerBound());
+  std::optional<int64_t> upperBound =
+      getConstantIndexLike(forOp.getUpperBound());
+  std::optional<int64_t> step = getConstantIndexLike(forOp.getStep());
+  if (!lowerBound || !upperBound || !step || *step <= 0)
+    return false;
+  return *lowerBound < *upperBound;
+}
+
 static SmallVector<Value>
 getScratchBuffersFromEffects(Operation *op, ValueRange dpsInits,
                              const StableValueOrderMap &stableValueOrder) {
@@ -589,31 +619,198 @@ void MemLivenessAnalysis::RecursiveIfOp(scf::IfOp ifOp, Liveness live) {
 
 SmallVector<Value>
 MemLivenessAnalysis::GetLiveBuffersInLoop(scf::ForOp forOp, Liveness live) {
-  SmallVector<Value> allocBeforeLoopBuffers;
   const auto *liveBlockInfo = live.getLiveness(forOp->getBlock());
   auto currentLiveValues =
       liveBlockInfo->currentlyLiveValues(forOp.getOperation());
   if (currentLiveValues.empty()) {
-    return allocBeforeLoopBuffers;
+    return {};
   }
   // The gen buffer of the same operation must ensure the order of priority.
   SetVector<Value> currentLiveValuesOrder;
   for (auto buffer : currentLiveValues) {
     currentLiveValuesOrder.insert(buffer);
   }
+  SetVector<Value> allocBeforeLoopBufferSet;
   for (const Value &operand : currentLiveValuesOrder) {
     auto aliasBuffers = GetAliasBuffers(operand);
     aliasBuffers.insert(operand);
     for (auto Buffer : aliasBuffers) {
       auto iter = buffer2status.find(Buffer);
-      if (iter != buffer2status.end())
-        allocBeforeLoopBuffers.push_back(Buffer);
+      if (iter == buffer2status.end())
+        continue;
+      if ((iter->second == BufferStatus::DEFFINED ||
+           iter->second == BufferStatus::KILLED) &&
+          CanDelayLoopEntryGenUntilFirstWrite(forOp, Buffer)) {
+        delayedLoopEntryGenBuffers[Buffer] = true;
+        continue;
+      }
+      allocBeforeLoopBufferSet.insert(Buffer);
     }
   }
+  SmallVector<Value> allocBeforeLoopBuffers(allocBeforeLoopBufferSet.begin(),
+                                            allocBeforeLoopBufferSet.end());
   sortValuesByStableOrder(allocBeforeLoopBuffers, stableValueOrder);
   return allocBeforeLoopBuffers;
 }
 
+bool MemLivenessAnalysis::CanDelayLoopEntryGenUntilFirstWrite(
+    scf::ForOp forOp, Value buffer) {
+  if (!isForLoopKnownNonEmpty(forOp))
+    return false;
+
+  SetVector<Value> aliasBuffers = GetAliasBuffers(buffer);
+  aliasBuffers.insert(buffer);
+  Block *body = forOp.getBody();
+  if (!body)
+    return false;
+
+  for (Operation &op : body->without_terminator()) {
+    if (!OperationOrNestedRegionTouchesAnyAlias(&op, aliasBuffers))
+      continue;
+    if (auto nestedForOp = dyn_cast<scf::ForOp>(&op)) {
+      return llvm::any_of(aliasBuffers, [&](Value alias) {
+        return CanDelayLoopEntryGenUntilFirstWrite(nestedForOp, alias);
+      });
+    }
+    return IsWriteOnlyDpsInitForAlias(&op, aliasBuffers);
+  }
+  return false;
+}
+
+bool MemLivenessAnalysis::OperationDirectlyTouchesAnyAlias(
+    Operation *op, const SetVector<Value> &aliasBuffers) const {
+  auto touchesValue = [&](Value value) {
+    return value && llvm::is_contained(aliasBuffers, value);
+  };
+  for (Value operand : op->getOperands()) {
+    if (touchesValue(operand))
+      return true;
+  }
+  for (Value result : op->getResults()) {
+    if (touchesValue(result))
+      return true;
+  }
+
+  auto memEffect = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!memEffect)
+    return false;
+  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>,
+              kMemoryEffectReserveSize>
+      effects;
+  memEffect.getEffects(effects);
+  for (const auto &effect : effects) {
+    if (touchesValue(effect.getValue()))
+      return true;
+  }
+  return false;
+}
+
+bool MemLivenessAnalysis::OperationOrNestedRegionTouchesAnyAlias(
+    Operation *op, const SetVector<Value> &aliasBuffers) const {
+  if (OperationDirectlyTouchesAnyAlias(op, aliasBuffers))
+    return true;
+  if (op->getNumRegions() == 0)
+    return false;
+
+  bool touches = false;
+  op->walk([&](Operation *nestedOp) {
+    if (nestedOp == op)
+      return WalkResult::advance();
+    if (OperationDirectlyTouchesAnyAlias(nestedOp, aliasBuffers)) {
+      touches = true;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return touches;
+}
+
+bool MemLivenessAnalysis::OperationReadsAnyAlias(
+    Operation *op, const SetVector<Value> &aliasBuffers) const {
+  auto touchesValue = [&](Value value) {
+    return value && llvm::is_contained(aliasBuffers, value);
+  };
+
+  auto memEffect = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!memEffect) {
+    return llvm::any_of(op->getOperands(), touchesValue);
+  }
+
+  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>,
+              kMemoryEffectReserveSize>
+      effects;
+  memEffect.getEffects(effects);
+  return llvm::any_of(effects, [&](const auto &effect) {
+    return isa<MemoryEffects::Read>(effect.getEffect()) &&
+           touchesValue(effect.getValue());
+  });
+}
+
+bool MemLivenessAnalysis::IsWriteOnlyDpsInitForAlias(
+    Operation *op, const SetVector<Value> &aliasBuffers) const {
+  auto ptoDpsOp = dyn_cast<DestinationStyleOpInterface>(op);
+  if (!ptoDpsOp)
+    return false;
+
+  bool hasAliasDpsInit = false;
+  for (Value init : ptoDpsOp.getDpsInits()) {
+    if (llvm::is_contained(aliasBuffers, init)) {
+      hasAliasDpsInit = true;
+      break;
+    }
+  }
+  if (!hasAliasDpsInit)
+    return false;
+
+  auto memEffect = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!memEffect)
+    return false;
+  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>,
+              kMemoryEffectReserveSize>
+      effects;
+  memEffect.getEffects(effects);
+
+  bool hasWrite = false;
+  for (const auto &effect : effects) {
+    Value value = effect.getValue();
+    if (!value || !llvm::is_contained(aliasBuffers, value))
+      continue;
+    if (isa<MemoryEffects::Read>(effect.getEffect()))
+      return false;
+    if (isa<MemoryEffects::Write>(effect.getEffect())) {
+      hasWrite = true;
+      continue;
+    }
+    return false;
+  }
+  return hasWrite;
+}
+
+bool MemLivenessAnalysis::CanKillBeforeNextOverwrite(
+    Operation *op, const SetVector<Value> &aliasBuffers) {
+  for (Operation *nextOp = op->getNextNode(); nextOp;
+       nextOp = nextOp->getNextNode()) {
+    if (!OperationOrNestedRegionTouchesAnyAlias(nextOp, aliasBuffers))
+      continue;
+
+    if (auto forOp = dyn_cast<scf::ForOp>(nextOp)) {
+      return llvm::any_of(aliasBuffers, [&](Value alias) {
+        return CanDelayLoopEntryGenUntilFirstWrite(forOp, alias);
+      });
+    }
+
+    return IsWriteOnlyDpsInitForAlias(nextOp, aliasBuffers);
+  }
+  return false;
+}
+
+bool MemLivenessAnalysis::CanRegenerateBufferAtOp(Operation *op,
+                                                  Value buffer) {
+  SetVector<Value> aliasBuffers = GetAliasBuffers(buffer);
+  aliasBuffers.insert(buffer);
+  return IsWriteOnlyDpsInitForAlias(op, aliasBuffers);
+}
+
 LogicalResult
 MemLivenessAnalysis::CheckLocalBufferAllocOp(Operation *op) const {
   auto allocOp = dyn_cast(op);
@@ -732,11 +929,25 @@ void MemLivenessAnalysis::UpdateOperandGenInfo(OpInfo *opInfo, Value operand) {
   if (iter_buffer == buffer2status.end())
     return;
   if (iter_buffer->second == BufferStatus::DEFFINED) {
-    genKillMap[opInfo].gen.push_back(operand);
+    appendUniqueValue(genKillMap[opInfo].gen, operand);
     buffer2status[iter_buffer->first] = BufferStatus::GENED;
+    buffer2GenOp[iter_buffer->first] = opInfo->operation;
   } else if (iter_buffer->second == BufferStatus::KILLED) {
-    llvm_unreachable("The buffer memory has been released and cannot be used "
-                     "again! ");
+    if (!CanRegenerateBufferAtOp(opInfo->operation, operand)) {
+      llvm_unreachable("The buffer memory has been released and cannot be "
+                       "used again before it is redefined!");
+    }
+    appendUniqueValue(genKillMap[opInfo].gen, operand);
+    buffer2status[iter_buffer->first] = BufferStatus::GENED;
+    buffer2GenOp[iter_buffer->first] = opInfo->operation;
+  } else if (iter_buffer->second == BufferStatus::GENED) {
+    SetVector<Value> aliasBuffers = GetAliasBuffers(operand);
+    aliasBuffers.insert(operand);
+    if (IsWriteOnlyDpsInitForAlias(opInfo->operation, aliasBuffers)) {
+      appendUniqueValue(genKillMap[opInfo].kill, operand);
+      appendUniqueValue(genKillMap[opInfo].gen, operand);
+      buffer2GenOp[iter_buffer->first] = opInfo->operation;
+    }
   }
 }
 
@@ -764,15 +975,35 @@ void MemLivenessAnalysis::UpdateOpKillInfo(OpInfo *opInfo, Value operand,
     auto iterBuffer = buffer2status.find(aliasBuffer);
     if (iterBuffer == buffer2status.end())
       return;
-    if (iterBuffer->second == BufferStatus::GENED &&
-        IsInSameBlock(iterBuffer->first.getDefiningOp(), opInfo->operation) &&
-        AllDeadAfter(opInfo->operation, aliasBuffers, live)) {
-      genKillMap[opInfo].kill.push_back(aliasBuffer);
+    Operation *defOp = iterBuffer->first.getDefiningOp();
+    bool canKillInThisBlock =
+        defOp && IsInSameBlock(defOp, opInfo->operation);
+    auto delayedGen = delayedLoopEntryGenBuffers.find(iterBuffer->first);
+    if (!canKillInThisBlock && delayedGen != delayedLoopEntryGenBuffers.end() &&
+        delayedGen->second) {
+      Operation *genOp = GetBufferGenOp(iterBuffer->first);
+      canKillInThisBlock = genOp && IsInSameBlock(genOp, opInfo->operation);
+    }
+    bool canKillCurrentValue =
+        AllDeadAfter(opInfo->operation, aliasBuffers, live) ||
+        (OperationReadsAnyAlias(opInfo->operation, aliasBuffers) &&
+         CanKillBeforeNextOverwrite(opInfo->operation, aliasBuffers));
+    if (iterBuffer->second == BufferStatus::GENED && canKillInThisBlock &&
+        canKillCurrentValue) {
+      appendUniqueValue(genKillMap[opInfo].kill, aliasBuffer);
       buffer2status[iterBuffer->first] = BufferStatus::KILLED;
+      buffer2GenOp.erase(iterBuffer->first);
     }
   }
 }
 
+Operation *MemLivenessAnalysis::GetBufferGenOp(Value buffer) const {
+  auto it = buffer2GenOp.find(buffer);
+  if (it != buffer2GenOp.end())
+    return it->second;
+  return nullptr;
+}
+
 bool MemLivenessAnalysis::IsInSameBlock(Operation *op1, Operation *op2) const {
   return op1->getBlock() == op2->getBlock();
 }
@@ -839,25 +1070,43 @@ BufferInfo MemLivenessAnalysis::GetBufferInfo(Operation *op, Value operand,
 
 void MemLivenessAnalysis::GenerateBufferLife() {
   int scopeTime = 0;
+  DenseMap<Value, std::shared_ptr<BufferLife>> openLives;
   for (size_t i = 0; i < linearOperation.size(); ++i) {
     auto it = genKillMap.find(linearOperation[i].get());
    if (it == genKillMap.end()) {
       scopeTime++;
       continue;
     }
-    // Time given to buffer start.
+
+    SmallVector<Value> postGenKills;
+    for (const Value &killBuffer : it->second.kill) {
+      auto iter = openLives.find(killBuffer);
+      if (iter != openLives.end()) {
+        iter->second->freeTime = scopeTime;
+        openLives.erase(iter);
+        continue;
+      }
+      if (!llvm::is_contained(it->second.gen, killBuffer))
+        llvm::report_fatal_error("buffer lifetime killed before generation");
+      appendUniqueValue(postGenKills, killBuffer);
+    }
+
     for (const Value &genBuffer : it->second.gen) {
-      std::unique_ptr<BufferLife> bufferLife =
-          std::make_unique<BufferLife>(genBuffer);
+      if (openLives.find(genBuffer) != openLives.end())
+        llvm::report_fatal_error("buffer lifetime generated before release");
+      std::shared_ptr<BufferLife> bufferLife =
+          std::make_shared<BufferLife>(genBuffer);
       bufferLife->allocTime = scopeTime;
-      buffer2Life[genBuffer] = std::move(bufferLife);
+      buffer2Life[genBuffer].push_back(bufferLife);
+      openLives[genBuffer] = std::move(bufferLife);
     }
-    // Time given to buffer end.
-    for (const Value &killBuffer : it->second.kill) {
-      auto iter = buffer2Life.find(killBuffer);
-      if (iter == buffer2Life.end())
-        llvm::report_fatal_error("buffer lifetime killed before generation");
+
+    for (const Value &killBuffer : postGenKills) {
+      auto iter = openLives.find(killBuffer);
+      if (iter == openLives.end())
+        llvm::report_fatal_error("buffer lifetime generated after release");
       iter->second->freeTime = scopeTime;
+      openLives.erase(iter);
     }
     scopeTime++;
   }
@@ -1068,6 +1317,7 @@ LogicalResult MemPlan::plan() {
 
 void MemPlan::GenerateStorageEntry() {
   // create new storage entry.
+  SetVector<Value> seenBuffers;
   for (auto &operation : linearOperation) {
     auto it = genKillMap.find(operation.get());
     if (it == genKillMap.end())
@@ -1075,14 +1325,17 @@ void MemPlan::GenerateStorageEntry() {
     SmallVector<Value> genBuffers(it->second.gen.begin(), it->second.gen.end());
     sortValuesByStableOrder(genBuffers, stableValueOrder);
     for (const Value &genBuffer : genBuffers) {
+      if (llvm::is_contained(seenBuffers, genBuffer))
+        continue;
       auto iter = bufferInfos.find(genBuffer);
       if (iter == bufferInfos.end()) {
        continue;
       }
-      const std::shared_ptr<BufferLife> &bufLife = buffer2Life.at(genBuffer);
+      seenBuffers.insert(genBuffer);
+      const BufferLifeVec &bufLives = buffer2Life.at(genBuffer);
       std::unique_ptr<StorageEntry> entry = std::make_unique<StorageEntry>();
       entry->bufInfo = &iter->second;
-      entry->bufferLifeVec.emplace_back(bufLife);
+      entry->bufferLifeVec.append(bufLives.begin(), bufLives.end());
       entry->inplaceBuffers.emplace_back(iter->first);
       auto multiBuffer = buffer2MultiNum.find(genBuffer);
       if (multiBuffer != buffer2MultiNum.end()) {
@@ -2095,7 +2348,7 @@ void MemPlan::ReportAllocatedEntryDebugInfo(StorageEntry *rootStorageEntry) {
   }
   size_t num = allocatedEntry.size() - 1;
   if (rootStorageEntry->mergedChildren.size() <= num)
-    llvm::report_fatal_error("missing failed storage entry");
+    return;
   const StorageEntry *failedSe = rootStorageEntry->mergedChildren[num];
   printRecord(failedSe);
   LDBG("alloc fail,because exceed bound of memory \n"
diff --git a/lib/PTO/Transforms/PTOPlanMemory.h b/lib/PTO/Transforms/PTOPlanMemory.h
index 6a1d40077..cdc9ae361 100644
--- a/lib/PTO/Transforms/PTOPlanMemory.h
+++ b/lib/PTO/Transforms/PTOPlanMemory.h
@@ -273,8 +273,8 @@ class MemLivenessAnalysis {
   /// stable IR order for Values used to keep memory planning deterministic.
   StableValueOrderMap stableValueOrder;
 
-  /// map from buffer to its lifetime.
-  DenseMap<Value, std::unique_ptr<BufferLife>> buffer2Life;
+  /// map from buffer to its lifetime intervals.
+  DenseMap<Value, BufferLifeVec> buffer2Life;
 
   /// map from operation to its gen and kill buffer.
   DenseMap genKillMap;
@@ -303,6 +303,38 @@ class MemLivenessAnalysis {
   /// Get the buffer used within the loop and defined outside the loop.
   SmallVector<Value> GetLiveBuffersInLoop(scf::ForOp forOp, Liveness live);
 
+  /// Check whether a loop-live buffer can be generated by its first overwrite
+  /// inside a statically non-empty loop instead of being generated at the loop
+  /// entry.
+  bool CanDelayLoopEntryGenUntilFirstWrite(scf::ForOp forOp, Value buffer);
+
+  /// Return true if op directly uses or affects any alias buffer.
+  bool OperationDirectlyTouchesAnyAlias(Operation *op,
+                                        const SetVector<Value> &aliasBuffers) const;
+
+  /// Return true if op or any nested op touches any alias buffer.
+  bool OperationOrNestedRegionTouchesAnyAlias(
+      Operation *op, const SetVector<Value> &aliasBuffers) const;
+
+  /// Return true if op directly reads any alias buffer.
+  bool OperationReadsAnyAlias(Operation *op,
+                              const SetVector<Value> &aliasBuffers) const;
+
+  /// Return true if op fully overwrites one alias as a DPS init without first
+  /// reading any alias in the same alias set.
+  bool IsWriteOnlyDpsInitForAlias(Operation *op,
+                                  const SetVector<Value> &aliasBuffers) const;
+
+  /// Return true if the next touch after op redefines the buffer before read.
+  bool CanKillBeforeNextOverwrite(Operation *op,
+                                  const SetVector<Value> &aliasBuffers);
+
+  /// Return true if a killed buffer can start a new lifetime at op.
+  bool CanRegenerateBufferAtOp(Operation *op, Value buffer);
+
+  /// Return the operation that actually generated the buffer lifetime.
+  Operation *GetBufferGenOp(Value buffer) const;
+
   /// Update for Op tensor init args and tensor result args alias info.
   void UpdateInitAndResAlias(DestinationStyleOpInterface dstStyleOp);
 
@@ -399,6 +431,13 @@ class MemLivenessAnalysis {
   /// Gen-kill status corresponding to buffer.
   DenseMap<Value, BufferStatus> buffer2status;
 
+  /// Operation where the current buffer lifetime was generated.
+  DenseMap<Value, Operation *> buffer2GenOp;
+
+  /// Buffers whose loop-entry generation was delayed to their first write in
+  /// the loop body.
+  DenseMap<Value, bool> delayedLoopEntryGenBuffers;
+
   /// map on buffer alias
   DenseMap> buffer2AliasVec;
 
@@ -433,9 +472,8 @@ class MemPlan {
     bufferInfos = bufsInfo;
   }
 
-  inline void
-  SetBuffer2Life(DenseMap<Value, std::shared_ptr<BufferLife>> buf2Life) {
-    buffer2Life = buf2Life;
+  inline void SetBuffer2Life(DenseMap<Value, BufferLifeVec> buf2Life) {
+    buffer2Life = std::move(buf2Life);
   }
 
   inline void SetGenKillMap(DenseMap gkMap) {
@@ -681,8 +719,8 @@ class MemPlan {
   /// map from buffer value to its buffer information.
   std::map bufferInfos;
 
-  /// map from buffer to its lifetime.
-  DenseMap<Value, std::shared_ptr<BufferLife>> buffer2Life;
+  /// map from buffer to its lifetime intervals.
+  DenseMap<Value, BufferLifeVec> buffer2Life;
 
   /// record the map from the buffer to its number of buffer if it does
   /// multibuffer optimization.
diff --git a/test/lit/pto/issue601_loop_first_write_reuse.pto b/test/lit/pto/issue601_loop_first_write_reuse.pto
new file mode 100644
index 000000000..2d0eb968f
--- /dev/null
+++ b/test/lit/pto/issue601_loop_first_write_reuse.pto
@@ -0,0 +1,32 @@
+// RUN: ptoas --pto-arch=a3 %s >/dev/null
+//
+// The two vec buffers are allocated before the loop, but each iteration fully
+// overwrites each buffer before reading it. A3 vec local memory cannot hold both
+// 128x256xf32 buffers at once, so the two loop stages must reuse the same
+// storage interval.
+
+module {
+  func.func @loop_first_write_reuse(
+      %src: memref<128x256xf32, #pto.address_space>,
+      %dst: memref<128x256xf32, #pto.address_space>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+
+    %stage0 = memref.alloc() : memref<128x256xf32, #pto.address_space>
+    %stage1 = memref.alloc() : memref<128x256xf32, #pto.address_space>
+
+    scf.for %i = %c0 to %c2 step %c1 {
+      pto.tload ins(%src : memref<128x256xf32, #pto.address_space>)
+                outs(%stage0 : memref<128x256xf32, #pto.address_space>)
+      pto.tstore ins(%stage0 : memref<128x256xf32, #pto.address_space>)
+                 outs(%dst : memref<128x256xf32, #pto.address_space>)
+
+      pto.tload ins(%src : memref<128x256xf32, #pto.address_space>)
+                outs(%stage1 : memref<128x256xf32, #pto.address_space>)
+      pto.tstore ins(%stage1 : memref<128x256xf32, #pto.address_space>)
+                 outs(%dst : memref<128x256xf32, #pto.address_space>)
+    }
+    return
+  }
+}
diff --git a/test/lit/pto/plan_memory_loop_may_zero_no_first_write_reuse.pto b/test/lit/pto/plan_memory_loop_may_zero_no_first_write_reuse.pto
new file mode 100644
index 000000000..ae57623be
--- /dev/null
+++ b/test/lit/pto/plan_memory_loop_may_zero_no_first_write_reuse.pto
@@ -0,0 +1,37 @@
+// RUN: not ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s
+//
+// The loop upper bound is dynamic, so the loop may execute zero iterations.
+// Even though each in-loop use starts with a full overwrite, PlanMemory must
+// keep the loop-entry lifetime because the after-loop read may observe the
+// pre-loop storage state when the loop is skipped.
+
+module {
+  func.func @loop_may_zero_preserve_live_in(
+      %src: memref<128x256xf32, #pto.address_space>,
+      %dst: memref<128x256xf32, #pto.address_space>,
+      %ub: index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+
+    %stage0 = memref.alloc() : memref<128x256xf32, #pto.address_space>
+    %stage1 = memref.alloc() : memref<128x256xf32, #pto.address_space>
+
+    scf.for %i = %c0 to %ub step %c1 {
+      pto.tload ins(%src : memref<128x256xf32, #pto.address_space>)
+                outs(%stage0 : memref<128x256xf32, #pto.address_space>)
+      pto.tstore ins(%stage0 : memref<128x256xf32, #pto.address_space>)
+                 outs(%dst : memref<128x256xf32, #pto.address_space>)
+
+      pto.tload ins(%src : memref<128x256xf32, #pto.address_space>)
+                outs(%stage1 : memref<128x256xf32, #pto.address_space>)
+      pto.tstore ins(%stage1 : memref<128x256xf32, #pto.address_space>)
+                 outs(%dst : memref<128x256xf32, #pto.address_space>)
+    }
+
+    pto.tstore ins(%stage0 : memref<128x256xf32, #pto.address_space>)
+               outs(%dst : memref<128x256xf32, #pto.address_space>)
+    return
+  }
+}
+
+// CHECK: vec overflow
diff --git a/test/lit/pto/plan_memory_right_next_write_reuse.pto b/test/lit/pto/plan_memory_right_next_write_reuse.pto
new file mode 100644
index 000000000..58705b7b3
--- /dev/null
+++ b/test/lit/pto/plan_memory_right_next_write_reuse.pto
@@ -0,0 +1,84 @@
+// RUN: ptoas --pto-arch=a3 --mlir-print-ir-after=pto-plan-memory %s 2>&1 1>/dev/null | FileCheck %s
+// RUN: ptoas --pto-arch=a3 %s >/dev/null
+//
+// The small right tile is used again after the large-right stage, but that
+// later use starts with a full write. The old small value can be killed before
+// the large stage so both right buffers fit in the single 64KiB right space.
+
+module {
+  func.func @right_next_write_reuse(%a: !pto.ptr, %b: !pto.ptr)
+      attributes {pto.kernel_kind = #pto.kernel_kind} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c128 = arith.constant 128 : index
+    %c256 = arith.constant 256 : index
+
+    %a_view = pto.make_tensor_view %a, shape = [%c128, %c256],
+        strides = [%c256, %c1] : !pto.tensor_view
+    %b_view = pto.make_tensor_view %b, shape = [%c256, %c128],
+        strides = [%c128, %c1] : !pto.tensor_view
+    %a_part = pto.partition_view %a_view, offsets = [%c0, %c0],
+        sizes = [%c128, %c256] : !pto.tensor_view
+        -> !pto.partition_tensor_view<128x256xf16>
+    %b_small_part = pto.partition_view %b_view, offsets = [%c0, %c0],
+        sizes = [%c128, %c128] : !pto.tensor_view
+        -> !pto.partition_tensor_view<128x128xf16>
+    %b_large_part = pto.partition_view %b_view, offsets = [%c0, %c0],
+        sizes = [%c256, %c128] : !pto.tensor_view
+        -> !pto.partition_tensor_view<256x128xf16>
+
+    %left_mat = pto.alloc_tile
+        : !pto.tile_buf
+    %left = pto.alloc_tile
+        : !pto.tile_buf
+    %small_mat = pto.alloc_tile
+        : !pto.tile_buf
+    %large_mat = pto.alloc_tile
+        : !pto.tile_buf
+    %small_right = pto.alloc_tile
+        : !pto.tile_buf
+    %large_right = pto.alloc_tile
+        : !pto.tile_buf
+    %acc = pto.alloc_tile
+        : !pto.tile_buf
+
+    pto.tload ins(%a_part : !pto.partition_tensor_view<128x256xf16>)
+              outs(%left_mat : !pto.tile_buf)
+    pto.tmov ins(%left_mat : !pto.tile_buf)
+             outs(%left : !pto.tile_buf)
+    %left_small = pto.subview %left[%c0, %c0] sizes [128, 128]
+        : !pto.tile_buf
+        -> !pto.tile_buf
+
+    scf.for %i = %c0 to %c2 step %c1 {
+      pto.tload ins(%b_small_part : !pto.partition_tensor_view<128x128xf16>)
+                outs(%small_mat : !pto.tile_buf)
+      pto.tmov ins(%small_mat : !pto.tile_buf)
+               outs(%small_right : !pto.tile_buf)
+      pto.tmatmul ins(%left_small, %small_right : !pto.tile_buf, !pto.tile_buf)
+                  outs(%acc : !pto.tile_buf)
+
+      scf.for %j = %c0 to %c2 step %c1 {
+        pto.tload ins(%b_large_part : !pto.partition_tensor_view<256x128xf16>)
+                  outs(%large_mat : !pto.tile_buf)
+        pto.tmov ins(%large_mat : !pto.tile_buf)
+                 outs(%large_right : !pto.tile_buf)
+        pto.tmatmul ins(%left, %large_right : !pto.tile_buf, !pto.tile_buf)
+                    outs(%acc : !pto.tile_buf)
+
+        pto.tload ins(%b_small_part : !pto.partition_tensor_view<128x128xf16>)
+                  outs(%small_mat : !pto.tile_buf)
+        pto.tmov ins(%small_mat : !pto.tile_buf)
+                 outs(%small_right : !pto.tile_buf)
+        pto.tmatmul ins(%left_small, %small_right : !pto.tile_buf, !pto.tile_buf)
+                    outs(%acc : !pto.tile_buf)
+      }
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: @right_next_write_reuse
+// CHECK-DAG: pto.pointer_cast(%c0_i64) {{.*}}memref<128x128xf16{{.*}}#pto.address_space
+// CHECK-DAG: pto.pointer_cast(%c0_i64) {{.*}}memref<256x128xf16{{.*}}#pto.address_space
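Reviewer note (not part of the patch): the core idea behind the BufferLifeVec change is that one buffer may now own several disjoint [allocTime, freeTime] intervals, and two buffers can share the same storage when no interval of one overlaps an interval of the other. The following is a minimal, self-contained C++ sketch of that interval model. The names Interval, intervalsOverlap, and canShareStorage, as well as the scope-time numbers, are invented for this illustration (they loosely mirror the right_next_write_reuse test) and are not identifiers from the PTO sources.

// Standalone sketch of the multi-interval lifetime idea; not PTO code.
#include <cstdint>
#include <iostream>
#include <vector>

struct Interval {
  int64_t allocTime; // scope time of the op that generates the buffer value
  int64_t freeTime;  // scope time of the op that kills the buffer value
};

// Two closed intervals overlap unless one ends before the other starts.
static bool intervalsOverlap(const Interval &a, const Interval &b) {
  return a.allocTime <= b.freeTime && b.allocTime <= a.freeTime;
}

// Two buffers can reuse the same storage when no lifetime interval of one
// overlaps any lifetime interval of the other.
static bool canShareStorage(const std::vector<Interval> &a,
                            const std::vector<Interval> &b) {
  for (const Interval &x : a)
    for (const Interval &y : b)
      if (intervalsOverlap(x, y))
        return false;
  return true;
}

int main() {
  // Hypothetical times: the small right tile is killed before the large-right
  // stage and regenerated by the later full write, so it gets two disjoint
  // intervals instead of one long one.
  std::vector<Interval> smallRight = {{0, 2}, {6, 8}};
  // The large right tile only lives during the inner-loop stage.
  std::vector<Interval> largeRight = {{3, 5}};
  std::cout << (canShareStorage(smallRight, largeRight) ? "share" : "separate")
            << "\n"; // prints "share"
  return 0;
}

Whether PlanMemory treats the interval end as closed or half-open at the kill point is an internal detail of the pass; the sketch assumes closed intervals purely for illustration.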