Created
February 16, 2022 00:07
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import enum | |
import itertools | |
from z3 import * | |
# Width, in bits, of every Z3 bit-vector the verifier creates.
BITS: int = 32
class Predicate(enum.Enum):
    """Integer comparison predicates: U* compare unsigned, S* compare signed."""

    # Unsigned comparisons (values match what enum.auto() would assign).
    UGT = 1
    UGE = 2
    ULT = 3
    ULE = 4
    # Signed comparisons.
    SGT = 5
    SGE = 6
    SLT = 7
    SLE = 8
# The predicate groups whose combinations are verified, in declaration order.
UNSIGNED_PREDS = [
    Predicate.UGT,
    Predicate.UGE,
    Predicate.ULT,
    Predicate.ULE,
]
SIGNED_PREDS = [
    Predicate.SGT,
    Predicate.SGE,
    Predicate.SLT,
    Predicate.SLE,
]
def compare(pred, lhs, rhs):
    """Build the Z3 boolean expression ``lhs <pred> rhs``.

    Unsigned predicates use Z3's explicit unsigned helpers (UGT, UGE, ULT,
    ULE); signed predicates use the ordinary comparison operators, which Z3
    treats as signed on bit-vectors. An unrecognized predicate raises KeyError.
    """
    unsigned_ops = {
        Predicate.UGT: UGT,
        Predicate.UGE: UGE,
        Predicate.ULT: ULT,
        Predicate.ULE: ULE,
    }
    if pred in unsigned_ops:
        return unsigned_ops[pred](lhs, rhs)

    signed_ops = {
        Predicate.SGT: lambda a, b: a > b,
        Predicate.SGE: lambda a, b: a >= b,
        Predicate.SLT: lambda a, b: a < b,
        Predicate.SLE: lambda a, b: a <= b,
    }
    return signed_ops[pred](lhs, rhs)
def signed_max(bits):
    """Return the largest two's-complement signed value of width `bits`.

    The result is a Z3 bit-vector constant of that width (0x7fffffff for 32).
    """
    magnitude_bits = bits - 1
    return BitVecVal(2 ** magnitude_bits - 1, bits)
def verify_canonicalize_clamp_like(pred0, pred1):
    """Use Z3 to check the clamp-like select canonicalization.

    The original pattern is

        select(pred0(x + c1, c0), x,
               select(pred1(x, c2), replacement_low, replacement_high))

    and it is rewritten into

        select(x s>= T_high, replacement_high,
               select(x s< T_low, replacement_low, x))

    with T_low = -c1 and T_high = c0 - c1 (swapped when the outer predicate is
    UGE). Before the thresholds are computed, pred0 is canonicalized to
    ULT/UGE and pred1 to SLT. ``old`` is built from the predicates, constants
    and replacement values exactly as given, so those canonicalization steps
    and their preconditions are verified together with the final rewrite.

    Args:
        pred0: one of UNSIGNED_PREDS, the outer comparison predicate.
        pred1: one of SIGNED_PREDS, the inner comparison predicate.

    Returns:
        A Z3 model witnessing an input where the two expressions differ, or
        None if no such input exists under the preconditions. None is also
        returned, without running the solver, when pred1 is SLE (not handled).
    """
    s = Solver()
    x = BitVec('x', BITS)
    c0 = BitVec('c0', BITS)
    c1 = BitVec('c1', BITS)
    c2 = BitVec('c2', BITS)
    replacement_low = BitVec('replacement_low', BITS)
    replacement_high = BitVec('replacement_high', BITS)

    # Build the source expression BEFORE canonicalizing. The code below
    # rebinds pred0/c0/pred1/c2 and may swap the replacements; building `old`
    # afterwards would only compare the canonical form against the rewrite
    # and never check the canonicalization itself.
    old = If(
        compare(pred0, x + c1, c0), x,
        If(compare(pred1, x, c2), replacement_low, replacement_high))

    # Canonicalize the outer predicate to ULT or UGE.
    if pred0 in [Predicate.ULT, Predicate.UGE]:
        # c0 == 0 is excluded: `x + c1 uge 0` is always true, but both
        # thresholds collapse to -c1, so the rewrite could never yield x.
        s.add(c0 != 0)
    else:
        # `a ule c0` == `a ult c0 + 1` and `a ugt c0` == `a uge c0 + 1` only
        # while c0 + 1 does not wrap, i.e. c0 is not all-ones.
        s.add(c0 != BitVecVal(-1, BITS))
        pred0 = {Predicate.ULE: Predicate.ULT, Predicate.UGT: Predicate.UGE}[pred0]
        c0 += 1

    # Canonicalize the inner predicate to SLT.
    if pred1 == Predicate.SLE:
        # SLE is not handled; there is nothing to verify.
        return None
    if pred1 == Predicate.SGT:
        # `x sgt c2` == `x sge c2 + 1` only while c2 + 1 does not wrap to the
        # signed minimum; without this precondition c2 == INT_MAX is a
        # counterexample.
        s.add(c2 != signed_max(BITS))
        c2 += 1
    if pred1 in [Predicate.SGT, Predicate.SGE]:
        # `x sge c2` is `not (x slt c2)`: switch to SLT and swap the arms.
        pred1 = Predicate.SLT
        replacement_low, replacement_high = replacement_high, replacement_low

    assert pred0 in (Predicate.ULT, Predicate.UGE)
    assert pred1 == Predicate.SLT

    # Thresholds of the rewritten clamp; x is kept when
    # threshold_low_incl s<= x s< threshold_high_excl.
    threshold_low_incl = -c1
    threshold_high_excl = c0 - c1
    if pred0 == Predicate.UGE:
        threshold_low_incl, threshold_high_excl = (
            threshold_high_excl, threshold_low_incl)

    # Fold preconditions (Z3's >= / <= on bit-vectors are signed):
    # threshold_low_incl s<= c2 s<= threshold_high_excl.
    s.add(c2 >= threshold_low_incl)
    s.add(c2 <= threshold_high_excl)

    new = If(x >= threshold_high_excl, replacement_high,
             If(x < threshold_low_incl, replacement_low, x))

    # Search for any assignment on which the two expressions disagree.
    s.add(old != new)
    if s.check() != unsat:
        return s.model()
    return None
def main():
    """Verify every (unsigned, signed) predicate pair and print the outcome.

    Pairs whose inner predicate is SLE are reported as skipped. The verifier
    returns None for them without running the solver, so printing 'pass'
    would wrongly suggest they were proven correct. For a failing pair, every
    declaration in the counterexample model is printed with its value.
    """
    for pred0, pred1 in itertools.product(UNSIGNED_PREDS, SIGNED_PREDS):
        print(f'Testing {pred0}, {pred1}... ', end='')
        if pred1 == Predicate.SLE:
            # Not handled by the verifier; nothing was checked.
            print('skip (SLE not handled)')
            continue
        m = verify_canonicalize_clamp_like(pred0, pred1)
        if m is None:
            print('pass')
            continue
        print('fail')
        for d in m.decls():
            print(f' {d.name()} = {m[d]}')
# Run the verification sweep only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment