Skip to content

Instantly share code, notes, and snippets.

@secondwtq
Last active February 23, 2020 16:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save secondwtq/a2d2a220df790f66569c2091b6e8e954 to your computer and use it in GitHub Desktop.
Save secondwtq/a2d2a220df790f66569c2091b6e8e954 to your computer and use it in GitHub Desktop.
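(* Implementation of Local Bidirectional Type Inference
* Based on:
* Simon Peyton Jones, Dimitrios Vytiniotis, Stephanie Weirich and
* Mark Shields, "Practical type inference for arbitrary-rank types",
* Journal of Functional Programming 17(1), 2007.
*)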
(*
* Tau Type - Monotype
* Rho Type - Tau | Sigma -> Sigma
* (the second form of Rho is absent in Rank-1 system)
* Sigma Type - Polytype (forall a. Rho)
*)
module D = Decl
module S = Semtree
module T = Types
module SP = Symtab.Pure
module U = Util
module M = Modular
module TU = Type_util
let transform_expr_apply_op_infix (src: S.expr): S.expr =
  (* Rewrites the infix application [lhs `func` rhs] into the curried
   * application [(func lhs) rhs] so the typechecker only has to handle
   * ExprApply. The inner application reuses the outer node's type and
   * annotation; the type copied here is meaningless. *)
  match src with
  | (ExprApplyOpInfix { func; lhs; rhs }, ty), annot ->
    let partial : S.expr = (ExprApply { func; arg = lhs }, ty), annot in
    (ExprApply { func = partial; arg = rhs }, ty), annot
  | _ -> U.unreachable "unexpected expression in transform_expr_apply_op_infix"
let restore_expr_apply_op_infix (src: S.expr): S.expr =
  (* Inverse of transform_expr_apply_op_infix: folds [(func lhs) rhs] back
   * into an infix node. The outer node's type and annotation are kept; the
   * inner (partial) application's type and annotation are discarded. *)
  match src with
  | (ExprApply {
      func = ((ExprApply { func; arg = lhs }, _), _);
      arg = rhs }, outer_ty), outer_annot ->
    (ExprApplyOpInfix { func; lhs; rhs }, outer_ty), outer_annot
  | _ -> U.unreachable "unexpected expression in restore_expr_apply_op_infix"
module Typechecker = struct
  (* The paper distinguishes three kinds of types. Here they share one OCaml
   * representation ([T.t]); the aliases only document intent:
   *   tau   - monotype
   *   rho   - a type with no top-level TForAll
   *   sigma - a possibly polymorphic type *)
  type tau = T.t
  type rho = T.t
  type sigma = T.t

  (* Not used within this module. *)
  type kind = | Infer | Check

  (* Solutions of meta (unification) variables. Written by unify_var,
   * tc_let and generalize; read by zonk. *)
  module MetaVarEnv = Symtab.Symap.Nonpure
  type metavar_env = tau MetaVarEnv.t

  (* Types of term variables, keyed by symbol. *)
  module VarEnv = Symtab.Symap.Nonpure
  type var_env = sigma VarEnv.t

  (* Not used within this module. *)
  module SkolEnv = Symtab.Symap.Nonpure
  type skol_env = sigma SkolEnv.t

  module TypeSet = CCSet.Make(T)
  module MetaVarSet = CCSet.Make(Symbol)
  module TypeVarSet = CCSet.Make(Symbol)

  (* Mutable state for one typechecking run. [inside] is the module used to
   * resolve global symbols and type aliases. *)
  type context = {
    inside: D.decl_module;
    metavar_env: metavar_env;
    var_env: var_env;
  }

  (* Creates a context with empty meta-variable and variable environments. *)
  let context_create ~(inside: D.decl_module): context = {
    inside;
    metavar_env = MetaVarEnv.create 31;
    var_env = VarEnv.create 31;
  }

  exception TcError of string
  exception TcUnificationError of T.t * T.t

  (* All meta variables occurring syntactically in [ty]. Solved meta
   * variables are not followed; zonk first if that matters. *)
  let rec get_metavars (ty: T.t): MetaVarSet.t =
    match ty with
    | TMetaVar var -> MetaVarSet.singleton var
    | TFunc (tyarg, tyret) ->
      MetaVarSet.union (get_metavars tyarg) (get_metavars tyret)
    | TForAll (_, ty) -> get_metavars ty
    | TIntr _ | TVar _ | TConstr _ -> MetaVarSet.empty
    | TTuple tys ->
      List.fold_left MetaVarSet.union MetaVarSet.empty
        (List.map get_metavars tys)
    | TApp (src, arg) ->
      MetaVarSet.union (get_metavars src) (get_metavars arg)
    | _ -> U.unreachable ("unexpected type in get_metavars " ^ (T.to_string ty))

  (* Type variables (TVar) occurring in [ty] that are neither in [bound] nor
   * bound by an enclosing TForAll inside [ty]. *)
  let rec get_free_typevars (bound: TypeVarSet.t) (ty: T.t): TypeVarSet.t =
    match ty with
    | TVar var ->
      if TypeVarSet.mem var bound
      then TypeVarSet.empty else TypeVarSet.singleton var
    | TForAll (var, ty) -> get_free_typevars (TypeVarSet.add var bound) ty
    | TFunc (tyarg, tyret) ->
      TypeVarSet.union (get_free_typevars bound tyarg)
        (get_free_typevars bound tyret)
    | TConstr _ | TMetaVar _ | TIntr _ -> TypeVarSet.empty
    | TTuple tys ->
      List.fold_left TypeVarSet.union TypeVarSet.empty
        (List.map (get_free_typevars bound) tys)
    | TApp (src, arg) ->
      TypeVarSet.union (get_free_typevars bound src)
        (get_free_typevars bound arg)
    | _ -> U.unreachable "unexpected type in get_free_typevars"

  (* Replaces every prenex TForAll-bound variable of [sigma] with a fresh
   * type variable (a skolem). It also does this for foralls in function
   * result position. Returns the fresh variables and the resulting rho. *)
  let rec skolemise ~inside (sigma: sigma): TypeVarSet.t * rho =
    match TU.full_dealias_and_apply_forall ~inside sigma with
    | TForAll (var, ty) ->
      let skol_typevar = Symbol.create None (Symbol.name var) in
      let skol_typevar_set, rho =
        skolemise ~inside
          (TU.substitute ~inside
             (TU.SubstEnv.singleton var (T.TVar skol_typevar)) ty) in
      TypeVarSet.add skol_typevar skol_typevar_set, rho
    | TFunc (tyarg, tyret) ->
      let skol_typevar_set, rhoret = skolemise ~inside tyret in
      skol_typevar_set, TFunc (tyarg, rhoret)
    | _ -> TypeVarSet.empty, sigma

  (* Substitutes the current solutions of solved meta variables into [ty],
   * recursively. Unsolved meta variables are left in place. *)
  let rec zonk (ctx: context) (ty: T.t): T.t =
    let zonk_ = zonk ctx in
    match ty with
    | TMetaVar var -> (
        match MetaVarEnv.get ctx.metavar_env var with
        | Some ty -> zonk_ ty
        | None -> ty)
    | TFunc (tyarg, tyret) -> TFunc (zonk_ tyarg, zonk_ tyret)
    | TForAll (var, ty) -> TForAll (var, zonk_ ty)
    | TVar _ | TConstr _ -> ty
    | TTuple tys -> TTuple (List.map zonk_ tys)
    | TIntr _ -> ty
    | TApp (src, arg) -> TApp (zonk_ src, zonk_ arg)
    | _ -> U.unreachable ("unexpected type in zonk: " ^ T.to_string ty)

  (* Unsolved meta variables remaining in [tys] after zonking. *)
  let zonk_and_metavars (ctx: context) (tys: TypeSet.t): MetaVarSet.t =
    TypeSet.fold
      (fun ty acc -> MetaVarSet.union (get_metavars (zonk ctx ty)) acc)
      tys MetaVarSet.empty

  (* Free type variables of [tys] after zonking. *)
  let zonk_and_free_typevars (ctx: context) (tys: TypeSet.t): TypeVarSet.t =
    TypeSet.fold
      (fun ty acc ->
         TypeVarSet.union (get_free_typevars TypeVarSet.empty (zonk ctx ty)) acc)
      tys TypeVarSet.empty

  (* The set of types currently bound in [env]. *)
  let var_env_get_types (env: var_env): TypeSet.t =
    env |> VarEnv.values |> TypeSet.of_iter

  (* Quantifies [rho] over [metavar_set]:
   * - each meta variable is solved to a fresh type variable cloned from
   *   its symbol;
   * - the type is wrapped in one TForAll per variable;
   * - the result is zonked. *)
  let generalize (ctx: context) (rho: rho) (metavar_set: MetaVarSet.t): sigma =
    (* TODO: Binder Generation *)
    let metavars = MetaVarSet.to_list metavar_set in
    let binders = List.map Symbol.clone metavars in
    List.iter2
      (fun metavar binder ->
         MetaVarEnv.add ctx.metavar_env metavar (T.TVar binder))
      metavars binders;
    let quantified =
      List.fold_left (fun ty binder -> T.TForAll (binder, ty)) rho binders in
    zonk ctx quantified

  (* Whether meta variable [var] occurs in [ty], following the solutions of
   * solved meta variables. *)
  let rec occurs_in (ctx: context) (var: Symbol.t) (ty: T.t): bool =
    match ty with
    | T.TMetaVar var2 when Symbol.equal var var2 -> true
    | T.TMetaVar var2 -> (
        match MetaVarEnv.get ctx.metavar_env var2 with
        | Some ty2 -> occurs_in ctx var ty2
        | None -> false)
    | T.TFunc (tyarg, tyret) ->
      occurs_in ctx var tyarg || occurs_in ctx var tyret
    | T.TApp (src, arg) -> occurs_in ctx var src || occurs_in ctx var arg
    | T.TForAll (_, ty) -> occurs_in ctx var ty
    | T.TTuple tys -> List.exists (occurs_in ctx var) tys
    | _ -> false

  (* Unifies meta variable [var] with [ty]. Records a solution when [var] is
   * unsolved. Raises TcUnificationError when [ty] mentions [var] (occurs
   * check). *)
  let rec unify_var (ctx: context) (var: Symbol.t) (ty: tau): tau =
    match MetaVarEnv.get ctx.metavar_env var with
    | Some ty2 -> unify_tau ctx ty ty2
    | None -> (
        match ty with
        | TMetaVar var2 when Symbol.equal var var2 ->
          (* Nothing to record: binding ?a := ?a would make zonk loop. *)
          ty
        | TMetaVar var2 -> (
            match MetaVarEnv.get ctx.metavar_env var2 with
            | Some ty2 -> unify_tau ctx (TMetaVar var) ty2
            | None -> MetaVarEnv.add ctx.metavar_env var ty; ty)
        | ty ->
          (* Occurs check: a solution mentioning [var] is an infinite type
           * and would make zonk loop forever. *)
          if occurs_in ctx var ty then
            raise (TcUnificationError (T.TMetaVar var, ty));
          MetaVarEnv.add ctx.metavar_env var ty;
          ty)

  (* Unifies two monotypes structurally, returning the unified type.
   * When no structural rule applies, it first tries TU.dealias and then
   * TU.apply_forall, each only when it succeeds on exactly one side.
   * Raises TcUnificationError otherwise. *)
  and unify_tau (ctx: context) (lhs: tau) (rhs: tau): tau =
    let try_apply_forall () =
      match TU.apply_forall ~inside:ctx.inside lhs,
            TU.apply_forall ~inside:ctx.inside rhs with
      | Some lhs, None -> unify_tau ctx lhs rhs
      | None, Some rhs -> unify_tau ctx lhs rhs
      | _ -> raise (TcUnificationError (lhs, rhs))
    in
    let try_dealias () =
      match TU.dealias ~inside:ctx.inside lhs,
            TU.dealias ~inside:ctx.inside rhs with
      | Some lhs, None -> unify_tau ctx lhs rhs
      | None, Some rhs -> unify_tau ctx lhs rhs
      | _ -> try_apply_forall ()
    in
    match lhs, rhs with
    | TVar vl, TVar vr when Symbol.equal vl vr -> lhs
    | TMetaVar vl, TMetaVar vr when Symbol.equal vl vr -> lhs
    | TMetaVar vl, rhs -> unify_var ctx vl rhs
    | lhs, TMetaVar vr -> unify_var ctx vr lhs
    | TFunc (tyargl, tyretl), TFunc (tyargr, tyretr) ->
      TFunc (unify_tau ctx tyargl tyargr, unify_tau ctx tyretl tyretr)
    | TIntr _, TIntr _ when T.equal lhs rhs -> lhs
    | TConstr ctorl, TConstr ctorr when Bind.equal ctorl ctorr -> lhs
    | TApp (lhsl, rhsl), TApp (lhsr, rhsr) ->
      TApp (unify_tau ctx lhsl lhsr, unify_tau ctx rhsl rhsr)
    | TTuple lhs, TTuple rhs when List.length lhs = List.length rhs ->
      TTuple (List.map2 (unify_tau ctx) lhs rhs)
    | _, _ -> try_dealias ()

  (* Splits [rho] into argument and result types. If [rho] is not
   * syntactically a TFunc, it is unified with [?arg -> ?ret] for fresh meta
   * variables, and those are returned. *)
  let unify_func (ctx: context) (rho: rho): sigma * rho =
    match rho with
    | TFunc (tyarg, tyret) -> tyarg, tyret
    | tau ->
      let metavar_arg = Symbol.create None "metavar_arg"
      and metavar_ret = Symbol.create None "metavar_ret" in
      let tyarg = T.TMetaVar metavar_arg and tyret = T.TMetaVar metavar_ret in
      let ret = T.TFunc (tyarg, tyret) in
      let _ = unify_tau ctx tau ret in
      tyarg, tyret

  (* Replaces each prenex TForAll-bound variable of [ty] with a fresh meta
   * variable. When no TForAll is left, the current [ty] is returned as-is
   * (not its dealiased form). *)
  let rec instantiate ~(inside: D.decl_module) (ty: sigma): rho =
    match TU.dealias_and_apply_forall_once ~inside ty with
    (* Since our ForAll only supports 1 argument yet, the type is not really
     * rho, and we need to recurse into it. *)
    | T.TForAll (tyvar, rhoty) ->
      let new_metavar = Symbol.create None (Symbol.name tyvar) in
      instantiate ~inside (
        TU.substitute ~inside (TU.SubstEnv.singleton tyvar
                                 (T.TMetaVar new_metavar)) rhoty)
    | _ -> ty

  (* Checks that [sigmal] is at least as polymorphic as [sigmar], i.e. that
   * a value of type [sigmal] may be used where [sigmar] is expected (the
   * paper's subsCheck). [sigmar] is skolemised. Raises TcError if a skolem
   * escapes into either type. *)
  let rec subsumption_check (ctx: context) (sigmal: sigma)
      (sigmar: sigma): sigma =
    let skol_tvars, rhor = skolemise ~inside:ctx.inside sigmar in
    let check = subsumption_check_rho ctx sigmal rhor in
    (if not (TypeVarSet.is_empty skol_tvars) then
       let esc_tvars =
         zonk_and_free_typevars ctx (TypeSet.of_list [sigmal; sigmar]) in
       if not (TypeVarSet.is_empty (TypeVarSet.inter skol_tvars esc_tvars))
       then raise (TcError ("Subsumption check failed: " ^
                            T.to_string sigmal ^
                            " is not as polymorphic as " ^
                            T.to_string sigmar)));
    check

  (* Checks that [s] is at least as polymorphic as the rho type [r] (the
   * paper's subsCheckRho). *)
  and subsumption_check_rho (ctx: context) (s: sigma) (r: rho): rho =
    match s, r with
    | T.TForAll _, _ ->
      subsumption_check_rho ctx (instantiate ~inside:ctx.inside s) r
    | rho1, T.TFunc (tyargr, tyretr) ->
      let tyargl, tyretl = unify_func ctx rho1 in
      subsumption_check_fun ctx tyargl tyretl tyargr tyretr
    | T.TFunc (tyargl, tyretl), rho2 ->
      (* Decompose the expected side; the direction is NOT swapped. *)
      let tyargr, tyretr = unify_func ctx rho2 in
      subsumption_check_fun ctx tyargl tyretl tyargr tyretr
    | _, _ (* Tau Types *) -> unify_tau ctx s r

  (* [tyargl -> tyretl] is at least as polymorphic as [tyargr -> tyretr]:
   * - arguments are contravariant: tyargr must subsume tyargl;
   * - results are covariant. *)
  and subsumption_check_fun (ctx: context) (tyargl: sigma) (tyretl: rho)
      (tyargr: sigma) (tyretr: rho): rho =
    T.TFunc (subsumption_check ctx tyargr tyargl,
             subsumption_check_rho ctx tyretl tyretr)

  (* instSigma in checking mode: [s] must be usable at rho type [r]. *)
  let inst_sigma_check (ctx: context) (s: sigma) (r: rho): rho =
    subsumption_check_rho ctx s r

  (* instSigma in inference mode: instantiate [s]. *)
  let inst_sigma_infer (ctx: context) (s: sigma): rho =
    instantiate ~inside:ctx.inside s

  (* Type of variable [ident]:
   * - first the local environment is consulted;
   * - then a global DeclValue resolved from [ctx.inside];
   * - anything else is unreachable. *)
  let resolve_var_type (ctx: context) (ident: Symbol.t): T.t =
    match VarEnv.find_opt ctx.var_env ident with
    | Some ty -> ty
    | None ->
      match Modular.resolve_global_symbol ~inside:ctx.inside ident with
      | Some (D.DeclValue decl) -> decl.val_type
      | _ -> U.unreachable "unexpected resolve failure in resolve_var_type"

  (* TODO: We currently do this very lazily and just return some random type,
   * what we really want to do is returning a placeholder type (or some
   * "typeclass" constrainted type) and unify later *)
  let infer_literal_type (src: Literal.t): T.t =
    match src with
    | LBool _ -> T.TIntr Intrinsic_type.Bool
    | LNum (Literal_num.NlitU32 _) -> T.TIntr (Intrinsic_type.Int {
        signed = Unsigned; width = 32;
      })
    | LString _ -> T.TIntr Intrinsic_type.Intr_string
    | _ ->
      U.unreachable "unexpected literal in infer_literal_type"

  (* Typechecks an ExprMatch:
   * - the scrutinee is inferred;
   * - every pattern is checked against the scrutinee type;
   * - every case body is checked against [expr_ty] when given, and
   *   inferred otherwise;
   * - the body types are unified with one another. *)
  let rec tc_match (ctx: context) (expr: S.expr)
      (expr_ty: rho option): S.expr * T.t =
    let tc_match_case ctx (case: S.match_case)
        pattern_ty expr_ty: S.match_case * T.t =
      let match_case_node, annot = case in
      let return ?(match_case = match_case_node) ty =
        (match_case, annot), ty in
      match match_case_node with
      | S.MatchCase { pattern; value; } ->
        let pattern, _ = check_pattern ctx pattern pattern_ty in
        (* Make the variables bound by the pattern visible in the body. *)
        let new_bindings = Semutil.pattern_new_bindings pattern in
        let _ = SP.iter (VarEnv.add ctx.var_env) new_bindings in
        let value, ty = match expr_ty with
          | Some ty -> check_rho ctx value ty
          | None -> infer_rho ctx value in
        return ~match_case:(S.MatchCase { pattern; value }) ty
    in
    let (expr_node, _), annot = expr in
    let return ?(expr = expr_node) ty = ((expr, ty), annot), ty in
    match expr_node with
    | ExprMatch { value; cases; } ->
      let value, value_ty = infer_rho ctx value in (
        match cases with
        | hd :: rest ->
          let head, head_ty = tc_match_case ctx hd value_ty expr_ty in
          let cases, ty = List.fold_left (
              fun (cases, last_ty) next ->
                let next, next_ty = tc_match_case ctx next value_ty expr_ty in
                next :: cases, unify_tau ctx last_ty next_ty
            ) ([head], head_ty) rest in
          return ~expr:(ExprMatch { value; cases = List.rev cases; }) ty
        | _ -> U.unreachable "match must have at least one case")
    | _ -> U.unreachable "unexpected expression in tc_match"

  (* Checks [pattern] against [pattern_ty]. A variable pattern takes
   * [pattern_ty] as its type. *)
  and check_pattern (ctx: context) (pattern: S.pattern)
      (pattern_ty: rho): S.pattern * T.t =
    let (pattern_node, _), annot = pattern in
    let return ?(pattern = pattern_node) ty = ((pattern, ty), annot), ty in
    match pattern_node with
    | PLiteral { value } ->
      return @@ unify_tau ctx (infer_literal_type value) pattern_ty
    | PVar _ ->
      return @@ pattern_ty
    | PConstrApp _ ->
      tc_constructor_pattern ctx pattern (Some pattern_ty)

  (* Typechecks a constructor pattern:
   * - the constructor's type is instantiated;
   * - one argument type is peeled off per sub-pattern;
   * - the remaining result type is unified with [pattern_ty] when given. *)
  and tc_constructor_pattern (ctx: context) (pattern: S.pattern)
      (pattern_ty: rho option): S.pattern * T.t =
    let (pattern_node, _), annot = pattern in
    let return ?(pattern = pattern_node) ty = ((pattern, ty), annot), ty in
    match pattern_node with
    | PConstrApp { bound; args } ->
      let constr_ty = inst_sigma_infer ctx bound.val_type in
      let args, ret_ty = List.fold_left (fun (ret, ctor_ty) arg ->
          match ctor_ty with
          | T.TFunc (tyarg, tyret) -> check_pattern ctx arg tyarg :: ret, tyret
          | _ -> U.unreachable "invalid data constructor pattern"
        ) ([], constr_ty) args in
      let ret_ty = CCOpt.map_or ~default:ret_ty
          (unify_tau ctx ret_ty) pattern_ty in
      return ~pattern:(PConstrApp {
          bound = bound; args = List.map fst @@ List.rev args;
        }) ret_ty
    | _ -> U.unreachable "unexpected pattern in tc_constructor_pattern"

  (* Infers a pattern's type. A variable pattern gets a fresh meta
   * variable. *)
  and infer_pattern (ctx: context) (pattern: S.pattern): S.pattern * T.t =
    let (pattern_node, _), annot = pattern in
    let return ?(pattern = pattern_node) ty = ((pattern, ty), annot), ty in
    match pattern_node with
    | PLiteral { value } ->
      return @@ infer_literal_type value
    | PVar { name } ->
      let metavar = Symbol.clone name in
      return @@ T.TMetaVar metavar
    | PConstrApp _ ->
      tc_constructor_pattern ctx pattern None

  (* Typechecks an ExprLet. The steps are:
   * 1. Parameters are added to the environment; unannotated ones get fresh
   *    meta variables.
   * 2. Functions (non-empty [args]) are also bound before their body is
   *    checked, so they can call themselves.
   * 3. The body is inferred or checked against the annotation, and the
   *    binding's type is generalized over meta variables not free in the
   *    environment.
   * 4. The scope is checked or inferred.
   * NOTE(review): bindings added to [ctx.var_env] here are never removed,
   * so they remain in the environment after the let's scope ends. *)
  and tc_let (ctx: context) (expr: S.expr)
      (expr_ty: rho option): S.expr * T.t =
    let (expr_node, _), annot = expr in
    let return ?(expr = expr_node) ty = ((expr, ty), annot), ty in
    let env_ty_set = var_env_get_types ctx.var_env in
    let env_metavar_set = zonk_and_metavars ctx env_ty_set in
    match expr_node with
    | ExprLet { name; value; scope; tyannot; typeargs; args; _ } ->
      let f (S.Arg { name; tyannot }, annot) =
        let ty =
          match tyannot with
          | TPlaceholder -> T.TMetaVar (Symbol.clone name)
          | ty -> ty in
        let _ = VarEnv.add ctx.var_env name ty in
        (S.Arg { name; tyannot = ty }, annot), ty
      in
      (* Builds [forall typeargs. arg_1 -> ... -> arg_n -> value_ty]. *)
      let get_var_ty value_ty arg_types =
        let var_ty = List.fold_right
            (fun arg last -> T.TFunc (arg, last))
            arg_types value_ty in
        let var_ty = List.fold_right
            (fun arg last -> T.TForAll (arg, last))
            (List.map (fun (S.TypeArg { name }, _) -> name) typeargs) var_ty in
        var_ty
      in
      let args_and_types = List.map f args in
      let value_ty, var_ty_metavar = match tyannot with
        | TPlaceholder ->
          let metavar = Symbol.clone (fst name) in
          T.TMetaVar metavar, Some metavar
        | ty -> ty, None in
      (if not (CCList.is_empty args) then
         VarEnv.add ctx.var_env (fst name)
           (get_var_ty value_ty (List.map snd args_and_types)));
      let value, new_value_ty =
        match tyannot with
        | TPlaceholder -> infer_sigma ctx value
        | tyannot -> check_sigma ctx value tyannot in
      (* Record the inferred type as the placeholder's solution. Skip the
       * trivial self-binding (e.g. [let f x = f x] infers exactly the
       * placeholder), which would make zonk loop forever. *)
      CCOpt.iter
        (fun metavar ->
           match new_value_ty with
           | T.TMetaVar var when Symbol.equal var metavar -> ()
           | _ -> MetaVarEnv.add ctx.metavar_env metavar new_value_ty)
        var_ty_metavar;
      let var_ty = get_var_ty new_value_ty (List.map snd args_and_types) in
      let res_metavar_set = zonk_and_metavars ctx (TypeSet.singleton var_ty) in
      let var_ty =
        generalize ctx var_ty
          (MetaVarSet.diff res_metavar_set env_metavar_set) in
      VarEnv.add ctx.var_env (fst name) var_ty;
      let scope, ty = match expr_ty with
        | Some expr_ty -> check_rho ctx scope expr_ty
        | None -> infer_rho ctx scope in
      return ~expr:(ExprLet { name; value; scope; tyannot = value_ty;
                              typeargs;
                              args = List.map fst args_and_types;
                              name_ty = var_ty }) ty
    | _ -> U.unreachable "unexpected expression in tc_let"

  (* TODO: Get rid of the duplication *)
  (* TODO: This functions are tightly coupled with the annotation data
   * structures of Semtree.expr *)
  (* Checking mode: typechecks [expr] against the rho type [expr_ty]. *)
  and check_rho (ctx: context) (expr: S.expr) (expr_ty: rho): S.expr * T.t =
    let (expr_node, _), annot = expr in
    let return ?(expr = expr_node) ty = ((expr, ty), annot), ty in
    match expr_node with
    | ExprConst { value } ->
      return @@ unify_tau ctx (infer_literal_type value) expr_ty
    | ExprIdent { value } ->
      let ty = resolve_var_type ctx (fst value) in
      return @@ inst_sigma_check ctx ty expr_ty
    | ExprApply { func; arg } ->
      let func, func_ty = infer_rho ctx func in
      let tyarg, tyret = unify_func ctx func_ty in
      let arg, _ = check_sigma ctx arg tyarg in
      let ty = inst_sigma_check ctx tyret expr_ty in
      return ~expr:(ExprApply { func; arg }) ty
    | ExprLet _ ->
      tc_let ctx expr (Some expr_ty)
    | ExprIf { cond; expr_then; expr_else } ->
      let cond, _ = check_rho ctx cond (T.TIntr Intrinsic_type.Bool) in
      let expr_then, tythen = check_rho ctx expr_then expr_ty in
      let expr_else, tyelse = check_rho ctx expr_else expr_ty in
      return ~expr:(ExprIf { cond; expr_then; expr_else }) @@
      unify_tau ctx tythen tyelse
    | ExprTuple _ ->
      (* Infer the tuple, then unify with the expected type. *)
      let expr, inferred_ty = infer_rho ctx expr in
      let inferred_ty = unify_tau ctx inferred_ty expr_ty in
      expr, inferred_ty
    | ExprApplyOpInfix _ ->
      let expr, ty =
        check_rho ctx (transform_expr_apply_op_infix expr) expr_ty in
      restore_expr_apply_op_infix expr, ty
    | ExprTyAnnot { value; tyannot } ->
      (* checkSigma body ann; instSigma ann exp_ty. The annotation itself,
       * not the skolemised type check_sigma returns, is what must fit the
       * expected type; otherwise skolems would leak out. *)
      let value, _ = check_sigma ctx value tyannot in
      let ty = inst_sigma_check ctx tyannot expr_ty in
      return ~expr:(ExprTyAnnot { value; tyannot }) ty
    | ExprMatch _ -> tc_match ctx expr (Some expr_ty)
    | ExprPlaceholder -> U.unreachable "unexpected expression in check_rho"

  (* Inference mode: infers the rho type of [expr]. *)
  and infer_rho (ctx: context) (expr: S.expr): S.expr * T.t =
    let (expr_node, _), annot = expr in
    let return ?(expr = expr_node) ty = ((expr, ty), annot), ty in
    match expr_node with
    | ExprConst { value } ->
      return @@ infer_literal_type value
    | ExprIdent { value } ->
      let ty = resolve_var_type ctx (fst value) in
      return @@ inst_sigma_infer ctx ty
    | ExprApply { func; arg } ->
      let func, func_ty = infer_rho ctx func in
      let tyarg, tyret = unify_func ctx func_ty in
      let arg, _ = check_sigma ctx arg tyarg in
      let ty = inst_sigma_infer ctx tyret in
      return ~expr:(ExprApply { func; arg }) ty
    | ExprLet _ ->
      tc_let ctx expr None
    | ExprIf { cond; expr_then; expr_else } ->
      (* P60 of the paper: infer both branches, then require each branch
       * type to subsume the other. *)
      let cond, _ = check_rho ctx cond (T.TIntr Intrinsic_type.Bool) in
      let expr_then, tythen = infer_rho ctx expr_then in
      let expr_else, tyelse = infer_rho ctx expr_else in
      ignore (subsumption_check ctx tythen tyelse : sigma);
      ignore (subsumption_check ctx tyelse tythen : sigma);
      return ~expr:(ExprIf { cond; expr_then; expr_else }) @@
      unify_tau ctx tythen tyelse
    | ExprTuple { value } ->
      let values = List.map (infer_rho ctx) value in
      return ~expr:(ExprTuple { value = List.map fst values }) @@
      T.TTuple (List.map snd values)
    | ExprApplyOpInfix _ ->
      let expr, ty = infer_rho ctx (transform_expr_apply_op_infix expr) in
      restore_expr_apply_op_infix expr, ty
    | ExprTyAnnot { value; tyannot } ->
      (* checkSigma body ann; then instantiate the annotation itself. *)
      let value, _ = check_sigma ctx value tyannot in
      let ty = inst_sigma_infer ctx tyannot in
      return ~expr:(ExprTyAnnot { value; tyannot }) ty
    | ExprMatch _ -> tc_match ctx expr None
    | ExprPlaceholder -> U.unreachable "unexpected expression in infer_rho"

  (* Checks [expr] against the (possibly polymorphic) [sigma]:
   * - [sigma] is skolemised and [expr] is checked against the result;
   * - raises TcError if a skolem escaped into [sigma] or into the
   *   environment types. *)
  and check_sigma (ctx: context) (expr: S.expr) (sigma: sigma): S.expr * sigma =
    let skol_typevar_set, rho = skolemise ~inside:ctx.inside sigma in
    let ret = check_rho ctx expr rho in
    let env_ty_set = var_env_get_types ctx.var_env in
    let esc_typevar_set =
      zonk_and_free_typevars ctx (TypeSet.add sigma env_ty_set) in
    let bad_typevars = TypeVarSet.inter skol_typevar_set esc_typevar_set in
    if TypeVarSet.is_empty bad_typevars
    then ret else raise (TcError ("Type " ^ (T.to_string sigma) ^
                                  " not polymorphic enough"))

  (* Infers [expr]'s rho type and generalizes it over the meta variables
   * that are not free in the environment. The environment snapshot is taken
   * before [expr] is inferred. The result type is stored as the
   * expression's type annotation.
   * TODO (original note): This function is useless ATM *)
  and infer_sigma (ctx: context) (expr: S.expr): S.expr * sigma =
    let env_ty_set = var_env_get_types ctx.var_env in
    let expr, expr_ty = infer_rho ctx expr in
    let (expr_node, _), annot = expr in
    let env_metavar_set = zonk_and_metavars ctx env_ty_set in
    let res_metavar_set = zonk_and_metavars ctx (TypeSet.singleton expr_ty) in
    let ret_ty =
      generalize ctx expr_ty
        (MetaVarSet.diff res_metavar_set env_metavar_set) in
    ((expr_node, ret_ty), annot), ret_ty

  (* Zonks every type stored in an expression tree:
   * - node types;
   * - let, argument and ascription annotations;
   * - patterns.
   * TODO: We need a map function for expr *)
  let rec zonk_expr (ctx: context) (expr: S.expr): S.expr =
    let zonk_expr_ = zonk_expr ctx in
    let zonk_ = zonk ctx in
    let (expr_node, ty), annot = expr in
    let return ?(expr = expr_node) () = ((expr, zonk_ ty), annot) in
    match expr_node with
    | ExprConst _ | ExprIdent _ -> return ()
    | ExprApply { func; arg } ->
      return ~expr:(ExprApply {
          func = zonk_expr_ func; arg = zonk_expr_ arg }) ()
    | ExprLet { name; value; scope; typeargs; args; tyannot; name_ty } ->
      let zonk_arg arg = S.(
          match arg with
          | Arg arg, annot -> Arg {
              arg with tyannot = zonk_ arg.tyannot }, annot) in
      return ~expr:(ExprLet {
          name; value = zonk_expr_ value; scope = zonk_expr_ scope;
          tyannot = zonk_ tyannot; typeargs; args = List.map zonk_arg args;
          name_ty = zonk_ name_ty }) ()
    | ExprIf { cond; expr_then; expr_else } ->
      return ~expr:(ExprIf {
          cond = zonk_expr_ cond; expr_then = zonk_expr_ expr_then;
          expr_else = zonk_expr_ expr_else }) ()
    | ExprTuple { value } ->
      return ~expr:(ExprTuple { value = List.map zonk_expr_ value }) ()
    | ExprApplyOpInfix { func; lhs; rhs } ->
      return ~expr:(ExprApplyOpInfix {
          func = zonk_expr_ func; lhs = zonk_expr_ lhs;
          rhs = zonk_expr_ rhs }) ()
    | ExprTyAnnot { value; tyannot } -> return ~expr:(ExprTyAnnot {
        value = zonk_expr_ value; tyannot = zonk_ tyannot;
      }) ()
    | ExprMatch { value; cases } -> return ~expr:(ExprMatch {
        value = zonk_expr_ value; cases = List.map (zonk_match_case ctx) cases;
      }) ()
    | ExprPlaceholder -> U.unreachable "unexpected expression in zonk_expr"

  (* Zonks a match case's pattern and body. *)
  and zonk_match_case (ctx: context) (match_case: S.match_case): S.match_case =
    let (S.MatchCase { pattern; value }), annot = match_case in
    S.MatchCase {
      pattern = zonk_pattern ctx pattern;
      value = zonk_expr ctx value;
    }, annot

  (* Zonks a pattern's type and, for constructor patterns, its
   * sub-patterns. *)
  and zonk_pattern (ctx: context) (pattern: S.pattern): S.pattern =
    let zonk_pattern_ = zonk_pattern ctx in
    let (pattern_node, ty), annot = pattern in
    let return ?(pattern = pattern_node) () = ((pattern, zonk ctx ty), annot) in
    match pattern_node with
    | PLiteral _ -> return ()
    | PVar _ -> return ()
    | PConstrApp { bound; args } -> return ~pattern:(PConstrApp {
        bound; args = List.map zonk_pattern_ args }) ()
end
(* let rec trans_expr ~(inside: D.decl_module)
* (symtab: T.t SP.t) (stree: S.expr): S.expr = *)
(* Computes the type a declaration's body must have.
 * - TForAll binders are skipped wherever they appear.
 * - One TFunc arrow is peeled off [ty] per declared argument.
 * - Whatever remains once [args] is exhausted is returned.
 * - Fails via U.unreachable when [ty] runs out of arrows before [args]
 *   does. *)
let rec get_decl_return_type (args: D.decl_arg list) (ty: T.t): T.t =
  match ty, args with
  | T.TForAll (_, body), _ -> get_decl_return_type args body
  | T.TFunc (_, result), _ :: remaining -> get_decl_return_type remaining result
  | _, [] -> ty
  | _, _ :: _ ->
    U.unreachable (
      Printf.sprintf "unexpected input in get_decl_return_type. Type: %s"
        (T.to_string ty))
(* Entry point: typechecks the module held by a top-level node. *)
let rec trans_top ~(inside: D.decl_module) (src: S.top): unit =
  match src with
  | S.TopLevel { value = modu }, _ -> trans_mod ~inside modu

(* Typechecks every declaration of a module fragment. Resolution happens
 * inside the fragment's own [bound] module; [inside] is not consulted. *)
and trans_mod ~(inside: D.decl_module) (src: S.modu): unit =
  match src with
  | S.ModuleFrag { value = decls; bound = frag_scope }, _ ->
    decls |> List.iter (trans_decl ~inside:frag_scope)

(* Handles one declaration:
 * - nested modules are recursed into;
 * - a value declaration has its expression replaced by the typechecked,
 *   zonked version;
 * - anything else is ignored. *)
and trans_decl ~(inside: D.decl_module) (src: S.decl): unit =
  match src with
  | S.DeclModule nested, _ -> trans_mod ~inside nested
  | S.DeclValue value_decl, _ ->
    value_decl.value <-
      trans_decl_value ~inside value_decl.value value_decl.bound
  | _ -> ()

(* Typechecks the body [src] of value declaration [decl] in a fresh
 * context:
 * - the declared arguments are put in scope;
 * - the body is checked against the return type left after peeling off
 *   one arrow per argument;
 * - the zonked expression is returned. *)
and trans_decl_value ~(inside: D.decl_module)
    (src: S.expr) (decl: D.decl_value): S.expr =
  let context = Typechecker.context_create ~inside in
  let arg_bindings =
    decl.args |> List.map (fun (D.{ name; val_type }) -> name, val_type) in
  ignore (Typechecker.VarEnv.add_iter context.var_env
            (CCList.to_iter arg_bindings));
  Printf.printf "Typechecking %s ...\n" (Symbol.show decl.name);
  let expected = get_decl_return_type decl.args decl.val_type in
  let checked, _ = Typechecker.check_sigma context src expected in
  Typechecker.zonk_expr context checked
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment