Create a gist now

Instantly share code, notes, and snippets.

@akabe /subtyping1.ml
Last active Nov 22, 2015

What would you like to do?
A subtyping encoding by phantom types [Fluet and Pucella, JFP 2006]
open Format
(** The types of the source language
    (Hindley-Milner + subtyping + bounded polymorphism).
    Types are parameterised by the representation ['a] of base types. *)
module SL =
struct
  type 'a typ =
    | Base of 'a (** base type *)
    | Var of string (** type variable *)
    | Arrow of 'a typ * 'a typ (** function type *)

  (** Type schemes with bounded quantification:
      [Forall ([(x1, t1); ...; (xn, tn)], t)] stands for
      [forall 'x1 <: t1, ..., 'xn <: tn. t]. *)
  type 'a type_scheme =
    | Forall of (string * 'a typ) list * 'a typ

  (** [pp_type pp_base ppf t] prints [t], using [pp_base] for base types.
      Arrows associate to the right; an arrow in argument position is
      parenthesised. *)
  let pp_type pp_base ppf t =
    (* [paren] is true when the type sits to the left of an arrow. The
       stdlib [format] type has three parameters, so a [_ format]
       annotation is ill-formed; each layout gets its own literal instead. *)
    let rec aux paren ppf = function
      | Base bt -> pp_base ppf bt
      | Var x -> fprintf ppf "'%s" x
      | Arrow (t1, t2) ->
        if paren then fprintf ppf "(%a -> %a)" (aux true) t1 (aux false) t2
        else fprintf ppf "%a -> %a" (aux true) t1 (aux false) t2
    in
    aux false ppf t

  (** [pp_type_scheme pp_base ppf s] prints [s] as
      [forall 'x1 <: t1, ..., 'xn <: tn. t], or just [t] when there are no
      quantified variables. *)
  let pp_type_scheme pp_base ppf = function
    | Forall ([], t) -> pp_type pp_base ppf t
    | Forall (args, t) ->
      let pp_sep ppf () = pp_print_string ppf ", " in
      let pp_elm ppf (x, bt) = fprintf ppf "'%s <: %a" x (pp_type pp_base) bt in
      fprintf ppf "forall %a. %a"
        (pp_print_list ~pp_sep pp_elm) args (pp_type pp_base) t
end
(** The types of the target language: ordinary polymorphic types plus the
    phantom type constructors [T] and [Z] used by the encoding. *)
module TL =
struct
  type typ =
    | Var of string (** type variable *)
    | Arrow of typ * typ (** function type *)
    | T of typ (** type constructor T *)
    | Z of typ (** type constructor Z (phantom) *)
    | Unit (** unit type *)
    | Tuple of typ list (** product type *)

  (** Prenex type schemes [Forall (vars, t)] of the target language.
      NOTE(review): the parameter ['a] is not used by the constructor; it is
      kept only so existing annotations keep compiling. *)
  type 'a type_scheme =
    | Forall of string list * typ

  (** Returns a fresh type variable ['a1], ['a2], ... The counter is shared
      by every call, so successive results never repeat. *)
  let genvar =
    let c = ref 0 in
    fun () -> incr c ; Var ("a" ^ string_of_int !c)

  (** Returns the type variables occurring in a given type, each listed
      once, in reverse order of first (left-to-right) occurrence. *)
  let fv t =
    let rec aux acc = function
      | Var x -> if List.mem x acc then acc else x :: acc
      | Arrow (t1, t2) -> aux (aux acc t1) t2
      | T t | Z t -> aux acc t
      | Unit -> acc
      | Tuple ts -> List.fold_left aux acc ts
    in
    aux [] t

  (** A pretty printer for types in the target language.
      [T] and [Z] are printed postfix ([x t], [x z]). Arrows and products are
      parenthesised when they are an arrow argument, a constructor argument
      or a tuple component. A one-element tuple is printed as its element,
      and the empty tuple as [unit]. *)
  let pp_type ppf t =
    (* [b] is true when the surrounding context needs parentheses around
       arrows and products. *)
    let rec aux b ppf = function
      | Var x -> fprintf ppf "'%s" x
      | Arrow (t1, t2) ->
        if b then fprintf ppf "(%a -> %a)" (aux true) t1 (aux false) t2
        else fprintf ppf "%a -> %a" (aux true) t1 (aux false) t2
      | T t -> fprintf ppf "%a t" (aux true) t
      | Z t -> fprintf ppf "%a z" (aux true) t
      | Unit -> pp_print_string ppf "unit"
      (* the nullary product is unit; printing nothing gave e.g. " t" *)
      | Tuple [] -> pp_print_string ppf "unit"
      (* a singleton is transparent, so it must inherit the context's flag *)
      | Tuple [t] -> aux b ppf t
      | Tuple ts ->
        let pp_sep ppf () = pp_print_string ppf " * " in
        let pp_elems = pp_print_list ~pp_sep (aux true) in
        if b then fprintf ppf "(%a)" pp_elems ts else pp_elems ppf ts
    in
    aux false ppf t

  (** A pretty printer for type schemes in the target language:
      [forall 'x1, ..., 'xn. t], or just [t] without quantified variables. *)
  let pp_type_scheme ppf = function
    | Forall ([], t) -> pp_type ppf t
    | Forall (args, t) ->
      let pp_sep ppf () = pp_print_string ppf ", " in
      let pp_elm ppf x = fprintf ppf "'%s" x in
      fprintf ppf "forall %a. %a" (pp_print_list ~pp_sep pp_elm) args pp_type t
end
module Lattice =
struct
  (** A rooted tree whose nodes carry labels of type ['a]. *)
  type 'a t = Node of 'a * 'a t list

  (** [find_map f tree] visits the labels of [tree] in pre-order (a node
      before its children, children left to right). It returns the first
      [Some _] produced by [f], or [None] when [f] never succeeds. *)
  let find_map f tree =
    let rec visit (Node (label, kids)) =
      match f label with
      | Some _ as found -> found
      | None -> visit_all kids
    and visit_all = function
      | [] -> None
      | kid :: rest ->
        (match visit kid with
         | None -> visit_all rest
         | found -> found)
    in
    visit tree

  (** [has_bottom tree] is [true] iff every leaf of [tree] carries the same
      label (compared with polymorphic equality), i.e. the tree has a single
      distinct leaf. *)
  let has_bottom tree =
    let rec collect_leaves seen (Node (label, kids)) =
      match kids with
      | [] -> if List.mem label seen then seen else label :: seen
      | _ :: _ -> List.fold_left collect_leaves seen kids
    in
    List.length (collect_leaves [] tree) = 1
end
(** {2 Powerset lattice} *)
(** Conversion from subtyping hierarchy into powerset lattice.

    Every node label [ty] of [tree] becomes [(ty, pset)]. Here [pset] is a
    set of labels (a duplicate-free list), and every node's set contains the
    sets of its children.
    - A leaf gets [[ty]], or [[]] when all leaves carry the same label
      (see {!Lattice.has_bottom}); this keeps the encoded types smaller.
    - An inner node gets the union of its children's sets. If that union
      equals one child's set, [ty] is added so that the parent stays
      distinguishable from that child.

    The equality test is order-insensitive. [union] does not preserve
    element order, so comparing the lists structurally could miss a
    coincidence and give a parent the same encoding as one of its
    children. *)
let make_powerset_lattice tree =
  let powerset (Lattice.Node ((_, pset), _)) = pset in
  (* union set of [xs] and [ys]: elements of [xs] missing from [ys] are
     prepended to [ys] *)
  let union xs ys =
    List.fold_left (fun acc x -> if List.mem x acc then acc else x :: acc) ys xs
  in
  let subset xs ys = List.for_all (fun x -> List.mem x ys) xs in
  let same_set xs ys = subset xs ys && subset ys xs in
  let mk_leaf =
    if Lattice.has_bottom tree
    then (fun ty -> Lattice.Node ((ty, []), [])) (* to simplify encoded types *)
    else (fun ty -> Lattice.Node ((ty, [ty]), []))
  in
  let rec aux = function
    | Lattice.Node (ty, []) -> mk_leaf ty
    | Lattice.Node (ty, children) ->
      let children' = List.map aux children in
      let s = List.fold_left (fun s v -> union s (powerset v)) [] children' in
      if List.exists (fun v -> same_set s (powerset v)) children'
      then Lattice.Node ((ty, ty :: s), children')
      else Lattice.Node ((ty, s), children')
  in
  aux tree
(** {2 Concrete and abstract encoding (for base types)} *)
(** [encode f g lattice ty] encodes the base type [ty] as a [TL.Tuple].
    The tuple has one component per element of the root's set, which acts
    as the universe. A component is [f ()] when that element belongs to the
    set attached to [ty], and [g ()] otherwise.
    @raise Failure if no node of [lattice] is labelled [ty]. *)
let encode f g lattice ty =
  let (Lattice.Node ((_, universe), _)) = lattice in
  let component pset elem = if List.mem elem pset then f () else g () in
  let encode_node (label, pset) =
    if label <> ty then None
    else Some (List.map (component pset) universe)
  in
  match Lattice.find_map encode_node lattice with
  | Some components -> TL.Tuple components
  | None -> failwith "Oops! A given type is not found in the lattice."
(** Concrete encoding for base types. Elements in [ty]'s set become [unit];
    the other elements become [unit z].
    @raise Failure if [ty] is not in [lattice]. *)
let encodeBaseC lattice ty =
  let present () = TL.Unit
  and absent () = TL.Z TL.Unit in
  encode present absent lattice ty
(** Abstract encoding for base types. Every component is a fresh type
    variable; components for elements outside [ty]'s set are wrapped in [z].
    @raise Failure if [ty] is not in [lattice]. *)
let encodeBaseA lattice ty =
  let wrapped_fresh () = TL.Z (TL.genvar ()) in
  encode TL.genvar wrapped_fresh lattice ty
(** {2 Translation} *)
(** Raised when a type variable occurs where only closed types are accepted. *)
exception Unexpected_type_param

(** [transC lattice t] is the concrete translation of the closed source type
    [t]. Each base type becomes [T] of its concrete encoding; arrows are
    translated component-wise.
    @raise Unexpected_type_param if [t] contains a type variable. *)
let transC lattice t =
  let rec go = function
    | SL.Var _ -> raise Unexpected_type_param
    | SL.Base name -> TL.T (encodeBaseC lattice name)
    | SL.Arrow (dom, cod) -> TL.Arrow (go dom, go cod)
  in
  go t
(** [transA lattice t] is the abstract translation of the closed source type
    [t]. Base types use the abstract encoding (fresh variables). An arrow's
    argument is translated concretely with [transC], and its result
    abstractly.
    @raise Unexpected_type_param if [t] contains a type variable. *)
let transA lattice t =
  let rec go = function
    | SL.Var _ -> raise Unexpected_type_param
    | SL.Base name -> TL.T (encodeBaseA lattice name)
    | SL.Arrow (arg, res) -> TL.Arrow (transC lattice arg, go res)
  in
  go t
(** [transT ?rho lattice scheme] translates the bounded source scheme
    [forall 'x1 <: t1, ..., 'xn <: tn. t] into a target scheme.

    - Each bound [ti] is translated with the abstract encoding ([transA]).
      Every occurrence of ['xi] in the body is replaced by that single
      translation, so all occurrences share the same fresh variables.
    - Base types in the body use the concrete encoding wrapped in [T],
      exactly as [transC] does.
    - Variables without a bound are mapped through [rho] (default: the
      target variable of the same name).
    - The result quantifies over the variables introduced by the bounds,
      sorted by name.

    @raise Unexpected_type_param if a bound contains a type variable.
    @raise Failure if a base type does not occur in [lattice]. *)
let transT ?(rho = fun x -> TL.Var x) lattice (SL.Forall (bounds, ty)) =
  let tauAs = List.map (fun (_, ty) -> transA lattice ty) bounds in
  let fvs = List.concat (List.map TL.fv tauAs) in
  (* each bound shadows the previous environment, so when a variable is
     bound twice the last bound wins *)
  let rho' = List.fold_left2
      (fun f (x, _) tauA -> (fun y -> if x = y then tauA else f y))
      rho bounds tauAs in
  let rec aux = function
    | SL.Base ty -> TL.T (encodeBaseC lattice ty)
    | SL.Var x -> rho' x
    | SL.Arrow (t1, t2) -> TL.Arrow (aux t1, aux t2)
  in
  TL.Forall (List.sort compare fvs, aux ty)
(* Make a powerset lattice from the example hierarchy rooted at "A":
   "F" is listed under both "B" and "D". *)
let lattice =
  let open Lattice in
  let leaf name = Node (name, []) in
  let hierarchy =
    Node ("A", [ Node ("B", [ leaf "F" ]);
                 Node ("C", [ Node ("D", [ leaf "F" ]); leaf "E" ]) ])
  in
  make_powerset_lattice hierarchy
let () =
  (* t = forall 'a <: E. 'a -> 'a *)
  let a = SL.Var "a" in
  let scheme = SL.Forall ([ ("a", SL.Base "E") ], SL.Arrow (a, a)) in
  (* print the source scheme first, then its translation *)
  printf "%a@." (SL.pp_type_scheme pp_print_string) scheme;
  let encoded = transT lattice scheme in
  printf "%a@." TL.pp_type_scheme encoded
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment