Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 116 additions & 3 deletions src/main/scala/analysis/Domains.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package analysis

import analysis.Interval.ConcreteInterval
import ir.*
import ir.transforms.AbstractDomain
import util.assertion.*
Expand All @@ -9,9 +10,12 @@ trait MustAnalysis

/** A domain that performs two analyses in parallel.
*/
class ProductDomain[L1, L2](d1: AbstractDomain[L1], d2: AbstractDomain[L2]) extends AbstractDomain[(L1, L2)] {
trait ProductDomain[L1, L2] extends AbstractDomain[(L1, L2)] {
def join(a: (L1, L2), b: (L1, L2), pos: Block): (L1, L2) = (d1.join(a._1, b._1, pos), d2.join(a._2, b._2, pos))

def d1: AbstractDomain[L1]
def d2: AbstractDomain[L2]

override def widen(a: (L1, L2), b: (L1, L2), pos: Block): (L1, L2) =
(d1.widen(a._1, b._1, pos), d2.widen(a._2, b._2, pos))
override def narrow(a: (L1, L2), b: (L1, L2)): (L1, L2) = (d1.narrow(a._1, b._1), d2.narrow(a._2, b._2))
Expand All @@ -25,11 +29,120 @@ class ProductDomain[L1, L2](d1: AbstractDomain[L1], d2: AbstractDomain[L2]) exte
def bot: (L1, L2) = (d1.bot, d2.bot)
}

/**
* A domain that is the reduced product of two other domains.
*/
trait ReducedProductDomain[L1, L2] extends ProductDomain[L1, L2] {
def reduce(a: (L1, L2), c: Command): (L1, L2)
override def transfer(a: (L1, L2), c: Command): (L1, L2) = {
val things = (d1.transfer(a._1, c), d2.transfer(a._2, c))
// PERFORMANCE: This reduces every variable every transfer. Needless to say that this is
// quite inefficient.
reduce(things, c)
}
}

private implicit val intervalTerm: Interval = Interval.Bottom

/**
* The reduced product between the interval and tnum domains.
*/
class TNumIntervalReducedProduct extends ReducedProductDomain[LatticeMap[Variable, Interval], Map[Variable, TNum]] {
override def d1: AbstractDomain[LatticeMap[Variable, Interval]] = UnsignedIntervalDomain()
override def d2: AbstractDomain[Map[Variable, TNum]] = TNumDomain()

override def reduce(
unreduced: (LatticeMap[Variable, Interval], Map[Variable, TNum]),
c: Command
): (LatticeMap[Variable, Interval], Map[Variable, TNum]) = {
def reduceSingle(int: Interval, x: TNum): (Interval, TNum) = {

int match {
case int: ConcreteInterval =>
val interval =
Interval.ConcreteInterval(refineLowerBound(int.lower, x), refineUpperBound(int.upper, x), int.width)
val tnum = refineTnum(int.lower, int.upper, x)
(interval, tnum)
case Interval.Top => (Interval.Top, x)
case Interval.Bottom => (Interval.Bottom, x)
}
}

c match {
case c: LocalAssign =>
val result = reduceSingle(unreduced._1(c.lhs), unreduced._2(c.lhs))
(unreduced._1 + (c.lhs -> result._1), unreduced._2 + (c.lhs -> result._2))
case c: MemoryAssign =>
val result = reduceSingle(unreduced._1(c.lhs), unreduced._2(c.lhs))
(unreduced._1 + (c.lhs -> result._1), unreduced._2 + (c.lhs -> result._2))
case c: SimulAssign =>
val (ints, tnums) = c.assignments
.map((v, e) => v)
.map(v => (v, reduceSingle(unreduced._1(v), unreduced._2(v))))
.map((v, t) => ((v, t._1), (v, t._2)))
.unzip
(unreduced._1 ++ ints.toMap, unreduced._2 ++ tnums)
case c: Return =>
val (ints, tnums) = c.outParams
.map((v, e) => v)
.map(v => (v, reduceSingle(unreduced._1(v), unreduced._2(v))))
.map((v, t) => ((v, t._1), (v, t._2)))
.unzip
(unreduced._1 ++ ints.toMap, unreduced._2 ++ tnums)
case _ => unreduced
}
}

/* These three methods (refine*()) are only exposed publicly so they can be tested, since they
* contain most of the logic for this domain. You probably shouldn't use them.
*/
def refineLowerBound(a: BigInt, x: TNum): BigInt = {
var newBound = x.maxUnsignedValue().value
for i <- x.width to 0 by -1
do
if 0 != (x.mask.value & (1 << i)) then
newBound = newBound & (~(1 << i)) // Unset the ith bit
if newBound < a then newBound = newBound | (1 << i) // Re-set the ith bit
newBound
}

def refineUpperBound(a: BigInt, x: TNum): BigInt = {
var newBound = x.minUnsignedValue().value
for i <- x.width to 0 by -1
do
if 0 != (x.mask.value & (1 << i)) then
newBound = newBound | (1 << i) // Set the ith bit
if newBound > a then newBound = newBound & (~(1 << i)) // Unset the ith bit
newBound
}

def refineTnum(a: BigInt, b: BigInt, x: TNum): TNum = {
val mask = ~(a ^ b)
var lb = a // An extra copy to obliterate instead of breaking from the loop
var stupidVariable = 1
var value = x.value.value
var tnumMask = x.mask.value
for i <- x.width - 1 to 0 by -1 do
if (mask & (1 << i)) == 0 then
lb = 0
stupidVariable = 0 // Because break doesn't exist because actually this for loop is a
// method call! or something. Just do nothing for the rest of the loop instead.
/* if !(((value & (1 << i)) > 0) || ((value & (1 << i)) == (a & (1 << i)))) then
* TODO: go to bottom here because interval and tnum don't overlap */
value = value | (lb & (1 << i))
tnumMask = tnumMask & (~(stupidVariable << i))

TNum(BitVecLiteral(value, x.width), BitVecLiteral(tnumMask, x.width))
}
}

/**
* Encodes the conjunction of two domain predicates.
*/
class PredProductDomain[L1, L2](d1: PredicateEncodingDomain[L1], d2: PredicateEncodingDomain[L2])
extends ProductDomain[L1, L2](d1, d2)
class PredProductDomain[L1, L2](
override val d1: PredicateEncodingDomain[L1],
override val d2: PredicateEncodingDomain[L2]
) extends ProductDomain[L1, L2]
with PredicateEncodingDomain[(L1, L2)] {

def toPred(x: (L1, L2)): Predicate = Predicate.and(d1.toPred(x._1), d2.toPred(x._2))
Expand Down
4 changes: 4 additions & 0 deletions src/main/scala/analysis/KnownBits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ case class TNum(value: BitVecLiteral, mask: BitVecLiteral) {
TNum(n, 0.bv(n.size))
}

def maxUnsignedValue(): BitVecLiteral = value | mask

def minUnsignedValue(): BitVecLiteral = value & (~mask)

override def toString() = {
val padwidth = width / 4 + (if width % 4 != 0 then 1 else 0)
def padded(number: BigInt) = {
Expand Down
160 changes: 160 additions & 0 deletions src/test/scala/TNumIntervalReducedProductTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import analysis.Interval.ConcreteInterval
import analysis.{TNum, TNumIntervalReducedProduct}
import ir.*
import ir.dsl.*
import ir.transforms.{reversePostOrder, worklistSolver}
import org.scalatest.funsuite.AnyFunSuiteLike

@test_util.tags.UnitTest
class TNumIntervalReducedProductTest extends AnyFunSuiteLike {

val domain = TNumIntervalReducedProduct()
val interval = ConcreteInterval(6, 10, 4)
val tnum = TNum(BitVecLiteral(0, 4), BitVecLiteral(9, 4))

val kbitsProg = prog(
proc(
"knownBitsExample_4196164",
Seq("R0_in" -> BitVecType(64), "R1_in" -> BitVecType(64)),
Seq("R0_out" -> BitVecType(64), "R2_out" -> BitVecType(64), "R3_out" -> BitVecType(64)),
block(
"lknownBitsExample",
LocalAssign(
LocalVar("R2", BitVecType(64), 2),
BinaryExpr(
BVOR,
BinaryExpr(BVAND, LocalVar("R0_in", BitVecType(64), 0), BitVecLiteral(BigInt("18374966859414961920"), 64)),
BitVecLiteral(BigInt("18446744069414584320"), 64)
),
Some("%0000023e")
),
LocalAssign(
LocalVar("R0", BitVecType(64), 3),
BinaryExpr(
BVOR,
BinaryExpr(BVAND, LocalVar("R0_in", BitVecType(64), 0), BitVecLiteral(BigInt("18374966859414961920"), 64)),
BitVecLiteral(BigInt("71777218305454335"), 64)
),
Some("%00000257")
),
goto("lknownBitsExample_phi_lknownBitsExample_goto_l00000271", "lknownBitsExample_goto_l0000026d")
),
block(
"l00000274",
LocalAssign(
LocalVar("R0", BitVecType(64), 6),
ZeroExtend(
4,
BinaryExpr(
BVSHL,
ZeroExtend(8, Extract(60, 8, LocalVar("R2", BitVecType(64), 9))),
BitVecLiteral(BigInt("8"), 60)
)
),
Some("%0000027e")
),
LocalAssign(
LocalVar("R0", BitVecType(64), 7),
BinaryExpr(BVOR, LocalVar("R0", BitVecType(64), 6), BitVecLiteral(BigInt("15"), 64)),
Some("%00000284")
),
goto("l00000274_phi_l00000274_goto_l00000299", "l00000274_goto_l0000029d")
),
block(
"l000002a0",
Assert(TrueLiteral, Some("is returning to caller-set R30"), None),
goto("knownBitsExample_4196164_basil_return")
),
block(
"lknownBitsExample_goto_l0000026d",
Assume(BinaryExpr(EQ, LocalVar("R1_in", BitVecType(64), 0), BitVecLiteral(BigInt("0"), 64)), None, None, true),
LocalAssign(LocalVar("R2", BitVecType(64), 9), LocalVar("R0", BitVecType(64), 3), Some("phiback")),
goto("l00000274")
),
block(
"l00000274_goto_l0000029d",
Assume(
BinaryExpr(EQ, Extract(16, 0, LocalVar("R2", BitVecType(64), 9)), BitVecLiteral(BigInt("0"), 16)),
None,
None,
true
),
LocalAssign(LocalVar("R0", BitVecType(64), 14), LocalVar("R2", BitVecType(64), 9), Some("phiback")),
goto("l000002a0")
),
block(
"knownBitsExample_4196164_basil_return",
ret(
"R0_out" -> LocalVar("R0", BitVecType(64), 14),
"R2_out" -> LocalVar("R2", BitVecType(64), 9),
"R3_out" -> BitVecLiteral(BigInt("71777218305454335"), 64)
)
),
block(
"lknownBitsExample_phi_lknownBitsExample_goto_l00000271",
Assume(
UnaryExpr(BoolNOT, BinaryExpr(EQ, LocalVar("R1_in", BitVecType(64), 0), BitVecLiteral(BigInt("0"), 64))),
None,
None,
true
),
LocalAssign(LocalVar("R2", BitVecType(64), 9), LocalVar("R2", BitVecType(64), 2), Some("phiback")),
goto("l00000274")
),
block(
"l00000274_phi_l00000274_goto_l00000299",
Assume(
UnaryExpr(
BoolNOT,
BinaryExpr(EQ, Extract(16, 0, LocalVar("R2", BitVecType(64), 9)), BitVecLiteral(BigInt("0"), 16))
),
None,
None,
true
),
LocalAssign(LocalVar("R0", BitVecType(64), 14), LocalVar("R0", BitVecType(64), 7), Some("phiback")),
goto("l000002a0")
)
)
)
val littleProc = kbitsProg.nameToProcedure("knownBitsExample_4196164")

test("Refine lower bound - simple example") {
val expected = 8
assert(expected == domain.refineLowerBound(6, tnum))
}

test("Refine lower bound - 0") {
val expected = 0
assert(expected == domain.refineLowerBound(0, tnum))
}

test("Refine upper bound - simple example") {
val expected = 9
assert(expected == domain.refineUpperBound(10, tnum))
}

test("Refine upper bound - 0") {
val expected = 0
assert(expected == domain.refineUpperBound(0, tnum))
}

test("Refine tnum - simple example") {
val expected = TNum(BitVecLiteral(8, 4), BitVecLiteral(1, 4))
assert(expected == domain.refineTnum(8, 9, tnum))
}

test("Refine tnum - 0") {
val expected = TNum(BitVecLiteral(0, 4), BitVecLiteral(0, 4))
assert(expected == domain.refineTnum(0, 0, tnum))
}

test("Run on some code!") {
val domain = TNumIntervalReducedProduct()
val solver = worklistSolver(domain)
reversePostOrder(littleProc)
print(solver.solveProc(littleProc, backwards = false))
assert(true)
}

}
Loading