Skip to content

Instantly share code, notes, and snippets.

@FredTheDino
Last active May 16, 2023 20:45
Show Gist options
  • Save FredTheDino/7c348f75b2f761532f72549be2ac07bd to your computer and use it in GitHub Desktop.
Save FredTheDino/7c348f75b2f761532f72549be2ac07bd to your computer and use it in GitHub Desktop.
A simple Hindley–Milner type checker implemented for the lambda calculus.
use std::collections::HashMap;
/// Syntax tree of the untyped lambda calculus extended with a unit literal.
///
/// Names are `&'static str`; the code assumes every variable name is unique
/// across the whole program.
#[derive(Debug, Clone)]
enum Ast {
    /// The unit literal.
    Unit,
    /// A variable reference (all strings are assumed unique).
    Var(&'static str),
    /// A lambda: parameter name, then body.
    Fun(&'static str, Box<Ast>),
    /// An application. NOTE: the first field is the ARGUMENT and the second is
    /// the expression being called — the reverse of the usual `f x` order.
    Call(Box<Ast>, Box<Ast>),
}
/// Shorthand constructor for the unit literal.
fn unit() -> Ast { Ast::Unit }
/// Shorthand constructor for a reference to the variable `name`.
fn var(name: &'static str) -> Ast {
    let node = Ast::Var(name);
    node
}
/// Shorthand constructor for the lambda `λparam. body`.
fn fun(param: &'static str, body: Ast) -> Ast {
    Ast::Fun(param, body.into())
}
/// Shorthand constructor for an application.
///
/// The ARGUMENT comes first and the expression being called second, mirroring
/// the field order of `Ast::Call`: `call(a, f)` means "apply `f` to `a`".
fn call(arg: Ast, callee: Ast) -> Ast {
    Ast::Call(arg.into(), callee.into())
}
/// Evaluates `ast`, returning its value or `None` if a variable is unbound.
///
/// Scoping is dynamic: a call inserts its parameter into the shared `vars` map
/// and never removes it. Lambdas are returned as-is without capturing an
/// environment, so a lambda body sees whatever `vars` holds when it runs. This
/// relies on variable names being unique.
///
/// # Panics
/// Panics if the callee of an application does not evaluate to a lambda.
fn interpret(ast: Ast, vars: &mut HashMap<&'static str, Ast>) -> Option<Ast> {
    let (arg_expr, callee_expr) = match ast {
        Ast::Var(name) => return vars.get(name).cloned(),
        Ast::Call(arg_expr, callee_expr) => (arg_expr, callee_expr),
        // Unit and lambdas are already values.
        value @ (Ast::Unit | Ast::Fun(..)) => return Some(value),
    };
    // The callee is evaluated before the argument.
    match interpret(*callee_expr, vars)? {
        Ast::Fun(param, body) => {
            let arg_value = interpret(*arg_expr, vars)?;
            vars.insert(param, arg_value);
            interpret(*body, vars)
        }
        _ => panic!("CANNOT CALL NON FUNCTION!"),
    }
}
/// A type as manipulated by the checker.
#[derive(Debug, Clone)]
enum Type {
    /// Placeholder for a `Ctx` slot that has no constraint yet.
    Unknown,
    /// A type variable: a reference to slot `n` of `Ctx::tys`.
    Node(usize),
    /// The type of the unit literal.
    Unit,
    /// A function type: argument type, then result type.
    Fun(Box<Type>, Box<Type>),
}
/// Mutable inference state threaded through type checking.
#[derive(Debug, Clone)]
struct Ctx {
    /// Type slots, indexed by the number inside `Type::Node`. A slot that holds
    /// another `Type::Node` is a link to that slot.
    tys: Vec<Type>,
    /// Variable name -> slot index, looked up by `generic_for_var`.
    names: HashMap<&'static str, usize>,
}
impl Ctx {
    /// Creates an empty context with no type slots and no named variables.
    fn new() -> Self {
        Ctx {
            tys: Vec::new(),
            names: HashMap::new(),
        }
    }

    /// Returns the type slot for the variable name `var`, allocating a fresh
    /// `Unknown` slot the first time the name is seen.
    ///
    /// The mapping is remembered in `names`, so every binder using the same name
    /// shares one slot. This matches the `Ast` assumption that variable names
    /// are unique.
    fn generic_for_var(&mut self, var: &'static str) -> Type {
        match self.names.entry(var) {
            std::collections::hash_map::Entry::Occupied(x) => Type::Node(*x.get()),
            std::collections::hash_map::Entry::Vacant(n) => {
                let id = self.tys.len();
                self.tys.push(Type::Unknown);
                // Record the slot; previously this was never stored, so `names`
                // stayed empty and the `Occupied` arm could never be taken.
                n.insert(id);
                Type::Node(id)
            }
        }
    }

    /// Allocates a fresh, unconstrained slot and returns a reference to it.
    fn new_generic(&mut self) -> Type {
        let id = self.tys.len();
        self.tys.push(Type::Unknown);
        Type::Node(id)
    }

    /// Follows `Type::Node` links from slot `id` and returns the first slot
    /// that is not itself a link (the representative of its class).
    fn find(&self, mut id: usize) -> usize {
        while let Type::Node(next) = self.tys[id] {
            id = next;
        }
        id
    }

    /// Returns `true` if the class represented by `root` appears anywhere inside
    /// `ty`. The search follows links and looks into the types already stored
    /// for other classes.
    fn occurs(&self, root: usize, ty: &Type) -> bool {
        match ty {
            Type::Node(id) => {
                let rep = self.find(*id);
                rep == root || self.occurs(root, &self.tys[rep])
            }
            Type::Fun(arg, ret) => self.occurs(root, arg) || self.occurs(root, ret),
            Type::Unknown | Type::Unit => false,
        }
    }

    /// Unifies the class containing slot `a` with `other` and returns the index
    /// of the slot that represents the merged class.
    ///
    /// Links are followed (and path-compressed) down to the representative.
    /// Unifying a class with itself is a no-op. Binding a class to a type that
    /// contains that same class is rejected (the occurs check). Either case would
    /// otherwise store a cycle in `tys`, and a later call would recurse forever.
    ///
    /// # Errors
    /// Returns `Err("Infinite type!")` if the occurs check fails, or any error
    /// from `unify` when the stored type and `other` cannot be reconciled.
    fn replace(&mut self, a: usize, other: Type) -> Result<usize, &'static str> {
        if let Type::Node(aa) = self.tys[a] {
            let inner_a = self.replace(aa, other)?;
            // Path compression: point `a` straight at the representative.
            self.tys[a] = Type::Node(inner_a);
            return Ok(inner_a);
        }
        // From here on `a` is a representative slot.
        if let Type::Node(b) = &other {
            if self.find(*b) == a {
                // Already the same class; linking would create a cycle.
                return Ok(a);
            }
        }
        if self.occurs(a, &other) {
            return Err("Infinite type!");
        }
        let ty_a = self.tys[a].clone();
        let ty_imp = unify(self, ty_a, other)?;
        self.tys[a] = ty_imp;
        Ok(a)
    }
}
/// Infers the type of `ast`, recording constraints in `ctx`.
///
/// `vars` maps each variable in scope to its type. A lambda's parameter is bound
/// only while its body is checked. Afterwards the previous binding (if any) is
/// restored, so a variable cannot be referenced outside the lambda that binds it.
///
/// # Errors
/// Returns `Err("Unknown variable!")` for a variable with no binding in scope,
/// or any error produced while unifying.
fn expr(
    ctx: &mut Ctx,
    ast: Ast,
    vars: &mut HashMap<&'static str, Type>,
) -> Result<Type, &'static str> {
    match ast {
        Ast::Unit => Ok(Type::Unit),
        Ast::Var(var) => vars.get(var).cloned().ok_or("Unknown variable!"),
        Ast::Fun(var, body) => {
            let ty = ctx.generic_for_var(var);
            let shadowed = vars.insert(var, ty.clone());
            let body_ty = expr(ctx, *body, vars);
            // Leave the scope: restore the outer binding even if the body failed.
            match shadowed {
                Some(outer) => {
                    vars.insert(var, outer);
                }
                None => {
                    vars.remove(var);
                }
            }
            Ok(Type::Fun(Box::new(ty), Box::new(body_ty?)))
        }
        Ast::Call(arg, fun) => {
            // Order fixes slot numbering: callee, then the return slot, then the argument.
            let fun_ty = expr(ctx, *fun, vars)?;
            let ret_ty = ctx.new_generic();
            let arg_ty = expr(ctx, *arg, vars)?;
            let inferred_ty = Type::Fun(Box::new(arg_ty), Box::new(ret_ty.clone()));
            unify(ctx, fun_ty, inferred_ty)?;
            Ok(ret_ty)
        }
    }
}
/// Unifies `a` with `b`, recording slot bindings in `ctx`, and returns the
/// unified type.
///
/// On a structural mismatch the two offending types are printed to stdout
/// before the error is returned.
fn unify(ctx: &mut Ctx, a: Type, b: Type) -> Result<Type, &'static str> {
    match (a, b) {
        // An unconstrained side simply adopts the other side.
        (Type::Unknown, known) | (known, Type::Unknown) => Ok(known),
        // A type variable: merge through the context and return its representative.
        (Type::Node(slot), rhs) | (rhs, Type::Node(slot)) => {
            let root = ctx.replace(slot, rhs)?;
            Ok(Type::Node(root))
        }
        (Type::Unit, Type::Unit) => Ok(Type::Unit),
        (Type::Fun(lhs_arg, lhs_ret), Type::Fun(rhs_arg, rhs_ret)) => {
            let arg = unify(ctx, *lhs_arg, *rhs_arg)?;
            let ret = unify(ctx, *lhs_ret, *rhs_ret)?;
            Ok(Type::Fun(arg.into(), ret.into()))
        }
        (lhs, rhs) => {
            println!(":( {:?} - {:?}", lhs, rhs);
            Err("Failed to unify!")
        }
    }
}
fn id(v: &'static str) -> Ast {
fun(v, var(v))
}
/// Type-checks and then evaluates a small sample program, printing each stage.
fn main() {
    // (λx. x) applied to (λy. y); `call` takes the argument first.
    let program = call(id("y"), id("x"));
    // `{:?}` only borrows, so no clone is needed for printing.
    println!("PROGRAM: {:?}", program);
    let mut ctx = Ctx::new();
    let typ = expr(&mut ctx, program.clone(), &mut HashMap::new());
    println!("TYPE: {:?}", typ);
    println!("CTX: {:?}", ctx);
    println!("EVAL: {:?}", interpret(program, &mut HashMap::new()));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment