@brendanzab
Last active June 21, 2023 07:18
Bidirectional type checker for a simple functional language, in Rust
```rust
//! Bidirectional type checker for a simple functional language

use std::rc::Rc;

/// Types of the object language
#[derive(Clone, PartialEq, Eq)]
enum Type {
    Bool,
    Int,
    /// Function type: parameter type and body (return) type
    Fun(Rc<Type>, Rc<Type>),
}

/// Terms of the object language
enum Term {
    /// Variable reference
    Var(String),
    /// Term annotated with the type it should be checked against
    Ann(Box<Term>, Rc<Type>),
    /// Let binding: name, definition, body
    Let(String, Box<Term>, Box<Term>),
    BoolLit(bool),
    IntLit(i32),
    /// Function literal: parameter name (unannotated) and body
    FunLit(String, Box<Term>),
    /// Function application: head and argument
    FunApp(Box<Term>, Box<Term>),
}

/// A stack of bindings currently in scope
type Context = Vec<(String, Rc<Type>)>;

/// Check a term against a type annotation
fn check(context: &mut Context, term: &Term, expected_type: &Rc<Type>) -> Result<(), &'static str> {
    match (term, expected_type.as_ref()) {
        // Synthesise the definition, then check the body with it in scope.
        // The result is stored rather than `?`-propagated so the binding is always popped.
        (Term::Let(name, def, body), _) => {
            let def_type = synth(context, def)?;
            context.push((name.clone(), def_type));
            let body_result = check(context, body, expected_type);
            context.pop();
            body_result
        }
        // The expected function type supplies the parameter's type
        (Term::FunLit(name, body), Type::Fun(param_type, body_type)) => {
            context.push((name.clone(), param_type.clone()));
            let body_result = check(context, body, body_type);
            context.pop();
            body_result
        }
        // Switch to synthesis mode and compare with the expected type
        (term, _) => {
            if synth(context, term)? == *expected_type {
                Ok(())
            } else {
                Err("mismatched types")
            }
        }
    }
}

/// Synthesise the type of a term
fn synth(context: &mut Context, term: &Term) -> Result<Rc<Type>, &'static str> {
    match term {
        // Search from the innermost binding outwards, so shadowing works
        Term::Var(name) => match context.iter().rev().find(|(n, _)| n == name) {
            Some((_, r#type)) => Ok(r#type.clone()),
            None => Err("unbound variable"),
        },
        // An annotation switches to checking mode
        Term::Ann(term, r#type) => {
            check(context, term, r#type)?;
            Ok(r#type.clone())
        }
        Term::Let(name, def, body) => {
            let def_type = synth(context, def)?;
            context.push((name.clone(), def_type));
            let body_type = synth(context, body);
            context.pop();
            body_type
        }
        Term::BoolLit(_) => Ok(Rc::new(Type::Bool)),
        Term::IntLit(_) => Ok(Rc::new(Type::Int)),
        // Without an expected type there is nothing to give the parameter a type
        Term::FunLit(_, _) => Err("ambiguous function literal"),
        // Synthesise the head, then check the argument against its parameter type
        Term::FunApp(head, arg) => match synth(context, head)?.as_ref() {
            Type::Fun(param_type, body_type) => {
                check(context, arg, param_type)?;
                Ok(body_type.clone())
            }
            _ => Err("not a function"),
        },
    }
}
```
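A usage sketch of `check` and `synth`. So that the block compiles on its own, it repeats a condensed copy of the definitions above, with one assumption: `Debug` is derived on `Type` so results can be compared with `assert_eq!`. The helpers `var`, `lam`, `app`, `ann` and `fun_ty` are hypothetical conveniences, not part of the gist. The examples show how an annotation on an otherwise unannotated function literal is propagated inward by checking mode.

```rust
// Condensed copy of the gist's checker so this block compiles on its own.
// Assumption: `Debug` is derived on `Type` so results work with `assert_eq!`.
use std::rc::Rc;

#[derive(Clone, Debug, PartialEq, Eq)]
enum Type { Bool, Int, Fun(Rc<Type>, Rc<Type>) }

enum Term {
    Var(String), Ann(Box<Term>, Rc<Type>), Let(String, Box<Term>, Box<Term>),
    BoolLit(bool), IntLit(i32), FunLit(String, Box<Term>), FunApp(Box<Term>, Box<Term>),
}

type Context = Vec<(String, Rc<Type>)>;

// Checking mode: the expected type flows into the term.
fn check(context: &mut Context, term: &Term, expected_type: &Rc<Type>) -> Result<(), &'static str> {
    match (term, expected_type.as_ref()) {
        (Term::Let(name, def, body), _) => {
            let def_type = synth(context, def)?;
            context.push((name.clone(), def_type));
            let result = check(context, body, expected_type);
            context.pop();
            result
        }
        (Term::FunLit(name, body), Type::Fun(param_type, body_type)) => {
            context.push((name.clone(), param_type.clone()));
            let result = check(context, body, body_type);
            context.pop();
            result
        }
        (term, _) => {
            let actual_type = synth(context, term)?;
            if actual_type == *expected_type { Ok(()) } else { Err("mismatched types") }
        }
    }
}

// Synthesis mode: the type flows out of the term.
fn synth(context: &mut Context, term: &Term) -> Result<Rc<Type>, &'static str> {
    match term {
        Term::Var(name) => match context.iter().rev().find(|(n, _)| n == name) {
            Some((_, ty)) => Ok(ty.clone()),
            None => Err("unbound variable"),
        },
        Term::Ann(term, ty) => {
            check(context, term, ty)?;
            Ok(ty.clone())
        }
        Term::Let(name, def, body) => {
            let def_type = synth(context, def)?;
            context.push((name.clone(), def_type));
            let body_type = synth(context, body);
            context.pop();
            body_type
        }
        Term::BoolLit(_) => Ok(Rc::new(Type::Bool)),
        Term::IntLit(_) => Ok(Rc::new(Type::Int)),
        Term::FunLit(_, _) => Err("ambiguous function literal"),
        Term::FunApp(head, arg) => match synth(context, head)?.as_ref() {
            Type::Fun(param_type, body_type) => {
                check(context, arg, param_type)?;
                Ok(body_type.clone())
            }
            _ => Err("not a function"),
        },
    }
}

// Hypothetical helper constructors, only to keep the examples short.
fn var(name: &str) -> Term { Term::Var(name.to_string()) }
fn lam(name: &str, body: Term) -> Term { Term::FunLit(name.to_string(), Box::new(body)) }
fn app(head: Term, arg: Term) -> Term { Term::FunApp(Box::new(head), Box::new(arg)) }
fn ann(term: Term, ty: Type) -> Term { Term::Ann(Box::new(term), Rc::new(ty)) }
fn fun_ty(param: Type, body: Type) -> Type { Type::Fun(Rc::new(param), Rc::new(body)) }

fn main() {
    let int_to_int = || fun_ty(Type::Int, Type::Int);

    // A bare function literal gives synthesis no parameter type to work with.
    assert_eq!(synth(&mut Vec::new(), &lam("x", var("x"))), Err("ambiguous function literal"));

    // An annotation switches to checking mode, which pushes `Int` into `x`.
    let id = ann(lam("x", var("x")), int_to_int());
    assert_eq!(synth(&mut Vec::new(), &id), Ok(Rc::new(int_to_int())));

    // Application synthesises the head, then checks the argument against `Int`.
    assert_eq!(synth(&mut Vec::new(), &app(id, Term::IntLit(42))), Ok(Rc::new(Type::Int)));
    let bad = app(ann(lam("x", var("x")), int_to_int()), Term::BoolLit(true));
    assert_eq!(synth(&mut Vec::new(), &bad), Err("mismatched types"));

    // `let n = 1 in n` checks against `Int`; an unbound name fails to synthesise.
    let let_n = Term::Let("n".to_string(), Box::new(Term::IntLit(1)), Box::new(var("n")));
    assert_eq!(check(&mut Vec::new(), &let_n, &Rc::new(Type::Int)), Ok(()));
    assert_eq!(synth(&mut Vec::new(), &var("y")), Err("unbound variable"));
}
```

Note that no unification or type variables are involved: the only way a function literal gets a parameter type is by being checked against a known `Type::Fun`.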
@brendanzab (author) commented Feb 16, 2023

@brendanzab (author)

Useful resources:

@Hirrolot

I find bidirectional type checking quite neat because it lets you "infer" (or, better, "propagate") type information without doing anything fancy (Hindley-Milner, looking at you). And thanks to its simplicity, it even scales to systems as complicated as CIC (the Calculus of Inductive Constructions)¹.

Footnotes

  1. "Complete Bidirectional Typing for the Calculus of Inductive Constructions" by Meven Lennon-Bertrand

@Hirrolot

I think it'd also be interesting to experiment with the tagless-final approach to building languages in Rust; the recent stabilisation of generic associated types (GATs) has made this possible. A couple of benefits off the top of my head would be a modular AST representation (possibly split across multiple files) and type-safe encodings of object-language constructs, similar to GADTs.
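A minimal sketch of the tagless-final idea with a GAT (requires Rust 1.65 or later, standard library only). The trait `Lang`, its `Repr<T>` type constructor, the interpreters `Eval` and `Pretty`, and the extension trait `LangMul` are hypothetical names chosen for illustration, and the object language is a toy integer/boolean language rather than the one in the gist.

```rust
// Tagless-final sketch using a generic associated type (GAT).
// `Lang`, `Repr`, `Eval`, `Pretty`, and `LangMul` are hypothetical names.

/// The object language's syntax as a trait; each interpreter chooses `Repr<T>`.
trait Lang {
    type Repr<T>;
    fn int_lit(&self, n: i32) -> Self::Repr<i32>;
    fn bool_lit(&self, b: bool) -> Self::Repr<bool>;
    fn add(&self, a: Self::Repr<i32>, b: Self::Repr<i32>) -> Self::Repr<i32>;
    fn if_then_else<T>(
        &self,
        cond: Self::Repr<bool>,
        then_branch: Self::Repr<T>,
        else_branch: Self::Repr<T>,
    ) -> Self::Repr<T>;
}

/// Interpreter 1: evaluate directly, so a `Repr<T>` is just a `T`.
struct Eval;
impl Lang for Eval {
    type Repr<T> = T;
    fn int_lit(&self, n: i32) -> i32 { n }
    fn bool_lit(&self, b: bool) -> bool { b }
    fn add(&self, a: i32, b: i32) -> i32 { a + b }
    // Both branches arrive already evaluated (strict evaluation in this sketch).
    fn if_then_else<T>(&self, cond: bool, then_branch: T, else_branch: T) -> T {
        if cond { then_branch } else { else_branch }
    }
}

/// Interpreter 2: pretty-print, so every `Repr<T>` is a `String`.
struct Pretty;
impl Lang for Pretty {
    type Repr<T> = String;
    fn int_lit(&self, n: i32) -> String { n.to_string() }
    fn bool_lit(&self, b: bool) -> String { b.to_string() }
    fn add(&self, a: String, b: String) -> String { format!("({a} + {b})") }
    fn if_then_else<T>(&self, cond: String, then_branch: String, else_branch: String) -> String {
        format!("if {cond} then {then_branch} else {else_branch}")
    }
}

/// One program, written once and run by any interpreter.
fn program<L: Lang>(l: &L) -> L::Repr<i32> {
    l.if_then_else::<i32>(l.bool_lit(true), l.add(l.int_lit(1), l.int_lit(2)), l.int_lit(0))
}

/// Modularity: new syntax lives in its own trait (and could live in its own file).
trait LangMul: Lang {
    fn mul(&self, a: Self::Repr<i32>, b: Self::Repr<i32>) -> Self::Repr<i32>;
}
impl LangMul for Eval {
    fn mul(&self, a: i32, b: i32) -> i32 { a * b }
}
impl LangMul for Pretty {
    fn mul(&self, a: String, b: String) -> String { format!("({a} * {b})") }
}

fn times<L: LangMul>(l: &L) -> L::Repr<i32> {
    l.mul(l.int_lit(6), l.int_lit(7))
}

fn main() {
    assert_eq!(program(&Eval), 3);
    assert_eq!(program(&Pretty), "if true then (1 + 2) else 0");
    assert_eq!(times(&Eval), 42);
    assert_eq!(times(&Pretty), "(6 * 7)");
}
```

In code that is generic over `L`, `L::Repr<bool>` and `L::Repr<i32>` are distinct types, so something like `l.add(l.bool_lit(true), l.int_lit(1))` is rejected at compile time; that is the GADT-like guarantee. A concrete interpreter such as `Pretty` collapses both to `String`, so the guarantee comes from writing programs generically.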
