Skip to content

Instantly share code, notes, and snippets.

@mb64
Last active March 26, 2022 06:09
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mb64/e178dd9893ae13d4f22241963770f6b2 to your computer and use it in GitHub Desktop.
Save mb64/e178dd9893ae13d4f22241963770f6b2 to your computer and use it in GitHub Desktop.
Itty bitty SMT solver: DPLL(T) where T = equality. Likely buggy
(* Itty bitty SMT solver: DPLL(T) where T = equality *)
(*
# let x, y = 0, 1 (* integer variable IDs *) ;;
# let prob = SMT.new_problem 2 (* 2 for two variables *) ;;
# let b = SMT.new_bool prob;;
# SMT.add_clause prob [b; SMT.eq prob x y];
SMT.add_clause prob [SMT.not b];
SMT.solve prob;;
- : SMT.response = SMT.SAT
# SMT.add_clause prob [SMT.not (SMT.eq prob x y)];
SMT.solve prob;;
- : SMT.response = SMT.UNSAT
*)
module SAT : sig
  (** A literal: a propositional variable or its negation. *)
  type atom

  (** [not a] is the negation of literal [a]; [not (not a)] is [a]. *)
  val not : atom -> atom

  (** A clause is a disjunction of literals. *)
  type clause = atom list

  (** An immutable CNF problem: a variable count plus a list of clauses. *)
  type problem

  (** The problem with no variables and no clauses. *)
  val no_problem : problem

  (** [add_var p] allocates a fresh variable, returning its positive literal
      together with the extended problem. *)
  val add_var : problem -> atom * problem

  (** [add_clause p c] is [p] extended with clause [c]. *)
  val add_clause : problem -> atom list -> problem

  (** The value of a literal under a partial assignment. *)
  type soln = Yes | No | IDK

  (** Raised by [dpll] when the problem is unsatisfiable, and raised by the
      [dpll] callback to reject a candidate model.
      TODO: have the callback provide an unsat core to learn *)
  exception Unsat

  (** [dpll p verify] searches for a total assignment that satisfies every
      clause of [p] and is accepted by [verify]. [verify] is called with a
      lookup function for each candidate model. It returns normally to accept
      the model, or raises [Unsat] to reject it, in which case the search
      backtracks and continues.
      @raise Unsat if no such assignment exists. *)
  val dpll : problem -> ((atom -> soln) -> unit) -> unit
end = struct
  type atom = int

  (* Variable [v] is the literal [v]; its negation is [lnot v = -v - 1], so
     the negative literals are exactly the negative ints. *)
  let not = lnot
  let is_neg x = x < 0
  let atom_to_var x = if is_neg x then not x else x

  type clause = atom list
  type problem = { num_vars : int; clauses : clause list }

  let no_problem : problem = { num_vars = 0; clauses = [] }
  let add_var (p : problem) = p.num_vars, { p with num_vars = p.num_vars + 1 }
  let add_clause (p : problem) c = { p with clauses = c :: p.clauses }

  type soln = Yes | No | IDK

  exception Unsat (* :( *)

  let flip_soln = function
    | Yes -> No
    | No -> Yes
    | IDK -> IDK

  let dpll ({ num_vars; clauses } : problem) (verify : (atom -> soln) -> unit) =
    (* Current assignment, indexed by variable. *)
    let model : soln array = Array.make num_vars IDK in
    (* Duplicate literals are dropped so that a clause's two watched literals
       are always distinct ([x \/ x] simply becomes the unit clause [x]). *)
    let clauses : clause array =
      clauses |> List.map (List.sort_uniq compare) |> Array.of_list
    in
    let num_clauses = Array.length clauses in
    (* watch_clauses.(atom_to_idx l) lists the clauses currently watching l. *)
    let watch_clauses : int list array = Array.make (2 * num_vars) [] in
    let watch_literal_1 : int array = Array.make num_clauses 0 in
    let watch_literal_2 : int array = Array.make num_clauses 0 in
    (* Literals found to be forced true but not yet assigned. *)
    let unit_prop_worklist = ref [] in
    let gotta_unit_prop x = unit_prop_worklist := x :: !unit_prop_worklist in
    (* Maps literals in [-num_vars, num_vars) onto [0, 2 * num_vars).
       NOTE(review): assumes every clause only mentions allocated variables. *)
    let atom_to_idx a = a + num_vars in
    let watch a i =
      let idx = atom_to_idx a in
      watch_clauses.(idx) <- i :: watch_clauses.(idx)
    in
    let unwatch a i =
      let idx = atom_to_idx a in
      watch_clauses.(idx) <- List.filter (fun j -> j <> i) watch_clauses.(idx)
    in
    (* Initialize watch literals: the empty clause is unsatisfiable, unit
       clauses are queued for propagation, and longer clauses watch their
       first two literals. *)
    Array.iteri
      (fun i cl ->
        match cl with
        | [] -> raise Unsat
        | [ x ] -> gotta_unit_prop x
        | x :: y :: _ ->
          watch x i;
          watch y i;
          watch_literal_1.(i) <- x;
          watch_literal_2.(i) <- y)
      clauses;
    (* Literals assigned true, most recent first. *)
    let trail = ref [] in
    let current_state a =
      if is_neg a then flip_soln model.(not a) else model.(a)
    in
    (* Undo every assignment made since, and including, literal [a]. Also
       drop any propagation still pending from the abandoned branch. *)
    let backtrack_through a =
      let rec pop () =
        match !trail with
        | [] -> failwith "impossible -- needs to reach a"
        | a' :: rest ->
          trail := rest;
          model.(atom_to_var a') <- IDK;
          if a' <> a then pop ()
      in
      pop ();
      unit_prop_worklist := []
    in
    (* Literal [a] has just become true, so [not a], a watcher of clause [i],
       is now false. Detect a conflict, queue a unit, or move the watch to
       another unassigned literal. *)
    let one_clause a i =
      let clause = clauses.(i) in
      if List.exists (fun b -> current_state b = Yes) clause then ()
      else
        match List.filter (fun b -> current_state b = IDK) clause with
        | [] -> raise Unsat (* every literal is false: conflict *)
        | [ x ] -> gotta_unit_prop x
        | unassigned ->
          (* At least two literals are unassigned. Watch one that is not
             already the clause's other watcher. *)
          let old_watcher = not a in
          let which_array, other_watcher =
            if watch_literal_1.(i) = old_watcher then
              (watch_literal_1, watch_literal_2.(i))
            else (watch_literal_2, watch_literal_1.(i))
          in
          assert (which_array.(i) = old_watcher);
          let new_watcher =
            List.find (fun b -> b <> other_watcher) unassigned
          in
          which_array.(i) <- new_watcher;
          watch new_watcher i;
          unwatch old_watcher i
    in
    (* Make [a] true (raising Unsat if it is already false) and visit the
       clauses watching [not a]. The watch list is read once, before the
       iteration moves any watchers. *)
    let assign a =
      match current_state a with
      | Yes -> ()
      | No -> raise Unsat
      | IDK ->
        trail := a :: !trail;
        model.(atom_to_var a) <- (if is_neg a then No else Yes);
        List.iter (one_clause a) watch_clauses.(atom_to_idx (not a))
    in
    (* Drain the worklist completely. After a normal return it is empty, so
       everything pending at a backtrack belongs to the failed branch. *)
    let rec unit_prop_all () =
      match !unit_prop_worklist with
      | [] -> ()
      | a :: rest ->
        unit_prop_worklist := rest;
        assign a;
        unit_prop_all ()
    in
    let set_to_true a =
      gotta_unit_prop a;
      unit_prop_all ()
    in
    (* The main recursive DPLL loop! *)
    (* This is already a lot of code but still it'd be nice to do CDCL :/ *)
    let rec go v =
      (* dumbest possible variable ordering: 0 to N-1, in order *)
      if v = num_vars then verify current_state
      else if model.(v) <> IDK then go (v + 1)
      else
        try
          set_to_true v;
          go (v + 1)
        with Unsat ->
          (* v = true failed: retract it together with everything implied
             after it, then commit to v = false. *)
          backtrack_through v;
          set_to_true (not v);
          go (v + 1)
    in
    unit_prop_all ();
    go 0
end
module SMT : sig
  (** A literal of the underlying SAT problem: a boolean variable, an
      equality between two integer variables, or the negation of either. *)
  type atom

  (** [not a] is the negation of [a]. *)
  val not : atom -> atom

  (** A clause is a disjunction of literals. *)
  type clause = atom list

  (** Theory variables are the ints [0 .. n-1], where [n] is the count given
      to [new_problem]. *)
  type var_id = int

  (** A mutable problem: clauses accumulate across [add_clause] calls. *)
  type problem

  (** [new_problem n] is an empty problem over theory variables
      [0 .. n-1]. *)
  val new_problem : int (* number of variables *) -> problem

  (** [eq p x y] is the literal "[x] equals [y]". Asking again for the same
      pair, in either order, returns the same literal.
      @raise Invalid_argument if [x] or [y] is outside [0 .. n-1]. *)
  val eq : problem -> var_id -> var_id -> atom

  (** [new_bool p] is a fresh boolean literal. *)
  val new_bool : problem -> atom

  (** [add_clause p c] requires at least one literal of [c] to hold. *)
  val add_clause : problem -> clause -> unit

  type response = SAT | UNSAT

  (** [solve p] decides whether some assignment satisfies every clause of
      [p] and is consistent with the theory of equality. *)
  val solve : problem -> response
end = struct
  type atom = SAT.atom
  let not = SAT.not
  type clause = SAT.clause
  type var_id = int

  type problem =
    { num_vars : int
    ; atoms : (var_id * var_id, atom) Hashtbl.t
          (* the equality literal for each (smaller, larger) variable pair *)
    ; mutable sat : SAT.problem }

  let new_problem num_vars : problem =
    { num_vars; atoms = Hashtbl.create 16; sat = SAT.no_problem }

  (* Allocate a fresh SAT variable in [p] and return its positive literal. *)
  let fresh_var (p : problem) =
    let a, new_sat = SAT.add_var p.sat in
    p.sat <- new_sat;
    a

  let eq (p : problem) x y =
    (* Reject bad ids here; otherwise they would only surface inside
       [solve]'s theory check as an opaque array index error. *)
    let check v =
      if v < 0 || v >= p.num_vars then
        invalid_arg "SMT.eq: variable id out of range"
    in
    check x;
    check y;
    (* Equality is symmetric: share one literal per unordered pair. *)
    let pair = if x < y then x, y else y, x in
    match Hashtbl.find_opt p.atoms pair with
    | Some a -> a
    | None ->
      let a = fresh_var p in
      Hashtbl.add p.atoms pair a;
      a

  let new_bool (p : problem) = fresh_var p

  let add_clause (p : problem) c =
    p.sat <- SAT.add_clause p.sat c

  type response = SAT | UNSAT

  let solve ({ num_vars; atoms; sat } : problem) =
    (* Theory check, run by the SAT solver on each complete boolean model.
       Merge the variables of every equality assigned Yes (union-find), then
       reject the model if an equality assigned No relates two merged
       variables. *)
    let verify model =
      let parents = Array.init num_vars (fun i -> i) in
      let rec find i =
        let p = parents.(i) in
        if p = i then i
        else begin
          let root = find p in
          parents.(i) <- root; (* path compression *)
          root
        end
      in
      let union i j = parents.(find i) <- find j in
      Hashtbl.iter (fun (i, j) a -> if model a = SAT.Yes then union i j) atoms;
      Hashtbl.iter
        (fun (i, j) a ->
          if model a = SAT.No && find i = find j then raise SAT.Unsat)
        atoms
    in
    match SAT.dpll sat verify with
    | () -> SAT
    | exception SAT.Unsat -> UNSAT
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment