Last active
February 23, 2020 16:58
-
-
Save secondwtq/a2d2a220df790f66569c2091b6e8e954 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Implementation of Local Bidirectional Type Inference
 * Based on:
 * Peyton Jones, Simon; Vytiniotis, Dimitrios; Weirich, Stephanie;
 * Shields, Mark. "Practical type inference for arbitrary-rank types".
 * Journal of Functional Programming 17(1), 2007.
 *)
(*
 * Tau Type - Monotype
 * Rho Type - Tau | Sigma -> Sigma
 * (the second form of Rho is absent in Rank-1 system)
 * Sigma Type - Polytype (forall a. Rho)
 *)
module D = Decl | |
module S = Semtree | |
module T = Types | |
module SP = Symtab.Pure | |
module U = Util | |
module M = Modular | |
module TU = Type_util | |
(* Rewrite an infix application [lhs `func` rhs] into the curried form
 * [(func lhs) rhs] so it can be checked like any other application.
 * Both synthesized [ExprApply] nodes copy the original node's type slot and
 * annotation; the copied type is meaningless, and the checker recomputes it.
 * Any other expression is a caller error ([U.unreachable]). *)
let transform_expr_apply_op_infix (src : S.expr) : S.expr =
  match src with
  | (ExprApplyOpInfix { func; lhs; rhs }, t), annot ->
      let (partial : S.expr) = ((ExprApply { func; arg = lhs }, t), annot) in
      ((ExprApply { func = partial; arg = rhs }, t), annot)
  | _ -> U.unreachable "unexpected expression in transform_expr_apply_op_infix"
(* Inverse of [transform_expr_apply_op_infix]: fold the curried form
 * [(func lhs) rhs] back into an [ExprApplyOpInfix] node. The outer
 * application's type and annotation are kept. The intermediate application's
 * type and annotation are discarded. Any other shape is a caller error. *)
let restore_expr_apply_op_infix (src : S.expr) : S.expr =
  match src with
  | ( ( ExprApply
          { func = (ExprApply { func; arg = lhs }, _), _; arg = rhs },
        tyexpr ),
      annot ) ->
      ((ExprApplyOpInfix { func; lhs; rhs }, tyexpr), annot)
  | _ -> U.unreachable "unexpected expression in restore_expr_apply_op_infix"
module Typechecker = struct
  (* The three type "levels" share one representation, [T.t]. These aliases
   * only document which shape a function expects or returns:
   * tau = monotype, rho = tau | sigma -> sigma, sigma = forall a. rho. *)
  type tau = T.t
  type rho = T.t
  type sigma = T.t

  type kind = Infer | Check

  module MetaVarEnv = Symtab.Symap.Nonpure
  (* Solutions of meta (unification) variables; mutated in place. *)
  type metavar_env = tau MetaVarEnv.t

  module VarEnv = Symtab.Symap.Nonpure
  (* Types of term variables; mutated in place. Nothing in this module
   * removes a binding once it has been added. *)
  type var_env = sigma VarEnv.t

  module SkolEnv = Symtab.Symap.Nonpure
  type skol_env = sigma SkolEnv.t

  module TypeSet = CCSet.Make (T)
  module MetaVarSet = CCSet.Make (Symbol)
  module TypeVarSet = CCSet.Make (Symbol)

  (* State shared by every checking function for one declaration.
   * [inside] is the module used to resolve aliases and global symbols. *)
  type context = {
    inside : D.decl_module;
    metavar_env : metavar_env;
    var_env : var_env;
  }

  let context_create ~(inside : D.decl_module) : context =
    { inside; metavar_env = MetaVarEnv.create 31; var_env = VarEnv.create 31 }

  exception TcError of string
  exception TcUnificationError of T.t * T.t

  (* Meta variables occurring syntactically in [ty]. Solved metavars are not
   * looked through; zonk first when that matters. *)
  let rec get_metavars (ty : T.t) : MetaVarSet.t =
    match ty with
    | TMetaVar var -> MetaVarSet.singleton var
    | TFunc (tyarg, tyret) ->
        MetaVarSet.union (get_metavars tyarg) (get_metavars tyret)
    | TForAll (_, ty) -> get_metavars ty
    | TIntr _ | TVar _ | TConstr _ -> MetaVarSet.empty
    | TTuple tys ->
        List.fold_left MetaVarSet.union MetaVarSet.empty
          (List.map get_metavars tys)
    | TApp (src, arg) -> MetaVarSet.union (get_metavars src) (get_metavars arg)
    | _ -> U.unreachable ("unexpected type in get_metavars " ^ T.to_string ty)

  (* Type variables of [ty] that are not in [bound] and not bound by a
   * [TForAll] inside [ty]. Metavars are not looked through. *)
  let rec get_free_typevars (bound : TypeVarSet.t) (ty : T.t) : TypeVarSet.t =
    match ty with
    | TVar var ->
        if TypeVarSet.mem var bound then TypeVarSet.empty
        else TypeVarSet.singleton var
    | TForAll (var, ty) -> get_free_typevars (TypeVarSet.add var bound) ty
    | TFunc (tyarg, tyret) ->
        TypeVarSet.union
          (get_free_typevars bound tyarg)
          (get_free_typevars bound tyret)
    | TConstr _ | TMetaVar _ | TIntr _ -> TypeVarSet.empty
    | TTuple tys ->
        List.fold_left TypeVarSet.union TypeVarSet.empty
          (List.map (get_free_typevars bound) tys)
    | TApp (src, arg) ->
        TypeVarSet.union
          (get_free_typevars bound src)
          (get_free_typevars bound arg)
    | _ -> U.unreachable "unexpected type in get_free_typevars"

  (* Deep skolemisation: replace every quantified variable of [sigma] (at the
   * top, and to the right of arrows) with a fresh type variable, a skolem.
   * Returns the skolems introduced and the resulting rho type. Aliases are
   * expanded while searching for [TForAll]/[TFunc]; when neither is found,
   * the unexpanded [sigma] is returned with no skolems. *)
  let rec skolemise ~inside (sigma : sigma) : TypeVarSet.t * rho =
    match TU.full_dealias_and_apply_forall ~inside sigma with
    | TForAll (var, ty) ->
        let skol_typevar = Symbol.create None (Symbol.name var) in
        let skol_typevar_set, rho =
          skolemise ~inside
            (TU.substitute ~inside
               (TU.SubstEnv.singleton var (T.TVar skol_typevar))
               ty)
        in
        (TypeVarSet.add skol_typevar skol_typevar_set, rho)
    | TFunc (tyarg, tyret) ->
        let skol_typevar_set, rhoret = skolemise ~inside tyret in
        (skol_typevar_set, T.TFunc (tyarg, rhoret))
    | _ -> (TypeVarSet.empty, sigma)

  (* Replace solved metavars in [ty] with their (recursively zonked)
   * solutions. Unsolved metavars stay; [ctx.metavar_env] is only read. *)
  let rec zonk (ctx : context) (ty : T.t) : T.t =
    let zonk_ = zonk ctx in
    match ty with
    | TMetaVar var -> (
        match MetaVarEnv.get ctx.metavar_env var with
        | Some ty -> zonk_ ty
        | None -> ty)
    | TFunc (tyarg, tyret) -> TFunc (zonk_ tyarg, zonk_ tyret)
    | TForAll (var, ty) -> TForAll (var, zonk_ ty)
    | TVar _ | TConstr _ | TIntr _ -> ty
    | TTuple tys -> TTuple (List.map zonk_ tys)
    | TApp (src, arg) -> TApp (zonk_ src, zonk_ arg)
    | _ -> U.unreachable ("unexpected type in zonk: " ^ T.to_string ty)

  (* Unsolved metavars of all [tys], after zonking. *)
  let zonk_and_metavars (ctx : context) (tys : TypeSet.t) : MetaVarSet.t =
    TypeSet.fold
      (fun ty acc -> MetaVarSet.union acc (get_metavars (zonk ctx ty)))
      tys MetaVarSet.empty

  (* Free type variables of all [tys], after zonking. *)
  let zonk_and_free_typevars (ctx : context) (tys : TypeSet.t) : TypeVarSet.t =
    TypeSet.fold
      (fun ty acc ->
        TypeVarSet.union acc
          (get_free_typevars TypeVarSet.empty (zonk ctx ty)))
      tys TypeVarSet.empty

  let var_env_get_types (env : var_env) : TypeSet.t =
    env |> VarEnv.values |> TypeSet.of_iter

  (* Quantify [rho] over [metavar_set]. Each metavar gets a fresh binder (a
   * clone of its symbol) and is solved to that binder's [TVar]. The result
   * is wrapped in one [TForAll] per binder and zonked. Callers pass metavars
   * collected after zonking, i.e. unsolved ones. *)
  let generalize (ctx : context) (rho : rho) (metavar_set : MetaVarSet.t) :
      sigma =
    (* TODO: Binder Generation *)
    let metavars = MetaVarSet.to_list metavar_set in
    let binders = List.map Symbol.clone metavars in
    let quantified =
      List.fold_left2
        (fun ty metavar binder ->
          MetaVarEnv.add ctx.metavar_env metavar (T.TVar binder);
          T.TForAll (binder, ty))
        rho metavars binders
    in
    zonk ctx quantified

  (* [occurs_in ctx var ty] is true when metavar [var] occurs in [ty], looking
   * through solved metavars. Constructors not listed below are treated as not
   * containing [var]; NOTE(review): extend this if [T.t] gains
   * type-carrying constructors. *)
  let rec occurs_in (ctx : context) (var : Symbol.t) (ty : T.t) : bool =
    match ty with
    | TMetaVar var2 ->
        Symbol.equal var var2
        || (match MetaVarEnv.get ctx.metavar_env var2 with
           | Some ty2 -> occurs_in ctx var ty2
           | None -> false)
    | TFunc (lhs, rhs) | TApp (lhs, rhs) ->
        occurs_in ctx var lhs || occurs_in ctx var rhs
    | TForAll (_, body) -> occurs_in ctx var body
    | TTuple tys -> List.exists (occurs_in ctx var) tys
    | _ -> false

  (* Unify metavar [var] with [ty], solving [var] if it is still unsolved.
   * @raise TcUnificationError on an occurs-check failure. *)
  let rec unify_var (ctx : context) (var : Symbol.t) (ty : tau) : tau =
    match MetaVarEnv.get ctx.metavar_env var with
    | Some ty2 -> unify_tau ctx ty ty2
    | None -> (
        match ty with
        | TMetaVar var2 -> (
            match MetaVarEnv.get ctx.metavar_env var2 with
            | Some ty2 -> unify_tau ctx (T.TMetaVar var) ty2
            | None ->
                MetaVarEnv.add ctx.metavar_env var ty;
                ty)
        | _ ->
            (* Occurs check: solving [var] to a type containing [var] builds
             * an infinite type, on which [zonk] never terminates. *)
            if occurs_in ctx var ty then
              raise (TcUnificationError (T.TMetaVar var, ty));
            MetaVarEnv.add ctx.metavar_env var ty;
            ty)

  (* Unify two monotypes structurally, solving metavars along the way. When
   * the shapes differ, aliases are expanded and then forall applications are
   * reduced (on either or both sides) before giving up.
   * @raise TcUnificationError when the types cannot be made equal. *)
  and unify_tau (ctx : context) (lhs : tau) (rhs : tau) : tau =
    let inside = ctx.inside in
    let try_apply_forall () =
      match TU.apply_forall ~inside lhs, TU.apply_forall ~inside rhs with
      | Some lhs, Some rhs -> unify_tau ctx lhs rhs
      | Some lhs, None -> unify_tau ctx lhs rhs
      | None, Some rhs -> unify_tau ctx lhs rhs
      | None, None -> raise (TcUnificationError (lhs, rhs))
    in
    (* Two distinct aliases of the same type must unify, so both sides are
     * expanded when both can be. *)
    let try_dealias () =
      match TU.dealias ~inside lhs, TU.dealias ~inside rhs with
      | Some lhs, Some rhs -> unify_tau ctx lhs rhs
      | Some lhs, None -> unify_tau ctx lhs rhs
      | None, Some rhs -> unify_tau ctx lhs rhs
      | None, None -> try_apply_forall ()
    in
    match lhs, rhs with
    | TVar vl, TVar vr when Symbol.equal vl vr -> lhs
    | TMetaVar vl, TMetaVar vr when Symbol.equal vl vr -> lhs
    | TMetaVar vl, _ -> unify_var ctx vl rhs
    | _, TMetaVar vr -> unify_var ctx vr lhs
    | TFunc (tyargl, tyretl), TFunc (tyargr, tyretr) ->
        T.TFunc (unify_tau ctx tyargl tyargr, unify_tau ctx tyretl tyretr)
    | TIntr _, TIntr _ when T.equal lhs rhs -> lhs
    | TConstr ctorl, TConstr ctorr when Bind.equal ctorl ctorr -> lhs
    | TApp (srcl, argl), TApp (srcr, argr) ->
        T.TApp (unify_tau ctx srcl srcr, unify_tau ctx argl argr)
    | TTuple tysl, TTuple tysr when List.length tysl = List.length tysr ->
        T.TTuple (List.map2 (unify_tau ctx) tysl tysr)
    | _, _ -> try_dealias ()

  (* View [rho] as a function type and return (argument, result). A
   * non-[TFunc] type is unified with [m1 -> m2] for fresh metavars. *)
  let unify_func (ctx : context) (rho : rho) : sigma * rho =
    match rho with
    | TFunc (tyarg, tyret) -> (tyarg, tyret)
    | tau ->
        let tyarg = T.TMetaVar (Symbol.create None "metavar_arg") in
        let tyret = T.TMetaVar (Symbol.create None "metavar_ret") in
        ignore (unify_tau ctx tau (T.TFunc (tyarg, tyret)) : tau);
        (tyarg, tyret)

  (* Replace the outer quantifiers of [ty] with fresh metavars. When no
   * [TForAll] remains, the type reached so far is returned without the last
   * alias expansion. *)
  let rec instantiate ~(inside : D.decl_module) (ty : sigma) : rho =
    match TU.dealias_and_apply_forall_once ~inside ty with
    (* Our ForAll binds a single variable, so its body may be another
     * ForAll: recurse until none is left. *)
    | T.TForAll (tyvar, rhoty) ->
        let new_metavar = Symbol.create None (Symbol.name tyvar) in
        instantiate ~inside
          (TU.substitute ~inside
             (TU.SubstEnv.singleton tyvar (T.TMetaVar new_metavar))
             rhoty)
    | _ -> ty

  (* [subsumption_check ctx sigmal sigmar]: [sigmal] is at least as
   * polymorphic as [sigmar] (paper: subsCheck). [sigmar] is skolemised and
   * [sigmal] is checked against the result. Afterwards no skolem may be free
   * in either (zonked) type.
   * @raise TcError if a skolem escapes.
   * @raise TcUnificationError if the types do not match. *)
  let rec subsumption_check (ctx : context) (sigmal : sigma) (sigmar : sigma) :
      sigma =
    let skol_tvars, rhor = skolemise ~inside:ctx.inside sigmar in
    let result = subsumption_check_rho ctx sigmal rhor in
    if TypeVarSet.is_empty skol_tvars then result
    else
      let esc_tvars =
        zonk_and_free_typevars ctx
          (TypeSet.add sigmal (TypeSet.singleton sigmar))
      in
      if TypeVarSet.is_empty (TypeVarSet.inter skol_tvars esc_tvars) then result
      else
        raise
          (TcError
             (Printf.sprintf "Type %s is not as polymorphic as %s"
                (T.to_string sigmal) (T.to_string sigmar)))

  (* [subsumption_check_rho ctx s r]: [s] is at least as polymorphic as the
   * rho type [r] (paper: subsCheckRho). Function types are compared
   * contravariantly in the argument and covariantly in the result, whichever
   * side is syntactically the [TFunc]. *)
  and subsumption_check_rho (ctx : context) (s : sigma) (r : rho) : rho =
    match s, r with
    | T.TForAll _, _ ->
        subsumption_check_rho ctx (instantiate ~inside:ctx.inside s) r
    | _, T.TFunc (tyargr, tyretr) ->
        let tyargl, tyretl = unify_func ctx s in
        subsumption_check_fun ctx tyargl tyretl tyargr tyretr
    | T.TFunc (tyargl, tyretl), _ ->
        let tyargr, tyretr = unify_func ctx r in
        subsumption_check_fun ctx tyargl tyretl tyargr tyretr
    | _, _ (* Tau Types *) -> unify_tau ctx s r

  (* Paper: subsCheckFun. (al -> rl) <= (ar -> rr) requires ar <= al
   * (argument reversed) and rl <= rr. *)
  and subsumption_check_fun ctx tyargl tyretl tyargr tyretr : rho =
    T.TFunc
      ( subsumption_check ctx tyargr tyargl,
        subsumption_check_rho ctx tyretl tyretr )

  let inst_sigma_check (ctx : context) (s : sigma) (r : rho) : rho =
    subsumption_check_rho ctx s r

  let inst_sigma_infer (ctx : context) (s : sigma) : rho =
    instantiate ~inside:ctx.inside s

  (* Type of a term variable: the local environment first, then the global
   * value declaration it resolves to. Anything else is a bug upstream. *)
  let resolve_var_type (ctx : context) (ident : Symbol.t) : T.t =
    match VarEnv.find_opt ctx.var_env ident with
    | Some ty -> ty
    | None -> (
        match Modular.resolve_global_symbol ~inside:ctx.inside ident with
        | Some (D.DeclValue decl) -> decl.val_type
        | _ -> U.unreachable "unexpected resolve failure in resolve_var_type")

  (* TODO: We currently do this very lazily and just return some random type,
   * what we really want to do is returning a placeholder type (or some
   * "typeclass" constrainted type) and unify later *)
  let infer_literal_type (src : Literal.t) : T.t =
    match src with
    | LBool _ -> T.TIntr Intrinsic_type.Bool
    | LNum (Literal_num.NlitU32 _) ->
        T.TIntr (Intrinsic_type.Int { signed = Unsigned; width = 32 })
    | LString _ -> T.TIntr Intrinsic_type.Intr_string
    | _ -> U.unreachable "unexpected literal in infer_literal_type"

  (* Check ([expr_ty = Some ty]) or infer ([None]) a match expression. The
   * scrutinee type is inferred, and each pattern is checked against it. Each
   * pattern's variables are added to [ctx.var_env], and the case bodies'
   * types are unified pairwise. *)
  let rec tc_match (ctx : context) (expr : S.expr) (expr_ty : rho option) :
      S.expr * T.t =
    let tc_match_case ctx (case : S.match_case) pattern_ty expr_ty :
        S.match_case * T.t =
      let match_case_node, annot = case in
      let return ?(match_case = match_case_node) ty = ((match_case, annot), ty) in
      match match_case_node with
      | S.MatchCase { pattern; value } ->
          let pattern, _ = check_pattern ctx pattern pattern_ty in
          let new_bindings = Semutil.pattern_new_bindings pattern in
          let _ = SP.iter (VarEnv.add ctx.var_env) new_bindings in
          let value, ty =
            match expr_ty with
            | Some ty -> check_rho ctx value ty
            | None -> infer_rho ctx value
          in
          return ~match_case:(S.MatchCase { pattern; value }) ty
    in
    let (expr_node, _), annot = expr in
    let return ?(expr = expr_node) ty = (((expr, ty), annot), ty) in
    match expr_node with
    | ExprMatch { value; cases } -> (
        let value, value_ty = infer_rho ctx value in
        match cases with
        | hd :: rest ->
            let head, head_ty = tc_match_case ctx hd value_ty expr_ty in
            let cases, ty =
              List.fold_left
                (fun (cases, last_ty) next ->
                  let next, next_ty = tc_match_case ctx next value_ty expr_ty in
                  (next :: cases, unify_tau ctx last_ty next_ty))
                ([ head ], head_ty) rest
            in
            return ~expr:(ExprMatch { value; cases = List.rev cases }) ty
        | [] -> U.unreachable "match must have at least one case")
    | _ -> U.unreachable "unexpected expression in tc_match"

  (* Check [pattern] against the expected type [pattern_ty]. *)
  and check_pattern (ctx : context) (pattern : S.pattern) (pattern_ty : rho) :
      S.pattern * T.t =
    let (pattern_node, _), annot = pattern in
    let return ?(pattern = pattern_node) ty = (((pattern, ty), annot), ty) in
    match pattern_node with
    | PLiteral { value } ->
        return @@ unify_tau ctx (infer_literal_type value) pattern_ty
    | PVar _ -> return pattern_ty
    | PConstrApp _ -> tc_constructor_pattern ctx pattern (Some pattern_ty)

  (* Constructor pattern: instantiate the constructor's type, and check each
   * sub-pattern against one arrow argument. The remaining result type is
   * unified with [pattern_ty] when one is given. *)
  and tc_constructor_pattern (ctx : context) (pattern : S.pattern)
      (pattern_ty : rho option) : S.pattern * T.t =
    let (pattern_node, _), annot = pattern in
    let return ?(pattern = pattern_node) ty = (((pattern, ty), annot), ty) in
    match pattern_node with
    | PConstrApp { bound; args } ->
        let constr_ty = inst_sigma_infer ctx bound.val_type in
        let args, ret_ty =
          List.fold_left
            (fun (checked, ctor_ty) arg ->
              match ctor_ty with
              | T.TFunc (tyarg, tyret) ->
                  (check_pattern ctx arg tyarg :: checked, tyret)
              | _ -> U.unreachable "invalid data constructor pattern")
            ([], constr_ty) args
        in
        let ret_ty =
          CCOpt.map_or ~default:ret_ty (unify_tau ctx ret_ty) pattern_ty
        in
        return
          ~pattern:(PConstrApp { bound; args = List.map fst (List.rev args) })
          ret_ty
    | _ -> U.unreachable "unexpected pattern in tc_constructor_pattern"

  and infer_pattern (ctx : context) (pattern : S.pattern) : S.pattern * T.t =
    let (pattern_node, _), annot = pattern in
    let return ?(pattern = pattern_node) ty = (((pattern, ty), annot), ty) in
    match pattern_node with
    | PLiteral { value } -> return @@ infer_literal_type value
    | PVar { name } -> return @@ T.TMetaVar (Symbol.clone name)
    | PConstrApp _ -> tc_constructor_pattern ctx pattern None

  (* Check or infer [let name typeargs args : tyannot = value in scope].
   * Unannotated arguments get fresh metavars. The name's type is
   * [forall typeargs. arg1 -> ... -> value_ty], generalized over metavars not
   * free in the environment as it was before the arguments were added. Then
   * [scope] is checked or inferred. *)
  and tc_let (ctx : context) (expr : S.expr) (expr_ty : rho option) :
      S.expr * T.t =
    let (expr_node, _), annot = expr in
    let return ?(expr = expr_node) ty = (((expr, ty), annot), ty) in
    let env_ty_set = var_env_get_types ctx.var_env in
    let env_metavar_set = zonk_and_metavars ctx env_ty_set in
    match expr_node with
    | ExprLet { name; value; scope; tyannot; typeargs; args; name_ty = _ } ->
        (* Give an argument its annotated type (or a fresh metavar) and bring
         * it into scope. *)
        let bind_arg (S.Arg { name; tyannot }, annot) =
          let ty =
            match tyannot with
            | TPlaceholder -> T.TMetaVar (Symbol.clone name)
            | ty -> ty
          in
          VarEnv.add ctx.var_env name ty;
          ((S.Arg { name; tyannot = ty }, annot), ty)
        in
        let get_var_ty value_ty arg_types =
          let var_ty =
            List.fold_right (fun arg last -> T.TFunc (arg, last)) arg_types value_ty
          in
          List.fold_right
            (fun arg last -> T.TForAll (arg, last))
            (List.map (fun (S.TypeArg { name }, _) -> name) typeargs)
            var_ty
        in
        let args_and_types = List.map bind_arg args in
        let arg_types = List.map snd args_and_types in
        let value_ty, var_ty_metavar =
          match tyannot with
          | TPlaceholder ->
              let metavar = Symbol.clone (fst name) in
              (T.TMetaVar metavar, Some metavar)
          | ty -> (ty, None)
        in
        (* Pre-bind a function's own name so its body can call it. *)
        (if not (CCList.is_empty args) then
           VarEnv.add ctx.var_env (fst name) (get_var_ty value_ty arg_types));
        let value, new_value_ty =
          match tyannot with
          | TPlaceholder -> infer_sigma ctx value
          | tyannot -> check_sigma ctx value tyannot
        in
        CCOpt.iter
          (fun metavar -> MetaVarEnv.add ctx.metavar_env metavar new_value_ty)
          var_ty_metavar;
        let var_ty = get_var_ty new_value_ty arg_types in
        let res_metavar_set = zonk_and_metavars ctx (TypeSet.singleton var_ty) in
        let var_ty =
          generalize ctx var_ty (MetaVarSet.diff res_metavar_set env_metavar_set)
        in
        VarEnv.add ctx.var_env (fst name) var_ty;
        let scope, ty =
          match expr_ty with
          | Some expr_ty -> check_rho ctx scope expr_ty
          | None -> infer_rho ctx scope
        in
        return
          ~expr:
            (ExprLet
               {
                 name;
                 value;
                 scope;
                 tyannot = value_ty;
                 typeargs;
                 args = List.map fst args_and_types;
                 name_ty = var_ty;
               })
          ty
    | _ -> U.unreachable "unexpected expression in tc_let"

  (* TODO: Get rid of the duplication between check_rho and infer_rho. These
   * functions are tightly coupled with the annotation data structures of
   * Semtree.expr. *)

  (* Check [expr] against the rho type [expr_ty] (paper: checkRho). Returns
   * the re-annotated expression and its type. *)
  and check_rho (ctx : context) (expr : S.expr) (expr_ty : rho) : S.expr * T.t =
    let (expr_node, _), annot = expr in
    let return ?(expr = expr_node) ty = (((expr, ty), annot), ty) in
    match expr_node with
    | ExprConst { value } ->
        return @@ unify_tau ctx (infer_literal_type value) expr_ty
    | ExprIdent { value } ->
        let ty = resolve_var_type ctx (fst value) in
        return @@ inst_sigma_check ctx ty expr_ty
    | ExprApply { func; arg } ->
        let func, func_ty = infer_rho ctx func in
        let tyarg, tyret = unify_func ctx func_ty in
        let arg, _ = check_sigma ctx arg tyarg in
        let ty = inst_sigma_check ctx tyret expr_ty in
        return ~expr:(ExprApply { func; arg }) ty
    | ExprLet _ -> tc_let ctx expr (Some expr_ty)
    | ExprIf { cond; expr_then; expr_else } ->
        let cond, _ = check_rho ctx cond (T.TIntr Intrinsic_type.Bool) in
        let expr_then, tythen = check_rho ctx expr_then expr_ty in
        let expr_else, tyelse = check_rho ctx expr_else expr_ty in
        return ~expr:(ExprIf { cond; expr_then; expr_else })
        @@ unify_tau ctx tythen tyelse
    | ExprTuple _ ->
        (* No dedicated checking rule: infer, then unify with [expr_ty]. *)
        let expr, inferred_ty = infer_rho ctx expr in
        let inferred_ty = unify_tau ctx inferred_ty expr_ty in
        (expr, inferred_ty)
    | ExprApplyOpInfix _ ->
        let expr, ty =
          check_rho ctx (transform_expr_apply_op_infix expr) expr_ty
        in
        (restore_expr_apply_op_infix expr, ty)
    | ExprTyAnnot { value; tyannot } ->
        (* Paper: checkSigma body ann; instSigma ann exp_ty. Instantiate the
         * annotation itself, not check_sigma's (skolemised) result. *)
        let value, _ = check_sigma ctx value tyannot in
        let ty = inst_sigma_check ctx tyannot expr_ty in
        return ~expr:(ExprTyAnnot { value; tyannot = ty }) ty
    | ExprMatch _ -> tc_match ctx expr (Some expr_ty)
    | ExprPlaceholder -> U.unreachable "unexpected expression in check_rho"

  (* Infer a rho type for [expr] (paper: inferRho). *)
  and infer_rho (ctx : context) (expr : S.expr) : S.expr * T.t =
    let (expr_node, _), annot = expr in
    let return ?(expr = expr_node) ty = (((expr, ty), annot), ty) in
    match expr_node with
    | ExprConst { value } -> return @@ infer_literal_type value
    | ExprIdent { value } ->
        let ty = resolve_var_type ctx (fst value) in
        return @@ inst_sigma_infer ctx ty
    | ExprApply { func; arg } ->
        let func, func_ty = infer_rho ctx func in
        let tyarg, tyret = unify_func ctx func_ty in
        let arg, _ = check_sigma ctx arg tyarg in
        let ty = inst_sigma_infer ctx tyret in
        return ~expr:(ExprApply { func; arg }) ty
    | ExprLet _ -> tc_let ctx expr None
    | ExprIf { cond; expr_then; expr_else } ->
        (* P60 of the paper: infer both branches and require each to subsume
         * the other. The steps mutate [ctx], so they are sequenced explicitly
         * rather than with [let ... and ...], whose order is unspecified. *)
        let cond, _ = check_rho ctx cond (T.TIntr Intrinsic_type.Bool) in
        let expr_then, tythen = infer_rho ctx expr_then in
        let expr_else, tyelse = infer_rho ctx expr_else in
        ignore (subsumption_check ctx tythen tyelse : sigma);
        ignore (subsumption_check ctx tyelse tythen : sigma);
        return ~expr:(ExprIf { cond; expr_then; expr_else })
        @@ unify_tau ctx tythen tyelse
    | ExprTuple { value } ->
        let values = List.map (infer_rho ctx) value in
        return ~expr:(ExprTuple { value = List.map fst values })
        @@ T.TTuple (List.map snd values)
    | ExprApplyOpInfix _ ->
        let expr, ty = infer_rho ctx (transform_expr_apply_op_infix expr) in
        (restore_expr_apply_op_infix expr, ty)
    | ExprTyAnnot { value; tyannot } ->
        (* Instantiate the annotation itself, not check_sigma's (skolemised)
         * result. *)
        let value, _ = check_sigma ctx value tyannot in
        let ty = inst_sigma_infer ctx tyannot in
        return ~expr:(ExprTyAnnot { value; tyannot = ty }) ty
    | ExprMatch _ -> tc_match ctx expr None
    | ExprPlaceholder -> U.unreachable "unexpected expression in infer_rho"

  (* Check [expr] against the polytype [sigma] (paper: checkSigma). Skolemise
   * [sigma] and check against the result. Then ensure that no skolem is free
   * in [sigma] or in any environment type after zonking. Returns
   * [check_rho]'s result, whose type may mention the skolems.
   * @raise TcError if a skolem escapes. *)
  and check_sigma (ctx : context) (expr : S.expr) (sigma : sigma) :
      S.expr * sigma =
    let skol_typevar_set, rho = skolemise ~inside:ctx.inside sigma in
    let ret = check_rho ctx expr rho in
    let env_ty_set = var_env_get_types ctx.var_env in
    let esc_typevar_set =
      zonk_and_free_typevars ctx (TypeSet.add sigma env_ty_set)
    in
    let bad_typevars = TypeVarSet.inter skol_typevar_set esc_typevar_set in
    if TypeVarSet.is_empty bad_typevars then ret
    else
      raise
        (TcError ("Type " ^ T.to_string sigma ^ " not polymorphic enough"))

  (* Infer a rho type for [expr], then generalize it over the metavars that
   * are not free in the environment (paper: inferSigma). The node's type slot
   * is replaced by the generalized type. *)
  and infer_sigma (ctx : context) (expr : S.expr) : S.expr * sigma =
    (* Snapshot the environment before inferring: bindings added to the
     * mutable [var_env] while checking [expr] must not block
     * generalization. *)
    let env_ty_set = var_env_get_types ctx.var_env in
    let expr, expr_ty = infer_rho ctx expr in
    let (expr_node, _), annot = expr in
    let env_metavar_set = zonk_and_metavars ctx env_ty_set in
    let res_metavar_set = zonk_and_metavars ctx (TypeSet.singleton expr_ty) in
    let ret_ty =
      generalize ctx expr_ty (MetaVarSet.diff res_metavar_set env_metavar_set)
    in
    (((expr_node, ret_ty), annot), ret_ty)

  (* Zonk every type stored in an expression tree: node types, let and
   * argument types, annotations, and nested match cases and patterns.
   * TODO: We need a map function for expr. *)
  let rec zonk_expr (ctx : context) (expr : S.expr) : S.expr =
    let zonk_expr_ = zonk_expr ctx in
    let zonk_ = zonk ctx in
    let (expr_node, ty), annot = expr in
    let return ?(expr = expr_node) () = ((expr, zonk_ ty), annot) in
    match expr_node with
    | ExprConst _ | ExprIdent _ -> return ()
    | ExprApply { func; arg } ->
        return
          ~expr:(ExprApply { func = zonk_expr_ func; arg = zonk_expr_ arg })
          ()
    | ExprLet { name; value; scope; typeargs; args; tyannot; name_ty } ->
        let zonk_arg arg =
          S.(
            match arg with
            | Arg arg, annot ->
                (Arg { arg with tyannot = zonk_ arg.tyannot }, annot))
        in
        return
          ~expr:
            (ExprLet
               {
                 name;
                 value = zonk_expr_ value;
                 scope = zonk_expr_ scope;
                 tyannot = zonk_ tyannot;
                 typeargs;
                 args = List.map zonk_arg args;
                 name_ty = zonk_ name_ty;
               })
          ()
    | ExprIf { cond; expr_then; expr_else } ->
        return
          ~expr:
            (ExprIf
               {
                 cond = zonk_expr_ cond;
                 expr_then = zonk_expr_ expr_then;
                 expr_else = zonk_expr_ expr_else;
               })
          ()
    | ExprTuple { value } ->
        return ~expr:(ExprTuple { value = List.map zonk_expr_ value }) ()
    | ExprApplyOpInfix { func; lhs; rhs } ->
        return
          ~expr:
            (ExprApplyOpInfix
               {
                 func = zonk_expr_ func;
                 lhs = zonk_expr_ lhs;
                 rhs = zonk_expr_ rhs;
               })
          ()
    | ExprTyAnnot { value; tyannot } ->
        return
          ~expr:
            (ExprTyAnnot { value = zonk_expr_ value; tyannot = zonk_ tyannot })
          ()
    | ExprMatch { value; cases } ->
        return
          ~expr:
            (ExprMatch
               {
                 value = zonk_expr_ value;
                 cases = List.map (zonk_match_case ctx) cases;
               })
          ()
    | ExprPlaceholder -> U.unreachable "unexpected expression in zonk_expr"

  and zonk_match_case (ctx : context) (match_case : S.match_case) :
      S.match_case =
    let S.MatchCase { pattern; value }, annot = match_case in
    ( MatchCase
        { pattern = zonk_pattern ctx pattern; value = zonk_expr ctx value },
      annot )

  and zonk_pattern (ctx : context) (pattern : S.pattern) : S.pattern =
    let (pattern_node, ty), annot = pattern in
    let return ?(pattern = pattern_node) () = ((pattern, zonk ctx ty), annot) in
    match pattern_node with
    | PLiteral _ | PVar _ -> return ()
    | PConstrApp { bound; args } ->
        return
          ~pattern:
            (PConstrApp { bound; args = List.map (zonk_pattern ctx) args })
          ()
end
(* let rec trans_expr ~(inside: D.decl_module) | |
* (symtab: T.t SP.t) (stree: S.expr): S.expr = *) | |
(* Peel a value's declared type [ty] down to the type of its body. Every
 * [TForAll] met while peeling is dropped, and one [TFunc] arrow is removed
 * per entry of [args]. Fails via [U.unreachable] when [ty] runs out of
 * arrows before [args] is exhausted. *)
let rec get_decl_return_type (args : D.decl_arg list) (ty : T.t) : T.t =
  match ty, args with
  | T.TForAll (_, body), _ -> get_decl_return_type args body
  | _, [] -> ty
  | T.TFunc (_, tyret), _ :: rest -> get_decl_return_type rest tyret
  | _ ->
      U.unreachable
        (Printf.sprintf "unexpected input in get_decl_return_type. Type: %s"
           (T.to_string ty))
(* Walk a top-level tree and type-check every value declaration. Each
 * declaration's mutable [value] is overwritten with the checked, zonked
 * body. Declarations other than modules and values are skipped. *)
let rec trans_top ~(inside : D.decl_module) (src : S.top) : unit =
  match src with
  | S.TopLevel { value = modu }, _ -> trans_mod ~inside modu

(* Module fragments are checked relative to their own [bound] module. *)
and trans_mod ~(inside : D.decl_module) (src : S.modu) : unit =
  match src with
  | S.ModuleFrag { value = decls; bound }, _ ->
      List.iter (fun decl -> trans_decl ~inside:bound decl) decls

and trans_decl ~(inside : D.decl_module) (src : S.decl) : unit =
  match fst src with
  | S.DeclModule modu -> trans_mod ~inside modu
  | S.DeclValue value_decl ->
      value_decl.value <-
        trans_decl_value ~inside value_decl.value value_decl.bound
  | _ -> ()

(* Check one value body against its declared type, with the declared
 * argument types in scope, and return the zonked result. *)
and trans_decl_value ~(inside : D.decl_module) (src : S.expr)
    (decl : D.decl_value) : S.expr =
  let context = Typechecker.context_create ~inside in
  let arg_bindings =
    List.map (fun (D.{ name; val_type }) -> (name, val_type)) decl.args
  in
  ignore
    (Typechecker.VarEnv.add_iter context.var_env (CCList.to_iter arg_bindings));
  Printf.printf "Typechecking %s ...\n" (Symbol.show decl.name);
  let expected = get_decl_return_type decl.args decl.val_type in
  let checked, _ = Typechecker.check_sigma context src expected in
  Typechecker.zonk_expr context checked
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment