(* GitHub gist msullivan/4a2e78688c4027507f583f62a5b5b4dc, created March 28, 2017 21:29.
 * A modification of smlnj-lib's BinarySetFn that lets sets contain sets.
 *)
(* A modification of smlnj-lib's binary-set-fn.sml that allows set elements
 * to themselves contain sets. *)
(* binary-set-fn.sml | |
* | |
* COPYRIGHT (c) 1993 by AT&T Bell Laboratories. See COPYRIGHT file for details. | |
* | |
* This code was adapted from Stephen Adams' binary tree implementation | |
* of applicative integer sets. | |
* | |
* Copyright 1992 Stephen Adams. | |
* | |
* This software may be used freely provided that: | |
* 1. This copyright notice is attached to any copy, derived work, | |
* or work including all or part of this software. | |
* 2. Any derived work must contain a prominent notice stating that | |
* it has been altered from the original. | |
* | |
* Name(s): Stephen Adams. | |
* Department, Institution: Electronics & Computer Science, | |
* University of Southampton | |
* Address: Electronics & Computer Science | |
* University of Southampton | |
* Southampton SO9 5NH | |
* Great Britian | |
* E-mail: sra@ecs.soton.ac.uk | |
* | |
* Comments: | |
* | |
* 1. The implementation is based on Binary search trees of Bounded | |
* Balance, similar to Nievergelt & Reingold, SIAM J. Computing | |
* 2(1), March 1973. The main advantage of these trees is that | |
* they keep the size of the tree in the node, giving a constant | |
* time size operation. | |
* | |
* 2. The bounded balance criterion is simpler than N&R's alpha. | |
* Simply, one subtree must not have more than `weight' times as | |
* many elements as the opposite subtree. Rebalancing is | |
* guaranteed to reinstate the criterion for weight>2.23, but | |
* the occasional incorrect behaviour for weight=2 is not | |
* detrimental to performance. | |
* | |
* 3. There are two implementations of union. The default, | |
* hedge_union, is much more complex and usually 20% faster. I | |
* am not sure that the performance increase warrants the | |
* complexity (and time it took to write), but I am leaving it | |
* in for the competition. It is derived from the original | |
* union by replacing the split_lt(gt) operations with a lazy | |
* version. The `obvious' version is called old_union. | |
* | |
* 4. Most time is spent in T', the rebalancing constructor. If my | |
* understanding of the output of *<file> in the sml batch | |
* compiler is correct then the code produced by NJSML 0.75 | |
* (sparc) for the final case is very disappointing. Most | |
* invocations fall through to this case and most of these cases | |
* fall to the else part, i.e. the plain constructor,
* T(v,ln+rn+1,l,r). The poor code allocates a 16 word vector | |
* and saves lots of registers into it. In the common case it | |
* then retrieves a few of the registers and allocates the 5 | |
* word T node. The values that it retrieves were live in | |
* registers before the massive save. | |
* | |
* Modified into a functor to support general ordered values.
*) | |
(* BINARY_TREE: the restricted public view of the tree type used by
 * BinarySetFn.
 *
 * Only the type constructor `set` and a generic comparison are exposed; the
 * tree constructors stay hidden.  `compare cmp (s1, s2)` orders two trees by
 * their element sequences, using `cmp` on corresponding elements, and the
 * element types of the two trees need not agree.
 *)
signature BINARY_TREE =
sig
  type 'elem set
  val compare : ('x * 'y -> order) -> 'x set * 'y set -> order
end
local | |
structure BinaryTreeInternal =
struct

  (* Binary tree shared by BinaryTree and BinarySetFn.  E is the empty
   * tree; a T node stores one element, the number of elements in the
   * subtree rooted at that node (cnt), and its two children.  Ordering and
   * balance are maintained by BinarySetFn, which opens this structure. *)
  datatype 'a set
    = E
    | T of {
        elt : 'a,
        cnt : int,
        left : 'a set,
        right : 'a set
      }

  local
    (* Push a tree and every node on its leftmost path onto `stack`, so
     * the leftmost node ends up on top.  E is never pushed. *)
    fun pushSpine (E, stack) = stack
      | pushSpine (node as T{left = child, ...}, stack) =
          pushSpine (child, node :: stack)

    (* Pop the next node of an in-order traversal, replacing it on the
     * stack with the left spine of its right child.  Yields (E, []) once
     * the traversal is exhausted. *)
    fun pop ((node as T{right = child, ...}) :: below) =
          (node, pushSpine (child, below))
      | pop _ = (E, [])
  in
    (********** Moved from BinarySetFn and made to take key_cmp **********)
    (* compare keyCmp (s1, s2): lexicographic comparison of the in-order
     * element sequences of s1 and s2, using keyCmp on corresponding
     * elements; a proper prefix compares LESS.  Tree shape and the cached
     * cnt fields are never consulted. *)
    fun compare keyCmp (s1, s2) = let
          fun walk (stack1, stack2) =
                case pop stack1
                 of (E, _) =>
                      (case pop stack2
                        of (E, _) => EQUAL
                         | _ => LESS)
                  | (T{elt = x, ...}, rest1) =>
                      (case pop stack2
                        of (E, _) => GREATER
                         | (T{elt = y, ...}, rest2) =>
                             (case keyCmp (x, y)
                               of EQUAL => walk (rest1, rest2)
                                | ord => ord))
          in
            walk (pushSpine (s1, []), pushSpine (s2, []))
          end
  end

end
in | |
(* Public view of BinaryTreeInternal: only the type constructor `set` and
 * the generic `compare` are visible, so the tree constructors E and T stay
 * hidden.  The ascription is transparent (`:`, not `:>`), so
 * 'a BinaryTree.set is the same type as the trees that BinarySetFn builds.
 * Prover relies on this to store PropSet values inside its Prop datatype.
 * Do not change this to opaque ascription. *)
structure BinaryTree : BINARY_TREE = BinaryTreeInternal
(* BinarySetFn(K): finite sets of K.ord_key, represented as bounded-balance
 * binary search trees ordered by K.compare.  Each node caches the size of
 * its subtree, so numItems is O(1).
 *
 * Unlike the stock smlnj-lib version, the tree datatype is not defined
 * here.  It is BinaryTreeInternal.set (opened below) instantiated at
 * K.ord_key.  The ascription to ORD_SET is transparent (`:`), so an
 * instance's `set` is visibly equal to `K.ord_key BinaryTree.set`.  That
 * equality is what lets a key type contain sets of itself (see Prover).
 *)
functor BinarySetFn (K : ORD_KEY) : ORD_SET =
  struct

    structure Key = K

    type item = K.ord_key

    (********** Modified type and compare implementation **********)
    (* Bring in the shared tree datatype (E, T) and its generic compare; the
     * `set` on the right-hand side below is BinaryTreeInternal's 'a set. *)
    open BinaryTreeInternal
    type set = item set

    (* Lexicographic order on the in-order element sequences, comparing
     * elements with K.compare. *)
    fun compare sets = BinaryTreeInternal.compare K.compare sets

    (* Number of elements: read from the cached count at the root. *)
    fun numItems E = 0
      | numItems (T{cnt,...}) = cnt

    fun isEmpty E = true
      | isEmpty _ = false

    (* Raw node constructor; the caller supplies the subtree size n. *)
    fun mkT(v,n,l,r) = T{elt=v,cnt=n,left=l,right=r}

    (* N(v,l,r) = T(v,1+numItems(l)+numItems(r),l,r) *)
    (* Node constructor that computes the size but does no rebalancing. *)
    fun N(v,E,E) = mkT(v,1,E,E)
      | N(v,E,r as T{cnt=n,...}) = mkT(v,n+1,E,r)
      | N(v,l as T{cnt=n,...}, E) = mkT(v,n+1,l,E)
      | N(v,l as T{cnt=n,...}, r as T{cnt=m,...}) = mkT(v,n+m+1,l,r)

    (* Single and double rotations used by T'.  Each requires the subtree it
     * destructures to be non-empty (double rotations also need the inner
     * grandchild) and raises Match otherwise. *)
    fun single_L (a,x,T{elt=b,left=y,right=z,...}) = N(b,N(a,x,y),z)
      | single_L _ = raise Match
    fun single_R (b,T{elt=a,left=x,right=y,...},z) = N(a,x,N(b,y,z))
      | single_R _ = raise Match
    fun double_L (a,w,T{elt=c,left=T{elt=b,left=x,right=y,...},right=z,...}) =
          N(b,N(a,w,x),N(c,y,z))
      | double_L _ = raise Match
    fun double_R (c,T{elt=a,left=w,right=T{elt=b,left=x,right=y,...},...},z) =
          N(b,N(a,w,x),N(c,y,z))
      | double_R _ = raise Match

    (*
    ** val weight = 3
    ** fun wt i = weight * i
    *)
    (* wt i = 3 * i: the balance weight, inlined as additions. *)
    fun wt (i : int) = i + i + i

    (* T'(v,l,r): rebalancing node constructor.  The first clauses handle
     * empty or near-empty subtrees.  In the general case (both subtrees
     * non-empty), a rotation is applied when one side has at least wt
     * (3 times) as many elements as the other; otherwise a plain node is
     * built.  Clause order matters: the general clause must stay last. *)
    fun T' (v,E,E) = mkT(v,1,E,E)
      | T' (v,E,r as T{left=E,right=E,...}) = mkT(v,2,E,r)
      | T' (v,l as T{left=E,right=E,...},E) = mkT(v,2,l,E)
      | T' (p as (_,E,T{left=T _,right=E,...})) = double_L p
      | T' (p as (_,T{left=E,right=T _,...},E)) = double_R p
      (* these cases almost never happen with small weight*)
      | T' (p as (_,E,T{left=T{cnt=ln,...},right=T{cnt=rn,...},...})) =
          if ln<rn then single_L p else double_L p
      | T' (p as (_,T{left=T{cnt=ln,...},right=T{cnt=rn,...},...},E)) =
          if ln>rn then single_R p else double_R p
      | T' (p as (_,E,T{left=E,...})) = single_L p
      | T' (p as (_,T{right=E,...},E)) = single_R p
      | T' (p as (v,l as T{elt=lv,cnt=ln,left=ll,right=lr},
                   r as T{elt=rv,cnt=rn,left=rl,right=rr})) =
          if rn >= wt ln (*right is too big*)
          then
            let val rln = numItems rl
                val rrn = numItems rr
            in
              if rln < rrn then single_L p else double_L p
            end
          else if ln >= wt rn (*left is too big*)
          then
            let val lln = numItems ll
                val lrn = numItems lr
            in
              if lrn < lln then single_R p else double_R p
            end
          else mkT(v,ln+rn+1,l,r)

    (* Insert x.  If an element EQUAL to x is already present, it is
     * replaced by x and the size is unchanged. *)
    fun add (E,x) = mkT(x,1,E,E)
      | add (set as T{elt=v,left=l,right=r,cnt},x) =
          case K.compare(x,v) of
            LESS => T'(v,add(l,x),r)
          | GREATER => T'(v,l,add(r,x))
          | EQUAL => mkT(x,cnt,l,r)
    (* add with the argument order flipped, i.e. add'(item, set).  Beware
     * the parameter names: `s` is the item and `x` is the set. *)
    fun add' (s, x) = add(x, s)

    (* concat3(l,v,r): join l, v and r into one balanced tree.  Callers must
     * ensure every element of l is less than v and v is less than every
     * element of r. *)
    fun concat3 (E,v,r) = add(r,v)
      | concat3 (l,v,E) = add(l,v)
      | concat3 (l as T{elt=v1,cnt=n1,left=l1,right=r1}, v,
                 r as T{elt=v2,cnt=n2,left=l2,right=r2}) =
          if wt n1 < n2 then T'(v2,concat3(l,v,l2),r2)
          else if wt n2 < n1 then T'(v1,l1,concat3(r1,v,r))
          else N(v,l,r)

    (* split_lt(t,x): the elements of t strictly less than x. *)
    fun split_lt (E,x) = E
      | split_lt (T{elt=v,left=l,right=r,...},x) =
          case K.compare(v,x) of
            GREATER => split_lt(l,x)
          | LESS => concat3(l,v,split_lt(r,x))
          | _ => l

    (* split_gt(t,x): the elements of t strictly greater than x. *)
    fun split_gt (E,x) = E
      | split_gt (T{elt=v,left=l,right=r,...},x) =
          case K.compare(v,x) of
            LESS => split_gt(r,x)
          | GREATER => concat3(split_gt(l,x),v,r)
          | _ => r

    (* Smallest element (leftmost node); raises Match on E. *)
    fun min (T{elt=v,left=E,...}) = v
      | min (T{left=l,...}) = min l
      | min _ = raise Match

    (* Remove the smallest element, rebalancing on the way up; raises Match
     * on E. *)
    fun delmin (T{left=E,right=r,...}) = r
      | delmin (T{elt=v,left=l,right=r,...}) = T'(v,delmin l,r)
      | delmin _ = raise Match

    (* delete'(l,r): join the two children of a deleted node (every element
     * of l is below every element of r) by promoting the minimum of r. *)
    fun delete' (E,r) = r
      | delete' (l,E) = l
      | delete' (l,r) = T'(min r,l,delmin r)

    (* concat(t1,t2): join two trees where every element of t1 is less than
     * every element of t2. *)
    fun concat (E, s) = s
      | concat (s, E) = s
      | concat (t1 as T{elt=v1,cnt=n1,left=l1,right=r1},
                t2 as T{elt=v2,cnt=n2,left=l2,right=r2}) =
          if wt n1 < n2 then T'(v2,concat(t1,l2),r2)
          else if wt n2 < n1 then T'(v1,l1,concat(r1,t2))
          else T'(min t2,t1, delmin t2)

    local
      (* trim(lo,hi,s): descend from the root of s to the first subtree
       * whose root v satisfies lo < v < hi; E if none is reached. *)
      fun trim (lo,hi,E) = E
        | trim (lo,hi,s as T{elt=v,left=l,right=r,...}) =
            if K.compare(v,lo) = GREATER
            then if K.compare(v,hi) = LESS then s else trim(lo,hi,l)
            else trim(lo,hi,r)
      (* uni_bd(s1,s2,lo,hi): bounded step of the hedge union.  From its
       * call sites, s2 has already been trimmed to (lo,hi), and the result
       * adds to s1 the elements of s2 lying in (lo,hi).
       * NOTE(review): inferred from the callers; confirm. *)
      fun uni_bd (s,E,_,_) = s
        | uni_bd (E,T{elt=v,left=l,right=r,...},lo,hi) =
            concat3(split_gt(l,lo),v,split_lt(r,hi))
        | uni_bd (T{elt=v,left=l1,right=r1,...},
                  s2 as T{elt=v2,left=l2,right=r2,...},lo,hi) =
            concat3(uni_bd(l1,trim(lo,v,s2),lo,v),
                    v,
                    uni_bd(r1,trim(v,hi,s2),v,hi))
      (* inv: lo < v < hi *)
      (* all the other versions of uni and trim are
       * specializations of the above two functions with
       * lo=-infinity and/or hi=+infinity
       *)
      (* trim_lo(lo,s): trim with hi = +infinity. *)
      fun trim_lo (_, E) = E
        | trim_lo (lo,s as T{elt=v,right=r,...}) =
            case K.compare(v,lo) of
              GREATER => s
            | _ => trim_lo(lo,r)
      (* trim_hi(hi,s): trim with lo = -infinity. *)
      fun trim_hi (_, E) = E
        | trim_hi (hi,s as T{elt=v,left=l,...}) =
            case K.compare(v,hi) of
              LESS => s
            | _ => trim_hi(hi,l)
      (* uni_hi: uni_bd with lo = -infinity. *)
      fun uni_hi (s,E,_) = s
        | uni_hi (E,T{elt=v,left=l,right=r,...},hi) =
            concat3(l,v,split_lt(r,hi))
        | uni_hi (T{elt=v,left=l1,right=r1,...},
                  s2 as T{elt=v2,left=l2,right=r2,...},hi) =
            concat3(uni_hi(l1,trim_hi(v,s2),v),v,uni_bd(r1,trim(v,hi,s2),v,hi))
      (* uni_lo: uni_bd with hi = +infinity. *)
      fun uni_lo (s,E,_) = s
        | uni_lo (E,T{elt=v,left=l,right=r,...},lo) =
            concat3(split_gt(l,lo),v,r)
        | uni_lo (T{elt=v,left=l1,right=r1,...},
                  s2 as T{elt=v2,left=l2,right=r2,...},lo) =
            concat3(uni_bd(l1,trim(lo,v,s2),lo,v),v,uni_lo(r1,trim_lo(v,s2),v))
      (* uni: unbounded union; splits s2 around the root of the first set. *)
      fun uni (s,E) = s
        | uni (E,s) = s
        | uni (T{elt=v,left=l1,right=r1,...},
               s2 as T{elt=v2,left=l2,right=r2,...}) =
            concat3(uni_hi(l1,trim_hi(v,s2),v), v, uni_lo(r1,trim_lo(v,s2),v))
    in
      val hedge_union = uni
    end

    (* The old_union version is about 20% slower than
     * hedge_union in most cases
     *)
    (* Not used by `union` below; kept as the simpler reference version. *)
    fun old_union (E,s2) = s2
      | old_union (s1,E) = s1
      | old_union (T{elt=v,left=l,right=r,...},s2) =
          let val l2 = split_lt(s2,v)
              val r2 = split_gt(s2,v)
          in
            concat3(old_union(l,l2),v,old_union(r,r2))
          end

    val empty = E

    fun singleton x = T{elt=x,cnt=1,left=E,right=E}

    (* Insert the items of l in list order. *)
    fun addList (s,l) = List.foldl (fn (i,s) => add(s,i)) s l

    fun fromList l = addList (E, l)

    (* Rebinding of add; has no effect on behavior. *)
    val add = add

    (* Binary search for an element EQUAL to x under K.compare. *)
    fun member (set, x) = let
          fun pk E = false
            | pk (T{elt=v, left=l, right=r, ...}) = (
                case K.compare(x,v)
                 of LESS => pk l
                  | EQUAL => true
                  | GREATER => pk r
                (* end case *))
          in
            pk set
          end

    local
      (* true if every item in t is in t' *)
      fun treeIn (t,t') = let
            fun isIn E = true
              | isIn (T{elt,left=E,right=E,...}) = member(t',elt)
              | isIn (T{elt,left,right=E,...}) =
                  member(t',elt) andalso isIn left
              | isIn (T{elt,left=E,right,...}) =
                  member(t',elt) andalso isIn right
              | isIn (T{elt,left,right,...}) =
                  member(t',elt) andalso isIn left andalso isIn right
            in
              isIn t
            end
    in
      (* The size test n<=n' rejects cheaply before checking membership. *)
      fun isSubset (E,_) = true
        | isSubset (_,E) = false
        | isSubset (t as T{cnt=n,...},t' as T{cnt=n',...}) =
            (n<=n') andalso treeIn (t,t')

      (* Equal sizes plus t contained in t' together imply t = t'. *)
      fun equal (E,E) = true
        | equal (t as T{cnt=n,...},t' as T{cnt=n',...}) =
            (n=n') andalso treeIn (t,t')
        | equal _ = false
    end

    (* Remove the element EQUAL to x; raises LibBase.NotFound if there is
     * none. *)
    fun delete (E,x) = raise LibBase.NotFound
      | delete (set as T{elt=v,left=l,right=r,...},x) =
          case K.compare(x,v) of
            LESS => T'(v,delete(l,x),r)
          | GREATER => T'(v,l,delete(r,x))
          | _ => delete'(l,r)

    val union = hedge_union

    (* Split s around each root v of the second set; v is kept only when it
     * is also a member of s. *)
    fun intersection (E, _) = E
      | intersection (_, E) = E
      | intersection (s, T{elt=v,left=l,right=r,...}) = let
          val l2 = split_lt(s,v)
          val r2 = split_gt(s,v)
          in
            if member(s,v)
              then concat3(intersection(l2,l),v,intersection(r2,r))
              else concat(intersection(l2,l),intersection(r2,r))
          end

    (* Elements of s not in the second set: split s around each root v of
     * the second set and drop v. *)
    fun difference (E,s) = E
      | difference (s,E) = s
      | difference (s, T{elt=v,left=l,right=r,...}) =
          let val l2 = split_lt(s,v)
              val r2 = split_gt(s,v)
          in
            concat(difference(l2,l),difference(r2,r))
          end

    (* Rebuilds the set with add, visiting elements in ascending order.  If
     * f maps several elements to EQUAL keys, the last one added wins and
     * the result has fewer elements. *)
    fun map f set = let
          fun map'(acc, E) = acc
            | map'(acc, T{elt,left,right,...}) =
                map' (add (map' (acc, left), f elt), right)
          in
            map' (E, set)
          end

    (* Apply apf to each element in ascending order. *)
    fun app apf =
          let fun apply E = ()
                | apply (T{elt,left,right,...}) =
                    (apply left;apf elt; apply right)
          in
            apply
          end

    (* Fold over the elements in ascending order. *)
    fun foldl f b set = let
          fun foldf (E, b) = b
            | foldf (T{elt,left,right,...}, b) =
                foldf (right, f(elt, foldf (left, b)))
          in
            foldf (set, b)
          end

    (* Fold over the elements in descending order. *)
    fun foldr f b set = let
          fun foldf (E, b) = b
            | foldf (T{elt,left,right,...}, b) =
                foldf (left, f(elt, foldf (right, b)))
          in
            foldf (set, b)
          end

    (* Elements in ascending order. *)
    fun listItems set = foldr (op::) [] set

    fun filter pred set =
          foldl (fn (item, s) => if (pred item) then add(s, item) else s)
                empty set

    (* (elements satisfying pred, elements not satisfying pred) *)
    fun partition pred set =
          foldl
            (fn (item, (s1, s2)) =>
               if (pred item) then (add(s1, item), s2) else (s1, add(s2, item))
            )
            (empty, empty) set

    (* First element in ascending order satisfying p. *)
    fun find p E = NONE
      | find p (T{elt,left,right,...}) = (case find p left
           of NONE => if (p elt)
                then SOME elt
                else find p right
            | a => a
          (* end case *))

    (* Short-circuits: left subtree, then the root, then the right subtree. *)
    fun exists p E = false
      | exists p (T{elt, left, right,...}) =
          (exists p left) orelse (p elt) orelse (exists p right)

  end (* BinarySetFn *)
end | |
structure Prover =
  struct

    type name = string

    (* A proposition is either an atom or an implication whose premises
     * form a *set* of propositions.  The premise set is written as
     * Prop BinaryTree.set because PropSet cannot exist until Prop and its
     * ordering do.  Since BinaryTree and BinarySetFn are both ascribed
     * transparently, a PropSet.set can be used directly as a premise set. *)
    datatype Prop
      = Atom of name
      | Imp of Prop BinaryTree.set * Prop

    (* Atom names are ordered as strings. *)
    val name_compare : name * name -> order = String.compare

    (* Total order on propositions:
     *   - every Atom precedes every Imp;
     *   - atoms are ordered by name;
     *   - implications are ordered by their premise sets (lexicographically
     *     over their elements), then by their conclusions. *)
    fun prop_compare (p, q) =
          case (p, q)
           of (Atom x, Atom y) => name_compare (x, y)
            | (Atom _, Imp _) => LESS
            | (Imp _, Atom _) => GREATER
            | (Imp (hyps1, concl1), Imp (hyps2, concl2)) =>
                (case BinaryTree.compare prop_compare (hyps1, hyps2)
                  of EQUAL => prop_compare (concl1, concl2)
                   | ord => ord)

    (* Sets of propositions, ordered by prop_compare. *)
    structure PropSet = BinarySetFn (type ord_key = Prop
                                     val compare = prop_compare)

    (* Sample propositions: a, b, c, a => b, and {a, a => b} => c. *)
    val (a, b, c) = (Atom "a", Atom "b", Atom "c")
    val a_imp_b = Imp (PropSet.singleton a, b)
    val c' = Imp (PropSet.fromList [a, a_imp_b], c)

  end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment