Skip to content

Instantly share code, notes, and snippets.

@mbillingr
Last active March 5, 2024 15:23
Show Gist options
  • Save mbillingr/18a673b7588aaa3b0befe3d76f128de1 to your computer and use it in GitHub Desktop.
Save mbillingr/18a673b7588aaa3b0befe3d76f128de1 to your computer and use it in GitHub Desktop.
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use std::sync::atomic::{AtomicU64, Ordering};
fn main() {
    use Term::*;

    // Small builders so the example terms read like the lambda calculus they encode.
    let var = |name: &str| Var { name: name.into() };
    let lam = |name: &str, body: Term| Lam {
        name: name.into(),
        body: body.into(),
    };
    let app = |lhs: Term, rhs: Term| App {
        lhs: lhs.into(),
        rhs: rhs.into(),
    };

    // (λx. x) 42
    let id = lam("x", var("x"));
    let expr = app(id, Lit { value: 42 });
    let ty = coalesce_type(&Ctx::default().type_term(&expr));
    println!("{:?}", ty);

    // twice = λf. λx. f (f x)
    let twice = lam("f", lam("x", app(var("f"), app(var("f"), var("x")))));
    let ctx = Ctx::default();
    println!("{:?}", coalesce_type(&ctx.type_term(&twice)))
}
/// Shared pointer used for immutable tree nodes and shared state (currently `Rc`).
type Ref<T> = Rc<T>;
/// Cheaply clonable immutable string used for names and identifiers.
type Str = Ref<str>;
/// Payload type of integer literals.
type Int = i64;
/// Surface syntax of the language being type-checked.
///
/// Derives `Debug` (in addition to `Clone`) so terms can be printed in diagnostics,
/// like the other core types of this file.
#[derive(Clone, Debug)]
enum Term {
    /// Integer literal.
    Lit { value: Int },
    /// Reference to a variable bound by an enclosing binder.
    Var { name: Str },
    /// Single-parameter lambda `λname. body`.
    Lam { name: Str, body: Ref<Term> },
    /// Application `lhs rhs`.
    App { lhs: Ref<Term>, rhs: Ref<Term> },
    /// Record literal with named fields.
    Rcd { fields: Vec<(Str, Term)> },
    /// Field selection `receiver.field_name`.
    Sel { receiver: Ref<Term>, field_name: Str },
    /// Let binding of `name` to `rhs` within `body`; `is_rec` presumably marks a
    /// recursive binding (not yet handled by `Ctx::type_term`).
    Let {
        is_rec: bool,
        name: Str,
        rhs: Ref<Term>,
        body: Ref<Term>,
    },
}
/// Type inferred for a term, before it is coalesced into a displayable `Type`.
#[derive(Clone, Debug)]
enum SimpleType {
    /// A type variable; its mutable bounds live in the shared `VariableState`.
    Variable(Ref<VariableState>),
    /// A named base type such as `int`.
    Primitive(Str),
    /// Function type: parameter type, then result type.
    Function(Ref<SimpleType>, Ref<SimpleType>),
    /// Record type mapping field names to field types.
    Record(Ref<HashMap<Str, SimpleType>>),
}
/// Mutable state of a type variable: its accumulated bounds and a display name.
///
/// The bounds sit in `RefCell`s because constraint solving prepends to them through
/// shared `Ref` pointers to the variable.
struct VariableState {
    /// Types constrained to be subtypes of this variable (`lb <: self`).
    lower_bounds: RefCell<List<SimpleType>>,
    /// Types this variable is constrained to be a subtype of (`self <: ub`).
    upper_bounds: RefCell<List<SimpleType>>,
    /// Globally unique display name.
    unique_name: Str,
}

/// Shows only the variable's name.
///
/// A derived `Debug` would recurse into the bound lists. Those lists may contain this
/// very variable (e.g. `α <: α -> β` from a self-application). Formatting such a type,
/// as the type-error messages do, would then recurse forever and overflow the stack.
impl std::fmt::Debug for VariableState {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct("VariableState")
            .field("unique_name", &self.unique_name)
            .finish_non_exhaustive()
    }
}
/// Displayable type produced by `coalesce_type`.
enum Type {
    /// Top type, printed `⊤`.
    Top,
    /// Bottom type, printed `⊥`.
    Bot,
    /// Union, printed `lhs ∨ rhs`.
    Union(Ref<Type>, Ref<Type>),
    /// Intersection, printed `lhs ∧ rhs`.
    Inter(Ref<Type>, Ref<Type>),
    /// Function type, printed `(param -> result)`.
    Function(Ref<Type>, Ref<Type>),
    /// Record type mapping field names to field types.
    Record(Ref<HashMap<Str, Type>>),
    /// Recursive type, printed `(body as name)`; `Variable(name)` occurrences inside
    /// `body` stand for the whole type.
    Recursive {
        name: Str,
        body: Ref<Type>,
    },
    /// Named type variable.
    Variable(Str),
    /// Named base type such as `int`.
    Primitive(Str),
}
/// A type variable tagged with the polarity at which `go` reached it.
/// Used as the key for cycle detection while coalescing types.
#[derive(Clone)]
struct PolarVariable(
    /// The variable itself (compared and hashed by pointer identity).
    Ref<VariableState>,
    /// Polarity of this occurrence.
    P,
);
/// Polarity of a position in a type.
///
/// At `Val` a variable is expanded with its lower bounds (joined by union); at `Use`
/// it is expanded with its upper bounds (joined by intersection).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum P {
    /// Positive position (a value produced).
    Val,
    /// Negative position (a value consumed).
    Use,
}

/// Flips the polarity; applied when descending into a function's parameter.
impl std::ops::Not for P {
    type Output = Self;
    fn not(self) -> Self {
        if self == P::Val {
            P::Use
        } else {
            P::Val
        }
    }
}
impl Eq for SimpleType {}
impl PartialEq for SimpleType {
fn eq(&self, other: &Self) -> bool {
use SimpleType::*;
match (self, other) {
(Variable(a), Variable(b)) => Ref::ptr_eq(a, b),
(Primitive(a), Primitive(b)) => a == b,
(Function(a1, r1), Function(a2, r2)) => a1 == a2 && r1 == r2,
(Record(a), Record(b)) => Ref::ptr_eq(a, b),
_ => false,
}
}
}
/// Hashing consistent with the `PartialEq` impl for `SimpleType`.
///
/// Variables and records are hashed by pointer identity, matching their `ptr_eq`
/// equality. Primitives are hashed by name and functions structurally.
impl std::hash::Hash for SimpleType {
    fn hash<H: std::hash::Hasher>(&self, h: &mut H) {
        // Mix in the variant tag so payloads of different variants that happen to
        // hash alike do not collide.
        std::mem::discriminant(self).hash(h);
        match self {
            SimpleType::Variable(rc) => std::ptr::hash(Rc::as_ptr(rc), h),
            SimpleType::Primitive(name) => name.hash(h),
            SimpleType::Function(a, r) => {
                a.hash(h);
                r.hash(h);
            }
            // Records compare by identity, so the pointer is a sufficient hash. This
            // also avoids walking, and recursively hashing, every field.
            SimpleType::Record(fs) => std::ptr::hash(Rc::as_ptr(fs), h),
        }
    }
}
impl Eq for PolarVariable {}
impl PartialEq for PolarVariable {
fn eq(&self, other: &Self) -> bool {
Rc::ptr_eq(&self.0, &other.0) && self.1 == other.1
}
}
impl std::hash::Hash for PolarVariable {
fn hash<H: std::hash::Hasher>(&self, h: &mut H) {
std::ptr::hash(Rc::as_ptr(&self.0), h);
self.1.hash(h)
}
}
impl VariableState {
pub fn new() -> Self {
VariableState {
lower_bounds: Default::default(),
upper_bounds: Default::default(),
unique_name: unique_name(),
}
}
}
/// Renders a coalesced type in type-theory notation.
///
/// The notation is `⊤`/`⊥`, `∨`/`∧` for union/intersection, `(a -> b)` for functions,
/// `{f: t, ...}` for records and `(body as name)` for recursive types.
///
/// Record fields are printed sorted by name. The underlying `HashMap` iterates in an
/// unspecified, randomly seeded order, which would otherwise make the output differ
/// between runs.
impl std::fmt::Debug for Type {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Type::Top => write!(f, "⊤"),
            Type::Bot => write!(f, "⊥"),
            Type::Union(lhs, rhs) => write!(f, "{lhs:?} ∨ {rhs:?}"),
            Type::Inter(lhs, rhs) => write!(f, "{lhs:?} ∧ {rhs:?}"),
            Type::Variable(v) => write!(f, "{v}"),
            Type::Primitive(n) => write!(f, "{n}"),
            Type::Function(lhs, rhs) => write!(f, "({:?} -> {:?})", lhs, rhs),
            Type::Record(fs) => {
                let mut fields: Vec<_> = fs.iter().collect();
                // Field names are unique map keys, so an unstable sort is deterministic.
                fields.sort_unstable_by(|a, b| a.0.cmp(b.0));
                write!(f, "{{")?;
                for (i, (n, t)) in fields.into_iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{n}: {t:?}")?;
                }
                write!(f, "}}")
            }
            Type::Recursive { name, body } => write!(f, "({body:?} as {name})"),
        }
    }
}
impl SimpleType {
fn function(lhs: SimpleType, rhs: SimpleType) -> Self {
Self::Function(Ref::new(lhs), Ref::new(rhs))
}
}
/// Typing context: the variables in scope plus the constraint cache.
#[derive(Default)]
struct Ctx {
    /// Simple types of the term variables currently in scope.
    vars: HashMap<Str, SimpleType>,
    /// `(lhs, rhs)` pairs already passed to `constrain`.
    /// Shared via `Ref` with every context derived through `bind_var`.
    constraint_cache: Ref<RefCell<HashSet<(SimpleType, SimpleType)>>>,
}
impl Ctx {
    /// Infers the simple type of `term`, recording subtyping constraints as a side effect.
    ///
    /// Variables are looked up in `self.vars`. Lambdas bind their parameter to a fresh
    /// type variable in an extended context. Applications and selections introduce a
    /// fresh result variable and constrain the operand accordingly.
    ///
    /// # Panics
    /// Panics (via `Ctx::err`) on unbound variables or unsatisfiable constraints.
    /// `Let` terms are not implemented yet and hit `todo!()`.
    fn type_term(&self, term: &Term) -> SimpleType {
        use SimpleType::*;
        use Term::*;
        match term {
            Lit { .. } => Primitive("int".into()),
            Var { name } => self
                .vars
                .get(&**name)
                .cloned()
                .unwrap_or_else(|| Self::err(format!("{} not found", name))),
            Rcd { fields } => Record(Ref::new(
                fields
                    .iter()
                    .map(|(n, t)| (n.clone(), self.type_term(t)))
                    .collect(),
            )),
            Lam { name, body } => {
                let param = self.fresh_var();
                let ctx_ = self.bind_var(name.clone(), param.clone());
                SimpleType::function(param, ctx_.type_term(body))
            }
            App { lhs, rhs } => {
                // The function must accept the argument's type and produce `res`.
                let res = self.fresh_var();
                self.constrain(
                    self.type_term(lhs),
                    SimpleType::function(self.type_term(rhs), res.clone()),
                );
                res
            }
            Sel {
                receiver,
                field_name,
            } => {
                // The receiver must be a record with at least this field.
                let res = self.fresh_var();
                let mut rec = HashMap::default();
                rec.insert(field_name.clone(), res.clone());
                self.constrain(self.type_term(receiver), Record(Ref::new(rec)));
                res
            }
            // Listed explicitly (rather than `_`) so newly added term forms trigger
            // a non-exhaustive-match error here.
            Let { .. } => todo!(),
        }
    }

    /// Records the constraint `lhs <: rhs` and propagates it through existing bounds.
    ///
    /// Each `(lhs, rhs)` pair is processed at most once per shared constraint cache.
    /// That is what makes propagation terminate when bounds are cyclic.
    ///
    /// # Panics
    /// Panics (via `Ctx::err`) when the constraint cannot be satisfied.
    fn constrain(&self, lhs: SimpleType, rhs: SimpleType) {
        // `insert` returns false if the pair was already cached. The `RefMut` guard is
        // a temporary of this `let` statement, so it is released before recursing.
        let first_time = self
            .constraint_cache
            .borrow_mut()
            .insert((lhs.clone(), rhs.clone()));
        if !first_time {
            return;
        }
        use SimpleType::*;
        match (&lhs, &rhs) {
            (Primitive(a), Primitive(b)) if a == b => {}
            (Function(a1, r1), Function(a2, r2)) => {
                // Parameters are contravariant, results covariant.
                self.constrain((**a2).clone(), (**a1).clone());
                self.constrain((**r1).clone(), (**r2).clone());
            }
            (Record(fs1), Record(fs2)) => {
                // Every field required by `rhs` must exist in `lhs` with a subtype.
                for (n2, t2) in fs2.iter() {
                    match fs1.get(n2) {
                        Some(t1) => self.constrain(t1.clone(), t2.clone()),
                        None => Self::err(format!("missing field: {n2} in {:?}", lhs)),
                    }
                }
            }
            (Variable(var), _) => {
                cons_(rhs.clone(), &var.upper_bounds);
                // Snapshot the bounds in a `let` first. A `borrow()` written in the
                // `for` header would keep its guard alive for the whole loop, and the
                // recursive calls may push onto this same `RefCell`, which would panic
                // with `BorrowMutError`.
                let lower = var.lower_bounds.borrow().clone();
                for lb in &lower {
                    self.constrain(lb.clone(), rhs.clone());
                }
            }
            (_, Variable(var)) => {
                cons_(lhs.clone(), &var.lower_bounds);
                // Snapshot for the same reason as above.
                let upper = var.upper_bounds.borrow().clone();
                for ub in &upper {
                    self.constrain(lhs.clone(), ub.clone());
                }
            }
            _ => Self::err(format!("cannot constrain {lhs:?} <: {rhs:?}")),
        }
    }

    /// Allocates a new type variable with no bounds and a fresh unique name.
    fn fresh_var(&self) -> SimpleType {
        SimpleType::Variable(Ref::new(VariableState::new()))
    }

    /// Aborts type checking by panicking with `type error: {msg}`.
    fn err(msg: impl ToString) -> ! {
        panic!("type error: {}", msg.to_string())
    }
}
impl Ctx {
fn bind_var(&self, name: Str, ty: SimpleType) -> Self {
let mut vars = self.vars.clone();
vars.insert(name, ty);
Ctx {
vars,
constraint_cache: self.constraint_cache.clone(),
}
}
}
fn coalesce_type(ty: &SimpleType) -> Type {
let mut recursive: HashMap<PolarVariable, Str> = Default::default();
go(ty, P::Val, &Default::default(), &mut recursive)
}
fn go(
ty: &SimpleType,
polar: P,
in_process: &HashSet<PolarVariable>,
recursive: &mut HashMap<PolarVariable, Str>,
) -> Type {
match ty {
SimpleType::Primitive(name) => Type::Primitive(name.clone()),
SimpleType::Function(lhs, rhs) => Type::Function(
Ref::new(go(lhs, !polar, in_process, recursive)),
Ref::new(go(rhs, polar, in_process, recursive)),
),
SimpleType::Record(fs) => Type::Record(Ref::new(
fs.iter()
.map(|(n, t)| (n.clone(), go(t, polar, in_process, recursive)))
.collect(),
)),
SimpleType::Variable(vs) => {
let vs_pol = PolarVariable(vs.clone(), polar);
if in_process.contains(&vs_pol) {
let name = recursive.entry(vs_pol).or_insert_with(|| unique_name());
Type::Variable(name.clone())
} else {
let bounds = match polar {
P::Val => (*vs.lower_bounds.borrow()).clone(),
P::Use => (*vs.upper_bounds.borrow()).clone(),
};
let mut ip_ = in_process.clone();
ip_.insert(vs_pol.clone());
let mut bound_types = vec![];
for b in &bounds {
bound_types.push(go(b, polar, &ip_, recursive))
}
let mut res = Type::Variable(vs.unique_name.clone());
for t in bound_types {
match polar {
P::Val => res = Type::Union(Ref::new(t), Ref::new(res)),
P::Use => res = Type::Inter(Ref::new(t), Ref::new(res)),
}
}
match recursive.get(&vs_pol) {
None => res,
Some(name) => Type::Recursive {
name: name.clone(),
body: Ref::new(res),
},
}
}
}
}
}
fn unique_name() -> Str {
format!("'{}", GLOBAL_COUNTER.fetch_add(1, Ordering::SeqCst)).into()
}
static GLOBAL_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Immutable, persistent singly linked list whose nodes are shared through `Rc`.
///
/// Cloning only copies the head pointer. New elements are prepended with `cons` or,
/// in place behind a `RefCell`, with `cons_`.
#[derive(Debug)]
enum List<T> {
    /// The empty list.
    Nil,
    /// A head element followed by the rest of the list.
    Item(Rc<(T, Self)>),
}

impl<T> List<T> {
    /// Returns the empty list.
    pub fn new() -> Self {
        Self::Nil
    }
}

impl<T> Default for List<T> {
    fn default() -> Self {
        List::new()
    }
}

// Written by hand rather than derived so cloning does not require `T: Clone`.
impl<T> Clone for List<T> {
    fn clone(&self) -> Self {
        if let Self::Item(node) = self {
            Self::Item(Rc::clone(node))
        } else {
            Self::Nil
        }
    }
}

/// A `&List` is itself an iterator over borrowed elements, front to back, so
/// `for x in &list` works directly.
impl<'a, T> Iterator for &'a List<T> {
    type Item = &'a T;
    fn next(&mut self) -> Option<&'a T> {
        let current: &'a List<T> = *self;
        if let List::Item(node) = current {
            let (head, tail) = &**node;
            *self = tail;
            Some(head)
        } else {
            None
        }
    }
}

/// Returns a new list with `x` in front of `xs`.
pub fn cons<T>(x: T, xs: List<T>) -> List<T> {
    let node = Rc::new((x, xs));
    List::Item(node)
}

/// Prepends `x` to the list stored in `xs`, in place.
pub fn cons_<T>(x: T, xs: &RefCell<List<T>>) {
    let tail = xs.take();
    xs.replace(cons(x, tail));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment