Skip to content

Instantly share code, notes, and snippets.

@ramntry
Last active August 29, 2015 14:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ramntry/68a7534d88f4a40e28b2 to your computer and use it in GitHub Desktop.
Save ramntry/68a7534d88f4a40e28b2 to your computer and use it in GitHub Desktop.
Extended GADT example in OCaml
(* ('a, 'v) expr GADT:
* 'a stands for the type of the constants of an expression.
* 'v stands for the type of the result of evaluating the expression.
*
* In most cases an expression evaluates to a value of the same
* type as its constants, but if the outermost operator of the expression
* is a comparison (Less constructor) or a negation (Not constructor),
* the whole expression has a boolean type. Note that in that case the
* 'a parameter stays the same, possibly a non-boolean type; what actually
* changes is 'v. So the following statement is true:
*
* ('v = bool) or ('v = 'a) (1)
*
* 'v serves only one purpose: to make it possible to constrain the type
* of the value that the Less and Not constructors return. That is exactly
* why ('a, 'v) expr is a GADT and not an ordinary ADT. An ('a, 'v) expr can
* not be treated as a container of values of type 'v! Instead, it should be
* treated as a container of values of type 'a, specifically as a tree whose
* leaves are Const nodes. Thus it makes sense to define map and even fold
* transformations for ('a, 'v) expr. Fold is simple, because you can
* extract from an expression the list of constants used in it, and fold is
* a perfectly applicable transformation for any kind of list.
*
* Map is the harder one. To keep things simple, the If constructor is
* defined so that an If expression can evaluate only to an 'a value, not
* to either an 'a or a boolean value depending on its then and else
* branches. Nevertheless, Less and Not expressions evaluate to bool while
* Const, Add and If evaluate to 'a, so the first argument of an If
* constructor can be any kind of expression if and only if 'a = bool.
* Otherwise the first argument of If must be a Less or Not expression.
* This means that in the general case a (bool, bool) expr can be mapped
* only to another (bool, bool) expr! In contrast, an ('a, _) expr where 'a
* is not bool can be freely mapped to a ('b, _) expr for any 'b. The full
* scheme for determining the possible ways of mapping follows; read its
* left-hand side as an exhaustive, top-to-bottom pattern match on types:
*
* map
* (bool, bool) expr --> (bool, bool) expr (2)
* ('a, bool) expr --> ('b, bool) expr (3)
* ('a, 'a) expr --> ('b, 'b) expr (4)
*)
(* ('a, 'v) expr: 'a is the type of the constants stored in the leaves,
   'v is the type of the value the expression produces under eval. *)
type ('a, _) expr =
| Const : 'a -> ('a, 'a) expr (** Leaf holding a constant; eval returns it unchanged. *)
| Add : ('a, 'a) expr * ('a, 'a) expr -> ('a, 'a) expr (** Operands are combined with the add function passed to eval. *)
| Less : ('a, 'a) expr * ('a, 'a) expr -> ('a, bool) expr (** Operands are compared with the less function passed to eval; yields bool. *)
| Not : ('a, bool) expr -> ('a, bool) expr (** Boolean negation of a bool-valued sub-expression. *)
| If : ('a, bool) expr * ('a, 'a) expr * ('a, 'a) expr -> ('a, 'a) expr (** Conditional; both branches, and the result, have the constant type. *)
(** [eval add less e] computes the value of [e]: constants evaluate to
    themselves, [Add] nodes are combined with [add], [Less] nodes are
    compared with [less], [Not] negates its operand and [If] picks a branch
    according to its condition. The result type is the second type
    parameter of [e]. *)
let eval : type a v. (a -> a -> a) -> (a -> a -> bool) -> (a, v) expr -> v =
  fun add less ->
    (* The walker closes over [add] and [less]. It needs its own
       polymorphic annotation because recursive calls change the result
       type: bool for conditions and comparisons, a for operands. *)
    let rec go : type w. (a, w) expr -> w = function
      | Const c -> c
      | Add (x, y) -> add (go x) (go y)
      | Less (x, y) -> less (go x) (go y)
      | Not b -> not (go b)
      | If (c, then_, else_) -> if go c then go then_ else go else_
    in
    go
(** [fold_right combine init e] folds [combine] over the constants of [e]
    from right to left: with leaves [c1 ... cn] in left-to-right order the
    result is [combine c1 (combine c2 (... (combine cn init)))]. *)
let fold_right : type a v b. (a -> b -> b) -> b -> (a, v) expr -> b =
  fun combine init expr ->
    (* Threads the accumulator through the tree, visiting the rightmost
       sub-expression first so that the leftmost constant is combined last. *)
    let rec go : type w. (a, w) expr -> b -> b =
      fun node acc ->
        match node with
        | Const c -> combine c acc
        | Add (lhs, rhs) -> go lhs (go rhs acc)
        | Less (lhs, rhs) -> go lhs (go rhs acc)
        | Not inner -> go inner acc
        | If (cond, yes, no) -> go cond (go yes (go no acc))
    in
    go expr init
(** [consts e] lists the constants of [e] in left-to-right order. *)
let consts expr =
  let prepend c cs = c :: cs in
  fold_right prepend [] expr
(** With the current implementation of map you have to choose the
* appropriate version of the map function (bool_map, a_bool_map or a_v_map)
* by hand. To make life more interesting, a_bool_map and a_v_map can fail
* at run time with Failure if you make the wrong choice.
*)
(* (2) *)
(** [bool_map func e] rebuilds the boolean expression [e] with every
    constant [c] replaced by [func c], keeping the tree shape intact. *)
let bool_map : (bool -> bool) -> (bool, bool) expr -> (bool, bool) expr =
  fun func ->
    let rec walk : (bool, bool) expr -> (bool, bool) expr = function
      | Const b -> Const (func b)
      | Add (x, y) -> Add (walk x, walk y)
      | Less (x, y) -> Less (walk x, walk y)
      | Not b -> Not (walk b)
      | If (c, yes, no) -> If (walk c, walk yes, walk no)
    in
    walk
(* (3) *)
(** [a_bool_map func e] rebuilds the boolean-valued expression [e] with
    every constant [c] replaced by [func c]. When the constant type is not
    bool, the root of [e] is necessarily [Less] or [Not].
    @raise Failure for some [(bool, bool) expr] inputs, e.g. one whose root
    is [Const], [Add] or [If]; use [bool_map] for those. *)
let rec a_bool_map : type a b. (a -> b) -> (a, bool) expr -> (b, bool) expr =
  fun func expr ->
    let misuse () =
      failwith ("a_bool_map: If expression has type (bool, bool) expr you "
                ^ "should use the bool_map instead. Otherwise it is internal error.")
    in
    match expr with
    | Less (l, r) -> Less (a_v_map func l, a_v_map func r)
    | Not e -> Not (a_bool_map func e)
    (* Const, Add and If can only occur here when a = bool, i.e. for a
       (bool, bool) expr. They are listed one by one instead of with a
       wildcard so that a new expr constructor triggers warning 8 here. *)
    | Const _ -> misuse ()
    | Add _ -> misuse ()
    | If _ -> misuse ()
(* (4) *)
(* [a_v_map func e] rebuilds the expression [e], whose value has the
   constant type, with every constant [c] replaced by [func c].
   Raises Failure if [e] contains a Less or Not node in a position where a
   constant-typed value is expected; that is only possible for a
   (bool, bool) expr, which must be mapped with bool_map instead. *)
and a_v_map : type a b. (a -> b) -> (a, a) expr -> (b, b) expr =
  fun func expr ->
    (* Reaching this is a caller error, not an internal one: Less and Not
       yield bool, so they sit where an a value is expected only when
       a = bool, and rebuilding them would require b = bool, which the
       types cannot express. *)
    let misuse () =
      failwith ("a_v_map: found a Less or Not node where a constant-typed "
                ^ "value is expected; this is only possible for (bool, bool) "
                ^ "expr, use bool_map instead.")
    in
    match expr with
    | Const x -> Const (func x)
    | Add (l, r) -> Add (a_v_map func l, a_v_map func r)
    | If (cond, t, e) ->
      If (a_bool_map func cond, a_v_map func t, a_v_map func e)
    | Less _ -> misuse ()
    | Not _ -> misuse ()
(** Expression examples and tests.
*)
(* Not (50 < 30 + 40): int constants, bool result. *)
let int_bool_expr =
  let thirty_plus_forty = Add (Const 30, Const 40) in
  Not (Less (Const 50, thirty_plus_forty))
(* (10 + 20) + (if not int_bool_expr then 60 + 70 else 80): int constants,
   int result. *)
let int_int_expr =
  let add_consts a b = Add (Const a, Const b) in
  Add (add_consts 10 20,
       If (Not int_bool_expr, add_consts 60 70, Const 80))
(* A (bool, bool) expr that uses every constructor; its root is an If node.
   Built bottom-up from named sub-expressions. *)
let bool_expr =
  let innermost =
    If (Less (Const true, Const false), Const true, Const false)
  in
  let middle =
    If (Add (Not (Const false), Const false), innermost, Const true)
  in
  let outer =
    If (middle,
        Const false,
        Less (Add (Const false, Const false), Const false))
  in
  If (Not outer, Const false, Const true)
(* Self-checks run at program start-up; a failing check aborts with
   Assert_failure. *)
let () =
  (* Comparison used for the bool examples: [implies l r] is false only
     when l is true and r is false. *)
  let implies l r = not l || r in
  let eval_int e = eval ( + ) ( < ) e in
  let eval_bool e = eval ( || ) implies e in
  let eval_str e = eval ( ^ ) ( < ) e in
  (* evaluation *)
  assert (eval_int (Const 10) = 10);
  assert (not (eval_bool (Not (Const true))));
  assert (eval_str (Const "ok") = "ok");
  assert (not (eval_int int_bool_expr));
  assert (eval_int int_int_expr = 160);
  assert (eval_bool bool_expr);
  (* constants come out in left-to-right order *)
  assert (consts int_int_expr = [10; 20; 50; 30; 40; 60; 70; 80]);
  assert (consts bool_expr = [false; false; true; false; true; false; true;
                              false; false; false; false; false; true]);
  (* the three map flavours *)
  assert (not (eval_bool (bool_map not bool_expr)));
  assert (bool_map not bool_expr <> bool_expr);
  assert (eval_int (a_bool_map (fun x -> x - 50) int_bool_expr));
  assert (eval_str (a_v_map string_of_int int_int_expr) = "102080");
  (* bool_expr is a (bool, bool) expr rooted at If, so choosing a_bool_map
     for it is the wrong choice and must fail at run time. *)
  let rejected =
    try ignore (a_bool_map (fun _ -> 0) bool_expr); false
    with Failure _ -> true
  in
  assert rejected;
  assert (eval_str
            (a_v_map string_of_bool
               (a_v_map (fun x -> x mod 20 = 0) int_int_expr))
          = "falsetruetruefalse")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment