Last active
July 1, 2023 00:52
-
-
Save zesterer/96e1449d721f006c4a9191c359892445 to your computer and use it in GitHub Desktop.
Type inference via unification with support for higher-ranked types
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Name of a type variable introduced by a `forall` binder ([`TyInfo::For`]).
type Term = &'static str;

/// What the engine currently knows about a single type variable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TyInfo {
    /// Not yet constrained; unification may bind it.
    Unknown,
    /// Forwarding link: this variable is the same type as the referenced one.
    Ref(TyVar),
    /// An occurrence of a quantified term bound by an enclosing `For`.
    Term(Term),
    Bool,
    Int,
    /// Function type: `input -> output`.
    Func(TyVar, TyVar),
    /// Universal quantification: `forall term. body`.
    For(Term, TyVar),
}

/// Index of a type variable inside [`Engine`]'s variable arena.
type TyVar = usize;

/// Unification engine: a flat arena of type variables indexed by [`TyVar`].
#[derive(Debug, Default)]
struct Engine {
    vars: Vec<TyInfo>,
}

impl Engine {
    /// Allocate a new type variable holding `info` and return its index.
    fn insert(&mut self, info: TyInfo) -> TyVar {
        let var = self.vars.len();
        self.vars.push(info);
        var
    }

    /// Follow `Ref` links from `var` and return the first variable that is not a `Ref`.
    ///
    /// [`Engine::bind`] never creates `Ref` cycles, so this loop terminates.
    fn root(&self, mut var: TyVar) -> TyVar {
        while let TyInfo::Ref(next) = self.vars[var] {
            var = next;
        }
        var
    }

    /// Return the information stored at the end of `var`'s `Ref` chain.
    fn follow(&self, var: TyVar) -> TyInfo {
        self.vars[self.root(var)]
    }

    /// Whether the root variable `var` appears anywhere inside the type rooted at `ty`.
    fn occurs(&self, var: TyVar, ty: TyVar) -> bool {
        let ty = self.root(ty);
        if ty == var {
            return true;
        }
        match self.vars[ty] {
            TyInfo::Func(input, output) => self.occurs(var, input) || self.occurs(var, output),
            TyInfo::For(_, body) => self.occurs(var, body),
            TyInfo::Unknown | TyInfo::Ref(_) | TyInfo::Term(_) | TyInfo::Bool | TyInfo::Int => {
                false
            }
        }
    }

    /// Bind the unknown variable `unknown` (or whatever it aliases) to `other`.
    ///
    /// The *root* of `unknown` is rewritten, not `unknown` itself, so every alias of
    /// the unknown observes the binding. Binding a variable to itself is a no-op
    /// (it would otherwise create a `Ref` cycle).
    ///
    /// # Errors
    /// Fails the occurs check when `other` contains `unknown`, e.g. `a = a -> Int`.
    fn bind(&mut self, unknown: TyVar, other: TyVar) -> Result<(), String> {
        let (var, target) = (self.root(unknown), self.root(other));
        if var == target {
            return Ok(());
        }
        debug_assert!(
            matches!(self.vars[var], TyInfo::Unknown),
            "only unknown variables may be bound"
        );
        if self.occurs(var, target) {
            return Err(format!(
                "infinite type: variable {} occurs in {:?}",
                var, self.vars[target]
            ));
        }
        self.vars[var] = TyInfo::Ref(target);
        Ok(())
    }

    /// Instantiate the quantified term `name` to `var`.
    ///
    /// `terms` is a stack of in-scope binders, innermost last. It is searched from
    /// the innermost binder outwards so shadowed names resolve lexically. The first
    /// use of a term records `var`; later uses must agree with it via
    /// [`Engine::unify_same`].
    ///
    /// # Errors
    /// `"Unknown term '<name>'"` if no enclosing binder introduces `name`, or any
    /// error from unifying with the earlier instantiation.
    fn bind_term(
        &mut self,
        terms: &mut [(Term, Option<TyVar>)],
        name: Term,
        var: TyVar,
    ) -> Result<(), String> {
        let (_, slot) = terms
            .iter_mut()
            .rev()
            .find(|(t, _)| *t == name)
            .ok_or_else(|| format!("Unknown term '{}'", name))?;
        match *slot {
            Some(bound) => self.unify_same(bound, var),
            None => {
                *slot = Some(var);
                Ok(())
            }
        }
    }

    /// Unify type variables originating from the same type.
    ///
    /// No instantiation happens here: quantified terms and binders must match by
    /// name. Prints a debug trace line on every call.
    ///
    /// # Errors
    /// Describes the first structural mismatch, or an infinite type.
    fn unify_same(&mut self, x: TyVar, y: TyVar) -> Result<(), String> {
        let (x_info, y_info) = (self.follow(x), self.follow(y));
        println!("Unify same {:?} with {:?}", x_info, y_info);
        match (x_info, y_info) {
            (TyInfo::Unknown, _) => self.bind(x, y),
            (_, TyInfo::Unknown) => self.bind(y, x),
            (TyInfo::Term(x_t), TyInfo::Term(y_t)) if x_t == y_t => Ok(()),
            (TyInfo::Bool, TyInfo::Bool) | (TyInfo::Int, TyInfo::Int) => Ok(()),
            (TyInfo::Func(x_i, x_o), TyInfo::Func(y_i, y_o)) => {
                self.unify_same(x_i, y_i)?;
                self.unify_same(x_o, y_o)
            }
            // Binders are compared by name, so alpha-equivalent types such as
            // `forall A. A` and `forall B. B` are NOT unified here.
            (TyInfo::For(x_t, x_body), TyInfo::For(y_t, y_body)) if x_t == y_t => {
                self.unify_same(x_body, y_body)
            }
            (x_info, y_info) => Err(format!("cannot unify {:?} and {:?}", x_info, y_info)),
        }
    }

    /// Unify type variables originating from different types, instantiating
    /// quantified terms where required.
    ///
    /// `x_terms` / `y_terms` are the stacks of binders currently in scope on each
    /// side. Each binder is paired with the variable it has been instantiated to
    /// (`None` until first used). Entries pushed for a `For` are popped again
    /// before returning. Prints a debug trace line on every call.
    ///
    /// # Errors
    /// Describes the first mismatch, an unknown term, or an infinite type.
    fn unify(
        &mut self,
        x: TyVar,
        y: TyVar,
        x_terms: &mut Vec<(Term, Option<TyVar>)>,
        y_terms: &mut Vec<(Term, Option<TyVar>)>,
    ) -> Result<(), String> {
        let (x_info, y_info) = (self.follow(x), self.follow(y));
        println!("Unify {:?} with {:?}", x_info, y_info);
        match (x_info, y_info) {
            (TyInfo::Unknown, _) => self.bind(x, y),
            (_, TyInfo::Unknown) => self.bind(y, x),
            // Handle higher-ranked types: a quantified term is instantiated to the
            // type it is unified against.
            (TyInfo::Term(x_t), TyInfo::Term(y_t)) => {
                self.bind_term(x_terms, x_t, y)?;
                self.bind_term(y_terms, y_t, x)
            }
            (TyInfo::Term(x_t), _) => self.bind_term(x_terms, x_t, y),
            (_, TyInfo::Term(y_t)) => self.bind_term(y_terms, y_t, x),
            (TyInfo::For(x_t, x_body), _) => {
                x_terms.push((x_t, None));
                let result = self.unify(x_body, y, x_terms, y_terms);
                x_terms.pop();
                result
            }
            (_, TyInfo::For(y_t, y_body)) => {
                y_terms.push((y_t, None));
                let result = self.unify(x, y_body, x_terms, y_terms);
                y_terms.pop();
                result
            }
            // Now, unify rank-1 types
            (TyInfo::Bool, TyInfo::Bool) | (TyInfo::Int, TyInfo::Int) => Ok(()),
            (TyInfo::Func(x_i, x_o), TyInfo::Func(y_i, y_o)) => {
                self.unify(x_i, y_i, x_terms, y_terms)?;
                self.unify(x_o, y_o, x_terms, y_terms)
            }
            (x_info, y_info) => Err(format!("cannot unify {:?} and {:?}", x_info, y_info)),
        }
    }
}
fn main() { | |
let mut engine = Engine::default(); | |
let x = engine.insert(TyInfo::Term("B")); | |
let y = engine.insert(TyInfo::Unknown); | |
let f = engine.insert(TyInfo::Func(x, y)); | |
let f = engine.insert(TyInfo::For("B", f)); | |
let a = engine.insert(TyInfo::Term("A")); | |
let b = engine.insert(TyInfo::Term("X")); | |
let b = engine.insert(TyInfo::For("X", b)); | |
let g = engine.insert(TyInfo::Func(a, b)); | |
let g = engine.insert(TyInfo::For("A", g)); | |
let g = engine.insert(TyInfo::For("B", g)); | |
engine.unify(f, g, &mut Vec::new(), &mut Vec::new()).unwrap(); | |
println!("y = {:?}", engine.follow(y)); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment