Created
December 29, 2019 04:09
-
-
Save philzook58/40634d9be7a52e760476d041c9299c4e to your computer and use it in GitHub Desktop.
simple z3py proofs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def babylonian(x, iterations=7):
    """Approximate the square root of ``x`` with the Babylonian iteration.

    Starting from the guess 1, the estimate is repeatedly replaced by the
    average of itself and ``x / estimate``.  Only ``/`` and ``+`` are applied
    to ``x``, so it may be a plain number or any object overloading those
    operators.

    Args:
        x: Value whose square root is approximated.
        iterations: Number of refinement steps.  Defaults to 7, the count
            that used to be hard-coded.

    Returns:
        The estimate after ``iterations`` refinement steps (``1`` when
        ``iterations`` is 0).
    """
    res = 1
    for _ in range(iterations):
        # Average the current guess with x / guess.
        res = (x / res + res) / 2
    return res
# x is the input, y is constrained below to be its non-negative square root.
x, y = Reals("x y")

# On 0 <= x <= 10 the Babylonian estimate exceeds the true root by at most 0.01.
prove(
    Implies(
        And(y**2 == x, y >= 0, 0 <= x, x <= 10),
        babylonian(x) - y <= 0.01,
    )
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Integer symbols used by the parity proofs.
x, y = Int("x"), Int("y")
def Even(x):
    """Return the z3 formula "there exists q with x == 2*q".

    A fresh integer witness is created on every call, so formulas built by
    separate calls do not share their bound variable.
    """
    witness = FreshInt()
    doubled = 2 * witness
    return Exists([witness], x == doubled)
def Odd(x):
    """Return the z3 formula stating that ``x`` is not even."""
    is_even = Even(x)
    return Not(is_even)
# even + odd is odd
prove(
    Implies(
        And(Even(x), Odd(y)),
        Odd(x + y),
    )
)
# even + even is even
prove(
    Implies(
        And(Even(x), Even(y)),
        Even(x + y),
    )
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def inductionNat(f):
    """Build the induction obligation for a predicate over the naturals.

    Proving the returned formula proves ``f`` for all naturals: it is the
    conjunction of the base case ``f(0)`` and the inductive step
    ``forall n. (n >= 0 and f(n)) -> f(n + 1)``.

    Args:
        f: Callable taking a z3 integer expression and returning a z3
            Boolean expression.

    Returns:
        The z3 formula ``And(base_case, inductive_step)``.
    """
    n = FreshInt()
    base_case = f(IntVal(0))
    # The step has to cover n == 0 as well; guarding it with n > 0 would
    # never require f(0) -> f(1), so f(1) (and everything after it) would
    # not actually follow from the obligation.
    inductive_step = ForAll([n], Implies(And(n >= 0, f(n)), f(n + 1)))
    return And(base_case, inductive_step)
# Disabled experiment, kept for reference: stating the closed form directly
# under a universal quantifier -- doesn't solve.
#
#     sumn = Function('sumn', IntSort(), IntSort())
#     n = FreshInt()
#     s = Solver()
#     s.add(ForAll([n], sumn(n) == If(n == 0, 0, n + sumn(n-1))))
#     claim = ForAll([n], Implies( n >= 0, sumn(n) == n * (n+1) / 2))
#     s.add(Not(claim))
#     s.check()

# solves immediately
sumn = Function('sumn', IntSort(), IntSort())
n = FreshInt()
s = Solver()
# Recursive definition of sumn, given to the solver as an axiom.
s.add(
    ForAll(
        [n],
        sumn(n) == If(n == 0, 0, n + sumn(n - 1)),
    )
)
# Closed form, stated through the induction obligation instead of a bare ForAll.
claim = inductionNat(lambda k: sumn(k) == k * (k + 1) / 2)
s.add(Not(claim))
s.check()  # comes back unsat = proven
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Interval():
    """Closed interval ``[l, r]`` with overloaded arithmetic and comparisons.

    The endpoints are stored as given and only combined through operators,
    so they may be plain numbers or z3 real expressions.  Methods that call
    ``And``/``Or``/``ForAll``/``FreshReal`` or ``Min``/``Max`` rely on those
    names being defined at module level when they are called.
    """

    def __init__(self, l, r):
        self.l = l  # lower endpoint
        self.r = r  # upper endpoint

    def __add__(self, rhs):
        """Endpoint-wise sum of two intervals."""
        if not isinstance(rhs, Interval):
            # Let Python raise TypeError instead of silently yielding None.
            return NotImplemented
        return Interval(self.l + rhs.l, self.r + rhs.r)

    def __sub__(self, rhs):
        """Interval difference: ``[l1 - r2, r1 - l2]``."""
        if not isinstance(rhs, Interval):
            return NotImplemented
        # The smallest difference pairs our lower end with rhs's upper end,
        # the largest pairs our upper end with rhs's lower end.
        return Interval(self.l - rhs.r, self.r - rhs.l)

    def __mul__(self, rhs):
        """Interval product: min and max over the four endpoint products."""
        combos = [self.l * rhs.l, self.l * rhs.r, self.r * rhs.l, self.r * rhs.r]
        return Interval(Min(*combos), Max(*combos))

    @staticmethod
    def fresh():
        """Return an interval whose endpoints are two fresh z3 reals."""
        l = FreshReal()
        r = FreshReal()
        return Interval(l, r)

    def valid(self):  # It is problematic that I have to remember to use this. A way around it?
        """Formula stating that the lower endpoint does not exceed the upper."""
        return self.l <= self.r

    def __le__(self, rhs):  # Or( self.r < self.l ) (ie is bottom)
        """Inclusion: both endpoints of ``self`` lie within ``rhs``."""
        return And(rhs.l <= self.l, self.r <= rhs.r)

    def __lt__(self, rhs):
        """Inclusion with both endpoint comparisons strict."""
        return And(rhs.l < self.l, self.r < rhs.r)

    @staticmethod
    def forall(eq):
        """Quantify ``eq(i)`` over every valid interval ``i``.

        ``eq`` is called with a fresh interval and must return a z3 Boolean.
        """
        i = Interval.fresh()
        return ForAll([i.l, i.r], Implies(i.valid(), eq(i)))

    def elem(self, item):
        """Formula stating that ``item`` lies between the two endpoints."""
        return And(self.l <= item, item <= self.r)

    def join(self, rhs):
        """Interval from the smaller lower end to the larger upper end."""
        return Interval(Min(self.l, rhs.l), Max(self.r, rhs.r))

    def meet(self, rhs):
        """Interval from the larger lower end to the smaller upper end."""
        return Interval(Max(self.l, rhs.l), Min(self.r, rhs.r))

    def width(self):
        """Upper endpoint minus lower endpoint."""
        return self.r - self.l

    def mid(self):
        """Arithmetic mean of the two endpoints."""
        return (self.r + self.l) / 2

    def bisect(self):
        """Split at the midpoint; returns the (lower half, upper half) pair."""
        return Interval(self.l, self.mid()), Interval(self.mid(), self.r)

    @staticmethod
    def point(x):
        """Degenerate interval ``[x, x]``."""
        return Interval(x, x)

    def recip(self):  # assume 0 is not in
        """Reciprocal interval ``[1/r, 1/l]``; assumes 0 is not in the interval."""
        return Interval(1 / self.r, 1 / self.l)

    def __truediv__(self, rhs):
        """Division as multiplication by the reciprocal of ``rhs``."""
        return self * rhs.recip()

    def __repr__(self):
        return f"[{self.l} , {self.r}]"

    def pos(self):
        """Formula stating that both endpoints are strictly positive."""
        return And(self.l > 0, self.r > 0)

    def neg(self):
        """Formula stating that both endpoints are strictly negative."""
        return And(self.l < 0, self.r < 0)

    def non_zero(self):
        """Formula stating that the interval is entirely positive or entirely negative."""
        return Or(self.pos(), self.neg())
x, y = Reals("x y")
# Four arbitrary intervals, each with fresh symbolic endpoints.
i1, i2, i3, i4 = (Interval.fresh() for _ in range(4))

# Soundness of + and *: members of the operands map to members of the result.
prove(
    Implies(
        And(i1.elem(x), i2.elem(y)),
        (i1 + i2).elem(x + y),
    )
)
prove(
    Implies(
        And(i1.elem(x), i2.elem(y)),
        (i1 * i2).elem(x * y),
    )
)

# transitivity of inclusion
prove(Implies(And(i1 <= i2, i2 <= i3), i1 <= i3))

# subdistributivity
prove(
    Implies(
        And(i1.valid(), i2.valid(), i3.valid()),
        i1 * (i2 + i3) <= i1 * i2 + i1 * i3,
    )
)

# isotonic
prove(
    Implies(
        And(i1 <= i2, i3 <= i4),
        (i1 + i3) <= i2 + i4,
    )
)
prove(
    Implies(
        And(i1.valid(), i2.valid(), i3.valid(), i4.valid(), i1 <= i2, i3 <= i4),
        (i1 * i3) <= i2 * i4,
    )
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import reduce | |
def Max1(x, y):
    """Symbolic maximum of two terms: ``y`` when ``x <= y``, otherwise ``x``."""
    y_is_larger = x <= y
    return If(y_is_larger, y, x)
def Min1(x, y):
    """Symbolic minimum of two terms: ``x`` when ``x <= y``, otherwise ``y``."""
    x_is_smaller = x <= y
    return If(x_is_smaller, x, y)
def Abs(x):
    """Symbolic absolute value: ``-x`` when ``x <= 0``, otherwise ``x``."""
    non_positive = x <= 0
    return If(non_positive, -x, x)
def Min(*args):
    """Fold ``Min1`` left-to-right over all the arguments."""
    return reduce(lambda smallest, candidate: Min1(smallest, candidate), args)
def Max(*args):
    """Fold ``Max1`` left-to-right over all the arguments."""
    return reduce(lambda largest, candidate: Max1(largest, candidate), args)
z = Real('z')

# Max is an upper bound of its arguments.
prove(z <= Max(x, y, z))
prove(x <= Max(x, y))

# Min is a lower bound of its arguments.
prove(Min(x, y) <= x)
prove(Min(x, y) <= y)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import operator as op | |
def NPArray(n, prefix=None, dtype=RealSort()):
    """Return a 1-D numpy array holding ``n`` fresh z3 constants.

    Args:
        n: Number of constants to create.
        prefix: Passed through to ``FreshConst`` as its ``prefix`` argument.
        dtype: z3 sort of the constants; real by default.
    """
    fresh_consts = []
    for _ in range(n):
        fresh_consts.append(FreshConst(dtype, prefix=prefix))
    return np.array(fresh_consts)
# Two symbolic 3-vectors and a symbolic scalar.
v, w = NPArray(3), NPArray(3)
l = Real("l")

# linearity of dot product
prove(np.dot(v, w * l) == l * np.dot(v, w))

# Cauchy-Schwarz
prove(np.dot(v, w)**2 <= np.dot(v, v) * np.dot(w, w))
def vec_eq(x, y):  # a vectorized z3 equality
    """Return the z3 conjunction of the elementwise equalities of ``x`` and ``y``.

    Args:
        x, y: Arrays of z3 expressions that numpy can pair up elementwise.

    Returns:
        ``And`` applied to the flat list of all ``x[i] == y[i]`` terms.
    """
    elementwise = np.vectorize(op.eq)(x, y)
    # Flatten before handing over to And: tolist() on a 2-D result would
    # produce nested lists rather than a flat sequence of Booleans.
    return And(np.ravel(elementwise).tolist())
# distributivity of scalar multiplication
prove(
    vec_eq(
        (v + w) * l,
        v * l + w * l,
    )
)

z = NPArray(9).reshape(3, 3)  # some matrix

# linearity of matrix multiply: additivity
prove(
    vec_eq(
        z @ (v + w),
        z @ v + z @ w,
    )
)
# linearity of matrix multiply: scalars factor out
prove(
    vec_eq(
        z @ (v * l),
        (z @ v) * l,
    )
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#https://z3prover.github.io/api/html/namespacez3py.html#a2f0f4611f0b706d666a8227b6347266a
def prove(claim, show=False, **keywords):
    """Try to prove the given claim.

    This is a simple function for creating demonstrations. It tries to prove
    `claim` by showing the negation is unsatisfiable.

    Args:
        claim: z3 Boolean expression to prove.
        show: When true, print the solver (holding the negated claim) before
            checking.  It is a display flag, so it is kept out of the options
            handed to ``Solver.set``.
        **keywords: Solver options, passed to ``Solver.set`` unchanged.

    Raises:
        Z3Exception: If z3 debugging is enabled and `claim` is not Boolean.

    >>> p, q = Bools('p q')
    >>> prove(Not(And(p, q)) == Or(Not(p), Not(q)))
    proved
    """
    # Raise directly instead of calling z3's private _z3_assert helper, which
    # a star-import of z3 does not bring into scope.
    if z3_debug() and not is_bool(claim):
        raise Z3Exception("Z3 Boolean expression expected")
    s = Solver()
    s.set(**keywords)
    s.add(Not(claim))
    if show:
        print(s)
    r = s.check()
    if r == unsat:
        print("proved")
    elif r == unknown:
        print("failed to prove")
        try:
            print(s.model())
        except Z3Exception:
            # No candidate model is available for this unknown result;
            # report why the solver gave up instead of crashing.
            print(s.reason_unknown())
    else:
        print("counterexample")
        print(s.model())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from z3 import * | |
# --- propositional logic ---
p, q = Bool("p"), Bool("q")
prove(Implies(And(p, q), p))  # simple destruction of the And
prove(And(p, q) == Not(Or(Not(p), Not(q))))  # De Morgan's Law

# --- real arithmetic ---
x, y, z = Real("x"), Real("y"), Real("z")
prove(x + y == y + x)  # commutativity
prove(((x + y) + z) == ((x + (y + z))))  # associativity
prove(x + 0 == x)  # 0 additive identity
prove(1 * x == x)  # 1 multiplicative identity
prove(Or(x > 0, x < 0, x == 0))  # trichotomy
prove(x**2 >= 0)  # positivity of a square
prove(x * (y + z) == x * y + x * z)  # distributive law
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# sin and cos are uninterpreted functions; everything the solver knows about
# them comes from the axioms collected in `trig`.
sin = Function("sin", RealSort(), RealSort())
cos = Function("cos", RealSort(), RealSort())
x = Real('x')

# Known values. Using degrees is easier than radians. We have no pi.
trig = [
    sin(0) == 0,
    cos(0) == 1,
    sin(180) == 0,
    cos(180) == -1,
]
# Double-angle and Pythagorean identities, quantified over x.
trig.extend([
    ForAll([x], sin(2*x) == 2*sin(x)*cos(x)),
    ForAll([x], sin(x)*sin(x) + cos(x) * cos(x) == 1),
    ForAll([x], cos(2*x) == cos(x)*cos(x) - sin(x) * sin(x)),
])

s = Solver()
s.set(auto_config=False, mbqi=False)
s.add(trig)
# Ask whether cos(45) can sit just above 1/sqrt(2).
s.add(RealVal(1 / np.sqrt(2) + 1e-16) <= cos(45))
s.check()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment