Skip to content

Instantly share code, notes, and snippets.

@Hirrolot
Last active January 15, 2024 15:28
Show Gist options
  • Save Hirrolot/505901460f131da1f0cd8b118e46a7bc to your computer and use it in GitHub Desktop.
Save Hirrolot/505901460f131da1f0cd8b118e46a7bc to your computer and use it in GitHub Desktop.
Higher-order polymorphic lambda calculus (Fω)

This is a minimalistic OCaml implementation of the type system from chapter 30, "Higher-Order Polymorphism", of *Types and Programming Languages* (TAPL).

The implementation uses bidirectional type checking and omits existential types. Binders are represented as metalanguage (OCaml) functions, in the style of higher-order abstract syntax (HOAS); free variables (`TyFreeVar` and `FreeVar`) are represented as De Bruijn levels, i.e. counted from the outermost binder.

See also:

Show unit tests
(* Checks the textual rendering of kinds, including a right-nested arrow. *)
let test_print_kind () =
  [
    (KnStar, "KnStar");
    (KnArr (KnStar, KnStar), "KnArr(KnStar, KnStar)");
    ( KnArr (KnStar, KnArr (KnStar, KnStar)),
      "KnArr(KnStar, KnArr(KnStar, KnStar))" );
  ]
  |> List.iter (fun (k, want) -> assert (String.equal (print_kind k) want))

(* Checks the textual rendering of types. Binders are opened at level 42, so
   a bound variable is printed as [TyFreeVar 42]. *)
let test_print_ty () =
  let check (ty, want) = assert (String.equal (print_ty 42 ty) want) in
  List.iter check
    [
      (TyFreeVar 0, "TyFreeVar 0");
      (TyArr (TyFreeVar 0, TyFreeVar 1), "TyArr(TyFreeVar 0, TyFreeVar 1)");
      ( TyAll (KnStar, fun a -> TyArr (TyFreeVar 3, a)),
        "TyAll(KnStar, TyArr(TyFreeVar 3, TyFreeVar 42))" );
      ( TyLam (fun a -> TyArr (TyFreeVar 3, a)),
        "TyLam(TyArr(TyFreeVar 3, TyFreeVar 42))" );
      (TyAppl (TyFreeVar 0, TyFreeVar 1), "TyAppl(TyFreeVar 0, TyFreeVar 1)");
      (TyAnn (TyFreeVar 0, KnStar), "TyAnn(TyFreeVar 0, KnStar)");
    ]

(* Checks the textual rendering of terms. Term and type binders are opened at
   level 42, so bound variables print with index 42. *)
let test_print_term () =
  let check (t, want) = assert (String.equal (print_term 42 t) want) in
  List.iter check
    [
      (FreeVar 0, "FreeVar 0");
      (Lam (fun t -> Appl (t, FreeVar 0)), "Lam(Appl(FreeVar 42, FreeVar 0))");
      (Appl (FreeVar 0, FreeVar 1), "Appl(FreeVar 0, FreeVar 1)");
      ( TLam (fun a -> TAppl (FreeVar 0, a)),
        "TLam(TAppl(FreeVar 0, TyFreeVar 42))" );
      (TAppl (FreeVar 0, TyFreeVar 1), "TAppl(FreeVar 0, TyFreeVar 1)");
      (Ann (FreeVar 0, TyFreeVar 1), "Ann(FreeVar 0, TyFreeVar 1)");
    ]

(* Checks [equate_ty]. Each sample must equal itself and must differ from
   every listed counterexample, in both argument orders. *)
let test_equate_ty () =
  let equal a b = equate_ty 42 (a, b) in
  let section sample counterexamples =
    assert (equal sample sample);
    List.iter
      (fun other ->
        assert (not (equal sample other));
        assert (not (equal other sample)))
      counterexamples
  in

  (* [TyFreeVar] *)
  section (TyFreeVar 0) [ TyFreeVar 1; TyArr (TyFreeVar 0, TyFreeVar 1) ];

  (* [TyArr] *)
  section
    (TyArr (TyFreeVar 0, TyFreeVar 1))
    [
      TyArr (TyFreeVar 1, TyFreeVar 1);
      TyArr (TyFreeVar 0, TyFreeVar 0);
      TyFreeVar 0;
    ];

  (* [TyAll] *)
  section
    (TyAll (KnStar, fun a -> a))
    [
      TyAll (KnArr (KnStar, KnStar), fun a -> a);
      TyAll (KnStar, fun _ -> TyFreeVar 123);
      TyFreeVar 0;
    ];

  (* [TyLam] *)
  section (TyLam (fun a -> a)) [ TyLam (fun _ -> TyFreeVar 123); TyFreeVar 0 ];

  (* [TyAppl] *)
  section
    (TyAppl (TyFreeVar 0, TyFreeVar 1))
    [
      TyAppl (TyFreeVar 1, TyFreeVar 1);
      TyAppl (TyFreeVar 0, TyFreeVar 0);
      TyFreeVar 0;
    ];

  (* [TyAnn] *)
  section
    (TyAnn (TyFreeVar 0, KnStar))
    [
      TyAnn (TyFreeVar 1, KnStar);
      TyAnn (TyFreeVar 0, KnArr (KnStar, KnStar));
      TyFreeVar 0;
    ]

(* [simple_comp_ty x] builds the redex "identity applied to [TyFreeVar x]",
   whose normal form is [TyFreeVar x]. *)
let simple_comp_ty x =
  let identity = TyLam (fun a -> a) in
  TyAppl (identity, TyFreeVar x)

(* Checks that [eval_ty] reduces every redex, including those under binders.
   Results are compared against the expected normal form with [equate_ty]. *)
let test_eval_ty () =
  let expect (input, normal_form) =
    assert (equate_ty 0 (eval_ty input, normal_form))
  in
  List.iter expect
    [
      (* [TyFreeVar] *)
      (TyFreeVar 42, TyFreeVar 42);
      (* [TyArr] *)
      ( TyArr (simple_comp_ty 0, simple_comp_ty 42),
        TyArr (TyFreeVar 0, TyFreeVar 42) );
      (* [TyAll] *)
      (TyAll (KnStar, fun a -> a), TyAll (KnStar, fun a -> a));
      ( TyAll (KnStar, fun a -> TyAppl (TyLam (fun a -> a), a)),
        TyAll (KnStar, fun a -> a) );
      (* [TyLam] *)
      (TyLam (fun a -> a), TyLam (fun a -> a));
      (TyLam (fun a -> TyAppl (TyLam (fun a -> a), a)), TyLam (fun a -> a));
      (* [TyAppl] *)
      (simple_comp_ty 42, TyFreeVar 42);
      ( TyAppl (TyFreeVar 0, TyFreeVar 42),
        TyAppl (TyFreeVar 0, TyFreeVar 42) );
      ( TyAppl (simple_comp_ty 0, TyFreeVar 42),
        TyAppl (TyFreeVar 0, TyFreeVar 42) );
      (* [TyAnn] *)
      (TyAnn (simple_comp_ty 42, KnStar), TyFreeVar 42);
    ]

(* For the sake of simplicity, type checker tests might abuse the
   well-formedness preconditions of the type and context parameters: for
   example, we often use magic De Bruijn levels. *)

(* Asserts that [ty] has kind [expected_k] under [ctx]. The level passed to
   [infer_kind] is the size of [ctx]; exceptions raised by [infer_kind]
   propagate to the caller. *)
let assert_infer_kind ctx ty expected_k =
  assert (expected_k = infer_kind (List.length ctx) ctx ty)

(* Kinding of type variables. Level 0 is the last list element (the
   outermost binding) and the head has the highest level. *)
let test_infer_ty_free_var () =
  let arrow = KnArr (KnStar, KnStar) in
  assert_infer_kind [ TyVarBind KnStar ] (TyFreeVar 0) KnStar;
  assert_infer_kind
    [ TyVarBind arrow; TyVarBind KnStar; TyVarBind arrow ]
    (TyFreeVar 1) KnStar;
  assert_infer_kind
    [ TyVarBind KnStar; TyVarBind arrow; TyVarBind arrow ]
    (TyFreeVar 2) KnStar;
  (* A level past the context yields a negative index for [List.nth]. *)
  (match
     assert_infer_kind
       [ TyVarBind KnStar; TyVarBind KnStar; TyVarBind KnStar ]
       (TyFreeVar 3) KnStar
   with
  | () -> assert false
  | exception Invalid_argument msg -> assert (msg = "List.nth"));
  (* A level bound to a term variable is rejected. *)
  match assert_infer_kind [ VarBind (TyFreeVar 0) ] (TyFreeVar 0) KnStar with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Expected a type variable: TyFreeVar 0")

(* Kinding of arrows: both the domain and the codomain must have kind
   [KnStar]. *)
let test_infer_ty_arr () =
  assert_infer_kind
    [ TyVarBind KnStar; TyVarBind KnStar ]
    (TyArr (TyFreeVar 1, TyFreeVar 0))
    KnStar;
  (* Level 1 is bound at [KnArr (KnStar, KnStar)] and level 0 at [KnStar]. *)
  let ctx = [ TyVarBind (KnArr (KnStar, KnStar)); TyVarBind KnStar ] in
  let rejects ty =
    match assert_infer_kind ctx ty KnStar with
    | () -> assert false
    | exception Failure msg ->
        assert (msg = "Want KnStar, got KnArr(KnStar, KnStar): TyFreeVar 1")
  in
  rejects (TyArr (TyFreeVar 1, TyFreeVar 0));
  rejects (TyArr (TyFreeVar 0, TyFreeVar 1))

(* Kinding of universal types: the body is kinded at [KnStar] with the
   quantified variable in scope. *)
let test_infer_ty_all () =
  let ctx = [ TyVarBind (KnArr (KnStar, KnStar)) ] in
  assert_infer_kind ctx
    (TyAll (KnStar, fun x -> TyAppl (TyFreeVar 0, x)))
    KnStar;
  match assert_infer_kind ctx (TyAll (KnStar, fun _ -> TyFreeVar 0)) KnStar with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want KnStar, got KnArr(KnStar, KnStar): TyFreeVar 0")

(* Kinding of type applications. Every case applies level 1 to level 0 and
   expects [KnStar]; only the context changes. *)
let test_infer_ty_appl () =
  let arrow = KnArr (KnStar, KnStar) in
  let appl = TyAppl (TyFreeVar 1, TyFreeVar 0) in
  let rejects ctx want =
    match assert_infer_kind ctx appl KnStar with
    | () -> assert false
    | exception Failure msg -> assert (msg = want)
  in
  assert_infer_kind [ TyVarBind arrow; TyVarBind KnStar ] appl KnStar;
  (* The operator must have an arrow kind... *)
  rejects
    [ TyVarBind KnStar; TyVarBind KnStar ]
    "Want KnArr, got KnStar: TyFreeVar 1";
  (* ...and the operand must have the operator's domain kind. *)
  rejects
    [ TyVarBind arrow; TyVarBind arrow ]
    "Want KnStar, got KnArr(KnStar, KnStar): TyFreeVar 0"

(* Kinding of annotated types. An annotation makes a [TyLam] inferrable, and
   the annotated type is checked against the annotation. *)
let test_infer_ty_ann () =
  let arrow = KnArr (KnStar, KnStar) in
  assert_infer_kind [] (TyAnn (TyLam (fun a -> a), arrow)) arrow;
  match
    assert_infer_kind [ TyVarBind KnStar ] (TyAnn (TyFreeVar 0, arrow)) KnStar
  with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want KnArr(KnStar, KnStar), got KnStar: TyFreeVar 0")

(* A bare type-level lambda carries no kind, so inference must refuse it. *)
let test_infer_ty_lam () =
  match assert_infer_kind [] (TyLam (fun a -> a)) KnStar with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Not inferrable: TyLam(TyFreeVar 0)")

(* Runs [check_kind] on a (type, kind) pair, using the size of [ctx] as the
   current level. Whatever [check_kind] raises propagates. *)
let assert_check_kind ctx ty_and_kind =
  check_kind (List.length ctx) ctx ty_and_kind

(* Checking type-level lambdas: they only check against arrow kinds, and the
   body is checked against the codomain. *)
let test_check_ty_lam () =
  let arrow = KnArr (KnStar, KnStar) in
  assert_check_kind
    [ TyVarBind arrow ]
    (TyLam (fun a -> TyAppl (TyFreeVar 0, a)), arrow);
  (match assert_check_kind [] (TyLam (fun a -> a), KnStar) with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want KnArr, got KnStar: TyLam(TyFreeVar 0)"));
  match
    assert_check_kind [ TyVarBind arrow ] (TyLam (fun _ -> TyFreeVar 0), arrow)
  with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want KnStar, got KnArr(KnStar, KnStar): TyFreeVar 0")

(* Non-lambda types are checked by inferring their kind and comparing it
   with the expected one. *)
let test_check_infer_ty () =
  let ctx = [ TyVarBind KnStar ] in
  assert_check_kind ctx (TyFreeVar 0, KnStar);
  match assert_check_kind ctx (TyFreeVar 0, KnArr (KnStar, KnStar)) with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want KnArr(KnStar, KnStar), got KnStar: TyFreeVar 0")

(* Asserts that [t] infers to exactly [expected_ty] under [ctx]. The level is
   the size of [ctx]. The comparison is the polymorphic [=], which raises if
   it reaches a binder closure. *)
let assert_infer_ty ctx t expected_ty =
  assert (infer_ty (List.length ctx) ctx t = expected_ty)

(* Typing of term variables: lookup by De Bruijn level, rejection of levels
   outside the context and of levels bound to type variables. *)
let test_infer_free_var () =
  let vars tys = List.map (fun ty -> VarBind ty) tys in
  assert_infer_ty (vars [ TyFreeVar 42 ]) (FreeVar 0) (TyFreeVar 42);
  assert_infer_ty
    (vars [ TyFreeVar 0; TyFreeVar 42; TyFreeVar 0 ])
    (FreeVar 1) (TyFreeVar 42);
  assert_infer_ty
    (vars [ TyFreeVar 42; TyFreeVar 0; TyFreeVar 0 ])
    (FreeVar 2) (TyFreeVar 42);
  (match
     assert_infer_ty
       (vars [ TyFreeVar 0; TyFreeVar 0; TyFreeVar 0 ])
       (FreeVar 3) (TyFreeVar 42)
   with
  | () -> assert false
  | exception Invalid_argument msg -> assert (msg = "List.nth"));
  match assert_infer_ty [ TyVarBind KnStar ] (FreeVar 0) (TyFreeVar 42) with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Expected a term variable: FreeVar 0")

(* Typing of term application. In [ctx], level 2 has type
   [TyFreeVar 123 -> TyFreeVar 42], level 1 has [TyFreeVar 123] and level 0
   has [TyFreeVar 42]. *)
let test_infer_appl () =
  let ctx =
    [
      VarBind (TyArr (TyFreeVar 123, TyFreeVar 42));
      VarBind (TyFreeVar 123);
      VarBind (TyFreeVar 42);
    ]
  in
  assert_infer_ty ctx (Appl (FreeVar 2, FreeVar 1)) (TyFreeVar 42);
  (* Applying a non-function is rejected. *)
  (match
     assert_infer_ty
       [ VarBind (TyFreeVar 123); VarBind (TyFreeVar 42) ]
       (Appl (FreeVar 1, FreeVar 0))
       (TyFreeVar 55)
   with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want TyArr, got TyFreeVar 123: FreeVar 1"));
  (* The argument must match the function's domain. *)
  match assert_infer_ty ctx (Appl (FreeVar 2, FreeVar 0)) (TyFreeVar 55) with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want TyFreeVar 123, got TyFreeVar 42: FreeVar 0")

(* Typing of type application. Every case instantiates level 1 with the type
   variable at level 0. *)
let test_infer_tappl () =
  let tappl = TAppl (FreeVar 1, TyFreeVar 0) in
  let rejects ctx want =
    match assert_infer_ty ctx tappl (TyFreeVar 42) with
    | () -> assert false
    | exception Failure msg -> assert (msg = want)
  in
  (* Instantiation substitutes the argument into the quantifier's body. *)
  assert_infer_ty
    [
      VarBind (TyAll (KnStar, fun a -> TyAppl (TyFreeVar 42, a)));
      TyVarBind KnStar;
    ]
    tappl
    (TyAppl (TyFreeVar 42, TyFreeVar 0));
  (* Only polymorphic terms can be instantiated... *)
  rejects
    [ VarBind (TyFreeVar 123); TyVarBind KnStar ]
    "Want TyAll, got TyFreeVar 123: FreeVar 1";
  (* ...and the argument must have the quantifier's kind. *)
  rejects
    [ VarBind (TyAll (KnStar, fun a -> a)); TyVarBind (KnArr (KnStar, KnStar)) ]
    "Want KnStar, got KnArr(KnStar, KnStar): TyFreeVar 0"

(* Typing of annotated terms. The annotation must be a well-formed type of
   kind [KnStar], and the term must check against it. *)
let test_infer_ann () =
  let endo = TyArr (TyFreeVar 0, TyFreeVar 0) in
  assert_infer_ty [ TyVarBind KnStar ] (Ann (Lam (fun a -> a), endo)) endo;
  (match
     assert_infer_ty
       [ VarBind (TyFreeVar 123); TyVarBind (KnArr (KnStar, KnStar)) ]
       (Ann (FreeVar 1, TyAnn (TyFreeVar 0, KnStar)))
       (TyFreeVar 42)
   with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want KnStar, got KnArr(KnStar, KnStar): TyFreeVar 0"));
  match
    assert_infer_ty
      [ VarBind (TyFreeVar 123); TyVarBind KnStar ]
      (Ann (FreeVar 1, TyFreeVar 0))
      (TyFreeVar 42)
  with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want TyFreeVar 0, got TyFreeVar 123: FreeVar 1")

(* A bare term-level lambda has no inferrable type without an annotation. *)
let test_infer_lam () =
  match assert_infer_ty [] (Lam (fun a -> a)) (TyFreeVar 42) with
  | () -> assert false
  | exception Failure msg -> assert (msg = "Not inferrable: Lam(FreeVar 0)")

(* A bare type abstraction has no inferrable type without an annotation. *)
let test_infer_tlam () =
  match assert_infer_ty [] (TLam (fun _ -> FreeVar 123)) (TyFreeVar 42) with
  | () -> assert false
  | exception Failure msg -> assert (msg = "Not inferrable: TLam(FreeVar 123)")

(* Runs [check_ty] on a (term, type) pair, using the size of [ctx] as the
   current level. Whatever [check_ty] raises propagates. *)
let assert_check_ty ctx term_and_ty =
  check_ty (List.length ctx) ctx term_and_ty

(* Checking term lambdas. [redex_arr] only reduces to
   [TyFreeVar 123 -> TyFreeVar 42], so the first case exercises the
   evaluation step in [check_ty]. *)
let test_check_lam () =
  let redex_arr =
    TyAppl (TyLam (fun _ -> TyArr (TyFreeVar 123, TyFreeVar 42)), TyFreeVar 55)
  in
  assert_check_ty
    [
      VarBind (TyArr (TyFreeVar 123, TyFreeVar 42));
      VarBind (TyFreeVar 123);
      VarBind (TyFreeVar 42);
    ]
    (Lam (fun x -> Appl (FreeVar 2, x)), redex_arr);
  (match assert_check_ty [] (Lam (fun x -> x), TyFreeVar 123) with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want TyArr, got TyFreeVar 123: Lam(FreeVar 0)"));
  match
    assert_check_ty
      [ VarBind (TyFreeVar 123) ]
      (Lam (fun _ -> FreeVar 0), TyArr (TyFreeVar 123, TyFreeVar 42))
  with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want TyFreeVar 42, got TyFreeVar 123: FreeVar 0")

(* Checking type abstractions. [redex_all] only reduces to a [TyAll], so the
   first case exercises the evaluation step in [check_ty]. *)
let test_check_tlam () =
  let redex_all =
    TyAppl (TyLam (fun _ -> TyAll (KnStar, fun a -> a)), TyFreeVar 55)
  in
  assert_check_ty
    [ VarBind (TyAll (KnStar, fun a -> a)) ]
    (TLam (fun a -> TAppl (FreeVar 0, a)), redex_all);
  (match assert_check_ty [] (TLam (fun _ -> FreeVar 0), TyFreeVar 123) with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want TyAll, got TyFreeVar 123: TLam(FreeVar 0)"));
  match
    assert_check_ty
      [ VarBind (TyFreeVar 123) ]
      (TLam (fun _ -> FreeVar 0), TyAll (KnStar, fun _ -> TyFreeVar 42))
  with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want TyFreeVar 42, got TyFreeVar 123: FreeVar 0")

(* Checking by inference. Types are compared up to beta-equality, so a redex
   in the expected type still matches. *)
let test_check_infer () =
  let ctx = [ VarBind (TyFreeVar 123) ] in
  assert_check_ty ctx (FreeVar 0, TyFreeVar 123);
  assert_check_ty
    [ VarBind (TyLam (fun _ -> TyFreeVar 123)) ]
    (FreeVar 0, TyLam (fun _ -> simple_comp_ty 123));
  match assert_check_ty ctx (FreeVar 0, TyFreeVar 42) with
  | () -> assert false
  | exception Failure msg ->
      assert (msg = "Want TyFreeVar 42, got TyFreeVar 123: FreeVar 0")

(* Runs every test in order; the first failing assertion aborts the run. *)
let () =
  List.iter
    (fun test -> test ())
    [
      (* Printing *)
      test_print_kind;
      test_print_ty;
      test_print_term;
      (* Type equality and evaluation *)
      test_equate_ty;
      test_eval_ty;
      (* Kinding *)
      test_infer_ty_free_var;
      test_infer_ty_arr;
      test_infer_ty_all;
      test_infer_ty_appl;
      test_infer_ty_ann;
      test_infer_ty_lam;
      test_check_ty_lam;
      test_check_infer_ty;
      (* Typing *)
      test_infer_free_var;
      test_infer_appl;
      test_infer_tappl;
      test_infer_ann;
      test_infer_lam;
      test_infer_tlam;
      test_check_lam;
      test_check_tlam;
      test_check_infer;
    ]
(** Kinds: [KnStar] classifies proper types, and [KnArr (k1, k2)] classifies
    type operators from [k1] to [k2]. *)
type kind =
  | KnStar
  | KnArr of kind * kind

(** Types. Binders ([TyAll], [TyLam]) are OCaml functions (HOAS); free type
    variables are De Bruijn levels. *)
type ty =
  | TyFreeVar of int  (** free type variable, as a De Bruijn level *)
  | TyArr of ty * ty  (** function type *)
  | TyAll of kind * (ty -> ty)  (** universal quantification over a kind *)
  | TyLam of (ty -> ty)  (** type-level abstraction *)
  | TyAppl of ty * ty  (** type-level application *)
  | TyAnn of ty * kind  (** kind annotation *)

(** [unfurl_ty lvl f] opens the binder [f] by applying it to the type
    variable at level [lvl]. *)
let unfurl_ty lvl f = TyFreeVar lvl |> f

(** Opens two binders with the same type variable at level [lvl]. *)
let unfurl_ty2 lvl (f, g) =
  let fresh = TyFreeVar lvl in
  (f fresh, g fresh)

(** Terms. [Lam] binds a term variable and [TLam] a type variable, both as
    OCaml functions; free term variables are De Bruijn levels. *)
type term =
  | FreeVar of int
  | Lam of (term -> term)
  | Appl of term * term
  | TLam of (ty -> term)
  | TAppl of term * ty
  | Ann of term * ty

(** [unfurl lvl f] opens the term binder [f] by applying it to the term
    variable at level [lvl]. *)
let unfurl lvl f = FreeVar lvl |> f
(* Renders a kind as a constructor-style string, e.g. "KnArr(KnStar, KnStar)". *)
let rec print_kind k =
  match k with
  | KnStar -> "KnStar"
  | KnArr (dom, cod) ->
      Printf.sprintf "KnArr(%s, %s)" (print_kind dom) (print_kind cod)
(* Renders a type as a constructor-style string. A binder is opened with the
   type variable at level [lvl], so nested bound variables print as
   [TyFreeVar lvl], [TyFreeVar (lvl + 1)], ... from the outside in. *)
let rec print_ty lvl ty =
  let open Printf in
  let here = print_ty lvl in
  let under_binder f = print_ty (lvl + 1) (unfurl_ty lvl f) in
  match ty with
  | TyFreeVar x -> sprintf "TyFreeVar %d" x
  | TyArr (a, b) -> sprintf "TyArr(%s, %s)" (here a) (here b)
  | TyAll (k, f) -> sprintf "TyAll(%s, %s)" (print_kind k) (under_binder f)
  | TyLam f -> sprintf "TyLam(%s)" (under_binder f)
  | TyAppl (a, b) -> sprintf "TyAppl(%s, %s)" (here a) (here b)
  | TyAnn (a, k) -> sprintf "TyAnn(%s, %s)" (here a) (print_kind k)
(* Renders a term as a constructor-style string. Term and type binders share
   one level counter: opening either kind of binder uses level [lvl] and
   continues at [lvl + 1]. *)
let rec print_term lvl t =
  let here = print_term lvl in
  let under f = print_term (lvl + 1) (unfurl lvl f) in
  let under_ty f = print_term (lvl + 1) (unfurl_ty lvl f) in
  match t with
  | FreeVar x -> Printf.sprintf "FreeVar %d" x
  | Lam f -> Printf.sprintf "Lam(%s)" (under f)
  | Appl (m, n) -> Printf.sprintf "Appl(%s, %s)" (here m) (here n)
  | TLam f -> Printf.sprintf "TLam(%s)" (under_ty f)
  | TAppl (m, a) -> Printf.sprintf "TAppl(%s, %s)" (here m) (print_ty lvl a)
  | Ann (m, a) -> Printf.sprintf "Ann(%s, %s)" (here m) (print_ty lvl a)
(* Structural equality of two types at binder depth [lvl]. Binders are
   compared by opening both with the same type variable at level [lvl].
   No reduction is performed. *)
let rec equate_ty lvl pair =
  let recur = equate_ty lvl in
  let under_binders f g = equate_ty (lvl + 1) (unfurl_ty2 lvl (f, g)) in
  match pair with
  | TyFreeVar x, TyFreeVar y -> x = y
  | TyArr (a, b), TyArr (a', b') -> recur (a, a') && recur (b, b')
  | TyAll (k, f), TyAll (k', g) -> k = k' && under_binders f g
  | TyLam f, TyLam g -> under_binders f g
  | TyAppl (a, b), TyAppl (a', b') -> recur (a, a') && recur (b, b')
  | TyAnn (a, k), TyAnn (a', k') -> recur (a, a') && k = k'
  | _ -> false
(* Normalizes a type by full beta-reduction and erases kind annotations.
   Binder bodies are wrapped so that they are normalized whenever the
   returned binder is applied. *)
let rec eval_ty ty =
  let under f a = eval_ty (f a) in
  match ty with
  | TyFreeVar x -> TyFreeVar x
  | TyArr (a, b) -> TyArr (eval_ty a, eval_ty b)
  | TyAll (k, f) -> TyAll (k, under f)
  | TyLam f -> TyLam (under f)
  | TyAppl (a, b) -> (
      let a' = eval_ty a and b' = eval_ty b in
      match a' with
      | TyLam body -> body b'
      | _ -> TyAppl (a', b'))
  | TyAnn (a, _) -> eval_ty a
(* A context entry: [VarBind ty] is a term variable of type [ty], and
   [TyVarBind k] is a type variable of kind [k]. *)
type binding =
  | VarBind of ty
  | TyVarBind of kind

(* Pushes a term variable of type [ty] onto the front of [ctx]. *)
let bind ty ctx = List.cons (VarBind ty) ctx

(* Pushes a type variable of kind [k] onto the front of [ctx]. *)
let bind_ty k ctx = List.cons (TyVarBind k) ctx
(* Raises [Failure]. The message is the printf-formatted text, then ": ",
   then [ty] rendered at level [lvl]. Never returns, so the result type is
   polymorphic. *)
let panic_ty lvl ty fmt =
  Printf.ksprintf
    (fun msg -> failwith (Printf.sprintf "%s: %s" msg (print_ty lvl ty)))
    fmt
(* Infers the kind of [ty] in a well-formed context [ctx]. [lvl] is expected
   to equal the number of entries in [ctx].

   Raises [Failure] in three cases:
   - the type is ill-kinded;
   - a level is bound to a term variable;
   - the type is a bare [TyLam], which needs an annotation.
   A level at or above [lvl] makes [List.nth] raise [Invalid_argument]. *)
let rec infer_kind lvl ctx ty =
  match ty with
  | TyFreeVar x -> (
      (* The head of [ctx] is the binding at level [lvl - 1]. *)
      match List.nth ctx (lvl - x - 1) with
      | TyVarBind k -> k
      | VarBind _ -> panic_ty lvl ty "Expected a type variable")
  | TyArr (dom, cod) ->
      check_kind lvl ctx (dom, KnStar);
      check_kind lvl ctx (cod, KnStar);
      KnStar
  | TyAll (k, body) ->
      check_kind (lvl + 1) (bind_ty k ctx) (unfurl_ty lvl body, KnStar);
      KnStar
  | TyAppl (op, arg) -> (
      match infer_kind lvl ctx op with
      | KnArr (dom, cod) ->
          check_kind lvl ctx (arg, dom);
          cod
      | KnStar -> panic_ty lvl op "Want KnArr, got %s" (print_kind KnStar))
  | TyAnn (a, k) ->
      check_kind lvl ctx (a, k);
      k
  | TyLam _ -> panic_ty lvl ty "Not inferrable"

(* Checks [ty] against kind [k] in a well-formed context. A lambda is checked
   structurally against an arrow kind. Everything else falls back to
   inference followed by a kind comparison. *)
and check_kind lvl ctx (ty, k) =
  match (ty, k) with
  | TyLam body, KnArr (dom, cod) ->
      check_kind (lvl + 1) (bind_ty dom ctx) (unfurl_ty lvl body, cod)
  | TyLam _, KnStar -> panic_ty lvl ty "Want KnArr, got %s" (print_kind k)
  | _ ->
      let got_k = infer_kind lvl ctx ty in
      if got_k <> k then
        panic_ty lvl ty "Want %s, got %s" (print_kind k) (print_kind got_k)
(* Beta-equality of two types: both are normalized with [eval_ty], then the
   normal forms are compared structurally at level [lvl]. *)
let beta_eq_ty lvl (a, b) =
  let normal_a = eval_ty a in
  let normal_b = eval_ty b in
  equate_ty lvl (normal_a, normal_b)
(* Raises [Failure]. The message is the printf-formatted text, then ": ",
   then term [t] rendered at level [lvl]. Never returns, so the result type
   is polymorphic. *)
let panic lvl t fmt =
  Printf.ksprintf
    (fun msg -> failwith (Printf.sprintf "%s: %s" msg (print_term lvl t)))
    fmt
(* Infers a type for a given term in a well-formed context. [lvl] is
   expected to equal the number of entries in [ctx].

   The eliminators ([Appl], [TAppl]) normalize the inferred type of their
   head with [eval_ty] before inspecting its shape, just as [check_ty] does
   for [Lam]/[TLam]. Without this, a head whose type is only beta-equivalent
   to an arrow or a quantifier would be wrongly rejected. An example is an
   annotation such as [TyAppl (TyAnn (TyLam ..., _), _)].

   Raises [Failure] on ill-typed terms, on a level bound to a type variable,
   and on a bare [Lam]/[TLam], which needs an annotation. A level at or above
   [lvl] makes [List.nth] raise [Invalid_argument]. *)
let rec infer_ty lvl ctx = function
  | FreeVar x -> (
      (* The head of [ctx] is the binding at level [lvl - 1]. *)
      match List.nth ctx (lvl - x - 1) with
      | VarBind ty -> ty
      | TyVarBind _ -> panic lvl (FreeVar x) "Expected a term variable")
  | Appl (m, n) -> (
      match eval_ty (infer_ty lvl ctx m) with
      | TyArr (a, b) ->
          check_ty lvl ctx (n, a);
          b
      | m_ty -> panic lvl m "Want TyArr, got %s" (print_ty lvl m_ty))
  | TAppl (m, a) -> (
      match eval_ty (infer_ty lvl ctx m) with
      | TyAll (k, f) ->
          check_kind lvl ctx (a, k);
          f a
      | m_ty -> panic lvl m "Want TyAll, got %s" (print_ty lvl m_ty))
  | Ann (m, a) ->
      check_kind lvl ctx (a, KnStar);
      check_ty lvl ctx (m, a);
      a
  | (Lam _ | TLam _) as t -> panic lvl t "Not inferrable"

(* Checks a given term against a given well-formed type in a well-formed
   context. The expected type is normalized before it is matched against the
   introduction forms. Other terms fall back to inference, compared up to
   beta-equality. *)
and check_ty lvl ctx = function
  | Lam f, ty -> (
      match eval_ty ty with
      | TyArr (a, b) -> check_ty (lvl + 1) (bind a ctx) (unfurl lvl f, b)
      | ty -> panic lvl (Lam f) "Want TyArr, got %s" (print_ty lvl ty))
  | TLam f, ty -> (
      match eval_ty ty with
      | TyAll (k, g) ->
          check_ty (lvl + 1) (bind_ty k ctx) (unfurl_ty2 lvl (f, g))
      | ty -> panic lvl (TLam f) "Want TyAll, got %s" (print_ty lvl ty))
  | t, ty ->
      let got_ty = infer_ty lvl ctx t in
      if not (beta_eq_ty lvl (ty, got_ty)) then
        panic lvl t "Want %s, got %s" (print_ty lvl ty) (print_ty lvl got_ty)
(* Term-level reduction would look much like [eval_ty] above. *)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment