Skip to content

Instantly share code, notes, and snippets.

@zesterer
Last active July 1, 2023 00:52
Show Gist options
  • Save zesterer/96e1449d721f006c4a9191c359892445 to your computer and use it in GitHub Desktop.
Save zesterer/96e1449d721f006c4a9191c359892445 to your computer and use it in GitHub Desktop.
Type inference via unification with support for higher-ranked types
/// Name of a quantified term: introduced by [`TyInfo::For`] and referenced by [`TyInfo::Term`].
type Term = &'static str;

/// What is currently known about a single type variable.
///
/// Every payload is `Copy` (`&'static str` or a variable index), so the whole enum is too;
/// this lets callers pass and compare type information by value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TyInfo {
    /// Not yet determined; unification may turn it into a `Ref` to another variable.
    Unknown,
    /// Alias of another variable; follow the link to reach the actual information.
    Ref(TyVar),
    /// An occurrence of a term bound by an enclosing `For`.
    Term(Term),
    /// The boolean type.
    Bool,
    /// The integer type.
    Int,
    /// A function type: input variable, then output variable.
    Func(TyVar, TyVar),
    /// Quantifies the named term over the body variable.
    For(Term, TyVar),
}
/// Identifier of a type variable: its index in the engine's `vars` table
/// (as handed out by `Engine::insert`).
type TyVar = usize;

/// Unification engine: a table holding the current [`TyInfo`] of every type variable.
///
/// Unification relates variables by overwriting entries with [`TyInfo::Ref`] links.
#[derive(Clone, Debug, Default)]
struct Engine {
    /// Type information for each variable, indexed by [`TyVar`].
    vars: Vec<TyInfo>,
}
impl Engine {
    /// Allocate a new type variable holding `info` and return its id.
    ///
    /// Ids are assigned sequentially as indices into `vars`.
    fn insert(&mut self, info: TyInfo) -> TyVar {
        let var = self.vars.len();
        self.vars.push(info);
        var
    }

    /// Return the representative of `var`: the variable reached by following
    /// [`TyInfo::Ref`] links until one that is not a `Ref`.
    ///
    /// Written as a loop so that long alias chains cannot overflow the stack.
    ///
    /// # Panics
    /// Panics if `var`, or any variable on its chain, is not a valid index into `vars`.
    fn root(&self, mut var: TyVar) -> TyVar {
        while let TyInfo::Ref(next) = self.vars[var] {
            var = next;
        }
        var
    }

    /// Resolve `var` through any chain of `Ref`s and return the information stored there.
    fn follow(&self, var: TyVar) -> TyInfo {
        self.vars[self.root(var)].clone()
    }

    /// Make the representative of `var` (which the caller has seen resolve to
    /// [`TyInfo::Unknown`]) refer to the representative of `to`.
    ///
    /// Binding the representative, not `var` itself, makes every variable that already
    /// aliases it observe the binding. When both sides share a representative nothing is
    /// written: a `Ref` pointing back into its own chain would make `follow` loop forever.
    /// No occurs check is performed.
    fn bind(&mut self, var: TyVar, to: TyVar) {
        let (var_root, to_root) = (self.root(var), self.root(to));
        if var_root != to_root {
            self.vars[var_root] = TyInfo::Ref(to_root);
        }
    }

    /// Find the innermost binding slot of `term` in the scope stack `terms`.
    ///
    /// Entries are pushed as quantifiers are entered, so searching from the end lets an
    /// inner `For` shadow an outer one that reuses the same name.
    ///
    /// # Errors
    /// Returns `"Unknown term '<term>'"` if no enclosing quantifier binds `term`.
    fn lookup(
        terms: &mut [(Term, Option<TyVar>)],
        term: Term,
    ) -> Result<&mut Option<TyVar>, String> {
        terms
            .iter_mut()
            .rev()
            .find(|(name, _)| *name == term)
            .map(|(_, slot)| slot)
            .ok_or_else(|| format!("Unknown term '{}'", term))
    }

    /// Unify type variables originating from the same type.
    ///
    /// Terms and quantifiers are compared by name only; nothing is instantiated.
    /// Prints a trace line for every call.
    ///
    /// # Errors
    /// Returns `"cannot unify <x> and <y>"` for the first incompatible pair encountered.
    fn unify_same(&mut self, x: TyVar, y: TyVar) -> Result<(), String> {
        let (x_info, y_info) = (self.follow(x), self.follow(y));
        println!("Unify same {:?} with {:?}", x_info, y_info);
        match (x_info, y_info) {
            (TyInfo::Unknown, _) => {
                self.bind(x, y);
                Ok(())
            }
            (_, TyInfo::Unknown) => {
                self.bind(y, x);
                Ok(())
            }
            (TyInfo::Term(x_t), TyInfo::Term(y_t)) if x_t == y_t => Ok(()),
            (TyInfo::Bool, TyInfo::Bool) | (TyInfo::Int, TyInfo::Int) => Ok(()),
            (TyInfo::Func(x_i, x_o), TyInfo::Func(y_i, y_o)) => {
                self.unify_same(x_i, y_i)?;
                self.unify_same(x_o, y_o)
            }
            (TyInfo::For(x_t, x_body), TyInfo::For(y_t, y_body)) if x_t == y_t => {
                self.unify_same(x_body, y_body)
            }
            (x_info, y_info) => Err(format!("cannot unify {:?} and {:?}", x_info, y_info)),
        }
    }

    /// Unify type variables originating from different types, allowing instantiations
    /// where required.
    ///
    /// `x_terms` / `y_terms` are the stacks of quantified terms currently in scope on each
    /// side, each paired with the variable it has been instantiated to so far (`None` until
    /// first use). Entering a `For` pushes its term for the duration of the recursive call.
    /// Prints a trace line for every call.
    ///
    /// # Errors
    /// Returns `"Unknown term '<t>'"` for a term not bound by an enclosing `For` on its
    /// side, or `"cannot unify <x> and <y>"` for incompatible types.
    fn unify(
        &mut self,
        x: TyVar,
        y: TyVar,
        x_terms: &mut Vec<(Term, Option<TyVar>)>,
        y_terms: &mut Vec<(Term, Option<TyVar>)>,
    ) -> Result<(), String> {
        let (x_info, y_info) = (self.follow(x), self.follow(y));
        println!("Unify {:?} with {:?}", x_info, y_info);
        match (x_info, y_info) {
            (TyInfo::Unknown, _) => {
                self.bind(x, y);
                Ok(())
            }
            (_, TyInfo::Unknown) => {
                self.bind(y, x);
                Ok(())
            }
            // Handle higher-ranked types: each term is instantiated to the other side, or
            // unified with the instantiation it already received.
            (TyInfo::Term(x_t), TyInfo::Term(y_t)) => {
                let x_slot = Self::lookup(x_terms, x_t)?;
                match *x_slot {
                    Some(bound) => self.unify_same(bound, y)?,
                    None => *x_slot = Some(y),
                }
                let y_slot = Self::lookup(y_terms, y_t)?;
                match *y_slot {
                    Some(bound) => self.unify_same(bound, x)?,
                    None => *y_slot = Some(x),
                }
                Ok(())
            }
            (TyInfo::Term(x_t), _) => {
                let slot = Self::lookup(x_terms, x_t)?;
                match *slot {
                    Some(bound) => self.unify_same(y, bound),
                    None => {
                        *slot = Some(y);
                        Ok(())
                    }
                }
            }
            (_, TyInfo::Term(y_t)) => {
                let slot = Self::lookup(y_terms, y_t)?;
                match *slot {
                    Some(bound) => self.unify_same(x, bound),
                    None => {
                        *slot = Some(x);
                        Ok(())
                    }
                }
            }
            // Bring a quantified term into scope only while unifying the body beneath it.
            (TyInfo::For(x_t, x_body), _) => {
                x_terms.push((x_t, None));
                let result = self.unify(x_body, y, x_terms, y_terms);
                x_terms.pop();
                result
            }
            (_, TyInfo::For(y_t, y_body)) => {
                y_terms.push((y_t, None));
                let result = self.unify(x, y_body, x_terms, y_terms);
                y_terms.pop();
                result
            }
            // Now, unify rank-1 types
            (TyInfo::Bool, TyInfo::Bool) | (TyInfo::Int, TyInfo::Int) => Ok(()),
            (TyInfo::Func(x_i, x_o), TyInfo::Func(y_i, y_o)) => {
                self.unify(x_i, y_i, x_terms, y_terms)?;
                self.unify(x_o, y_o, x_terms, y_terms)
            }
            (x_info, y_info) => Err(format!("cannot unify {:?} and {:?}", x_info, y_info)),
        }
    }
}
/// Demo: unify `for<B> (B -> ?y)` with `for<B> for<A> (A -> for<X> X)` and print
/// whatever the unknown `?y` ends up resolving to.
fn main() {
    let mut engine = Engine::default();

    // Left-hand side: for<B> (B -> ?y)
    let lhs_param = engine.insert(TyInfo::Term("B"));
    let y = engine.insert(TyInfo::Unknown);
    let lhs_fn = engine.insert(TyInfo::Func(lhs_param, y));
    let lhs = engine.insert(TyInfo::For("B", lhs_fn));

    // Right-hand side: for<B> for<A> (A -> for<X> X)
    let rhs_param = engine.insert(TyInfo::Term("A"));
    let rhs_ret_body = engine.insert(TyInfo::Term("X"));
    let rhs_ret = engine.insert(TyInfo::For("X", rhs_ret_body));
    let rhs_fn = engine.insert(TyInfo::Func(rhs_param, rhs_ret));
    let rhs_inner = engine.insert(TyInfo::For("A", rhs_fn));
    let rhs = engine.insert(TyInfo::For("B", rhs_inner));

    // Both scope stacks start empty: no quantifier has been entered yet.
    let (mut lhs_terms, mut rhs_terms) = (Vec::new(), Vec::new());
    engine
        .unify(lhs, rhs, &mut lhs_terms, &mut rhs_terms)
        .unwrap();

    let resolved = engine.follow(y);
    println!("y = {:?}", resolved);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment