Skip to content

Instantly share code, notes, and snippets.

@Hirrolot
Created March 24, 2023 10:34
Show Gist options
  • Save Hirrolot/3b233b7a80edf43234d2b89051af12b4 to your computer and use it in GitHub Desktop.
Save Hirrolot/3b233b7a80edf43234d2b89051af12b4 to your computer and use it in GitHub Desktop.
A simple CPS evaluation as in "Compiling with Continuations", Andrew W. Appel
(* A variable occurring in a CPS term. Values of this type are the keys of
the evaluation environment ([env]) and are looked up with [List.assoc],
i.e. compared structurally. *)
type cps_var =
(* Taken from the lambda term during CPS conversion. *)
| CLamVar of string
(* Generated uniquely during CPS conversion. *)
| CGenVar of int
(* A term in continuation-passing style. Operands are always variables
([cps_var]), never nested terms; the only nested terms are function bodies
and the "rest of the program" that follows a binding. *)
type cps_term =
(* [CFix (defs, m)]: defines the functions [defs], each written as
[(name, params, body)], and continues with [m]. The definitions are
mutually recursive: every body sees every name of [defs]. *)
| CFix of (cps_var * cps_var list * cps_term) list * cps_term
(* [CAppl (f, args)]: calls the function bound to [f] with the values bound
to [args]. *)
| CAppl of cps_var * cps_var list
(* [CRecord (fields, (x, m))]: builds a record out of the values bound to
[fields], binds it to [x] and continues with [m]. *)
| CRecord of cps_var list * binder
(* [CSelect (r, i, (x, m))]: binds [x] to field number [i] (counted from 0)
of the record bound to [r] and continues with [m]. *)
| CSelect of cps_var * int * binder
(* [CHalt v]: terminal term. [eval] ignores [v] and raises
[Failure "Halted"] when it reaches this term. *)
| CHalt of cps_var
(* Binds a unique [cps_var] within [cps_term]. *)
and binder = cps_var * cps_term
(* The result of evaluation.
['a] is the answer type, i.e. what applying a function value to its
arguments produces. [VRecord] carries the field values in order; [VFn]
wraps an OCaml function from the list of argument values to an answer. *)
type 'a value = VRecord of 'a value list | VFn of ('a value list -> 'a)
(* The evaluation environment: an association list from variables to values.
Lookups go through [List.assoc], so when a variable is bound more than once
the binding closest to the head of the list wins. *)
type 'a env = (cps_var * 'a value) list
(* [bindn env names values] pairs every element of [names] with the element
   of [values] at the same position and puts the resulting bindings in front
   of [env], so that they shadow older bindings of the same names.
   Raises [Invalid_argument] (from [List.combine]) when [names] and [values]
   have different lengths. *)
let bindn env names values =
  let fresh = List.combine names values in
  List.append fresh env
(* [getn env vars] resolves each variable of [vars] in [env] and returns the
   values found, in the order of [vars].
   Raises [Not_found] (from [List.assoc]) when a variable has no binding. *)
let getn env vars =
  let lookup var = List.assoc var env in
  List.map lookup vars
(* [eval env term] interprets the CPS term [term] under the environment [env].

   An evaluation only ever returns through a [CAppl] that reaches a [VFn];
   the answer is whatever that function returns.

   Raises [Failure "Not a function"] when [CAppl] targets a record,
   [Failure "Not a record"] when [CSelect] targets a function and
   [Failure "Halted"] on [CHalt]. Exceptions of the list helpers are not
   caught: [Not_found] for an unbound variable ([List.assoc]), the failure
   of [List.nth] for a bad field index, and [Invalid_argument] from
   [List.combine] when a [CFix]-defined function receives the wrong number
   of arguments. *)
let rec eval (env : 'a env) term =
  match term with
  | CHalt _ -> failwith "Halted"
  | CSelect (record, index, (target, rest)) -> (
      match List.assoc record env with
      | VRecord fields ->
          let field = List.nth fields index in
          eval ((target, field) :: env) rest
      | VFn _ -> failwith "Not a record")
  | CRecord (fields, (target, rest)) ->
      let record = VRecord (getn env fields) in
      eval ((target, record) :: env) rest
  | CAppl (callee, args) -> (
      (* The callee is resolved and checked before the arguments are. *)
      match List.assoc callee env with
      | VFn fn -> fn (getn env args)
      | VRecord _ -> failwith "Not a function")
  | CFix (defs, rest) ->
      let names = List.map (fun (name, _params, _body) -> name) defs in
      (* [close scope def] turns one definition into a function value. Its
         body runs with the parameters bound on top of [extend scope], which
         is rebuilt on every call and is what lets the definitions see one
         another. *)
      let rec close scope (_name, params, body) =
        VFn (fun args -> eval (bindn (extend scope) params args) body)
      (* [extend scope] adds a binding for every function of [defs] in front
         of [scope]. *)
      and extend scope = bindn scope names (List.map (close scope) defs) in
      eval (extend env) rest
(* Test CPS evaluation. *)
let () =
  (* [check env term expected] asserts that [term] evaluates to [expected]. *)
  let check env term expected = assert (eval env term = expected) in
  (* [check_failure env term msg] asserts that evaluating [term] raises
     [Failure msg]. A normal return trips [assert false], which is not a
     [Failure] and therefore is not mistaken for the expected error. *)
  let check_failure env term expected_msg =
    match eval env term with
    | _ -> assert false
    | exception Failure msg -> assert (msg = expected_msg)
  in
  (* A function value that ignores its arguments and answers [n]. *)
  let const n = VFn (fun _args -> n) in
  let uncallable_fn = VFn (fun _args -> failwith "Must not be called") in
  (* Receives two functions that must answer 42 and 52 respectively. *)
  let pair_h =
    VFn
      (fun args ->
        match args with
        | [ VFn f; VFn g ] ->
            assert (f [] = 42);
            assert (g [] = 52);
            123
        | _ -> failwith "Invalid arguments")
  in
  let call_h = CAppl (CLamVar "h", [ CLamVar "f"; CGenVar 33 ]) in
  (* [CAppl] *)
  check
    [ (CLamVar "f", const 42); (CGenVar 33, const 52); (CLamVar "h", pair_h) ]
    call_h 123;
  (* Get the first variable in the list. *)
  check
    [
      (CLamVar "f", const 42);
      (CLamVar "f", const (-1));
      (CGenVar 33, const 52);
      (CGenVar 33, const (-1));
      (CLamVar "h", pair_h);
    ]
    call_h 123;
  (* [CAppl] not a function. *)
  check_failure
    [ (CLamVar "r", VRecord []) ]
    (CAppl (CLamVar "r", []))
    "Not a function";
  (* [CFix] *)
  let fix_h =
    VFn
      (fun args ->
        match args with
        | [ VFn f; VFn g ] ->
            (* [f] calls its first argument, [g] its second one. *)
            assert (f [ const 1; uncallable_fn ] = 1);
            assert (g [ uncallable_fn; const 1 ] = 1);
            123
        | _ -> failwith "Invalid arguments")
  in
  (* A two-parameter function [name] whose body calls parameter [callee]
     with the two mutually recursive functions. *)
  let mutual name callee =
    ( CLamVar name,
      [ CLamVar "x"; CLamVar "y" ],
      CAppl (CLamVar callee, [ CLamVar "f"; CLamVar "g" ]) )
  in
  check
    [ (CLamVar "h", fix_h) ]
    (CFix
       ( [ mutual "f" "x"; mutual "g" "y" ],
         CAppl (CLamVar "f", [ CLamVar "h"; CLamVar "g" ]) ))
    123;
  (* [CRecord] *)
  let record_h =
    VFn
      (fun args ->
        match args with
        | [ VRecord [ VFn f; VFn g ] ] ->
            assert (f [] = 42);
            assert (g [] = 52);
            123
        | _ -> failwith "Invalid arguments")
  in
  check
    [
      (CLamVar "f", const 42); (CLamVar "g", const 52); (CLamVar "h", record_h);
    ]
    (CRecord
       ( [ CLamVar "f"; CLamVar "g" ],
         (CLamVar "r", CAppl (CLamVar "h", [ CLamVar "r" ])) ))
    123;
  let r = (CLamVar "r", VRecord [ const 42; const 52 ]) in
  (* Receives a single function that must answer [n]. *)
  let field_h n =
    VFn
      (fun args ->
        match args with
        | [ VFn f ] ->
            assert (f [] = n);
            123
        | _ -> failwith "Invalid arguments")
  in
  let select_and_call i =
    CSelect (CLamVar "r", i, (CLamVar "f", CAppl (CLamVar "h", [ CLamVar "f" ])))
  in
  (* [CSelect] 0 *)
  check [ r; (CLamVar "h", field_h 42) ] (select_and_call 0) 123;
  (* [CSelect] 1 *)
  check [ r; (CLamVar "h", field_h 52) ] (select_and_call 1) 123;
  let select_and_halt record i =
    CSelect (CLamVar record, i, (CLamVar "x", CHalt (CLamVar "x")))
  in
  (* [CSelect] non-existent field. *)
  check_failure [ r ] (select_and_halt "r" 3) "nth";
  (* [CSelect] not a record. *)
  check_failure
    [ (CLamVar "f", uncallable_fn) ]
    (select_and_halt "f" 0)
    "Not a record";
  (* [CHalt] *)
  check_failure [] (CHalt (CLamVar "x")) "Halted"
@Hirrolot
Copy link
Author

Hirrolot commented Mar 24, 2023

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment