Skip to content

Instantly share code, notes, and snippets.

@zesterer
Created May 14, 2020 21:28
Show Gist options
  • Save zesterer/0fea5caf5dcf028feae275eabc6a0f4a to your computer and use it in GitHub Desktop.
Save zesterer/0fea5caf5dcf028feae275eabc6a0f4a to your computer and use it in GitHub Desktop.
Type Inference in 66 lines of Rust
use std::collections::HashMap;
/// A fully-resolved type, as produced by [`Engine::reconstruct`].
#[derive(Clone, Debug, PartialEq, Eq)]
enum Type {
    /// The `Num` primitive type.
    Num,
    /// The `Bool` primitive type.
    Bool,
    /// A list whose items have the boxed type.
    List(Box<Type>),
    /// A function from the first boxed type (input) to the second (output).
    Func(Box<Type>, Box<Type>),
}

/// Identifier of a type variable stored in an [`Engine`].
pub type TypeId = usize;

/// What the engine currently knows about a single type variable.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum TypeInfo {
    /// Nothing is known yet; unification may bind this variable.
    Unknown,
    /// This variable has been unified with, and now aliases, another variable.
    Ref(TypeId),
    /// The `Num` primitive type.
    Num,
    /// The `Bool` primitive type.
    Bool,
    /// A list whose item type is the referenced variable.
    List(TypeId),
    /// A function from the first variable (input) to the second (output).
    Func(TypeId, TypeId),
}

/// A unification-based type inference engine.
///
/// Type variables are allocated with [`Engine::insert`], constrained with
/// [`Engine::unify`], and read back as concrete [`Type`]s with
/// [`Engine::reconstruct`].
#[derive(Default)]
struct Engine {
    /// The most recently allocated id; the first id handed out is 1.
    id_counter: usize,
    /// Current knowledge about every allocated type variable.
    vars: HashMap<TypeId, TypeInfo>,
}

impl Engine {
    /// Allocates a fresh type variable holding `info` and returns its id.
    pub fn insert(&mut self, info: TypeInfo) -> TypeId {
        self.id_counter += 1;
        let id = self.id_counter;
        self.vars.insert(id, info);
        id
    }

    /// Constrains the types `a` and `b` to be equal.
    ///
    /// # Errors
    /// Returns an error if the two types have incompatible structure, or if
    /// unification would create an infinite type (e.g. `t = List(t)`).
    ///
    /// # Panics
    /// Panics if either id was not returned by [`Engine::insert`].
    pub fn unify(&mut self, a: TypeId, b: TypeId) -> Result<(), String> {
        use TypeInfo::*;
        // Compare the representatives of both alias chains first. Unifying a
        // variable with itself, directly or via aliases, must be a no-op.
        // Otherwise an `Unknown` would be bound to a `Ref` cycle, and every
        // later `unify`/`reconstruct` on it would recurse forever.
        let a = self.resolve(a);
        let b = self.resolve(b);
        if a == b {
            return Ok(());
        }
        match (self.vars[&a], self.vars[&b]) {
            (Unknown, _) => self.bind(a, b),
            (_, Unknown) => self.bind(b, a),
            (Ref(_), _) | (_, Ref(_)) => unreachable!("resolve() never returns a Ref variable"),
            (Num, Num) | (Bool, Bool) => Ok(()),
            (List(a_item), List(b_item)) => self.unify(a_item, b_item),
            (Func(a_i, a_o), Func(b_i, b_o)) => {
                self.unify(a_i, b_i)?;
                self.unify(a_o, b_o)
            }
            (a, b) => Err(format!("Conflict between {:?} and {:?}", a, b)),
        }
    }

    /// Converts the variable `id` into a concrete [`Type`].
    ///
    /// # Errors
    /// Returns `"Cannot infer"` if `id`, or any type nested inside it, is
    /// still `Unknown`.
    pub fn reconstruct(&self, id: TypeId) -> Result<Type, String> {
        use TypeInfo::*;
        match self.vars[&id] {
            Unknown => Err("Cannot infer".to_string()),
            Ref(id) => self.reconstruct(id),
            Num => Ok(Type::Num),
            Bool => Ok(Type::Bool),
            List(item) => Ok(Type::List(Box::new(self.reconstruct(item)?))),
            Func(i, o) => Ok(Type::Func(
                Box::new(self.reconstruct(i)?),
                Box::new(self.reconstruct(o)?),
            )),
        }
    }

    /// Follows `Ref` aliases from `id` to the variable that holds its information.
    fn resolve(&self, mut id: TypeId) -> TypeId {
        while let TypeInfo::Ref(next) = self.vars[&id] {
            id = next;
        }
        id
    }

    /// Makes the unknown variable `var` an alias of `ty`, refusing to build an
    /// infinite type. Both ids must already be resolved.
    fn bind(&mut self, var: TypeId, ty: TypeId) -> Result<(), String> {
        if self.occurs(var, ty) {
            return Err(format!(
                "Infinite type: variable {} occurs in {:?}",
                var, self.vars[&ty]
            ));
        }
        self.vars.insert(var, TypeInfo::Ref(ty));
        Ok(())
    }

    /// Returns whether the variable `var` appears anywhere inside the type `id`.
    fn occurs(&self, var: TypeId, id: TypeId) -> bool {
        let id = self.resolve(id);
        if id == var {
            return true;
        }
        match self.vars[&id] {
            TypeInfo::List(item) => self.occurs(var, item),
            TypeInfo::Func(i, o) => self.occurs(var, i) || self.occurs(var, o),
            TypeInfo::Unknown | TypeInfo::Num | TypeInfo::Bool => false,
            TypeInfo::Ref(_) => unreachable!("resolve() never returns a Ref variable"),
        }
    }
}
/// Demo: infer the type of a function from two partial descriptions of it.
fn main() {
    let mut engine = Engine::default();

    // Left-hand side `? -> Num`: the input type is not known yet.
    let lhs_input = engine.insert(TypeInfo::Unknown);
    let lhs_output = engine.insert(TypeInfo::Num);
    let lhs = engine.insert(TypeInfo::Func(lhs_input, lhs_output));

    // Right-hand side `Bool -> ?`: the output type is not known yet.
    let rhs_input = engine.insert(TypeInfo::Bool);
    let rhs_output = engine.insert(TypeInfo::Unknown);
    let rhs = engine.insert(TypeInfo::Func(rhs_input, rhs_output));

    // Each side fills in the other's unknown when the two are unified...
    engine.unify(lhs, rhs).unwrap();

    // ...so reconstructing the left-hand side yields `Bool -> Num`.
    let final_type = engine.reconstruct(lhs);
    println!("Final type = {:?}", final_type);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment