Skip to content

Instantly share code, notes, and snippets.

@TyOverby
Created December 12, 2020 19:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save TyOverby/52c283d5d082f674be93024ee31a118d to your computer and use it in GitHub Desktop.
Save TyOverby/52c283d5d082f674be93024ee31a118d to your computer and use it in GitHub Desktop.
open! Base
(* Binary arithmetic operators carried by [Structure.Binop] nodes.
   [Structure.label] renders them as +, -, * and / respectively. *)
type binop =
[ `Add
| `Sub
| `Mul
| `Div
]
[@@deriving equal, sexp, compare, hash]
(* Unary operators carried by [Structure.Unop] nodes: square root, negation
   and squaring. [Structure.label] renders them as sqrt, neg and ^2. *)
type unop =
[ `Sqrt
| `Neg
| `Square
]
[@@deriving equal, sexp, compare, hash]
(* One layer of an arithmetic expression. The ['recursive] parameter is the
   type used for child positions, so the same shape serves both trees (children
   are sub-expressions) and graphs (children are node identifiers). *)
module Structure = struct
  type 'recursive t =
    | Const of float
    | Var of int
    | Binop of binop * 'recursive * 'recursive
    | Unop of unop * 'recursive
  [@@deriving sexp, equal, compare, hash]

  (** [label node] is a short display string for [node] that ignores its
      children: constants are printed with [%f], variables as their index in
      square brackets, and operators as a fixed symbol. *)
  let label : _ t -> string =
    fun node ->
    let symbol_of_binop : binop -> string = function
      | `Add -> "+"
      | `Sub -> "-"
      | `Mul -> "*"
      | `Div -> "/"
    in
    let symbol_of_unop : unop -> string = function
      | `Sqrt -> "sqrt"
      | `Neg -> "neg"
      | `Square -> "^2"
    in
    match node with
    | Const value -> Printf.sprintf "%f" value
    | Var index -> Printf.sprintf "[%d]" index
    | Binop (operator, _, _) -> symbol_of_binop operator
    | Unop (operator, _) -> symbol_of_unop operator
  ;;
end
(* Expression trees: every node directly contains its children. Provides
   Graphviz rendering and a bottom-up algebraic simplifier. *)
module Tree = struct
  type t = T of t Structure.t [@@deriving equal, compare]

  (** [sexp_of_t tree] serializes [tree], delegating each layer to
      [Structure.sexp_of_t]. *)
  let rec sexp_of_t (T t) = Structure.sexp_of_t sexp_of_t t

  (** [label tree] is the [Structure.label] of the root node of [tree]. *)
  let label (T inner) = Structure.label inner

  (** [to_graphviz tree] renders [tree] as the text of a Graphviz digraph.

      Each node gets a fresh identifier of the form t_N (N counts from 1 in
      pre-order, left child before right child) and a label statement; each
      child gets an edge pointing at its parent. Structurally equal subtrees
      are emitted once per occurrence. *)
  let to_graphviz tree =
    let id = ref 0 in
    let buf = Buffer.create 1024 in
    let rec loop (T t) =
      Int.incr id;
      let id = Printf.sprintf "t_%d" !id in
      Printf.bprintf buf "%s [label=\"%s\"];\n" id (label (T t));
      match t with
      | Const _ | Var _ -> id
      | Binop (_, a, b) ->
        (* Sequential bindings rather than a tuple: tuple components have no
           specified evaluation order, and [loop] has side effects (it bumps
           the counter and appends to the buffer). *)
        let a = loop a in
        let b = loop b in
        Printf.bprintf buf "%s -> %s;\n" a id;
        Printf.bprintf buf "%s -> %s;\n" b id;
        id
      | Unop (_, a) ->
        let a = loop a in
        Printf.bprintf buf "%s -> %s;\n" a id;
        id
    in
    let (_root : string) = loop tree in
    "digraph G {\n" ^ Buffer.contents buf ^ "}"
  ;;

  (** [optimize tree] simplifies [tree] in a single bottom-up pass: children
      are optimized first, then one local rewrite is applied at each node.

      Rewrites performed: constant folding for the four binary operators and
      for sqrt; additive and multiplicative identities (x + 0, x - 0, 0 - x,
      x * 1); annihilation by zero (0 * x, 0 / x); x + x, x - x, x * x and
      x / x on structurally equal operands; double negation; and cancellation
      of sqrt against square.

      Division by a literal zero folds to [Float.nan].

      The result of a rewrite is not re-optimized, so the output is not
      guaranteed to be a fixed point.

      NOTE(review): rewrites such as x - x, 0 * x and x / x treat operands as
      ordinary finite reals; they are not exact when a subexpression evaluates
      to nan or an infinity. *)
  let rec optimize : t -> t = function
    | T ((Const _ | Var _) as t) -> T t
    | T (Binop (`Add, a, b)) ->
      (match optimize a, optimize b with
       | T (Const a), T (Const b) -> T (Const (a +. b))
       (* Zero is the additive identity on either side. *)
       | T (Const 0.0), a | a, T (Const 0.0) -> a
       | a, b when equal a b -> T (Binop (`Mul, T (Const 2.0), a))
       | a, b -> T (Binop (`Add, a, b)))
    | T (Binop (`Sub, a, b)) ->
      (match optimize a, optimize b with
       | T (Const a), T (Const b) -> T (Const (a -. b))
       (* 0 - x is the negation of x; x - 0 is x. *)
       | T (Const 0.0), a -> T (Unop (`Neg, a))
       | a, T (Const 0.0) -> a
       | a, b when equal a b -> T (Const 0.0)
       | a, b -> T (Binop (`Sub, a, b)))
    | T (Binop (`Mul, a, b)) ->
      (match optimize a, optimize b with
       | T (Const a), T (Const b) -> T (Const (a *. b))
       | T (Const 0.0), _ | _, T (Const 0.0) -> T (Const 0.0)
       | T (Const 1.0), a | a, T (Const 1.0) -> a
       | a, b when equal a b -> T (Unop (`Square, a))
       | a, b -> T (Binop (`Mul, a, b)))
    | T (Binop (`Div, a, b)) ->
      (match optimize a, optimize b with
       (* The zero-divisor rule must come first so that 0 / 0 folds to nan
          like every other division by a literal zero. *)
       | _, T (Const 0.0) -> T (Const Float.nan)
       | T (Const 0.0), _ -> T (Const 0.0)
       | a, b when equal a b -> T (Const 1.0)
       | T (Const a), T (Const b) -> T (Const (a /. b))
       | a, b -> T (Binop (`Div, a, b)))
    | T (Unop (`Sqrt, a)) ->
      (match optimize a with
       (* NOTE(review): sqrt (x ^ 2) equals the absolute value of x, so this
          rewrite is only sound when x is known to be non-negative. Kept
          because there is no absolute-value operator to rewrite to. *)
       | T (Unop (`Square, a)) -> a
       | T (Const a) -> T (Const (Float.sqrt a))
       | a -> T (Unop (`Sqrt, a)))
    | T (Unop (`Neg, a)) ->
      (match optimize a with
       | T (Unop (`Neg, a)) -> a
       | a -> T (Unop (`Neg, a)))
    | T (Unop (`Square, a)) ->
      (match optimize a with
       | T (Unop (`Sqrt, a)) -> a
       | a -> T (Unop (`Square, a)))
  ;;
end
(* Expression graphs: the same node shapes as [Tree], but children are
   referenced by string identifier so that structurally identical
   subexpressions are stored once and shared. *)
module Graph = struct
  module Node = struct
    type t = string Structure.t [@@deriving equal, sexp, compare, hash]
  end

  type t =
    { nodes : (string * Node.t) list
    ; root : string
    }
  [@@deriving equal, sexp, compare]

  (** [collide_reorderable node] puts the operands of a commutative operator
      (add or mul) into a canonical order, larger identifier first, so that
      a op b and b op a become the same node. Other nodes are returned
      unchanged. *)
  let collide_reorderable = function
    | Structure.Binop (((`Add | `Mul) as op), a, b) ->
      let a' = String.max a b in
      let b' = String.min a b in
      Structure.Binop (op, a', b')
    | other -> other
  ;;

  (** [of_tree tree] converts [tree] into a graph with shared subexpressions.

      Nodes are visited bottom-up, left child before right child. Each
      distinct node (after [collide_reorderable]) is assigned an identifier
      of the form t_N on first sight and reused afterwards. [nodes] lists the
      distinct nodes in order of first appearance, so every node comes after
      the nodes it refers to; [root] is the identifier of the node for the
      whole tree. *)
  let of_tree tree =
    let id = ref 0 in
    let nodes = ref [] in
    let seen = Hashtbl.create (module Node) in
    let find_or_push node =
      let node = collide_reorderable node in
      (* [default] only runs when [node] has not been seen before; it
         allocates the identifier and records the node. *)
      Hashtbl.find_or_add seen node ~default:(fun () ->
        Int.incr id;
        let id = Printf.sprintf "t_%d" !id in
        nodes := (id, node) :: !nodes;
        id)
    in
    let rec loop : Tree.t -> string = function
      | T (Const c) -> find_or_push (Const c)
      | T (Var c) -> find_or_push (Var c)
      | T (Binop (op, a, b)) ->
        let a = loop a in
        let b = loop b in
        find_or_push (Binop (op, a, b))
      | T (Unop (op, a)) ->
        let a = loop a in
        find_or_push (Unop (op, a))
    in
    let root = loop tree in
    (* [nodes] was built by consing, so reverse to get first-seen order. *)
    { nodes = List.rev !nodes; root }
  ;;

  (** [to_graphviz graph] renders [graph] as the text of a Graphviz digraph:
      one label statement per node, followed by one edge per operand pointing
      from the operand to the node. The [root] field is not used. *)
  let to_graphviz { root = _; nodes } =
    let buf = Buffer.create 1024 in
    (* Single place that formats an edge, so every edge statement is
       terminated the same way. *)
    let edge ~child ~parent = Printf.bprintf buf "%s -> %s;\n" child parent in
    List.iter nodes ~f:(fun (id, node) ->
      Printf.bprintf buf "%s [label=\"%s\"];\n" id (Structure.label node);
      match (node : Node.t) with
      | Var _ | Const _ -> ()
      | Binop (_op, a, b) ->
        edge ~child:a ~parent:id;
        edge ~child:b ~parent:id
      | Unop (_op, a) -> edge ~child:a ~parent:id);
    "digraph G {\n" ^ Buffer.contents buf ^ "}"
  ;;
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment