(* Skip to content

Instantly share code, notes, and snippets.

@keleshev
Last active June 2, 2022 18:37
Show Gist options
  • Save keleshev/3529129da1bd03b4e9e3e983434cedd8 to your computer and use it in GitHub Desktop.
Save keleshev/3529129da1bd03b4e9e3e983434cedd8 to your computer and use it in GitHub Desktop.
*)
open Printf
module Syntax = struct
  (** Abstract syntax of a small expression language: unit, booleans,
      integers, variables, division, sequencing, [let] and [if]. *)
  type t =
    | Unit
    | Boolean of bool
    | Number of int
    | Name of string
    | Divide of t * t
    | Sequence of t * t
    | Let of {name: string; value: t; body: t}
    | If of {conditional: t; consequence: t; alternative: t}

  (** [map f t] rebuilds [t] with [f] applied to each of its direct
      subterms. Leaves ([Unit], [Boolean], [Number], [Name]) are returned
      as they are. [map] does not recurse on its own; a pass gets a full
      traversal by handing [map] a recursive function. *)
  let map f t =
    match t with
    | Unit | Boolean _ | Number _ | Name _ -> t
    | Divide (l, r) -> Divide (f l, f r)
    | Sequence (l, r) -> Sequence (f l, f r)
    | Let {name; value; body} -> Let {name; value = f value; body = f body}
    | If {conditional = c; consequence = yes; alternative = no} ->
      (* Explicitly sequenced: condition, then consequence, then
         alternative. *)
      let c' = f c in
      let yes' = f yes in
      let no' = f no in
      If {conditional = c'; consequence = yes'; alternative = no'}

  (** Fully parenthesised concrete syntax of a tree, for example
      [(let x = (40 / 20) in x)]. *)
  let rec to_string t =
    let show2 fmt a b = sprintf fmt (to_string a) (to_string b) in
    match t with
    | Unit -> "()"
    | Boolean b -> string_of_bool b
    | Number n -> string_of_int n
    | Name n -> n
    | Divide (l, r) -> show2 "(%s / %s)" l r
    | Sequence (l, r) -> show2 "(%s; %s)" l r
    | Let {name; value; body} ->
      sprintf "(let %s = %s in %s)" name (to_string value) (to_string body)
    | If {conditional; consequence; alternative} ->
      sprintf "(if %s then %s else %s)"
        (to_string conditional) (to_string consequence) (to_string alternative)

  (** Print a tree on stdout, prefixed with [>>> ] and followed by a
      newline. *)
  let print t = to_string t |> printf ">>> %s\n"

  (** [random ()] draws a tree using the global [Random] state. Each node
      is one of twelve equally likely shapes. Unit, Boolean, Number and Name
      each have two slots. The two Number slots differ: one draws 1..100,
      the other 0..99, so a zero divisor can occur. Divide, Sequence, Let
      and If have one slot each, and their subterms are drawn recursively.
      Variable names are x, y or z. *)
  let rec random () =
    let random_name () = [|"x"; "y"; "z"|].(Random.int 3) in
    match Random.int 12 with
    | 0 | 4 -> Unit
    | 1 | 5 -> Boolean (Random.bool ())
    | 2 -> Number (Random.int 100 + 1)
    | 6 -> Number (Random.int 100)
    | 3 | 7 -> Name (random_name ())
    | 8 -> Divide (random (), random ())
    | 9 -> Sequence (random (), random ())
    | 10 -> Let {name=random_name (); value=random (); body=random ()}
    | 11 ->
      If {
        conditional=random ();
        consequence=random ();
        alternative=random ();
      }
    | _ -> assert false
end
open Syntax
(* Three passes that are sequenced, not fused *)
module Not_fused = struct
  (* Each pass below is a self-contained top-down rewrite. It tries its rule
     on a node, runs itself again on the result when the rule fires, and
     otherwise recurses into the children through [map]. *)
  module Dead_code_elimination = struct
    (* [if true then a else b] becomes [a]; [if false then a else b]
       becomes [b]. *)
    let rec pass t =
      match t with
      | If {conditional = Boolean taken; consequence; alternative} ->
        pass (if taken then consequence else alternative)
      | _ -> map pass t
  end
  module Constant_propagation = struct
    (* Folds [n / m] on integer literals using integer division. A division
       by the literal 0 is left untouched. *)
    let rec pass t =
      match t with
      | Divide (Number n, Number m) when m <> 0 ->
        let quotient = n / m in
        pass (Number quotient)
      | _ -> map pass t
  end
  module Remove_redundant_let = struct
    (* [let x = v in x] becomes [v]. *)
    let rec pass t =
      match t with
      | Let {name; value; body = Name used} when String.equal used name ->
        pass value
      | _ -> map pass t
  end
  (* Run the three passes one after the other, each as its own traversal:
     dead-code elimination, then constant folding, then redundant-let
     removal. *)
  let pass t =
    t
    |> Dead_code_elimination.pass
    |> Constant_propagation.pass
    |> Remove_redundant_let.pass
end
(* Three passes fused manually into a single pass *)
module Manually_fused = struct
  (* The three rules of Not_fused, combined by hand into one traversal. A
     rewritten node goes through [pass] again, so every rule gets another
     chance on it. The rules match different constructors (If, Divide, Let),
     so the order of the cases is irrelevant. *)
  let rec pass t =
    match t with
    | Divide (Number n, Number m) when m <> 0 -> pass (Number (n / m))
    | Let {name; value; body = Name used} when String.equal used name ->
      pass value
    | If {conditional = Boolean taken; consequence; alternative} ->
      pass (if taken then consequence else alternative)
    | _ -> map pass t
end
module Fusion_by_extension = struct
  (* Fusion by extension. Each pass is parameterised over [first_pass], the
     final fused pass, and applies it to whatever its own rule produces.
     When its rule does not match, a pass hands the node to the pass it
     extends. The chain is Remove_redundant_let -> Constant_propagation ->
     Dead_code_elimination. The last pass falls back to [map first_pass] to
     rewrite the children. *)
  module Dead_code_elimination = struct
    (* Replace an [If] on a literal boolean by the branch that is taken. *)
    let pass first_pass = function
      | If {conditional=Boolean true; consequence; _} ->
        first_pass consequence
      | If {conditional=Boolean false; alternative; _} ->
        first_pass alternative
      | other ->
        map first_pass other
  end
  module Constant_propagation = struct
    (* Fold a division of two integer literals (skipped when the divisor is
       0). Otherwise defer to Dead_code_elimination. *)
    let pass first_pass = function
      | Divide (Number n, Number m) when m <> 0 ->
        first_pass (Number (n / m))
      | other ->
        Dead_code_elimination.pass first_pass other
  end
  module Remove_redundant_let = struct
    (* Rewrite [let x = v in x] to [v]. Otherwise defer to
       Constant_propagation. *)
    let pass first_pass = function
      | Let {name; value; body=Name n} when n = name ->
        first_pass value
      | other ->
        Constant_propagation.pass first_pass other
  end
  (* Fuse the passes together by tying the knot. The outermost extension
     receives [pass] itself as [first_pass], so every rewritten term runs
     through all three rules again. This version passes the regression
     check in Test. It also agrees with Manually_fused.pass: the rules match
     disjoint constructors, so the order in which they are tried does not
     matter. *)
  let rec pass t = Remove_redundant_let.pass pass t
  (* A point-free alternative, kept for reference. Note: unlike [pass]
     above, THIS version does not pass the test, so it is not equivalent.

     let (>>) f g x = f (g x)
     let fix f_nonrec =
       let rec f t = f_nonrec f t in
       f
     let pass' =
       fix (
         Dead_code_elimination.pass >>
         Constant_propagation.pass >>
         Remove_redundant_let.pass >>
         map
       )

     Composition does not tie the knot the same way.
     Dead_code_elimination receives
     [Constant_propagation.pass (Remove_redundant_let.pass (map f))] as its
     first_pass, instead of the fixpoint [f]. On the Test tree,
     Remove_redundant_let hands the let-bound value Divide (40, 20) to
     [map f]. That only rewrites the children, so the division is never
     folded. It is also slower, because the composed functions are not
     fully applied. *)
end
module Fused = struct
  (* Fusion with explicit continuations. Every pass takes two extra
     arguments:
     - [first_pass]: the complete fused pass. It is applied to whatever the
       pass's own rule produces, so rewritten terms see every rule again.
     - [next_pass]: what to run on the node when this rule does not match.
     The passes know nothing about each other. The chain is assembled once,
     in [pass] below. *)
  module Dead_code_elimination = struct
    (* [if true then a else b] gives [a]; [if false then a else b] gives
       [b]. *)
    let pass first_pass next_pass t =
      match t with
      | If {conditional = Boolean taken; consequence; alternative} ->
        first_pass (if taken then consequence else alternative)
      | _ -> next_pass t
  end
  module Constant_propagation = struct
    (* [n / m] on integer literals gives its quotient, unless [m] is 0. *)
    let pass first_pass next_pass t =
      match t with
      | Divide (Number n, Number m) when m <> 0 ->
        let quotient = n / m in
        first_pass (Number quotient)
      | _ -> next_pass t
  end
  module Remove_redundant_let = struct
    (* [let x = v in x] gives [v]. *)
    let pass first_pass next_pass t =
      match t with
      | Let {name; value; body = Name used} when String.equal used name ->
        first_pass value
      | _ -> next_pass t
  end
  (* Tie the knot. Try dead-code elimination, then constant folding, then
     redundant-let removal, and finally recurse into the children with
     [map pass]. *)
  let rec pass t =
    let descend = map pass in
    let try_let = Remove_redundant_let.pass pass descend in
    let try_fold = Constant_propagation.pass pass try_let in
    Dead_code_elimination.pass pass try_fold t
end
module Test = struct
  (* Regression check on one fixed input. The tree below corresponds to
       if false then () else ((); let x = 40 / 20 in x)
     and every implementation must reduce it to Sequence (Unit, Number 2).
     The implementations are checked in the same order as they are
     defined. *)
  let () =
    let redundant_let =
      Let {name = "x"; value = Divide (Number 40, Number 20); body = Name "x"}
    in
    let tree =
      If {conditional = Boolean false;
          consequence = Unit;
          alternative = Sequence (Unit, redundant_let)}
    in
    let expected = Sequence (Unit, Number 2) in
    assert (Not_fused.pass tree = expected);
    assert (Manually_fused.pass tree = expected);
    assert (Fusion_by_extension.pass tree = expected);
    assert (Fused.pass tree = expected)
end
module Bench = struct
  (* Number of random trees to generate; each implementation runs once on
     every tree. All [n] trees are built up front and kept alive in [trees]
     for the whole benchmark. *)
  let n = 100_000_000

  let trees = Array.init n (fun _ -> Syntax.random ())

  (* [time name thunk] evaluates [thunk i] for i = 0 .. n - 1 and discards
     the results. It then prints the processor time spent, as measured by
     [Sys.time], labelled with [name]. *)
  let time name thunk =
    let started = Sys.time () in
    for i = 0 to n - 1 do
      ignore (thunk i)
    done;
    printf "%s: %gs\n" name (Sys.time () -. started)

  let () =
    time "Not_fused" (fun i -> Not_fused.pass trees.(i));
    time "Manually_fused" (fun i -> Manually_fused.pass trees.(i));
    time "Fusion_by_extension" (fun i -> Fusion_by_extension.pass trees.(i));
    time "Fused" (fun i -> Fused.pass trees.(i))
    (* time "Statistics" (fun i ->
      let tree = trees.(i) in
      let total = (String.length (to_string tree))
      and not_fused = (String.length (to_string (Not_fused.pass tree)))
      and manually_fused = (String.length (to_string (Manually_fused.pass tree)))
      and fused_v0 = (String.length (to_string (Fusion_by_extension.pass tree)))
      and fused = (String.length (to_string (Fused.pass tree)))
      in
      assert (fused = manually_fused && fused = fused_v0);
      if total <> not_fused then
        ()
        (*Printf.printf "%d,%d,%d,%d,%d\n" total not_fused manually_fused fused_v0 fused*)
    )*)

  (* Property check over all generated trees: the three fused
     implementations must produce identical results. Not_fused is not
     compared, and its output can legitimately differ. For
     [Divide (If {conditional = Boolean true; consequence = Number 4; ...}, Number 2)]
     it yields [Number 2]. The fused passes stop at
     [Divide (Number 4, Number 2)], because a node is not revisited after
     its children have been rewritten. *)
  let test_property_the_implementations_are_identical =
    let agree tree =
      let manually_fused = Manually_fused.pass tree in
      let fused_by_extension = Fusion_by_extension.pass tree in
      let fused = Fused.pass tree in
      List.for_all (( = ) fused) [manually_fused; fused_by_extension]
    in
    assert (Array.for_all agree trees)
end
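(*
Sample output (processor time per implementation, as printed by Bench):

$ ocamlopt fusion.ml && ./camlprog.exe
Not_fused: 16.696s
Fused: 8.120s
Manually_fused: 5.786s

camlprog.exe is the default ocamlopt output name on Windows; on Unix the
default is a.out. The lines above are in a different order from the one in
which Bench runs the implementations, and there is no line for
Fusion_by_extension. These numbers therefore presumably come from an
earlier version of Bench.
*)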
(* Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment *)