Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 51 additions & 30 deletions src/main/scala/analysis/data_structure_analysis/IntervalDSA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class IntervalGraph(
var phase: DSAPhase,
val irContext: IRContext,
var sva: SymValues[OSet],
val constraints: Set[Constraint],
var constraints: Set[Constraint],
val glIntervals: Seq[DSInterval],
val eqCells: Boolean,
val nodeBuilder: Option[() => Map[SymBase, IntervalNode]]
Expand All @@ -44,8 +44,9 @@ class IntervalGraph(
p.scc.isDefined || CallGraph.pred(p).exists(calledBySCC)
}

def exprToSymVal(expr: Expr): SymValSet[OSet] =
def exprToSymVal(expr: Expr): SymValSet[OSet] = {
SymValues.exprToSymValSet(sva, i => isGlobal(i, irContext), glIntervals)(expr)
}

protected def symValToNodes(
symVal: SymValSet[OSet],
Expand Down Expand Up @@ -245,19 +246,18 @@ class IntervalGraph(
debugAssert(IntervalDSA.equiv(valueCells), s"value cells should be unified instead got $valueCells")

val indexCells = constraintArgToCells(constraint.arg1, ignoreContents = true).map(get)
if indexCells.nonEmpty then
if indexCells.nonEmpty && valueCells.nonEmpty then
indexCells.foreach(indexCell =>
valueCells.foreach(valueCell =>
debugAssert(indexCell.node.isUptoDate, "outdated cell in local correctness check")
debugAssert(indexCell.getPointee.node.isUptoDate, "outdated cell in local correctness check")
debugAssert(valueCell.node.isUptoDate, "outdated cell in local correctness check")
debugAssert(
indexCell.hasPointee && indexCell.getPointee.equiv(valueCell),
s"$constraint, $indexCell doesn't point to ${valueCell} instead ${indexCell.getPointee}"
)
)
)
for {
indexCell <- indexCells
valueCell <- valueCells
} {
debugAssert(indexCell.node.isUptoDate, "outdated cell in local correctness check")
debugAssert(indexCell.getPointee.node.isUptoDate, "outdated cell in local correctness check")
debugAssert(valueCell.node.isUptoDate, "outdated cell in local correctness check")
debugAssert(
indexCell.hasPointee && indexCell.getPointee.equiv(valueCell),
s"$constraint, $indexCell doesn't point to $valueCell instead ${indexCell.getPointee}"
)
}
case _ =>
}
}
Expand Down Expand Up @@ -359,8 +359,12 @@ class IntervalGraph(
val cells = exprToCells(constraintArg.value, newEqv)
val exprCells = cells.map(find)

if constraintArg.contents && !ignoreContents then exprCells.map(_.getPointee)
else exprCells
if (constraintArg.contents && !ignoreContents) {
val pointees = exprCells.map(_.getPointee)
pointees
} else {
exprCells
}

}

Expand Down Expand Up @@ -914,13 +918,17 @@ class IntervalCell(val node: IntervalNode, val interval: DSInterval) {
* Can check if a cell has pointee without creating one for it with hasPointee
*/
def getPointee: IntervalCell = {
if node.get(this.interval) ne this then node.get(this.interval).getPointee
else if _pointee.isEmpty then
// throw Exception("expected a pointee")
debugAssert(this.node.isUptoDate)
val cell = node.get(this.interval)
if (cell ne this) {
cell.getPointee
} else if (_pointee.isEmpty) {
// throw Exception("expected a pointee")
debugAssert(node.isUptoDate)
_pointee = Some(IntervalNode(graph, Map.empty).add(0))
graph.find(_pointee.get)
else graph.find(_pointee.get)
} else {
graph.find(_pointee.get)
}
}

def hasPointee: Boolean = node.get(this.interval)._pointee.nonEmpty
Expand Down Expand Up @@ -1047,14 +1055,22 @@ object IntervalDSA {

def unifyGraphs(source: IntervalGraph, target: IntervalGraph)(using svDomain: SymValSetDomain[OSet]) = {
val oldToNew = mutable.Map[IntervalNode, IntervalNode]()
source.proc.formalInParam.foreach(p =>
source.proc.formalInParam.foreach { p =>
val base = Par(target.proc, p)
if !target.nodes.contains(base) then {
target.nodes += (base -> IntervalNode(target, Map(base -> Set(0))))
target.sva = SymValues(target.sva.state + (p -> svDomain.init(base)))
if (!target.nodes.contains(base)) {
val node = IntervalNode(target, Map(base -> Set(0)))
target.nodes += (base -> node)
if (target.sva.state.contains(p)) {
val cells = target.exprToCells(p).map(target.find)
val newNodeCell = node.get(0)
val allCells = cells + newNodeCell
target.mergeCells(allCells)
} else {
target.sva = SymValues(target.sva.state + (p -> svDomain.init(base)))
}
}
exprTransfer(p, p, source, target, oldToNew)
)
}

/*source.proc.formalOutParam.foreach(
p =>
Expand Down Expand Up @@ -1336,14 +1352,19 @@ object IntervalDSA {
require(scc.forall(graphs.keySet.contains))
require(scc.size > 1)
val sscc = scc.toSeq.sortBy(_.formalInParam.size).reverse
val head = graphs(scc.head)
sscc.tail.foreach(p => IntervalDSA.unifyGraphs(graphs(p), head))
val head = graphs(sscc.head)
sscc.tail.foreach { p =>
IntervalDSA.unifyGraphs(graphs(p), head)
}
val allConstraints = scc.flatMap(graphs(_).constraints)
head.constraints = allConstraints
scc.foreach(p => graphs.update(p, head))
}

/**
* resolves bottom up constraints
* @param init root node in call graph (main)
*
* @param init root node in call graph (main)
* @param locals mapping from procedures to their DSG after the local phase
* @return procedures
*/
Expand Down