Skip to content

Instantly share code, notes, and snippets.

@rickyz
Created February 16, 2022 00:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rickyz/dd9e7c5b73ccf176a80752adf5a0b3f6 to your computer and use it in GitHub Desktop.
Save rickyz/dd9e7c5b73ccf176a80752adf5a0b3f6 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import enum
import itertools
from z3 import *
# Bit width of the symbolic bit-vectors used in the verification query.
BITS = 32
class Predicate(enum.Enum):
    """Integer comparison predicates understood by ``compare``.

    Names starting with ``U`` are the unsigned comparisons and names starting
    with ``S`` are the signed ones. The member values themselves carry no
    meaning; they are assigned automatically.
    """

    # Unsigned comparisons.
    UGT = enum.auto()
    UGE = enum.auto()
    ULT = enum.auto()
    ULE = enum.auto()

    # Signed comparisons.
    SGT = enum.auto()
    SGE = enum.auto()
    SLT = enum.auto()
    SLE = enum.auto()
# Predicates tried for the outer (unsigned) comparison.
UNSIGNED_PREDS = [
    Predicate.UGT,
    Predicate.UGE,
    Predicate.ULT,
    Predicate.ULE,
]

# Predicates tried for the inner (signed) comparison.
SIGNED_PREDS = [
    Predicate.SGT,
    Predicate.SGE,
    Predicate.SLT,
    Predicate.SLE,
]
def compare(pred, lhs, rhs):
    """Apply the comparison named by ``pred`` to ``lhs`` and ``rhs``.

    Unsigned predicates are routed to the explicit ``UGT``/``UGE``/``ULT``/
    ``ULE`` helpers, while signed predicates use the operands' ordinary
    comparison operators.

    Args:
        pred: The ``Predicate`` member selecting the comparison.
        lhs: Left-hand operand.
        rhs: Right-hand operand.

    Returns:
        Whatever the selected comparison produces for ``lhs`` and ``rhs``.

    Raises:
        KeyError: If ``pred`` is not one of the ``Predicate`` members.
    """
    dispatch = {
        # Unsigned comparisons need the dedicated helper functions.
        Predicate.UGT: UGT,
        Predicate.UGE: UGE,
        Predicate.ULT: ULT,
        Predicate.ULE: ULE,
        # Signed comparisons use the plain operators.
        Predicate.SGT: lambda a, b: a > b,
        Predicate.SGE: lambda a, b: a >= b,
        Predicate.SLT: lambda a, b: a < b,
        Predicate.SLE: lambda a, b: a <= b,
    }
    handler = dispatch[pred]
    return handler(lhs, rhs)
def signed_max(bits):
    """Return the largest signed ``bits``-wide value as a bit-vector constant.

    The value is ``(1 << (bits - 1)) - 1``: every bit set except the topmost.
    """
    top_bit = 1 << (bits - 1)
    return BitVecVal(top_bit - 1, bits)
def verify_canonicalize_clamp_like(pred0, pred1, bits=BITS):
    """Search for a counterexample to the clamp-like select rewrite.

    Builds two symbolic expressions over ``bits``-wide bit-vectors::

        old = x              if (x + c1) <pred0> c0
              low_or_high    otherwise, chosen by  x <pred1> c2

        new = replacement_high  if x s>= threshold_high_excl
              replacement_low   if x s<  threshold_low_incl
              x                 otherwise

    and asks the solver whether any assignment of ``x``, ``c0``, ``c1``,
    ``c2`` and the two replacement values makes them differ, subject to the
    preconditions added along the way.

    NOTE(review): the name and the in-code comment suggest this models an
    LLVM transform; confirm the preconditions against that implementation.

    Args:
        pred0: Predicate of the outer comparison; must be one of the unsigned
            predicates.
        pred1: Predicate of the inner comparison; must be one of the signed
            predicates.
        bits: Width of every bit-vector in the query. Defaults to ``BITS``.

    Returns:
        ``None`` if the solver reports the difference unsatisfiable, i.e. no
        counterexample exists. Otherwise the solver model describing a
        counterexample. ``None`` is also returned immediately, without any
        solver query, when ``pred1`` is ``Predicate.SLE``.

    Raises:
        KeyError: If ``pred0`` is not an unsigned predicate.
        AssertionError: If ``pred1`` is not a signed predicate.
    """
    s = Solver()
    x = BitVec('x', bits)
    c0 = BitVec('c0', bits)
    c1 = BitVec('c1', bits)
    c2 = BitVec('c2', bits)
    replacement_low = BitVec('replacement_low', bits)
    replacement_high = BitVec('replacement_high', bits)

    # Canonicalize the outer predicate to ULT or UGE.
    if pred0 in (Predicate.ULT, Predicate.UGE):
        s.add(c0 != 0)
    else:
        # ULE/UGT are turned into ULT/UGE by incrementing c0, so c0 must not
        # be all-ones (the increment would wrap around).
        s.add(c0 != BitVecVal(-1, bits))
        pred0 = {Predicate.ULE: Predicate.ULT, Predicate.UGT: Predicate.UGE}[pred0]
        c0 = c0 + 1

    # SLE is not checked at all; the caller sees the same result as a pass.
    if pred1 == Predicate.SLE:
        return None

    if pred1 == Predicate.SGT:
        # Note: Still passes with this line removed. This is probably only
        # because we do not accurately model LLVM's poison/overflow
        # semantics.
        s.add(c2 != signed_max(bits))
        c2 = c2 + 1

    # Canonicalize the inner predicate to SLT; inverting the comparison is
    # compensated for by exchanging the two replacement values.
    if pred1 in (Predicate.SGT, Predicate.SGE):
        pred1 = Predicate.SLT
        replacement_low, replacement_high = replacement_high, replacement_low

    assert pred0 in (Predicate.ULT, Predicate.UGE)
    assert pred1 == Predicate.SLT

    # (x + c1) u< c0 holds exactly for x in [-c1, c0 - c1) modulo 2**bits.
    threshold_low_incl = -c1
    threshold_high_excl = c0 - c1
    if pred0 == Predicate.UGE:
        threshold_low_incl, threshold_high_excl = threshold_high_excl, threshold_low_incl

    # Preconditions on c2; the plain operators are the signed comparisons.
    s.add(c2 >= threshold_low_incl)
    s.add(c2 <= threshold_high_excl)

    old = If(
        compare(pred0, x + c1, c0), x,
        If(compare(pred1, x, c2), replacement_low, replacement_high))
    new = If(x >= threshold_high_excl, replacement_high,
             If(x < threshold_low_incl, replacement_low, x))

    # Any satisfying assignment of old != new is a counterexample.
    s.add(old != new)
    if s.check() != unsat:
        return s.model()
    return None
def main():
    """Check every (unsigned, signed) predicate pair and print the outcome.

    For each pair a ``Testing ...`` line is printed, followed by ``pass`` when
    ``verify_canonicalize_clamp_like`` returns ``None`` and ``fail`` otherwise.
    On failure, every declaration of the returned model is listed with its
    value.
    """
    for pred0, pred1 in itertools.product(UNSIGNED_PREDS, SIGNED_PREDS):
        # Flush explicitly: the message has no trailing newline, so it could
        # otherwise stay buffered until the solver call below has finished.
        print(f'Testing {pred0}, {pred1}... ', end='', flush=True)
        m = verify_canonicalize_clamp_like(pred0, pred1)
        if m is None:
            print('pass')
            continue
        print('fail')
        for d in m.decls():
            print(f' {d.name()} = {m[d]}')
# Run the predicate sweep only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment