Last active
June 2, 2022 18:37
-
-
Save keleshev/3529129da1bd03b4e9e3e983434cedd8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
open Printf | |
module Syntax = struct
  (** Abstract syntax of a tiny expression language on which the rewrite
      passes below operate. *)
  type t =
    | Unit
    | Boolean of bool
    | Number of int
    | Name of string
    | Divide of t * t
    | Sequence of t * t
    | Let of {name: string; value: t; body: t}
    | If of {conditional: t; consequence: t; alternative: t}

  (** [map f t] applies [f] to each immediate child of [t] and rebuilds the
      node. Leaves are returned unchanged. The traversal is shallow: whether
      to go deeper is up to [f]. *)
  let map f = function
    | Unit | Boolean _ | Number _ | Name _ as t ->
        t
    | Divide (left, right) ->
        Divide (f left, f right)
    | Sequence (left, right) ->
        Sequence (f left, f right)
    | Let {name; value; body} ->
        Let {name; value = f value; body = f body}
    | If {conditional; consequence; alternative} ->
        let conditional = f conditional in
        let consequence = f consequence in
        let alternative = f alternative in
        If {conditional; consequence; alternative}

  (** [to_string t] renders [t] as fully parenthesised source text,
      e.g. (1 / x) or (let x = 1 in x). *)
  let rec to_string = function
    | Unit -> "()"
    | Boolean b -> string_of_bool b
    | Number n -> string_of_int n
    | Name n -> n
    | Divide (left, right) ->
        sprintf "(%s / %s)" (to_string left) (to_string right)
    | Sequence (left, right) ->
        sprintf "(%s; %s)" (to_string left) (to_string right)
    | Let {name; value; body} ->
        sprintf "(let %s = %s in %s)" name (to_string value) (to_string body)
    | If {conditional; consequence; alternative} ->
        sprintf "(if %s then %s else %s)"
          (to_string conditional) (to_string consequence) (to_string alternative)

  (** [print t] writes [t] to stdout, prefixed with >>> and followed by a
      newline. *)
  let print t = printf ">>> %s\n" (to_string t)

  (** [random ()] builds a random term using the global [Random] state.
      Cases 0-7 are leaves (each leaf kind is listed twice). Cases 8-11 are
      internal nodes with 2, 2, 2 and 3 children, so a draw produces on
      average 9/12 = 0.75 children. Generation therefore terminates with
      probability 1, and the expected tree size is 1 / (1 - 0.75) = 4 nodes.
      Case 2 yields numbers in 1..100, while case 6 yields 0..99, so literal
      zero divisors do occur. *)
  let rec random () =
    (* Match on the draw instead of indexing an array literal: an array
       literal is mutable, so it would be freshly allocated on every call.
       This still consumes exactly one [Random.int 3] draw. *)
    let random_string () =
      match Random.int 3 with
      | 0 -> "x"
      | 1 -> "y"
      | _ -> "z"
    in
    (* Constructor arguments are left in their original form: OCaml does not
       specify their evaluation order, and changing the form could change
       which tree a given seed produces. *)
    match Random.int 12 with
    | 0 -> Unit
    | 1 -> Boolean (Random.bool ())
    | 2 -> Number (Random.int 100 + 1)
    | 3 -> Name (random_string ())
    | 4 -> Unit
    | 5 -> Boolean (Random.bool ())
    | 6 -> Number (Random.int 100)
    | 7 -> Name (random_string ())
    | 8 -> Divide (random (), random ())
    | 9 -> Sequence (random (), random ())
    | 10 -> Let {name = random_string (); value = random (); body = random ()}
    | 11 -> If {
        conditional = random ();
        consequence = random ();
        alternative = random ();
      }
    | _ -> assert false
end
open Syntax | |
(* Three passes that are sequenced, not fused *) | |
module Not_fused = struct
  (* Every rule below is its own complete top-down traversal; [pass] runs
     them one after another, so the tree is walked three times. *)

  module Dead_code_elimination = struct
    (* if true/false then a else b  ~>  the branch that is taken *)
    let rec pass term =
      match term with
      | If {conditional = Boolean taken; consequence; alternative} ->
          pass (if taken then consequence else alternative)
      | _ -> map pass term
  end

  module Constant_propagation = struct
    (* n / m  ~>  the quotient, for literal n and non-zero literal m *)
    let rec pass term =
      match term with
      | Divide (Number numerator, Number denominator) when denominator <> 0 ->
          pass (Number (numerator / denominator))
      | _ -> map pass term
  end

  module Remove_redundant_let = struct
    (* let x = v in x  ~>  v *)
    let rec pass term =
      match term with
      | Let {name; value; body = Name referenced} when name = referenced ->
          pass value
      | _ -> map pass term
  end

  (* Dead-code elimination first, then constant propagation, then
     redundant-let removal. *)
  let pass t =
    t
    |> Dead_code_elimination.pass
    |> Constant_propagation.pass
    |> Remove_redundant_let.pass
end
(* Three passes fused manually into a single pass *) | |
module Manually_fused = struct
  (* The three rules of [Not_fused] folded by hand into one traversal. Each
     node is tried against every rule, and any rewrite result is fed back
     through [pass]. *)
  let rec pass term =
    match term with
    | If {conditional = Boolean taken; consequence; alternative} ->
        (* dead-code elimination *)
        pass (if taken then consequence else alternative)
    | Divide (Number numerator, Number denominator) when denominator <> 0 ->
        (* constant propagation *)
        pass (Number (numerator / denominator))
    | Let {name; value; body = Name referenced} when name = referenced ->
        (* redundant let removal *)
        pass value
    | _ -> map pass term
end
module Fusion_by_extension = struct
  (* Every rule takes the recursive entry point [first_pass] as a parameter.
     When its own pattern does not match, a rule delegates to the rule it
     extends (Remove_redundant_let -> Constant_propagation ->
     Dead_code_elimination). The last rule falls back to [map first_pass]. *)
  module Dead_code_elimination = struct
    let pass first_pass = function
      | If {conditional=Boolean true; consequence; _} ->
          first_pass consequence
      | If {conditional=Boolean false; alternative; _} ->
          first_pass alternative
      | other ->
          map first_pass other
  end

  module Constant_propagation = struct
    let pass first_pass = function
      | Divide (Number n, Number m) when m <> 0 ->
          first_pass (Number (n / m))
      | other ->
          Dead_code_elimination.pass first_pass other
  end

  module Remove_redundant_let = struct
    let pass first_pass = function
      | Let {name; value; body=Name n} when n = name ->
          first_pass value
      | other ->
          Constant_propagation.pass first_pass other
  end

  (* Fuse the passes together by tying the knot: the outermost rule receives
     [pass] itself, so every rewrite result and every child goes through the
     whole chain. This is the version checked by [Test]. *)
  let rec pass t = Remove_redundant_let.pass pass t

  (* Another way of writing this with composition and an explicit fixpoint.
     Note: this version does NOT pass the test. The composition does not
     thread the fixpoint into every stage. The redundant-let rule receives
     [map pass'] as its [first_pass], so after [let x = 40 / 20 in x] is
     reduced to its value, that value is only rewritten one level down and
     the division is never folded. It is also slower, because the functions
     are not fully applied.

  let (>>) f g x = f (g x)
  let fix f_nonrec =
    let rec f t = f_nonrec f t in
    f
  let pass' =
    fix (
      Dead_code_elimination.pass >>
      Constant_propagation.pass >>
      Remove_redundant_let.pass >>
      map
    )
  *)
end
module Fused = struct
  (* Open-recursion style. Every rule takes [recur], the complete fused pass
     applied to whatever the rule produces, and [fallback], the handler for
     nodes the rule does not rewrite. The chain is assembled in [pass]. *)

  module Dead_code_elimination = struct
    (* if true/false then a else b  ~>  the branch that is taken *)
    let pass recur fallback term =
      match term with
      | If {conditional = Boolean taken; consequence; alternative} ->
          recur (if taken then consequence else alternative)
      | _ -> fallback term
  end

  module Constant_propagation = struct
    (* n / m  ~>  the quotient, for literal n and non-zero literal m *)
    let pass recur fallback term =
      match term with
      | Divide (Number numerator, Number denominator) when denominator <> 0 ->
          recur (Number (numerator / denominator))
      | _ -> fallback term
  end

  module Remove_redundant_let = struct
    (* let x = v in x  ~>  v *)
    let pass recur fallback term =
      match term with
      | Let {name; value; body = Name referenced} when name = referenced ->
          recur value
      | _ -> fallback term
  end

  (* Tie the knot. Each rule recurses through [pass] and falls through to the
     next rule; the final fallback rebuilds the node with every child
     passed. *)
  let rec pass t =
    (Dead_code_elimination.pass pass
     @@ Constant_propagation.pass pass
     @@ Remove_redundant_let.pass pass
     @@ map pass)
      t
end
module Test = struct
  (* Smoke test run at program start-up. Every strategy must reduce
       if false then () else ((); let x = 40 / 20 in x)
     to ((); 2). The strategies are checked in declaration order. *)
  let () =
    let tree =
      If {
        conditional = Boolean false;
        consequence = Unit;
        alternative =
          Sequence
            (Unit,
             Let {name = "x"; value = Divide (Number 40, Number 20); body = Name "x"});
      }
    in
    let expected = Sequence (Unit, Number 2) in
    List.iter
      (fun strategy -> assert (strategy tree = expected))
      [Not_fused.pass; Manually_fused.pass; Fusion_by_extension.pass; Fused.pass]
end
module Bench = struct
  (* Number of random trees generated up front. Every strategy is timed on
     the same array of trees. *)
  let n = 100_000_000
  let trees = Array.init n (fun _ -> Syntax.random ())

  (* [time name thunk] calls [thunk i] for every index i from 0 to n - 1,
     then prints the processor time consumed, as reported by [Sys.time].
     Each result is passed through [Sys.opaque_identity] so that an
     optimising compiler cannot treat the benchmarked work as dead code and
     drop it. *)
  let time name thunk =
    let start = Sys.time () in
    for i = 0 to n - 1 do
      ignore (Sys.opaque_identity (thunk i))
    done;
    printf "%s: %gs\n" name (Sys.time () -. start)

  let () =
    time "Not_fused" (fun i -> Not_fused.pass trees.(i));
    time "Manually_fused" (fun i -> Manually_fused.pass trees.(i));
    time "Fusion_by_extension" (fun i -> Fusion_by_extension.pass trees.(i));
    time "Fused" (fun i -> Fused.pass trees.(i))
    (* time "Statistics" (fun i ->
      let tree = trees.(i) in
      let total = (String.length (to_string tree))
      and not_fused = (String.length (to_string (Not_fused.pass tree)))
      and manually_fused = (String.length (to_string (Manually_fused.pass tree)))
      and fused_v0 = (String.length (to_string (Fusion_by_extension.pass tree)))
      and fused = (String.length (to_string (Fused.pass tree)))
      in
      assert (fused = manually_fused && fused = fused_v0);
      if total <> not_fused then
        ()
        (*Printf.printf "%d,%d,%d,%d,%d\n" total not_fused manually_fused fused_v0 fused*)
    )*)

  (* Cross-check, on every benchmark tree, that the three single-traversal
     strategies agree. [Not_fused] is not included because running the
     passes as separate sweeps can simplify further than one fused traversal.
     A fused traversal does not revisit a node after rewriting its children.
     For example, 4 / (if true then 2 else ()) becomes 2 under [Not_fused]
     but 4 / 2 under the fused strategies. *)
  let test_property_the_implementations_are_identical =
    assert (Array.for_all (fun tree ->
        let manually_fused = Manually_fused.pass tree in
        let fused_by_extension = Fusion_by_extension.pass tree in
        let fused = Fused.pass tree in
        fused = manually_fused && fused = fused_by_extension
      ) trees)
end
(* | |
$ ocamlopt fusion.ml && ./camlprog.exe | |
Not_fused: 16.696s | |
Fused: 8.120s | |
Manually_fused: 5.786s | |
*) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment