Skip to content

Instantly share code, notes, and snippets.

@osa1
Created June 15, 2023 16:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save osa1/5d505dbed2dac30955a822e1bb9079fb to your computer and use it in GitHub Desktop.
Save osa1/5d505dbed2dac30955a822e1bb9079fb to your computer and use it in GitHub Desktop.
HM type inference in Rust
/*
HM type inference with:
- Union-find for unification.
- Type variable levels for generalization.
Implementation became tricky because of the mutable level and link fields.
In this implementation links cannot form cycles (the occurs check catches them), so we could use
`Rc<RefCell<..>>`-wrapped types. However:
- I'm not sure that `Rc<RefCell<..>>` is actually more convenient.
- More importantly, there's a variant of this algorithm presented in [2], which lazily performs
occurs checks and allows forming cycles. I don't know if that implementation will cause leaks
with `Rc` (we may have to manually break the cycles).
This version of the algorithm is ported from Programming Language Concepts chapter 6. Code is
available in [1].
Example programs are mostly copied from [2].
[1]: https://www.itu.dk/~sestoft/plc
[2]: https://okmij.org/ftp/ML/generalization.html
*/
use std::collections::HashMap as Map;
use std::collections::HashSet as Set;
use std::fmt;
/// Arena that owns every `Type` node and every `TypeVar` cell used during inference.
///
/// `TypeRef` values index into `types` and `TypeVarRef` values index into `vars`.
#[derive(Debug)]
struct TypeArena {
/// Type nodes, addressed by `TypeRef`.
types: Vec<Type>,
/// Type variable cells (link + level), addressed by `TypeVarRef`.
vars: Vec<TypeVar>,
}
/// Diagnostic dump of the arena: every type node, then every type variable cell, one per line
/// and prefixed with its index, followed by a separator line.
impl fmt::Display for TypeArena {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Types:\n")?;
        self.types
            .iter()
            .enumerate()
            .try_for_each(|(idx, node)| writeln!(f, "{}: {:?}", idx, node))?;

        f.write_str("Vars:\n")?;
        self.vars
            .iter()
            .enumerate()
            .try_for_each(|(idx, cell)| writeln!(f, "{}: {:?}", idx, cell))?;

        f.write_str("-----------\n")
    }
}
// size = 12
/// A type node stored in `TypeArena::types`.
#[derive(Debug, Clone, Copy)]
enum Type {
/// The integer type.
Int,
/// The boolean type.
Bool,
/// Function type: argument type, then result type.
Fun(TypeRef, TypeRef),
/// A type variable; its link and level live in the referenced `TypeVar` cell.
Var(TypeVarRef),
/// A variable bound by a `TypeScheme`, identified by its index among the scheme's
/// quantified variables. `instantiate` replaces it with a fresh variable.
Quantified(u32),
}
// size = 12
/// Mutable state of one type variable.
#[derive(Debug, Clone, Copy)]
struct TypeVar {
/// What the variable has been unified with, if anything.
link: TypeVarLink,
/// Binding level of the variable. `prune_level` can lower it during unification, and
/// `generalize` quantifies only variables whose level is greater than the
/// generalization level.
level: u32,
}
// size = 8
/// Link state of a type variable.
#[derive(Debug, Clone, Copy)]
enum TypeVarLink {
/// The variable has been linked to the referenced type (set by `link_var`, and
/// re-pointed at the fully resolved type by `normalize_type`).
Link(TypeRef),
/// The variable is not linked to any type yet.
NoLink,
}
// size = 4
/// Index of a `Type` node in `TypeArena::types`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct TypeRef(u32);
// size = 4
/// Index of a `TypeVar` cell in `TypeArena::vars`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct TypeVarRef(u32);
impl TypeArena {
    /// Reference to the pre-allocated `Type::Int` node (see `new`).
    const INT_REF: TypeRef = TypeRef(0);
    /// Reference to the pre-allocated `Type::Bool` node (see `new`).
    const BOOL_REF: TypeRef = TypeRef(1);

    /// Creates an arena holding only the `Int` and `Bool` nodes, at the indices named by
    /// `INT_REF` and `BOOL_REF`.
    fn new() -> Self {
        Self {
            types: vec![Type::Int, Type::Bool],
            vars: Vec::new(),
        }
    }

    /// Returns the type node behind `ty_ref`.
    ///
    /// # Panics
    /// Panics if `ty_ref` is out of bounds for this arena.
    fn get_type(&self, ty_ref: &TypeRef) -> &Type {
        &self.types[ty_ref.0 as usize]
    }

    /// Returns the variable cell behind `ty_var_ref`.
    ///
    /// # Panics
    /// Panics if `ty_var_ref` is out of bounds for this arena.
    fn get_type_var(&self, ty_var_ref: &TypeVarRef) -> &TypeVar {
        &self.vars[ty_var_ref.0 as usize]
    }

    /// Mutable counterpart of `get_type_var`.
    ///
    /// # Panics
    /// Panics if `ty_var_ref` is out of bounds for this arena.
    fn get_type_var_mut(&mut self, ty_var_ref: &TypeVarRef) -> &mut TypeVar {
        &mut self.vars[ty_var_ref.0 as usize]
    }

    /// Allocates a new, unlinked type variable at `level` and returns its reference.
    ///
    /// # Panics
    /// Panics if the number of variables no longer fits in a `u32` index.
    fn fresh_type_var(&mut self, level: u32) -> TypeVarRef {
        let ref_ = TypeVarRef(Self::next_index(self.vars.len(), "type variables"));
        self.vars.push(TypeVar {
            link: TypeVarLink::NoLink,
            level,
        });
        ref_
    }

    /// Appends `ty` to the arena and returns its reference.
    ///
    /// # Panics
    /// Panics if the number of types no longer fits in a `u32` index.
    fn new_type(&mut self, ty: Type) -> TypeRef {
        let ref_ = TypeRef(Self::next_index(self.types.len(), "types"));
        self.types.push(ty);
        ref_
    }

    /// Reference to the integer type.
    fn int(&self) -> TypeRef {
        Self::INT_REF
    }

    /// Reference to the boolean type.
    fn bool(&self) -> TypeRef {
        Self::BOOL_REF
    }

    /// Converts a vector length into the `u32` index of the next element.
    ///
    /// A plain `as u32` cast would silently wrap once the arena grows past `u32::MAX`
    /// entries and hand out references that alias existing ones; failing loudly is safer.
    fn next_index(len: usize, what: &str) -> u32 {
        <u32 as std::convert::TryFrom<usize>>::try_from(len)
            .unwrap_or_else(|_| panic!("TypeArena: too many {} for a u32 index", what))
    }
}
/// Abstract syntax of the expression language whose types are inferred by `infer_types`.
///
/// Identifiers are `&'static str` because programs are built from literals in this file.
#[derive(Debug, Clone)]
#[allow(unused)]
enum Expr {
/// An integer literal (no value is stored).
Int,
/// A boolean literal (no value is stored).
Bool,
/// A variable reference.
Var(&'static str),
/// A binary primitive applied to a left and a right operand.
Prim(Prim, Box<Expr>, Box<Expr>),
/// `let var = var_body in let_body`.
Let {
var: &'static str,
var_body: Box<Expr>,
let_body: Box<Expr>,
},
/// `if cond then then_ else else_`.
If {
cond: Box<Expr>,
then_: Box<Expr>,
else_: Box<Expr>,
},
/// Anonymous function `\arg -> body`.
Lam {
arg: &'static str,
body: Box<Expr>,
},
/// Function definition `let fun arg = fun_body in let_body`.
LetRec {
fun: &'static str,
arg: &'static str,
fun_body: Box<Expr>,
let_body: Box<Expr>,
},
/// Application of a function (first field) to an argument (second field).
App(Box<Expr>, Box<Expr>),
}
/// Binary primitive operators. The types noted below are the ones `infer_types` enforces.
#[derive(Debug, Clone, Copy)]
#[allow(unused)]
enum Prim {
/// `int -> int -> int`
Multiply,
/// `int -> int -> int`
Add,
/// `int -> int -> int`
Subtract,
/// `a -> a -> bool`: both operands must have the same type.
Equal,
/// `int -> int -> bool`
LessThan,
/// `bool -> bool -> bool`
And,
}
/// A type together with the number of type variables it quantifies over.
///
/// Quantified variables appear inside `ty` as `Type::Quantified(i)`, where `i` ranges over
/// `0..num_foralls` (this is how `generalize` builds schemes and how `instantiate` reads them).
#[derive(Debug)]
struct TypeScheme {
/// Number of quantified type variables.
num_foralls: u32,
/// The type, with `Quantified` types for quantified type variables.
ty: TypeRef,
}
impl TypeScheme {
fn monomorphic(ty: TypeRef) -> Self {
TypeScheme { num_foralls: 0, ty }
}
}
/// Instantiates `scheme` at `level`: every quantified variable is replaced by a fresh,
/// unlinked type variable created at `level`.
///
/// A scheme without quantified variables is returned as-is (no new nodes are allocated).
///
/// # Panics
/// Panics if the scheme's type contains a `Quantified` index that is not smaller than
/// `scheme.num_foralls`.
fn instantiate(types: &mut TypeArena, level: u32, scheme: &TypeScheme) -> TypeRef {
    /// Copies the structure of `ty`, swapping each `Quantified(i)` for `instances[i]`.
    /// Function nodes are rebuilt; every other node is shared with the original.
    fn substitute(types: &mut TypeArena, instances: &[TypeRef], ty: &TypeRef) -> TypeRef {
        let node = *types.get_type(ty);
        match node {
            Type::Quantified(idx) => instances[idx as usize],
            Type::Fun(arg, ret) => {
                let arg = substitute(types, instances, &arg);
                let ret = substitute(types, instances, &ret);
                types.new_type(Type::Fun(arg, ret))
            }
            Type::Int | Type::Bool | Type::Var(_) => *ty,
        }
    }

    if scheme.num_foralls == 0 {
        return scheme.ty;
    }

    // `instances[i]` is the fresh variable standing in for `Quantified(i)`.
    let instances: Vec<TypeRef> = (0..scheme.num_foralls)
        .map(|_| {
            let var = types.fresh_type_var(level);
            types.new_type(Type::Var(var))
        })
        .collect();

    substitute(types, &instances, &scheme.ty)
}
/// Generalizes `ty` at `level`: every unlinked type variable reachable from `ty` whose level
/// is greater than `level` becomes a quantified variable of the returned scheme.
///
/// Quantifier indices are assigned in increasing order of the variables' arena indices, so
/// the resulting scheme (and its printed form) is the same on every run. Numbering them in
/// `HashSet` iteration order, as was done previously, made the indices vary between runs.
///
/// # Panics
/// Panics if `ty` already contains a `Quantified` node.
fn generalize(types: &mut TypeArena, level: u32, ty: &TypeRef) -> TypeScheme {
    /// Rebuilds `ty` with every variable listed in `substs` replaced by its `Quantified` node.
    fn generalize_vars(
        types: &mut TypeArena,
        substs: &Map<TypeVarRef, u32>,
        ty: &TypeRef,
    ) -> TypeRef {
        // Work on the resolved type so linked variables are seen through.
        let ty = normalize_type(types, ty);
        match *types.get_type(&ty) {
            Type::Int | Type::Bool => ty,
            Type::Fun(ty1, ty2) => {
                let ty1 = generalize_vars(types, substs, &ty1);
                let ty2 = generalize_vars(types, substs, &ty2);
                types.new_type(Type::Fun(ty1, ty2))
            }
            Type::Var(var) => match substs.get(&var) {
                Some(gen_idx) => types.new_type(Type::Quantified(*gen_idx)),
                // Variable belongs to an enclosing scope: keep it as is.
                None => ty,
            },
            Type::Quantified(_) => panic!("Quantified type in generalize_vars"),
        }
    }

    let free_vars: Set<TypeVarRef> = type_vars(types, ty);

    // Variables created deeper than `level` are the ones that may be generalized.
    let mut generalizable: Vec<TypeVarRef> = free_vars
        .into_iter()
        .filter(|var| types.get_type_var(var).level > level)
        .collect();
    // Fix the numbering independently of hash iteration order.
    generalizable.sort_unstable_by_key(|var| var.0);

    let generalized_type_vars: Map<TypeVarRef, u32> =
        generalizable.into_iter().zip(0u32..).collect();

    let scheme_ty = generalize_vars(types, &generalized_type_vars, ty);
    TypeScheme {
        // Cannot truncate: there are at most as many entries as arena variables, and those
        // are addressed by `u32` indices.
        num_foralls: generalized_type_vars.len() as u32,
        ty: scheme_ty,
    }
}
/// Infers the type of `expr`.
///
/// * `types` - arena in which all types and type variables are allocated.
/// * `level` - current binding level; fresh variables are created at this level (or one
///   deeper for function bodies and `let`-bound expressions) and `generalize` uses it to
///   decide which variables may be quantified.
/// * `env` - maps variable names to their type schemes. Bindings introduced by `Let` and
///   `Lam` are removed (and shadowed bindings restored) before returning, on success and on
///   error alike. The binding introduced by `LetRec` is deliberately left in place; see the
///   note in that arm.
///
/// Returns a reference to the inferred type, or an error message for an unbound variable,
/// a unification failure or an occurs-check failure.
fn infer_types(
    types: &mut TypeArena,
    level: u32,
    env: &mut Map<String, TypeScheme>,
    expr: &Expr,
) -> Result<TypeRef, String> {
    match expr {
        Expr::Int => Ok(types.int()),

        Expr::Bool => Ok(types.bool()),

        Expr::Var(var) => match env.get(*var) {
            Some(scheme) => Ok(instantiate(types, level, scheme)),
            None => Err(format!("Unbound variable: {}", var)),
        },

        Expr::Prim(op, e1, e2) => {
            let e1_ty = infer_types(types, level, env, e1)?;
            let e2_ty = infer_types(types, level, env, e2)?;
            let int_ty = types.int();
            let bool_ty = types.bool();
            match op {
                // int -> int -> int
                Prim::Multiply | Prim::Add | Prim::Subtract => {
                    unify(types, &int_ty, &e1_ty)?;
                    unify(types, &int_ty, &e2_ty)?;
                    Ok(int_ty)
                }
                // a -> a -> bool
                Prim::Equal => {
                    unify(types, &e1_ty, &e2_ty)?;
                    Ok(bool_ty)
                }
                // int -> int -> bool
                Prim::LessThan => {
                    unify(types, &int_ty, &e1_ty)?;
                    unify(types, &int_ty, &e2_ty)?;
                    Ok(bool_ty)
                }
                // bool -> bool -> bool
                Prim::And => {
                    unify(types, &bool_ty, &e1_ty)?;
                    unify(types, &bool_ty, &e2_ty)?;
                    Ok(bool_ty)
                }
            }
        }

        Expr::Let {
            var,
            var_body,
            let_body,
        } => {
            // The bound expression is inferred one level deeper so that the variables it
            // creates can be told apart from those of the enclosing scope when generalizing.
            let var_body_ty = infer_types(types, level + 1, env, var_body)?;
            let var_body_scheme = generalize(types, level, &var_body_ty);
            let shadowed = env.insert((*var).to_owned(), var_body_scheme);
            let let_body_ty = infer_types(types, level, env, let_body);
            // Undo the binding before propagating any error from the body.
            restore_binding(env, *var, shadowed);
            let_body_ty
        }

        Expr::If { cond, then_, else_ } => {
            let cond_ty = infer_types(types, level, env, cond)?;
            let then_ty = infer_types(types, level, env, then_)?;
            let else_ty = infer_types(types, level, env, else_)?;
            let bool_ty = types.bool();
            unify(types, &bool_ty, &cond_ty)?;
            unify(types, &then_ty, &else_ty)?;
            Ok(then_ty)
        }

        Expr::Lam { arg, body } => infer_function(types, level, env, None, *arg, body),

        Expr::LetRec {
            fun,
            arg,
            fun_body,
            let_body,
        } => {
            // `fun` is in scope (monomorphically) inside its own body, so recursive
            // definitions type-check.
            let fun_ty = infer_function(types, level, env, Some(*fun), *arg, fun_body)?;
            let fun_scheme = generalize(types, level, &fun_ty);
            // NOTE: unlike `Let`, the binding for `fun` is intentionally NOT removed
            // afterwards: callers such as `main` read the generalized scheme back out of
            // `env` once inference has finished.
            env.insert((*fun).to_owned(), fun_scheme);
            infer_types(types, level, env, let_body)
        }

        Expr::App(fun, arg) => {
            let fun_ty = infer_types(types, level, env, fun)?;
            let arg_ty = infer_types(types, level, env, arg)?;
            // The function must have type `arg_ty -> ret_ty` for a fresh `ret_ty`.
            let ret_ty = fresh_var_type(types, level);
            let expected_fun_ty = types.new_type(Type::Fun(arg_ty, ret_ty));
            unify(types, &expected_fun_ty, &fun_ty)?;
            Ok(ret_ty)
        }
    }
}

/// Infers the type of the function `\arg -> body` at `level`.
///
/// When `self_name` is given, that name is bound to the (monomorphic) type of the function
/// itself while `body` is inferred, which is what makes `LetRec` recursive. All bindings
/// added here are undone before returning, whether or not inference of `body` succeeded.
fn infer_function(
    types: &mut TypeArena,
    level: u32,
    env: &mut Map<String, TypeScheme>,
    self_name: Option<&str>,
    arg: &str,
    body: &Expr,
) -> Result<TypeRef, String> {
    let fun_level = level + 1;
    // Allocation order (function variable first, then argument variable) determines the
    // printed names of variables, so keep it stable.
    let fun_ty = fresh_var_type(types, fun_level);
    let arg_ty = fresh_var_type(types, fun_level);

    // Bind the function name first so that the argument shadows it if both share a name.
    let shadowed_fun = match self_name {
        Some(name) => env.insert(name.to_owned(), TypeScheme::monomorphic(fun_ty)),
        None => None,
    };
    let shadowed_arg = env.insert(arg.to_owned(), TypeScheme::monomorphic(arg_ty));

    let ret_ty = infer_types(types, fun_level, env, body);

    // Undo the bindings in reverse order of insertion, before propagating any error.
    restore_binding(env, arg, shadowed_arg);
    if let Some(name) = self_name {
        restore_binding(env, name, shadowed_fun);
    }

    let ret_ty = ret_ty?;
    let expected_fun_ty = types.new_type(Type::Fun(arg_ty, ret_ty));
    unify(types, &fun_ty, &expected_fun_ty)?;
    Ok(fun_ty)
}

/// Allocates a fresh, unlinked type variable at `level` and returns a type node referring to it.
fn fresh_var_type(types: &mut TypeArena, level: u32) -> TypeRef {
    let var = types.fresh_type_var(level);
    types.new_type(Type::Var(var))
}

/// Puts `env[name]` back to what it was before a temporary binding: re-inserts the shadowed
/// scheme if there was one, otherwise removes the entry.
fn restore_binding(env: &mut Map<String, TypeScheme>, name: &str, shadowed: Option<TypeScheme>) {
    match shadowed {
        Some(scheme) => {
            env.insert(name.to_owned(), scheme);
        }
        None => {
            env.remove(name);
        }
    }
}
/// Unifies the types behind `ref1` and `ref2`, linking type variables as needed.
///
/// Returns an error when the two types have different shapes or when linking a variable
/// fails (see `link_var`).
fn unify(types: &mut TypeArena, ref1: &TypeRef, ref2: &TypeRef) -> Result<(), String> {
    // Resolve both sides through variable links first.
    let lhs = normalize_type(types, ref1);
    let rhs = normalize_type(types, ref2);
    let lhs_ty = *types.get_type(&lhs);
    let rhs_ty = *types.get_type(&rhs);

    match (lhs_ty, rhs_ty) {
        (Type::Int, Type::Int) | (Type::Bool, Type::Bool) => Ok(()),

        (Type::Fun(arg1, ret1), Type::Fun(arg2, ret2)) => {
            unify(types, &arg1, &arg2)?;
            unify(types, &ret1, &ret2)
        }

        // A variable trivially unifies with itself.
        (Type::Var(var1), Type::Var(var2)) if var1 == var2 => Ok(()),

        (Type::Var(var1), Type::Var(var2)) => {
            // The variable with the strictly smaller level is the one that gets linked when
            // there is one; otherwise the right-hand variable is linked. `link_var` caps the
            // level of the variables in the link target at the linked variable's level.
            let link_left =
                types.get_type_var(&var1).level < types.get_type_var(&var2).level;
            if link_left {
                link_var(types, &var1, &rhs)
            } else {
                link_var(types, &var2, &lhs)
            }
        }

        (Type::Var(var), _) => link_var(types, &var, &rhs),
        (_, Type::Var(var)) => link_var(types, &var, &lhs),

        (ty1, ty2) => Err(format!("Cannot unify {:?} and {:?}", ty1, ty2)),
    }
}
/// Collects the unlinked type variables reachable from `ty`, looking through variable links.
///
/// Takes the arena mutably because resolving links (`normalize_type`) rewrites them.
///
/// # Panics
/// Panics if a `Quantified` node is reachable from `ty`.
fn type_vars(types: &mut TypeArena, ty: &TypeRef) -> Set<TypeVarRef> {
    /// Adds the variables of `ty` to `acc`.
    fn collect_into(types: &mut TypeArena, ty: &TypeRef, acc: &mut Set<TypeVarRef>) {
        let resolved = normalize_type(types, ty);
        let node = *types.get_type(&resolved);
        match node {
            Type::Var(var) => {
                acc.insert(var);
            }
            Type::Fun(arg, ret) => {
                collect_into(types, &arg, acc);
                collect_into(types, &ret, acc);
            }
            Type::Quantified(_) => panic!("Quantified type variable in type_vars"),
            Type::Int | Type::Bool => {}
        }
    }

    let mut found = Set::new();
    collect_into(types, ty, &mut found);
    found
}
/// Caps the level of every variable in `tvs` at `max_level`; variables that are already at
/// or below `max_level` are left untouched.
fn prune_level(types: &mut TypeArena, max_level: u32, tvs: &Set<TypeVarRef>) {
    tvs.iter().for_each(|var| {
        let cell = types.get_type_var_mut(var);
        if cell.level > max_level {
            cell.level = max_level;
        }
    });
}
/// Links the type variable `var` to the type `ty`.
///
/// Before linking:
/// * occurs check: if `var` itself is reachable from `ty`, the link would create a cyclic
///   type, so an error is returned and nothing is modified apart from link compression;
/// * every variable reachable from `ty` has its level capped at `var`'s level.
///
/// On an occurs-check failure the whole arena is dumped to stderr as a debugging aid. It is
/// written to stderr rather than stdout so the dump cannot interleave with the program's
/// regular output.
fn link_var(types: &mut TypeArena, var: &TypeVarRef, ty: &TypeRef) -> Result<(), String> {
    let var_level = types.get_type_var(var).level;
    let ty_vars = type_vars(types, ty);

    if ty_vars.contains(var) {
        eprintln!("{}", types);
        return Err(format!(
            "Occurs check fails when linking {:?} -> {:?}",
            var,
            types.get_type(ty)
        ));
    }

    prune_level(types, var_level, &ty_vars);
    types.get_type_var_mut(var).link = TypeVarLink::Link(*ty);
    Ok(())
}
/// Resolves `ty_ref` through type variable links.
///
/// Returns `ty_ref` itself unless it refers to a linked variable; in that case the link
/// chain is followed to its end and the result is returned. Every variable visited along
/// the way is re-linked directly to that final type, which is why the arena is taken mutably.
fn normalize_type(types: &mut TypeArena, ty_ref: &TypeRef) -> TypeRef {
    let var_ref = match *types.get_type(ty_ref) {
        Type::Var(var_ref) => var_ref,
        Type::Int | Type::Bool | Type::Fun(_, _) | Type::Quantified(_) => return *ty_ref,
    };

    let link = types.get_type_var(&var_ref).link;
    if let TypeVarLink::Link(target) = link {
        let resolved = normalize_type(types, &target);
        // Shorten the chain so later lookups reach the final type in one step.
        types.get_type_var_mut(&var_ref).link = TypeVarLink::Link(resolved);
        resolved
    } else {
        *ty_ref
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Printing
//
////////////////////////////////////////////////////////////////////////////////////////////////////
use std::io::Write;
/// Pretty-printing for values that need the `TypeArena` to resolve the references they hold.
trait Print {
/// Writes a human-readable rendering of `self` to `f`, using `types` to look up the
/// types and type variables that `self` refers to.
///
/// The implementations in this file `unwrap` the result of every write, so they panic
/// if writing to `f` fails.
fn print<W: Write>(&self, types: &TypeArena, f: &mut W);
}
/// Prints a type after looking through variable links: `int`, `bool`, `(arg -> ret)`,
/// `_N` for the unlinked variable with index `N`, and `$N` for quantified variable `N`.
impl Print for Type {
    fn print<W: Write>(&self, types: &TypeArena, f: &mut W) {
        // Chase links until reaching something that is not a linked variable.
        let mut resolved = *self;
        while let Type::Var(var_ref) = resolved {
            match types.get_type_var(&var_ref).link {
                TypeVarLink::Link(target) => resolved = *types.get_type(&target),
                TypeVarLink::NoLink => break,
            }
        }

        match resolved {
            Type::Int => f.write_all(b"int").unwrap(),
            Type::Bool => f.write_all(b"bool").unwrap(),
            Type::Fun(arg, ret) => {
                f.write_all(b"(").unwrap();
                arg.print(types, f);
                f.write_all(b" -> ").unwrap();
                ret.print(types, f);
                f.write_all(b")").unwrap();
            }
            Type::Var(var) => write!(f, "_{}", var.0).unwrap(),
            Type::Quantified(idx) => write!(f, "${}", idx).unwrap(),
        }
    }
}
impl Print for TypeRef {
fn print<W: Write>(&self, types: &TypeArena, f: &mut W) {
let ty = *types.get_type(self);
ty.print(types, f);
}
}
/// Prints the scheme's type, preceded by `∀ N . ` when it quantifies over `N > 0` variables.
impl Print for TypeScheme {
    fn print<W: Write>(&self, types: &TypeArena, f: &mut W) {
        if self.num_foralls > 0 {
            write!(f, "∀ {} . ", self.num_foralls).unwrap();
        }
        self.ty.print(types, f);
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Test programs
//
////////////////////////////////////////////////////////////////////////////////////////////////////
use std::io::stdout;
/// Example programs used to exercise type inference.
mod test_programs {
    use super::{Expr, Prim};

    /// Shorthand for `Box::new`.
    fn b<T>(t: T) -> Box<T> {
        Box::new(t)
    }

    /// Boxed variable expression.
    fn bv(s: &'static str) -> Box<Expr> {
        b(Expr::Var(s))
    }

    /// Left-nested application: `app(vec![f, a, b])` builds `(f a) b`.
    ///
    /// # Panics
    /// Panics if `exprs` is empty.
    fn app(exprs: Vec<Box<Expr>>) -> Box<Expr> {
        let mut iter = exprs.into_iter();
        let mut e = iter.next().unwrap();
        for arg in iter {
            e = b(Expr::App(e, arg));
        }
        e
    }

    // let id a = a in id
    pub(crate) fn id() -> Expr {
        Expr::LetRec {
            fun: "id",
            arg: "a",
            fun_body: bv("a"),
            let_body: bv("id"),
        }
    }

    // let first x =
    //   let
    //     second y = x
    //   in
    //     second
    // in
    //   first
    pub(crate) fn first() -> Expr {
        Expr::LetRec {
            fun: "first",
            arg: "x",
            fun_body: b(Expr::LetRec {
                fun: "second",
                arg: "y",
                fun_body: bv("x"),
                let_body: bv("second"),
            }),
            let_body: bv("first"),
        }
    }

    // let f x =
    //   let g y =
    //     x y
    //   in
    //     g
    // in
    //   f
    pub(crate) fn app_fxy() -> Expr {
        Expr::LetRec {
            fun: "f",
            arg: "x",
            fun_body: b(Expr::LetRec {
                fun: "g",
                arg: "y",
                fun_body: b(Expr::App(bv("x"), bv("y"))),
                let_body: bv("g"),
            }),
            let_body: bv("f"),
        }
    }

    // let add x =
    //   let add1 y =
    //     x + y
    //   in
    //     add1
    // in
    //   add
    pub(crate) fn add() -> Expr {
        Expr::LetRec {
            fun: "add",
            arg: "x",
            fun_body: b(Expr::LetRec {
                fun: "add1",
                arg: "y",
                fun_body: b(Expr::Prim(Prim::Add, bv("x"), bv("y"))),
                let_body: bv("add1"),
            }),
            let_body: bv("add"),
        }
    }

    // \y -> y (\z -> y z)
    //
    // occurs check fail
    pub(crate) fn heiber() -> Expr {
        Expr::Lam {
            arg: "y",
            body: b(Expr::App(
                bv("y"),
                b(Expr::Lam {
                    arg: "z",
                    body: b(Expr::App(bv("y"), bv("z"))),
                }),
            )),
        }
    }

    // \x -> \y -> \k -> k (k x y) (k y x)
    // : t -> t -> (t -> t -> t) -> t
    pub(crate) fn kirang() -> Expr {
        Expr::Lam {
            arg: "x",
            body: b(Expr::Lam {
                arg: "y",
                body: b(Expr::Lam {
                    arg: "k",
                    // `k` is applied to exactly two arguments: `(k x y)` and `(k y x)`.
                    // Wrapping the second argument in another application of `k` would
                    // build `k (k x y) (k (k y x))`, which is ill-typed and fails the
                    // occurs check.
                    body: app(vec![
                        bv("k"),
                        app(vec![bv("k"), bv("x"), bv("y")]),
                        app(vec![bv("k"), bv("y"), bv("x")]),
                    ]),
                }),
            }),
        }
    }

    // \x -> \k -> k (k x)
    // : t -> (t -> t) -> t
    pub(crate) fn kirang_simple() -> Expr {
        Expr::Lam {
            arg: "x",
            body: b(Expr::Lam {
                arg: "k",
                body: b(Expr::App(bv("k"), b(Expr::App(bv("k"), bv("x"))))),
            }),
        }
    }

    // let id x = let y = x in y in id
    pub(crate) fn sound_generalization_1() -> Expr {
        Expr::LetRec {
            fun: "id",
            arg: "x",
            fun_body: b(Expr::Let {
                var: "y",
                var_body: bv("x"),
                let_body: bv("y"),
            }),
            let_body: bv("id"),
        }
    }
}
fn main() {
{
let mut types = TypeArena::new();
let mut env: Map<String, TypeScheme> = Default::default();
infer_types(&mut types, 0, &mut env, &test_programs::id()).unwrap();
let scm = env.get("id").unwrap();
print!("(\\x -> x) : ");
scm.print(&types, &mut stdout());
println!();
}
{
let mut types = TypeArena::new();
let mut env: Map<String, TypeScheme> = Default::default();
infer_types(&mut types, 0, &mut env, &test_programs::first()).unwrap();
let scm = env.get("first").unwrap();
print!("(\\x y -> x) : ");
scm.print(&types, &mut stdout());
println!();
}
{
let mut types = TypeArena::new();
let mut env: Map<String, TypeScheme> = Default::default();
infer_types(&mut types, 0, &mut env, &test_programs::app_fxy()).unwrap();
let scm = env.get("f").unwrap();
print!("(\\x y -> x y) : ");
scm.print(&types, &mut stdout());
println!();
}
{
let mut types = TypeArena::new();
let mut env: Map<String, TypeScheme> = Default::default();
infer_types(&mut types, 0, &mut env, &test_programs::add()).unwrap();
let scm = env.get("add").unwrap();
print!("(\\x y -> x + y) : ");
scm.print(&types, &mut stdout());
println!();
}
/*
Fails occurs check:
{
let mut types = TypeArena::new();
let mut env: Map<String, TypeScheme> = Default::default();
let ty = infer_types(&mut types, 0, &mut env, &test_programs::heiber()).unwrap();
print!("(\\y -> y (\\z -> y z)) : ");
ty.print(&types, &mut stdout());
println!();
}
*/
{
let mut types = TypeArena::new();
let mut env: Map<String, TypeScheme> = Default::default();
let ty = infer_types(&mut types, 0, &mut env, &test_programs::kirang_simple()).unwrap();
print!("(\\x -> \\k -> k (k x) : ");
ty.print(&types, &mut stdout());
println!();
}
/*
FIXME: This incorrectly triggers occurs check
{
let mut types = TypeArena::new();
let mut env: Map<String, TypeScheme> = Default::default();
let ty = infer_types(&mut types, 0, &mut env, &test_programs::kirang()).unwrap();
print!("(\\x -> \\y -> \\k -> k (k x y) (k y x) : ");
ty.print(&types, &mut stdout());
println!();
}
*/
{
let mut types = TypeArena::new();
let mut env: Map<String, TypeScheme> = Default::default();
infer_types(
&mut types,
0,
&mut env,
&test_programs::sound_generalization_1(),
)
.unwrap();
let scm = env.get("id").unwrap();
print!("let id x = let y = x in y : ");
scm.print(&types, &mut stdout());
println!();
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment