Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 276 additions & 23 deletions lib/PTO/Transforms/PTOPlanMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,36 @@ static void sortValuesByStableOrder(
});
}

/// Appends `value` to `values` unless an equal element is already stored,
/// so repeated calls never introduce duplicates and insertion order is kept.
static void appendUniqueValue(SmallVectorImpl<Value> &values, Value value) {
  if (llvm::is_contained(values, value))
    return;
  values.push_back(value);
}

/// Resolves `value` to a compile-time integer when possible.
///
/// Recognized producers are arith::ConstantIndexOp, arith::ConstantIntOp and
/// a generic arith::ConstantOp whose attribute is an IntegerAttr. Chains of
/// arith::IndexCastOp are looked through by resolving the cast's input.
///
/// \returns the constant as int64_t, or std::nullopt for any other producer.
static std::optional<int64_t> getConstantIndexLike(Value value) {
  Value current = value;
  while (true) {
    if (auto indexConst = current.getDefiningOp<arith::ConstantIndexOp>())
      return indexConst.value();
    if (auto intConst = current.getDefiningOp<arith::ConstantIntOp>())
      return intConst.value();
    if (auto genericConst = current.getDefiningOp<arith::ConstantOp>()) {
      if (auto intAttr = dyn_cast<IntegerAttr>(genericConst.getValue()))
        return intAttr.getInt();
    }

    auto indexCast = current.getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return std::nullopt;
    // The cast carries no constant of its own; keep resolving its input.
    current = indexCast.getIn();
  }
}

/// Returns true only when the loop's lower bound, upper bound and step all
/// resolve to constants (via getConstantIndexLike), the step is strictly
/// positive, and the lower bound is strictly below the upper bound. Any bound
/// that cannot be resolved makes the answer false.
static bool isForLoopKnownNonEmpty(scf::ForOp forOp) {
  const std::optional<int64_t> lb =
      getConstantIndexLike(forOp.getLowerBound());
  const std::optional<int64_t> ub =
      getConstantIndexLike(forOp.getUpperBound());
  const std::optional<int64_t> stride = getConstantIndexLike(forOp.getStep());

  const bool allConstant =
      lb.has_value() && ub.has_value() && stride.has_value();
  if (!allConstant)
    return false;
  if (*stride <= 0)
    return false;
  return *lb < *ub;
}

static SmallVector<Value> getScratchBuffersFromEffects(Operation *op,
ValueRange dpsInits,
const StableValueOrderMap &stableValueOrder) {
Expand Down Expand Up @@ -589,31 +619,198 @@ void MemLivenessAnalysis::RecursiveIfOp(scf::IfOp ifOp, Liveness live) {

SmallVector<Value> MemLivenessAnalysis::GetLiveBuffersInLoop(scf::ForOp forOp,
Liveness live) {
SmallVector<Value> allocBeforeLoopBuffers;
const auto *liveBlockInfo = live.getLiveness(forOp->getBlock());
auto currentLiveValues =
liveBlockInfo->currentlyLiveValues(forOp.getOperation());
if (currentLiveValues.empty()) {
return allocBeforeLoopBuffers;
return {};
}
// The gen buffer of the same operation must ensure the order of priority.
SetVector<Value> currentLiveValuesOrder;
for (auto buffer : currentLiveValues) {
currentLiveValuesOrder.insert(buffer);
}
SetVector<Value> allocBeforeLoopBufferSet;
for (const Value &operand : currentLiveValuesOrder) {
auto aliasBuffers = GetAliasBuffers(operand);
aliasBuffers.insert(operand);
for (auto Buffer : aliasBuffers) {
auto iter = buffer2status.find(Buffer);
if (iter != buffer2status.end())
allocBeforeLoopBuffers.push_back(Buffer);
if (iter == buffer2status.end())
continue;
if ((iter->second == BufferStatus::DEFFINED ||
iter->second == BufferStatus::KILLED) &&
CanDelayLoopEntryGenUntilFirstWrite(forOp, Buffer)) {
delayedLoopEntryGenBuffers[Buffer] = true;
continue;
Comment on lines +643 to +645
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve loop-entry liveness for zero-trip loops

Delaying loop-entry generation here assumes the loop body executes at least once, but scf.for may legally have zero iterations (e.g., dynamic bounds where lb >= ub). In that case no in-loop overwrite occurs, yet this code marks the live-in buffer as delayable and skips generating it at loop entry, which can let PlanMemory reuse that storage too early and corrupt a value that is read after the loop. The optimization should only apply when non-empty execution is provable, or else keep the loop-entry lifetime.

Useful? React with 👍 / 👎.

}
allocBeforeLoopBufferSet.insert(Buffer);
}
}
SmallVector<Value> allocBeforeLoopBuffers(allocBeforeLoopBufferSet.begin(),
allocBeforeLoopBufferSet.end());
sortValuesByStableOrder(allocBeforeLoopBuffers, stableValueOrder);
return allocBeforeLoopBuffers;
}

/// Decides whether generation of `buffer` may be postponed from the entry of
/// `forOp` to the first operation in the loop body that touches it.
///
/// The answer is true only if the loop is known to be non-empty
/// (isForLoopKnownNonEmpty) and the first body operation (terminator
/// excluded) that touches `buffer` or one of its aliases is either
///   * a nested scf.for for which this predicate holds for at least one
///     alias, or
///   * a write-only DPS init of an alias (IsWriteOnlyDpsInitForAlias).
/// If no body operation touches the buffer the answer is false.
bool MemLivenessAnalysis::CanDelayLoopEntryGenUntilFirstWrite(
    scf::ForOp forOp, Value buffer) {
  if (!isForLoopKnownNonEmpty(forOp))
    return false;

  SetVector<Value> aliasBuffers = GetAliasBuffers(buffer);
  aliasBuffers.insert(buffer);

  Block *loopBody = forOp.getBody();
  if (!loopBody)
    return false;

  for (Operation &bodyOp : loopBody->without_terminator()) {
    if (OperationOrNestedRegionTouchesAnyAlias(&bodyOp, aliasBuffers)) {
      // Only the first touching operation matters; it decides the result.
      auto innerLoop = dyn_cast<scf::ForOp>(&bodyOp);
      if (!innerLoop)
        return IsWriteOnlyDpsInitForAlias(&bodyOp, aliasBuffers);

      for (Value alias : aliasBuffers) {
        if (CanDelayLoopEntryGenUntilFirstWrite(innerLoop, alias))
          return true;
      }
      return false;
    }
  }
  return false;
}

/// Returns true when `op` itself (nested regions are not inspected here)
/// references any value in `aliasBuffers` as an operand, as a result, or as
/// the value attached to one of its declared memory effects.
bool MemLivenessAnalysis::OperationDirectlyTouchesAnyAlias(
    Operation *op, const SetVector<Value> &aliasBuffers) const {
  auto isAlias = [&](Value candidate) {
    return candidate && llvm::is_contained(aliasBuffers, candidate);
  };

  if (llvm::any_of(op->getOperands(), isAlias))
    return true;
  if (llvm::any_of(op->getResults(), isAlias))
    return true;

  // Operands and results are clean; fall back to the declared memory effects,
  // if the operation exposes any.
  if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
    SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>,
                kMemoryEffectReserveSize>
        effects;
    memEffect.getEffects(effects);
    for (const auto &effect : effects) {
      if (isAlias(effect.getValue()))
        return true;
    }
  }
  return false;
}

/// Returns true when `op`, or any operation nested inside its regions,
/// directly touches a value in `aliasBuffers`
/// (see OperationDirectlyTouchesAnyAlias).
bool MemLivenessAnalysis::OperationOrNestedRegionTouchesAnyAlias(
    Operation *op, const SetVector<Value> &aliasBuffers) const {
  if (OperationDirectlyTouchesAnyAlias(op, aliasBuffers))
    return true;
  if (op->getNumRegions() == 0)
    return false;

  // Pre-order walk over the nested operations; `op` itself was checked above,
  // so it is skipped. The walk stops at the first hit.
  WalkResult walkResult =
      op->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) {
        if (nestedOp != op &&
            OperationDirectlyTouchesAnyAlias(nestedOp, aliasBuffers))
          return WalkResult::interrupt();
        return WalkResult::advance();
      });
  return walkResult.wasInterrupted();
}

/// Returns true when `op` reads any value in `aliasBuffers`.
///
/// For operations implementing MemoryEffectOpInterface, only declared Read
/// effects attached to an alias count. For operations without that interface,
/// merely having an alias as an operand is treated as a read.
bool MemLivenessAnalysis::OperationReadsAnyAlias(
    Operation *op, const SetVector<Value> &aliasBuffers) const {
  auto isAlias = [&](Value candidate) {
    return candidate && llvm::is_contained(aliasBuffers, candidate);
  };

  if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
    SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>,
                kMemoryEffectReserveSize>
        effects;
    memEffect.getEffects(effects);
    for (const auto &effect : effects) {
      if (isa<MemoryEffects::Read>(effect.getEffect()) &&
          isAlias(effect.getValue()))
        return true;
    }
    return false;
  }

  // No effect information available: fall back to operand usage.
  return llvm::any_of(op->getOperands(), isAlias);
}

/// Returns true when `op` uses one of `aliasBuffers` as a DPS init and its
/// declared memory effects on the alias set consist solely of writes.
///
/// All of the following must hold:
///   * `op` implements pto::PTO_DpsInitOpInterface and at least one of its
///     DPS inits is in `aliasBuffers`;
///   * `op` implements MemoryEffectOpInterface;
///   * every declared effect attached to an alias is a Write, and there is
///     at least one such effect.
/// Effects with no attached value, or attached to a non-alias value, are
/// ignored.
bool MemLivenessAnalysis::IsWriteOnlyDpsInitForAlias(
    Operation *op, const SetVector<Value> &aliasBuffers) const {
  auto ptoDpsOp = dyn_cast<pto::PTO_DpsInitOpInterface>(op);
  if (!ptoDpsOp)
    return false;

  const bool initIsAlias =
      llvm::any_of(ptoDpsOp.getDpsInits(), [&](Value init) {
        return llvm::is_contained(aliasBuffers, init);
      });
  if (!initIsAlias)
    return false;

  auto memEffect = dyn_cast<MemoryEffectOpInterface>(op);
  if (!memEffect)
    return false;
  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>,
              kMemoryEffectReserveSize>
      effects;
  memEffect.getEffects(effects);

  bool sawWrite = false;
  for (const auto &effect : effects) {
    Value target = effect.getValue();
    const bool onAlias = target && llvm::is_contained(aliasBuffers, target);
    if (!onAlias)
      continue;
    // A Read, or any other non-Write effect on an alias, disqualifies the op.
    if (!isa<MemoryEffects::Write>(effect.getEffect()))
      return false;
    sawWrite = true;
  }
  return sawWrite;
}

/// Looks at the operations following `op` in its block and finds the first
/// one that touches any of `aliasBuffers` (directly or in a nested region).
///
/// \returns true if that operation is an scf.for whose loop-entry generation
/// can be delayed for at least one alias, or otherwise a write-only DPS init
/// of an alias; false if it is anything else or if no later operation in the
/// block touches the aliases.
bool MemLivenessAnalysis::CanKillBeforeNextOverwrite(
    Operation *op, const SetVector<Value> &aliasBuffers) {
  // Advance to the first later operation that touches the alias set.
  Operation *cursor = op->getNextNode();
  while (cursor &&
         !OperationOrNestedRegionTouchesAnyAlias(cursor, aliasBuffers))
    cursor = cursor->getNextNode();
  if (!cursor)
    return false;

  if (auto forOp = dyn_cast<scf::ForOp>(cursor)) {
    for (Value alias : aliasBuffers) {
      if (CanDelayLoopEntryGenUntilFirstWrite(forOp, alias))
        return true;
    }
    return false;
  }

  return IsWriteOnlyDpsInitForAlias(cursor, aliasBuffers);
}

/// Returns true when `op` is a write-only DPS init for `buffer` or one of its
/// aliases (see IsWriteOnlyDpsInitForAlias).
bool MemLivenessAnalysis::CanRegenerateBufferAtOp(Operation *op,
                                                  Value buffer) {
  // The check runs over the buffer together with everything aliasing it.
  SetVector<Value> candidates = GetAliasBuffers(buffer);
  candidates.insert(buffer);
  const bool writeOnlyInit = IsWriteOnlyDpsInitForAlias(op, candidates);
  return writeOnlyInit;
}

LogicalResult
MemLivenessAnalysis::CheckLocalBufferAllocOp(Operation *op) const {
auto allocOp = dyn_cast<memref::AllocOp>(op);
Expand Down Expand Up @@ -732,11 +929,25 @@ void MemLivenessAnalysis::UpdateOperandGenInfo(OpInfo *opInfo, Value operand) {
if (iter_buffer == buffer2status.end())
return;
if (iter_buffer->second == BufferStatus::DEFFINED) {
genKillMap[opInfo].gen.push_back(operand);
appendUniqueValue(genKillMap[opInfo].gen, operand);
buffer2status[iter_buffer->first] = BufferStatus::GENED;
buffer2GenOp[iter_buffer->first] = opInfo->operation;
} else if (iter_buffer->second == BufferStatus::KILLED) {
llvm_unreachable("The buffer memory has been released and cannot be used "
"again! ");
if (!CanRegenerateBufferAtOp(opInfo->operation, operand)) {
llvm_unreachable("The buffer memory has been released and cannot be "
"used again before it is redefined! ");
}
appendUniqueValue(genKillMap[opInfo].gen, operand);
buffer2status[iter_buffer->first] = BufferStatus::GENED;
buffer2GenOp[iter_buffer->first] = opInfo->operation;
} else if (iter_buffer->second == BufferStatus::GENED) {
SetVector<Value> aliasBuffers = GetAliasBuffers(operand);
aliasBuffers.insert(operand);
if (IsWriteOnlyDpsInitForAlias(opInfo->operation, aliasBuffers)) {
appendUniqueValue(genKillMap[opInfo].kill, operand);
appendUniqueValue(genKillMap[opInfo].gen, operand);
buffer2GenOp[iter_buffer->first] = opInfo->operation;
}
}
}

Expand Down Expand Up @@ -764,15 +975,35 @@ void MemLivenessAnalysis::UpdateOpKillInfo(OpInfo *opInfo, Value operand,
auto iterBuffer = buffer2status.find(aliasBuffer);
if (iterBuffer == buffer2status.end())
return;
if (iterBuffer->second == BufferStatus::GENED &&
IsInSameBlock(iterBuffer->first.getDefiningOp(), opInfo->operation) &&
AllDeadAfter(opInfo->operation, aliasBuffers, live)) {
genKillMap[opInfo].kill.push_back(aliasBuffer);
Operation *defOp = iterBuffer->first.getDefiningOp();
bool canKillInThisBlock =
defOp && IsInSameBlock(defOp, opInfo->operation);
auto delayedGen = delayedLoopEntryGenBuffers.find(iterBuffer->first);
if (!canKillInThisBlock && delayedGen != delayedLoopEntryGenBuffers.end() &&
delayedGen->second) {
Operation *genOp = GetBufferGenOp(iterBuffer->first);
canKillInThisBlock = genOp && IsInSameBlock(genOp, opInfo->operation);
}
bool canKillCurrentValue =
AllDeadAfter(opInfo->operation, aliasBuffers, live) ||
(OperationReadsAnyAlias(opInfo->operation, aliasBuffers) &&
CanKillBeforeNextOverwrite(opInfo->operation, aliasBuffers));
if (iterBuffer->second == BufferStatus::GENED && canKillInThisBlock &&
canKillCurrentValue) {
appendUniqueValue(genKillMap[opInfo].kill, aliasBuffer);
buffer2status[iterBuffer->first] = BufferStatus::KILLED;
buffer2GenOp.erase(iterBuffer->first);
}
}
}

/// Looks up the operation recorded in `buffer2GenOp` for `buffer`.
/// \returns the recorded operation, or nullptr when the buffer has no entry.
Operation *MemLivenessAnalysis::GetBufferGenOp(Value buffer) const {
  auto entry = buffer2GenOp.find(buffer);
  return entry == buffer2GenOp.end() ? nullptr : entry->second;
}

/// Returns true when both operations are non-null and live in the same block.
///
/// Callers obtain their arguments from lookups that may yield nullptr (for
/// example Value::getDefiningOp() or GetBufferGenOp()), so a null operation
/// is treated as "not in the same block" instead of being dereferenced.
bool MemLivenessAnalysis::IsInSameBlock(Operation *op1, Operation *op2) const {
  if (!op1 || !op2)
    return false;
  return op1->getBlock() == op2->getBlock();
}
Expand Down Expand Up @@ -839,25 +1070,43 @@ BufferInfo MemLivenessAnalysis::GetBufferInfo(Operation *op, Value operand,

void MemLivenessAnalysis::GenerateBufferLife() {
int scopeTime = 0;
DenseMap<Value, std::shared_ptr<BufferLife>> openLives;
for (size_t i = 0; i < linearOperation.size(); ++i) {
auto it = genKillMap.find(linearOperation[i].get());
if (it == genKillMap.end()) {
scopeTime++;
continue;
}
// Time given to buffer start.

SmallVector<Value> postGenKills;
for (const Value &killBuffer : it->second.kill) {
auto iter = openLives.find(killBuffer);
if (iter != openLives.end()) {
iter->second->freeTime = scopeTime;
openLives.erase(iter);
continue;
}
if (!llvm::is_contained(it->second.gen, killBuffer))
llvm::report_fatal_error("buffer lifetime killed before generation");
appendUniqueValue(postGenKills, killBuffer);
}

for (const Value &genBuffer : it->second.gen) {
std::unique_ptr<BufferLife> bufferLife =
std::make_unique<BufferLife>(genBuffer);
if (openLives.find(genBuffer) != openLives.end())
llvm::report_fatal_error("buffer lifetime generated before release");
std::shared_ptr<BufferLife> bufferLife =
std::make_shared<BufferLife>(genBuffer);
bufferLife->allocTime = scopeTime;
buffer2Life[genBuffer] = std::move(bufferLife);
buffer2Life[genBuffer].push_back(bufferLife);
openLives[genBuffer] = std::move(bufferLife);
}
// Time given to buffer end.
for (const Value &killBuffer : it->second.kill) {
auto iter = buffer2Life.find(killBuffer);
if (iter == buffer2Life.end())
llvm::report_fatal_error("buffer lifetime killed before generation");

for (const Value &killBuffer : postGenKills) {
auto iter = openLives.find(killBuffer);
if (iter == openLives.end())
llvm::report_fatal_error("buffer lifetime generated after release");
iter->second->freeTime = scopeTime;
openLives.erase(iter);
}
scopeTime++;
}
Expand Down Expand Up @@ -1068,21 +1317,25 @@ LogicalResult MemPlan::plan() {

void MemPlan::GenerateStorageEntry() {
// create new storage entry.
SetVector<Value> seenBuffers;
for (auto &operation : linearOperation) {
auto it = genKillMap.find(operation.get());
if (it == genKillMap.end())
continue;
SmallVector<Value> genBuffers(it->second.gen.begin(), it->second.gen.end());
sortValuesByStableOrder(genBuffers, stableValueOrder);
for (const Value &genBuffer : genBuffers) {
if (llvm::is_contained(seenBuffers, genBuffer))
continue;
auto iter = bufferInfos.find(genBuffer);
if (iter == bufferInfos.end()) {
continue;
}
const std::shared_ptr<BufferLife> &bufLife = buffer2Life.at(genBuffer);
seenBuffers.insert(genBuffer);
const BufferLifeVec &bufLives = buffer2Life.at(genBuffer);
std::unique_ptr<StorageEntry> entry = std::make_unique<StorageEntry>();
entry->bufInfo = &iter->second;
entry->bufferLifeVec.emplace_back(bufLife);
entry->bufferLifeVec.append(bufLives.begin(), bufLives.end());
entry->inplaceBuffers.emplace_back(iter->first);
auto multiBuffer = buffer2MultiNum.find(genBuffer);
if (multiBuffer != buffer2MultiNum.end()) {
Expand Down Expand Up @@ -2095,7 +2348,7 @@ void MemPlan::ReportAllocatedEntryDebugInfo(StorageEntry *rootStorageEntry) {
}
size_t num = allocatedEntry.size() - 1;
if (rootStorageEntry->mergedChildren.size() <= num)
llvm::report_fatal_error("missing failed storage entry");
return;
const StorageEntry *failedSe = rootStorageEntry->mergedChildren[num];
printRecord(failedSe);
LDBG("alloc fail,because exceed bound of memory \n"
Expand Down
Loading
Loading