Skip to content

Instantly share code, notes, and snippets.

@akabe
Created November 19, 2015 11:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save akabe/f3f9f37e6344cb7385a7 to your computer and use it in GitHub Desktop.
Save akabe/f3f9f37e6344cb7385a7 to your computer and use it in GitHub Desktop.
A subtyping encoding by phantom types
open Format
(** The types of the source language
    (Hindley-Milner + subtyping + bounded polymorphism) *)
module SL =
struct
  (** Source-language types, parameterised by the representation ['a] of
      base types. *)
  type 'a typ =
    | Base of 'a (** base type *)
    | Var of string (** type variable *)
    | Arrow of 'a typ * 'a typ (** function type *)

  (** [pp_type pp_base ppf t] prints [t] on [ppf], using [pp_base] for base
      types and printing a type variable [x] as ['x].  Arrows are printed as
      right-associative: only an arrow in argument (left) position is wrapped
      in parentheses. *)
  let pp_type pp_base ppf t =
    (* [b] is true when the type being printed is the left operand of an
       arrow, so an arrow there must be parenthesised. *)
    let rec aux b ppf = function
      | Base bt -> pp_base ppf bt
      | Var x -> fprintf ppf "'%s" x
      | Arrow (t1, t2) ->
        (* [format] takes three type parameters; the annotation makes the
           literals typecheck as format strings rather than as [string]. *)
        let (fmt : (_, _, _) format) =
          if b then "(%a -> %a)" else "%a -> %a" in
        fprintf ppf fmt (aux true) t1 (aux false) t2
    in
    aux false ppf t
end
module TL =
struct
  (** Target-language types. *)
  type typ =
    | Var of string (** type variable *)
    | Arrow of typ * typ (** function type *)
    | T of typ (** type constructor T *)
    | W (** phantom type W *)
    | Z (** phantom type Z *)
    | Tuple of typ list (** product type *)

  (** [genvar ()] returns a fresh type variable [Var "a1"], [Var "a2"], ...
      The counter is shared by all callers, so every call yields a name not
      returned before in this run. *)
  let genvar =
    let c = ref 0 in
    fun () -> incr c ; Var ("a" ^ string_of_int !c)

  (** A pretty printer for types in the target language.
      [T t] is printed postfix as [t t]; [W] and [Z] as [w] and [z]; the empty
      tuple as [unit]; a one-element tuple as its element; and larger tuples
      as [t1 * t2 * ...]. *)
  let pp_type ppf t =
    (* [b] is true in positions where a compound type (arrow or product)
       must be parenthesised: left of an arrow, under [T], inside a tuple. *)
    let rec aux b ppf = function
      | Var x -> fprintf ppf "'%s" x
      | Arrow (t1, t2) ->
        (* [format] takes three type parameters; the annotation makes the
           literals typecheck as format strings rather than as [string]. *)
        let (fmt : (_, _, _) format) =
          if b then "(%a -> %a)" else "%a -> %a" in
        fprintf ppf fmt (aux true) t1 (aux false) t2
      | T t -> fprintf ppf "%a t" (aux true) t
      | W -> pp_print_string ppf "w"
      | Z -> pp_print_string ppf "z"
      (* the nullary product; printing nothing would yield e.g. " t" *)
      | Tuple [] -> pp_print_string ppf "unit"
      (* a singleton tuple is transparent, so it must keep the caller's
         parenthesisation requirement rather than reset it *)
      | Tuple [t] -> aux b ppf t
      | Tuple ts ->
        let pp_sep ppf () = pp_print_string ppf " * " in
        let (fmt : (_, _, _) format) = if b then "(%a)" else "%a" in
        fprintf ppf fmt (pp_print_list ~pp_sep (aux true)) ts
    in
    aux false ppf t
end
module Lattice =
struct
  (** A rose tree: a label together with the list of its sub-trees. *)
  type 'a t = Node of 'a * 'a t list

  (** [find_map f tree] visits the labels of [tree] in pre-order (a node
      before its children, children from left to right) and returns the
      first [Some _] produced by [f], or [None] if [f] never produces one. *)
  let rec find_map f (Node (label, kids)) =
    match f label with
    | Some _ as found -> found
    | None ->
      let rec search = function
        | [] -> None
        | kid :: rest ->
          (match find_map f kid with
           | None -> search rest
           | found -> found)
      in
      search kids

  (** [has_bottom tree] is [true] iff the leaves of [tree] carry exactly one
      distinct label (compared with structural equality), i.e. every path
      from the root ends at the same element. *)
  let has_bottom tree =
    let rec collect seen (Node (label, kids)) =
      match kids with
      | [] -> if List.mem label seen then seen else label :: seen
      | _ :: _ -> List.fold_left collect seen kids
    in
    List.length (collect [] tree) = 1
end
(** {2 Powerset lattice} *)
(** Conversion from a subtyping hierarchy into a powerset lattice.

    [make_powerset_lattice tree] annotates every node [ty] of [tree] (whose
    children are taken to be subtypes of their parent) with a set [pset],
    represented as a duplicate-free list, yielding [(ty, pset)] nodes:
    - a leaf gets [[ty]], or [[]] when the whole tree has a single distinct
      leaf (see [Lattice.has_bottom]), which keeps encoded types smaller;
    - an inner node gets the union of its children's sets, plus [ty] itself
      when that union coincides with one child's set, so that the node and
      that child are not given the same set.

    Sets are compared order-insensitively: [union] does not preserve element
    order, so comparing lists with [=] could miss a coincidence and give two
    distinct types the same set. *)
let make_powerset_lattice tree =
  let powerset (Lattice.Node ((_, pset), _)) = pset in
  (* union set of [xs] and [ys]; the resulting element order is arbitrary *)
  let union xs ys =
    List.fold_left (fun acc x -> if List.mem x acc then acc else x :: acc) ys xs
  in
  (* set equality: each list contains every element of the other *)
  let same_set xs ys =
    List.for_all (fun x -> List.mem x ys) xs
    && List.for_all (fun y -> List.mem y xs) ys
  in
  let mk_leaf =
    if Lattice.has_bottom tree
    then (fun ty -> Lattice.Node ((ty, []), [])) (* to simplify encoded types *)
    else (fun ty -> Lattice.Node ((ty, [ty]), [])) in
  let rec aux = function
    | Lattice.Node (ty, []) -> mk_leaf ty
    | Lattice.Node (ty, children) ->
      let children' = List.map aux children in
      let s = List.fold_left (fun s v -> union s (powerset v)) [] children' in
      if List.exists (fun v -> same_set s (powerset v)) children'
      then Lattice.Node ((ty, ty :: s), children')
      else Lattice.Node ((ty, s), children')
  in
  aux tree
(** {2 Concrete and abstract encoding (for base types)} *)
(** [encode f g lattice ty] encodes the base type [ty] as a [TL.Tuple] with
    one component per element of the root's set (the universe): the component
    is [f ()] when the element belongs to the set attached to [ty] and [g ()]
    otherwise.  The first node labelled [ty] in pre-order is used.
    @raise Failure if no node of [lattice] is labelled [ty]. *)
let encode f g lattice ty =
  let Lattice.Node ((_, universe), _) = lattice in
  let component members i = if List.mem i members then f () else g () in
  let lookup (label, members) =
    if ty = label
    then Some (List.map (component members) universe)
    else None
  in
  match Lattice.find_map lookup lattice with
  | Some components -> TL.Tuple components
  | None -> failwith "Oops! A given type is not found in the lattice."
(** Encoding for base types at covariant positions: [TL.W] for elements in
    [ty]'s set, a fresh type variable for every other element. *)
let encodeBaseP lattice ty =
  let member () = TL.W in
  let non_member () = TL.genvar () in
  encode member non_member lattice ty
(** Encoding for base types at contravariant positions: a fresh type variable
    for elements in [ty]'s set, [TL.Z] for every other element. *)
let encodeBaseN lattice ty =
  let member () = TL.genvar () in
  let non_member () = TL.Z in
  encode member non_member lattice ty
(** {2 Translation} *)
(** [trans lattice t] translates the source type [t] into the target language.
    Variables are kept as-is; a base type becomes [TL.T] of its encoding,
    using [encodeBaseP] at positive positions and [encodeBaseN] at negative
    ones; an arrow flips the polarity of its argument and keeps the polarity
    of its result.  The top-level type is at a positive position. *)
let trans lattice =
  let rec go positive = function
    | SL.Var x -> TL.Var x
    | SL.Base base ->
      let encode_base = if positive then encodeBaseP else encodeBaseN in
      TL.T (encode_base lattice base)
    | SL.Arrow (dom, cod) ->
      TL.Arrow (go (not positive) dom, go positive cod)
  in
  go true
(* Build the example hierarchy -- A on top with B and C below it, F below
   both B and D, E below C -- and turn it into a powerset lattice: *)
let lattice =
  let leaf ty = Lattice.Node (ty, []) in
  let node ty children = Lattice.Node (ty, children) in
  make_powerset_lattice
    (node "A" [ node "B" [ leaf "F" ];
                node "C" [ node "D" [ leaf "F" ]; leaf "E" ] ])
(* For each example, print the source type, then its translation. *)
let () =
  let show t =
    printf "%a@." (SL.pp_type pp_print_string) t;
    printf "%a@." TL.pp_type (trans lattice t)
  in
  show (SL.Arrow (SL.Base "A", SL.Arrow (SL.Base "B", SL.Base "C")));
  show (SL.Arrow (SL.Arrow (SL.Base "A", SL.Base "B"), SL.Base "C"))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment