Skip to content

Instantly share code, notes, and snippets.

@jad-hamza
Created September 8, 2020 12:50
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save jad-hamza/5878148330daf7b90cf6bf083276b5c6 to your computer and use it in GitHub Desktop.
import stainless.collection._ // for List
import stainless.equations._ // for ==:| and |: notations
import stainless.annotation._ // for @extern and @opaque annotations
/** Stainless proof demo relating a binary search tree to a list.
  *
  * Main theorem (`insertTreeList`): for a search tree `t`, inserting `n` into the
  * tree and then taking the in-order list gives the same list as taking the
  * in-order list first and then inserting `n` with the list-level `insert`.
  */
object Demo3 {
  /** A binary tree of `BigInt` values: either the empty `Leaf` or a `Node`. */
  sealed abstract class Tree {
    /** Binary-search-tree invariant, checked recursively.
      *
      * Holds when both subtrees satisfy it, every value in the left subtree is
      * `<= root`, and every value in the right subtree is `>= root`. Values equal
      * to `root` may therefore appear on either side.
      */
    def isSearchTree: Boolean = this match {
      case Leaf => true
      case Node(root, left, right) =>
        left.isSearchTree &&
        right.isSearchTree &&
        left.forall(_ <= root) &&
        right.forall(_ >= root)
    }

    /** Whether `p` holds for every value stored in this tree (vacuously true for `Leaf`). */
    def forall(p: BigInt => Boolean): Boolean = this match {
      case Leaf => true
      case Node(root, left, right) =>
        p(root) && left.forall(p) && right.forall(p)
    }

    /** Returns a new tree with `n` added as a new leaf. No rebalancing is done.
      *
      * Values `<= root` descend left and larger values descend right.
      * `insertTreeList` proves its result case by case using this same `n <= root`
      * split.
      */
    def insert(n: BigInt): Tree = this match {
      case Leaf => Node(n, Leaf, Leaf)
      case Node(root, left, right) if n <= root =>
        Node(root, left.insert(n), right)
      case Node(root, left, right) =>
        Node(root, left, right.insert(n))
    }

    /** In-order traversal: the left subtree's values, then `root`, then the right subtree's. */
    def toList: List[BigInt] = this match {
      case Leaf => Nil()
      case Node(root, left, right) =>
        left.toList ++ (root :: right.toList)
    }
  }

  /** The empty tree. */
  case object Leaf extends Tree

  /** An inner node holding value `root` with subtrees `left` and `right`. */
  case class Node(root: BigInt, left: Tree, right: Tree) extends Tree

  /** List counterpart of `Tree.insert`.
    *
    * Places `n` immediately before the first element `x` with `n <= x`, or at the
    * end if there is no such element. Ties put `n` first, which agrees with
    * `Tree.insert` sending `n <= root` to the left, i.e. before `root` in order.
    */
  def insert(l: List[BigInt], n: BigInt): List[BigInt] = l match {
    case Nil() => n :: Nil()
    case Cons(x, xs) if n <= x => n :: l
    case Cons(x, xs) => x :: insert(xs, n)
  }

  // Debugging aid, kept disabled. `assume` is `@extern` with an unimplemented body
  // and the postcondition `b`. Enabling it lets a proof take any `b` for granted
  // without justification.
  // WARNING: This is UNSOUND, only use for debugging
  // @extern
  // def assume(b: Boolean): Unit = {
  //   (??? : Unit)
  // }.ensuring(_ => b)

  /** Lemma: when `n <= middle`, list insertion into `l1 ++ (middle :: l2)` only
    * affects the `l1` prefix, because it stops at `middle` at the latest.
    *
    * The body only states the precondition. The proof is delegated to Stainless
    * via `@induct` on `l1`.
    */
  def insertAppendLeft(
    @induct l1: List[BigInt], l2: List[BigInt], middle: BigInt, n: BigInt
  ): Unit = {
    require(n <= middle)
  }.ensuring(_ =>
    insert(l1 ++ (middle :: l2), n) == insert(l1, n) ++ (middle :: l2)
  )

  /** Lemma: when `n` is strictly greater than every element of `l1` and than
    * `middle`, list insertion into `l1 ++ (middle :: l2)` passes over `l1` and
    * `middle` and takes effect inside `l2`.
    *
    * The proof is delegated to Stainless via `@induct` on `l1`.
    */
  def insertAppendRight(
    @induct l1: List[BigInt], l2: List[BigInt], middle: BigInt, n: BigInt
  ): Unit = {
    require(l1.forall(_ < n) && middle < n)
  }.ensuring(_ =>
    insert(l1 ++ (middle :: l2), n) == l1 ++ (middle :: insert(l2, n))
  )

  /** Lemma: if `p` holds at every node of `t`, it holds for every element of
    * `t.toList`.
    *
    * Proved by recursion on `t`. In the `Node` case, the recursive calls cover
    * both subtrees. `ListSpecs.listAppendValidProp` is then used to conclude
    * `forall(p)` on `left.toList ++ (root :: right.toList)`.
    * NOTE(review): that step relies on the library lemma's contract.
    */
  def forallToList(t: Tree, p: BigInt => Boolean): Unit = {
    require(t.forall(p))
    t match {
      case Leaf =>
      case Node(root, left, right) =>
        forallToList(left, p)
        forallToList(right, p)
        ListSpecs.listAppendValidProp(root :: right.toList, left.toList, p)
    }
  }.ensuring(_ =>
    t.toList.forall(p)
  )

  /** Lemma: a non-strict upper bound `x1` on `l` becomes a strict bound for any
    * `x2 > x1`.
    *
    * The proof is delegated to Stainless via `@induct` on `l`.
    */
  def strictUpperBound(
    @induct l: List[BigInt], x1: BigInt, x2: BigInt
  ): Unit = {
    require(l.forall(_ <= x1) && x1 < x2)
  }.ensuring(_ =>
    l.forall(_ < x2)
  )

  /** Theorem: for a search tree `t`, flattening `t.insert(n)` equals list-inserting
    * `n` into `t.toList`.
    *
    * Proved by recursion on `t`, following the case split of `Tree.insert`. In the
    * `==:| proof |:` chains, each step states one equality justified by `proof`.
    * Steps marked `trivial` use no extra lemma.
    */
  def insertTreeList(t: Tree, n: BigInt): Unit = {
    require(t.isSearchTree)
    t match {
      case Leaf =>
        // t.insert(n) is Node(n, Leaf, Leaf) and insert(Nil(), n) is n :: Nil(),
        // so both sides are the one-element list containing n.
        ()
      case Node(root, left, right) if n <= root =>
        // n goes into the left subtree. Rewrite with the induction hypothesis on
        // `left`, then move the list insertion in front of `root :: right.toList`
        // using insertAppendLeft. Its precondition n <= root is this case's guard.
        (
          t.insert(n).toList ==:| trivial |:
          Node(root, left.insert(n), right).toList ==:| trivial |:
          left.insert(n).toList ++ (root :: right.toList) ==:|
            insertTreeList(left, n) |: // induction hypothesis on `left`
          insert(left.toList, n) ++ (root :: right.toList) ==:|
            insertAppendLeft(left.toList, right.toList, root, n) |:
          insert(left.toList ++ (root :: right.toList), n) ==:| trivial |:
          insert(t.toList, n)
        ).qed
      case Node(root, left, right) =>
        // Here root < n, because the previous guard failed. insertAppendRight also
        // needs every element of left.toList to be < n. isSearchTree gives
        // left.forall(_ <= root), forallToList transfers that to left.toList, and
        // strictUpperBound tightens the bound to < n.
        forallToList(left, _ <= root)
        // assert(left.toList.forall(_ <= root))
        strictUpperBound(left.toList, root, n)
        // assert(left.toList.forall(_ < n))
        (
          t.insert(n).toList ==:| trivial |:
          Node(root, left, right.insert(n)).toList ==:| trivial |:
          left.toList ++ (root :: right.insert(n).toList) ==:|
            insertTreeList(right, n) |: // induction hypothesis on `right`
          left.toList ++ (root :: insert(right.toList, n)) ==:|
            insertAppendRight(left.toList, right.toList, root, n) |:
          insert(left.toList ++ (root :: right.toList), n) ==:| trivial |:
          insert(t.toList, n)
        ).qed
    }
  }.ensuring(_ =>
    t.insert(n).toList == insert(t.toList, n)
  )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment