Skip to content

Instantly share code, notes, and snippets.

@philzook58
Created December 29, 2019 04:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save philzook58/40634d9be7a52e760476d041c9299c4e to your computer and use it in GitHub Desktop.
Save philzook58/40634d9be7a52e760476d041c9299c4e to your computer and use it in GitHub Desktop.
simple z3py proofs
def babylonian(x, iters=7):
    """Approximate sqrt(x) with ``iters`` steps of the Babylonian (Heron) method.

    Starting from the guess 1, each step replaces the current guess ``res``
    with the average of ``res`` and ``x / res``. Only plain arithmetic
    operators are used, so ``x`` may be a Python number or a z3 real term; in
    the latter case the result is a symbolic rational expression in ``x``.

    Args:
        x: value whose square root is approximated.
        iters: number of refinement steps (default 7, the count this demo
            originally hard-coded).

    Returns:
        The approximation after ``iters`` steps (the initial guess ``1`` when
        ``iters`` is 0).
    """
    res = 1
    for _ in range(iters):
        res = (x / res + res) / 2
    return res
# Check that babylonian(x) overshoots sqrt(x) by at most 0.01 for 0 <= x <= 10.
# y is pinned to the non-negative square root of x. Only the one-sided bound
# babylonian(x) - y <= 0.01 is stated; nothing is asserted about undershoot.
# babylonian(x) is a nonlinear rational expression in x, so this leans on z3's
# nonlinear real arithmetic.
x, y = Reals("x y")
prove(Implies(And(y**2 == x, y >= 0, 0 <= x, x <= 10), babylonian(x) - y <= 0.01))
# Integer variables for the parity proofs below (rebinding x and y).
x = Int("x")
y = Int("y")
def Even(x):
    """Return a z3 formula stating that ``x`` is even.

    The formula is ``Exists q. x == 2 * q`` over integers. Each call uses a
    brand-new witness constant from ``FreshInt``.
    """
    witness = FreshInt()
    doubled = x == 2 * witness
    return Exists([witness], doubled)
def Odd(x):
    """Return a z3 formula stating that ``x`` is odd, i.e. the negation of ``Even(x)``."""
    even_claim = Even(x)
    return Not(even_claim)
# even + odd is odd
prove(Implies( And(Even(x), Odd(y)) , Odd(x + y)))
# even + even is even
prove(Implies( And(Even(x), Even(y)) , Even(x + y)))
def inductionNat(f):
    """Build the premise of natural-number induction for the predicate ``f``.

    The returned formula is ``f(0) and ForAll n. (n >= 0 and f(n)) -> f(n + 1)``.
    Establishing that formula (e.g. showing ``Not(inductionNat(f))`` is unsat
    under the relevant axioms) proves ``f(n)`` for every natural ``n``, while
    only requiring the solver to handle a base case and a single step.

    Args:
        f: Python callable mapping a z3 Int term to a z3 Bool term.

    Returns:
        A z3 Bool formula (base case conjoined with the inductive step).
    """
    n = FreshInt()
    base = f(IntVal(0))
    # The step must start at n = 0: with ``n > 0`` nothing connects f(0) to
    # f(1), so the premise would not imply f for all naturals.
    step = ForAll([n], Implies(And(n >= 0, f(n)), f(n + 1)))
    return And(base, step)
# Direct attempt, disabled by wrapping it in a string literal (never executed):
# state the closed form for sumn over all n >= 0 against its recursive
# definition. The note inside records that z3 does not solve this form.
'''
# doesn't solve
sumn = Function('sumn', IntSort(), IntSort())
n = FreshInt()
s = Solver()
s.add(ForAll([n], sumn(n) == If(n == 0, 0, n + sumn(n-1))))
claim = ForAll([n], Implies( n >= 0, sumn(n) == n * (n+1) / 2))
s.add(Not(claim))
s.check()
'''
# Same recursive definition, but the claim is phrased through inductionNat, so
# z3 only has to discharge the base case and one inductive step.
# solves immediately
sumn = Function('sumn', IntSort(), IntSort())
n = FreshInt()
s = Solver()
# Recursive definition: sumn(n) = n + sumn(n - 1), with sumn(0) = 0.
s.add(ForAll([n], sumn(n) == If(n == 0, 0, n + sumn(n-1))))
# n and n+1 are z3 Ints, so "/" here is integer division (exact, since
# n * (n+1) is always even).
claim = inductionNat(lambda n : sumn(n) == n * (n+1) / 2)
s.add(Not(claim))
# NOTE(review): in a non-interactive script this result is discarded;
# print(s.check()) to see it.
s.check() # expected result: unsat, i.e. the claim is proven
class Interval:
    """A closed interval ``[l, r]`` whose endpoints are usually z3 real terms.

    The operations build symbolic endpoint expressions so that laws of interval
    arithmetic can be stated and checked with ``prove``. Multiplication uses the
    module-level ``Min``/``Max`` helpers; predicates such as ``elem`` and ``<=``
    return z3 formulas. Scalar operands of ``+``, ``-``, ``*`` and ``/`` are
    treated as point intervals ``[x, x]``.
    """

    def __init__(self, l, r):
        self.l = l  # lower endpoint
        self.r = r  # upper endpoint

    @staticmethod
    def _coerce(other):
        """Return ``other`` as an Interval, promoting a scalar to a point interval."""
        if isinstance(other, Interval):
            return other
        return Interval.point(other)

    def __add__(self, rhs):
        """Interval sum: ``[a, b] + [c, d] = [a + c, b + d]``."""
        rhs = Interval._coerce(rhs)
        return Interval(self.l + rhs.l, self.r + rhs.r)

    def __sub__(self, rhs):
        """Interval difference: ``[a, b] - [c, d] = [a - d, b - c]``."""
        rhs = Interval._coerce(rhs)
        return Interval(self.l - rhs.r, self.r - rhs.l)

    def __mul__(self, rhs):
        """Interval product: the hull (Min/Max) of the four endpoint products."""
        rhs = Interval._coerce(rhs)
        combos = [self.l * rhs.l, self.l * rhs.r, self.r * rhs.l, self.r * rhs.r]
        return Interval(Min(*combos), Max(*combos))

    @staticmethod
    def fresh():
        """Return an Interval whose endpoints are two fresh, unconstrained z3 reals."""
        l = FreshReal()
        r = FreshReal()
        return Interval(l, r)

    def valid(self):
        """Return the condition ``l <= r`` (the interval is non-empty).

        Nothing enforces this automatically: it must be conjoined explicitly
        wherever a law only holds for non-empty intervals.
        """
        return self.l <= self.r

    def __le__(self, rhs):
        """Inclusion: ``self`` is contained in ``rhs``.

        NOTE: an empty ``self`` (``r < l``) would also be a subset of anything,
        e.g. via an extra ``Or(self.r < self.l)``; that case is not encoded here.
        """
        return And(rhs.l <= self.l, self.r <= rhs.r)

    def __lt__(self, rhs):
        """Strict inclusion on both ends: ``rhs.l < self.l`` and ``self.r < rhs.r``."""
        return And(rhs.l < self.l, self.r < rhs.r)

    @staticmethod
    def forall(eq):
        """Return ``ForAll`` over a fresh interval ``i`` of ``i.valid() -> eq(i)``."""
        i = Interval.fresh()
        return ForAll([i.l, i.r], Implies(i.valid(), eq(i)))

    def elem(self, item):
        """Membership formula: ``l <= item <= r``."""
        return And(self.l <= item, item <= self.r)

    def join(self, rhs):
        """Smallest interval containing both ``self`` and ``rhs`` (interval hull)."""
        return Interval(Min(self.l, rhs.l), Max(self.r, rhs.r))

    def meet(self, rhs):
        """Intersection of ``self`` and ``rhs`` (may be empty, i.e. not valid)."""
        return Interval(Max(self.l, rhs.l), Min(self.r, rhs.r))

    def width(self):
        """Return ``r - l``."""
        return self.r - self.l

    def mid(self):
        """Return the midpoint ``(r + l) / 2``."""
        return (self.r + self.l) / 2

    def bisect(self):
        """Split at the midpoint into ``([l, mid], [mid, r])``."""
        return Interval(self.l, self.mid()), Interval(self.mid(), self.r)

    @staticmethod
    def point(x):
        """Return the degenerate interval ``[x, x]``."""
        return Interval(x, x)

    def recip(self):
        """Reciprocal ``[1/r, 1/l]``.

        Assumes 0 is not in the interval (see ``non_zero``); otherwise the
        result is meaningless, and numeric endpoints may divide by zero.
        """
        return Interval(1 / self.r, 1 / self.l)

    def __truediv__(self, rhs):
        """Interval quotient ``self * (1 / rhs)``; ``rhs`` must not contain 0."""
        rhs = Interval._coerce(rhs)
        return self * rhs.recip()

    def __repr__(self):
        return f"[{self.l} , {self.r}]"

    def pos(self):
        """Formula: both endpoints are strictly positive."""
        return And(self.l > 0, self.r > 0)

    def neg(self):
        """Formula: both endpoints are strictly negative."""
        return And(self.l < 0, self.r < 0)

    def non_zero(self):
        """Formula: the interval lies entirely on one side of 0."""
        return Or(self.pos(), self.neg())
# Laws of interval arithmetic for the Interval class, checked with z3.
x, y = Reals("x y")
i1 = Interval.fresh()
i2 = Interval.fresh()
i3 = Interval.fresh()
i4 = Interval.fresh()
# Soundness of + and *: if x is in i1 and y is in i2, the result contains the
# pointwise result. Membership (elem) already forces l <= r, so no valid()
# premise is needed.
prove(Implies(And(i1.elem(x), i2.elem(y)), (i1 + i2).elem(x + y)))
prove(Implies(And(i1.elem(x), i2.elem(y)), (i1 * i2).elem(x * y)))
prove(Implies( And(i1 <= i2, i2 <= i3), i1 <= i3 )) # transitivity of inclusion
prove( Implies( And(i1.valid(), i2.valid(), i3.valid()), i1 * (i2 + i3) <= i1 * i2 + i1 * i3)) # subdistributivity: i1*(i2+i3) is contained in i1*i2 + i1*i3
# isotonic (inclusion-monotone) addition and multiplication
prove(Implies( And( i1 <= i2, i3 <= i4 ), (i1 + i3) <= i2 + i4 ))
prove(Implies( And(i1.valid(), i2.valid(), i3.valid(), i4.valid(), i1 <= i2, i3 <= i4 ), (i1 * i3) <= i2 * i4 ))
from functools import reduce
def Max1(x, y):
    """Return a z3 term for the larger of ``x`` and ``y`` (``y`` on ties)."""
    y_not_smaller = x <= y
    return If(y_not_smaller, y, x)
def Min1(x, y):
    """Return a z3 term for the smaller of ``x`` and ``y`` (``x`` on ties)."""
    x_not_larger = x <= y
    return If(x_not_larger, x, y)
def Abs(x):
    """Return a z3 term for ``|x|``: ``-x`` when ``x <= 0``, otherwise ``x``."""
    non_positive = x <= 0
    return If(non_positive, -x, x)
def Min(*args):
    """Return a z3 term for the minimum of ``args``, folded left-to-right with ``Min1``.

    Follows ``reduce`` semantics: a single argument is returned unchanged and a
    call with no arguments raises ``TypeError``.
    """
    smallest = reduce(Min1, args)
    return smallest
def Max(*args):
    """Return a z3 term for the maximum of ``args``, folded left-to-right with ``Max1``.

    Follows ``reduce`` semantics: a single argument is returned unchanged and a
    call with no arguments raises ``TypeError``.
    """
    largest = reduce(Max1, args)
    return largest
# Sanity checks for the folded Min/Max helpers.
# NOTE(review): x and y are presumably the reals declared earlier in this
# script — confirm they are z3 Reals at this point.
z = Real('z')
prove(z <= Max(x,y,z))
prove(x <= Max(x,y))
prove(Min(x,y) <= x)
prove(Min(x,y) <= y)
import numpy as np
import operator as op
def NPArray(n, prefix=None, dtype=RealSort()):
    """Return a numpy object array filled with fresh z3 constants.

    Args:
        n: number of elements, or a shape tuple such as ``(3, 3)`` for an
            n-dimensional array.
        prefix: optional name prefix forwarded to ``FreshConst``.
        dtype: z3 sort of the constants (default: reals).

    Returns:
        An ``object``-dtype numpy array of shape ``(n,)`` (or ``n`` itself when
        a shape tuple is given) whose entries are distinct fresh constants.
    """
    try:
        shape = tuple(n)
    except TypeError:  # a plain integer length
        shape = (n,)
    size = int(np.prod(shape))
    # Fill a preallocated object array so numpy never tries to interpret the
    # z3 terms themselves while inferring a dtype or nesting.
    flat = np.empty(size, dtype=object)
    for k in range(size):
        flat[k] = FreshConst(dtype, prefix=prefix)
    return flat.reshape(shape)
# Symbolic 3-vectors of fresh reals and a scalar l; numpy operators on the
# object arrays build z3 expressions elementwise.
v = NPArray(3)
w = NPArray(3)
l = Real("l")
prove( np.dot(v,w * l) == l * np.dot(v,w) ) # linearity (homogeneity) of the dot product
prove(np.dot(v, w)**2 <= np.dot(v,v) * np.dot(w,w)) # Cauchy-Schwarz inequality
def vec_eq(x, y):
    """Return a z3 formula asserting elementwise equality of two arrays.

    ``x`` and ``y`` are broadcast together by ``np.vectorize``. Each pair is
    compared with ``==`` (which yields a z3 Bool for z3 terms), and the
    comparisons are conjoined with ``And``. Arrays of any dimension are
    accepted: the comparison results are flattened before being passed to
    ``And``, which cannot take nested lists.
    """
    # otypes=[object] keeps the z3 Bool results as Python objects and skips
    # numpy's trial call used to infer the output dtype.
    eqs = np.vectorize(op.eq, otypes=[object])(x, y)
    return And(eqs.ravel().tolist())
prove( vec_eq((v + w) * l, v * l + w * l)) # distributivity of scalar multiplication over vector addition
z = NPArray(9).reshape(3,3) # a symbolic 3x3 matrix of fresh reals
prove( vec_eq( z @ (v + w) , z @ v + z @ w )) # linearity of matrix multiply (additivity)
prove( vec_eq( z @ (v * l) , (z @ v) * l)) # linearity of matrix multiply (homogeneity)
#https://z3prover.github.io/api/html/namespacez3py.html#a2f0f4611f0b706d666a8227b6347266a
def prove(claim, show=False, **keywords):
    """Try to prove the given claim.

    This is a simple function for creating demonstrations. It tries to prove
    `claim` by showing the negation is unsatisfiable.

    Args:
        claim: z3 Boolean expression to prove.
        show: if true, print the solver (its assertions) before checking.
        **keywords: solver parameters forwarded to ``Solver.set``.

    Prints ``proved``, ``counterexample`` followed by a model, or
    ``failed to prove`` followed by a model (or, if no model is available, the
    solver's reason for ``unknown``).

    >>> p, q = Bools('p q')
    >>> prove(Not(And(p, q)) == Or(Not(p), Not(q)))
    proved
    """
    if z3_debug() and not is_bool(claim):
        # z3's private _z3_assert is not exported by `from z3 import *`, so
        # raise the same exception type directly.
        raise Z3Exception("Z3 Boolean expression expected")
    s = Solver()
    # `show` is a separate parameter: passing it through to Solver.set would be
    # rejected as an unknown solver parameter.
    s.set(**keywords)
    s.add(Not(claim))
    if show:
        print(s)
    r = s.check()
    if r == unsat:
        print("proved")
    elif r == unknown:
        print("failed to prove")
        # A model is not always available after an `unknown` result.
        try:
            print(s.model())
        except Z3Exception:
            print(s.reason_unknown())
    else:
        print("counterexample")
        print(s.model())
from z3 import *
# Propositional warm-ups.
p = Bool("p")
q = Bool("q")
prove(Implies(And(p,q), p)) # simple destruction of the And
prove( And(p,q) == Not(Or(Not(p),Not(q)))) # De Morgan's law
# Field axioms and basic facts over the reals.
x = Real("x")
y = Real("y")
z = Real("z")
prove(x + y == y + x) # commutativity
prove(((x + y) + z) == ((x + (y + z)))) # associativity
prove(x + 0 == x) # 0 is the additive identity
prove(1 * x == x) # 1 is the multiplicative identity
prove(Or(x > 0, x < 0, x == 0)) # trichotomy
prove(x**2 >= 0) # non-negativity of a square
prove(x * (y + z) == x * y + x * z) # distributive law
# sin and cos are uninterpreted functions; only the axioms in `trig` constrain
# them. Angles are in degrees.
sin = Function("sin", RealSort(), RealSort())
cos = Function("cos", RealSort(), RealSort())
x = Real('x')
trig = [sin(0) == 0,
cos(0) == 1,
sin(180) == 0,
cos(180) == -1, # Using degrees is easier than radians. We have no pi.
ForAll([x], sin(2*x) == 2*sin(x)*cos(x)), # double-angle formula for sin
ForAll([x], sin(x)*sin(x) + cos(x) * cos(x) == 1), # Pythagorean identity
ForAll([x], cos(2*x) == cos(x)*cos(x) - sin(x) * sin(x))] # double-angle formula for cos
s = Solver()
# Disable auto-configuration and model-based quantifier instantiation
# (presumably to keep quantifier handling predictable — confirm intent).
s.set(auto_config=False, mbqi=False)
s.add(trig)
# Intended to come back unsat. From the axioms at 90 and 45 degrees,
# cos(45)^2 == 1/2, so cos(45) cannot exceed 1/sqrt(2). The tiny epsilon
# nudges the float just above the true value of 1/sqrt(2).
# NOTE(review): relies on numpy already being imported as `np` — confirm.
s.add( RealVal(1 / np.sqrt(2) + 0.0000000000000001) <= cos(45))
# NOTE(review): in a non-interactive script this result is discarded.
s.check()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment