Skip to content

Instantly share code, notes, and snippets.

@the-sofi-uwu
Last active February 12, 2024 22:34
Show Gist options
  • Save the-sofi-uwu/bf79be8eee70ea224b23ec5820e3624b to your computer and use it in GitHub Desktop.
Save the-sofi-uwu/bf79be8eee70ea224b23ec5820e3624b to your computer and use it in GitHub Desktop.
Dependent type checker with substitution for lambda calculus.
use std::{collections::HashSet, fmt::Display, rc::Rc};
/// The AST: the syntactic tree of a program.
///
/// Sub-terms are stored behind [`Rc`], so cloning a node only bumps reference
/// counts. The derived `PartialEq`/`Eq` compare terms purely structurally:
/// `λx. x` and `λy. y` are *not* equal under `==`, because bound names are
/// compared literally.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Syntax {
    /// A function `λparam. body`; `param` is bound inside `body`.
    Lambda {
        param: String,
        body: Rc<Syntax>,
    },
    /// The application of `fun` to `arg`.
    App {
        fun: Rc<Syntax>,
        arg: Rc<Syntax>,
    },
    /// A reference to the variable called `name`.
    Var {
        name: String,
    },
    /// A dependent function type `(param : typ) -> body`; `param` is bound
    /// inside `body` but not inside `typ`.
    Pi {
        param: String,
        typ: Rc<Syntax>,
        body: Rc<Syntax>,
    },
    /// A type annotation `(expr : typ)`.
    Ann {
        expr: Rc<Syntax>,
        typ: Rc<Syntax>,
    },
    /// The type of types, printed as `Type`.
    Typ,
}
/// Pretty printing of the code.
impl Display for Syntax {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Syntax::Lambda { param, body } => write!(f, "(λ{param}. {body})"),
Syntax::App { fun, arg } => write!(f, "({fun} {arg})"),
Syntax::Var { name } => write!(f, "{name}"),
Syntax::Pi { param, typ, body } => write!(f, "(({param} : {typ}) -> {body})"),
Syntax::Ann { expr: body, typ } => write!(f, "({body} : {typ})"),
Syntax::Typ => write!(f, "Type"),
}
}
}
impl Syntax {
    /// Collects the free variables of this term.
    ///
    /// A variable is free when it is not bound by an enclosing `Lambda` or `Pi`
    /// *within the term being analysed*. Whether a name is free therefore
    /// depends on where you look: in the body `x` of `λx. x` the variable `x`
    /// is free, while in the whole expression it is bound. Likewise, in
    /// `(x : T) -> G x` the `x` is free in `G x` but bound in the Pi type as a
    /// whole. A Pi parameter is in scope only in the body, not in the
    /// parameter's own type.
    pub fn free_vars(&self) -> HashSet<String> {
        /// Walks `expr`, keeping the binders currently in scope on the `bound`
        /// stack. Every binder pushed for a body is popped again afterwards,
        /// so the stack is balanced when the call returns.
        fn collect<'a>(expr: &'a Syntax, bound: &mut Vec<&'a str>, free: &mut HashSet<String>) {
            match expr {
                Syntax::Var { name } => {
                    if !bound.contains(&name.as_str()) {
                        free.insert(name.clone());
                    }
                }
                Syntax::Lambda { param, body } => {
                    bound.push(param.as_str());
                    collect(body, bound, free);
                    bound.pop();
                }
                Syntax::Pi { param, typ, body } => {
                    // The parameter type is visited before the parameter is in scope.
                    collect(typ, bound, free);
                    bound.push(param.as_str());
                    collect(body, bound, free);
                    bound.pop();
                }
                Syntax::App { fun, arg } => {
                    collect(fun, bound, free);
                    collect(arg, bound, free);
                }
                Syntax::Ann { expr: inner, typ } => {
                    collect(inner, bound, free);
                    collect(typ, bound, free);
                }
                Syntax::Typ => {}
            }
        }

        let mut free = HashSet::new();
        collect(self, &mut Vec::new(), &mut free);
        free
    }
}
/// The type checker runs inside a "type checker context", which exists so
/// that fresh variable names can be created on demand.
#[derive(Debug)]
pub struct TyCtx {
    /// Counter from which `new_name` derives the next fresh name; `new_name`
    /// increments it by one on every call.
    pub name_counter: u64,
}
impl TyCtx {
    /// Generates a fresh name from the counter, then advances the counter.
    ///
    /// The counter is spelled with the letters `a`..=`z` in bijective base 26
    /// and prefixed with a quote: `1 -> 'a`, `26 -> 'z`, `27 -> 'aa`. Distinct
    /// counter values always give distinct names. A counter of `0` yields the
    /// bare prefix `'`.
    pub fn new_name(&mut self) -> String {
        let mut name = String::new();
        let mut rest = self.name_counter;
        while rest > 0 {
            // Bijective numeration: shift to 0..=25 before taking the digit.
            rest -= 1;
            // `rest % 26` is below 26, so the narrowing cast cannot truncate.
            let digit = (rest % 26) as u8;
            name.push(char::from(b'a' + digit));
            rest /= 26;
        }
        self.name_counter += 1;
        // Letters were produced least-significant first; after the reversal
        // the quote pushed last becomes the prefix.
        name.push('\'');
        name.chars().rev().collect()
    }

    /// Capture-avoiding substitution `expr[from = to]`: replaces every free
    /// occurrence of the variable `from` in `expr` with the term `to`.
    ///
    /// May consume fresh names from the context when a binder has to be
    /// renamed to avoid capturing a free variable of `to`.
    pub fn subst(&mut self, expr: Rc<Syntax>, from: &String, to: Rc<Syntax>) -> Rc<Syntax> {
        match &*expr {
            // The variable being replaced.
            Syntax::Var { name } if name == from => to,
            // A lambda binding a *different* name; a lambda binding `from`
            // itself shadows it and is handled by the last arm.
            Syntax::Lambda { param, body } if param != from => {
                let (param, body) = self.freshen_binder(param, body, &to);
                let body = self.subst(body, from, to);
                Rc::new(Syntax::Lambda { param, body })
            }
            // Substitute on both sides.
            Syntax::App { fun, arg } => {
                let fun = self.subst(fun.clone(), from, to.clone());
                let arg = self.subst(arg.clone(), from, to);
                Rc::new(Syntax::App { fun, arg })
            }
            Syntax::Pi { param, typ, body } => {
                // The parameter type lies outside the binder's scope, so it is
                // always substituted.
                let typ = self.subst(typ.clone(), from, to.clone());
                if param == from {
                    // Shadowing: `from` is rebound here, the body stays as is.
                    return Rc::new(Syntax::Pi {
                        param: param.clone(),
                        typ,
                        body: body.clone(),
                    });
                }
                let (param, body) = self.freshen_binder(param, body, &to);
                let body = self.subst(body, from, to);
                Rc::new(Syntax::Pi { param, typ, body })
            }
            Syntax::Ann { expr: inner, typ } => {
                let inner = self.subst(inner.clone(), from, to.clone());
                let typ = self.subst(typ.clone(), from, to);
                Rc::new(Syntax::Ann { expr: inner, typ })
            }
            // Another variable, a lambda that shadows `from`, or `Type`:
            // nothing to replace.
            Syntax::Var { .. } | Syntax::Lambda { .. } | Syntax::Typ => expr.clone(),
        }
    }

    /// Prepares the binder `param` of `body` for a substitution of `to` under it.
    ///
    /// This is the awkward case `(\x. E)[y = x]`: if `param` occurs free in
    /// `to`, substituting under the binder would capture it, so the binder is
    /// renamed to a fresh name throughout `body`. Otherwise the binder and the
    /// body are returned unchanged.
    fn freshen_binder(
        &mut self,
        param: &String,
        body: &Rc<Syntax>,
        to: &Syntax,
    ) -> (String, Rc<Syntax>) {
        if !to.free_vars().contains(param) {
            return (param.clone(), body.clone());
        }
        let fresh = self.new_name();
        let renamed = self.subst(
            body.clone(),
            param,
            Rc::new(Syntax::Var {
                name: fresh.clone(),
            }),
        );
        (fresh, renamed)
    }

    /// Evaluates an expression to its weak head normal form.
    ///
    /// Only the head is reduced: an application whose head evaluates to a
    /// lambda is beta-reduced (the argument is evaluated first), a stuck
    /// application is returned with its evaluated head, and annotations are
    /// erased. Nothing under a binder is touched. There is no step limit.
    pub fn eval(&mut self, expr: Rc<Syntax>) -> Rc<Syntax> {
        match &*expr {
            Syntax::App { fun, arg } => {
                let head = self.eval(fun.clone());
                match &*head {
                    Syntax::Lambda { param, body } => {
                        let arg = self.eval(arg.clone());
                        let applied = self.subst(body.clone(), param, arg);
                        self.eval(applied)
                    }
                    _ => app(head.clone(), arg.clone()),
                }
            }
            // The annotated term itself may still be a redex (or another
            // annotation), so keep evaluating after dropping the annotation.
            Syntax::Ann { expr: inner, .. } => self.eval(inner.clone()),
            Syntax::Lambda { .. } | Syntax::Var { .. } | Syntax::Pi { .. } | Syntax::Typ => {
                expr.clone()
            }
        }
    }

    /// Strongly normalises the entire expression, reducing under binders and
    /// inside arguments as well and erasing every annotation.
    ///
    /// There is no step limit, so a term without a normal form makes this
    /// recurse forever.
    pub fn reduce(&mut self, expr: Rc<Syntax>) -> Rc<Syntax> {
        match &*expr {
            Syntax::App { fun, arg } => {
                // Reduce the head exactly once and reuse it in both outcomes.
                let head = self.reduce(fun.clone());
                match &*head {
                    Syntax::Lambda { param, body } => {
                        let arg = self.reduce(arg.clone());
                        let applied = self.subst(body.clone(), param, arg);
                        self.reduce(applied)
                    }
                    _ => {
                        let arg = self.reduce(arg.clone());
                        app(head.clone(), arg)
                    }
                }
            }
            Syntax::Lambda { param, body } => lam(param, self.reduce(body.clone())),
            Syntax::Pi { param, typ, body } => {
                let typ = self.reduce(typ.clone());
                let body = self.reduce(body.clone());
                pi(param, typ, body)
            }
            Syntax::Ann { expr: inner, .. } => self.reduce(inner.clone()),
            Syntax::Var { .. } | Syntax::Typ => expr.clone(),
        }
    }
}
// Some helper functions
/// Builds a variable node referring to `name`.
pub fn var(name: &str) -> Rc<Syntax> {
    let name = name.to_owned();
    Rc::new(Syntax::Var { name })
}
/// Builds the `Type` node.
pub fn typ() -> Rc<Syntax> {
    let universe = Syntax::Typ;
    Rc::new(universe)
}
/// Builds the application of `fun` to `arg`.
pub fn app(fun: Rc<Syntax>, arg: Rc<Syntax>) -> Rc<Syntax> {
    let application = Syntax::App { fun, arg };
    Rc::new(application)
}
/// Builds the annotation `(expr : typ)`.
pub fn ann(expr: Rc<Syntax>, typ: Rc<Syntax>) -> Rc<Syntax> {
    let annotated = Syntax::Ann { expr, typ };
    Rc::new(annotated)
}
/// Builds the lambda `λparam. body`.
pub fn lam(param: &str, body: Rc<Syntax>) -> Rc<Syntax> {
    let param = param.to_owned();
    Rc::new(Syntax::Lambda { param, body })
}
/// Builds the dependent function type `(param : typ) -> body`.
pub fn pi(param: &str, typ: Rc<Syntax>, body: Rc<Syntax>) -> Rc<Syntax> {
    let param = param.to_owned();
    Rc::new(Syntax::Pi { param, typ, body })
}
// An immutable environment for the type checking phase: it maps each variable
// name in scope to the type recorded for it. `check` and `infer` add lambda
// and Pi parameters to a clone of the environment before descending into a
// body, and `infer` looks variables up in it.
type Env = im::HashMap<String, Rc<Syntax>>;
// Type checking functions
impl TyCtx {
    /// Conversion check ("conv"): decides whether two expressions are equal.
    ///
    /// Both sides are first evaluated with `eval` and then compared
    /// constructor by constructor, recursing into the sub-terms. Binders are
    /// compared up to renaming. Any pair of different constructors is unequal.
    pub fn conv(&mut self, left: Rc<Syntax>, right: Rc<Syntax>) -> bool {
        let left = self.eval(left);
        let right = self.eval(right);
        match (&*left, &*right) {
            (Syntax::Var { name: name_a }, Syntax::Var { name: name_b }) => name_a == name_b,
            (
                Syntax::Lambda {
                    param: pa,
                    body: ba,
                },
                Syntax::Lambda {
                    param: pb,
                    body: bb,
                },
            ) => {
                // Rename both binders to the same fresh name so that
                // alpha-equivalent bodies become literally comparable:
                // (\x. x) and (\y. y) both turn into 'a under the binder, and
                // it is enough to compare the bodies.
                let fresh = self.new_name();
                let ba_subst = self.subst(ba.clone(), pa, var(&fresh));
                let bb_subst = self.subst(bb.clone(), pb, var(&fresh));
                self.conv(ba_subst, bb_subst)
            }
            (
                Syntax::Pi {
                    param: pa,
                    typ: ta,
                    body: ba,
                },
                Syntax::Pi {
                    param: pb,
                    typ: tb,
                    body: bb,
                },
            ) => {
                // Same renaming trick as for lambdas; the parameter types are
                // compared as well.
                let fresh = self.new_name();
                let ba_subst = self.subst(ba.clone(), pa, var(&fresh));
                let bb_subst = self.subst(bb.clone(), pb, var(&fresh));
                self.conv(ta.clone(), tb.clone()) && self.conv(ba_subst, bb_subst)
            }
            (Syntax::Ann { expr, typ }, Syntax::Ann { expr: eb, typ: tb }) => {
                self.conv(expr.clone(), eb.clone()) && self.conv(typ.clone(), tb.clone())
            }
            (Syntax::App { fun, arg }, Syntax::App { fun: fb, arg: ab }) => {
                self.conv(fun.clone(), fb.clone()) && self.conv(arg.clone(), ab.clone())
            }
            (Syntax::Typ, Syntax::Typ) => true,
            (_, _) => false,
        }
    }

    /// Checks that `expr` has type `typ` in the environment `ctx`.
    ///
    /// A lambda is checked directly against a Pi type; for every other
    /// combination the type of `expr` is inferred and compared with the
    /// expected one using `conv`.
    ///
    /// # Panics
    ///
    /// Panics if the inferred type does not match `typ`, or if inference of
    /// `expr` itself panics (see `infer`).
    pub fn check(&mut self, ctx: Env, expr: Rc<Syntax>, typ: Rc<Syntax>) {
        let expected = self.eval(typ);
        match (&*expr, &*expected) {
            // Γ ⊢ λx. e ⇐ (y: A) -> B
            (
                Syntax::Lambda { param, body },
                Syntax::Pi {
                    param: pb,
                    typ,
                    body: tb,
                },
            ) => {
                // Γ, x : A
                let mut new_ctx = ctx;
                new_ctx.insert(param.clone(), typ.clone());
                // B[y = x]
                let ret_type = self.subst(tb.clone(), pb, var(param));
                // e ⇐ B[y = x]
                self.check(new_ctx, body.clone(), ret_type);
            }
            // Γ ⊢ x ⇐ A
            (_, _) => {
                // Γ ⊢ x => B
                let inferred = self.infer(ctx, expr.clone());
                // A = B
                if !self.conv(expected.clone(), inferred.clone()) {
                    panic!("Type '{}' does not match with '{}'", expected, inferred)
                }
            }
        }
    }

    /// Infers the type of `expr` in the environment `ctx`.
    ///
    /// # Panics
    ///
    /// Panics when `expr` is a bare lambda (it needs an annotation), when a
    /// variable is not in `ctx`, when the head of an application does not have
    /// a Pi type, or when a nested `check` fails.
    pub fn infer(&mut self, ctx: Env, expr: Rc<Syntax>) -> Rc<Syntax> {
        match &*expr {
            Syntax::Lambda { .. } => panic!("Cannot infer lambda"),
            // Γ ⊢ a b => B[x = b]
            Syntax::App { fun, arg } => {
                // Γ ⊢ a => (x: A) -> B
                let inferred = self.infer(ctx.clone(), fun.clone());
                // The inferred type is not necessarily in head normal form
                // (e.g. a type looked up from the environment can be a redex),
                // so evaluate it before looking for the Pi.
                let fun_ty = self.eval(inferred);
                match &*fun_ty {
                    Syntax::Pi { param, typ, body } => {
                        // Γ ⊢ b ⇐ A
                        self.check(ctx, arg.clone(), typ.clone());
                        // B[x = b]
                        self.subst(body.clone(), param, arg.clone())
                    }
                    _ => panic!("Not a function to apply"),
                }
            }
            // Γ ⊢ x => A
            Syntax::Var { name } => match ctx.get(name) {
                // x : A ∈ Γ
                Some(ty) => ty.clone(),
                None => panic!("Cannot find variable '{name}'"),
            },
            // Γ ⊢ (x: A) -> B => Type
            Syntax::Pi {
                param,
                typ: tipo,
                body,
            } => {
                // Γ ⊢ A ⇐ Type
                self.check(ctx.clone(), tipo.clone(), typ());
                // Γ, x : A ⊢ B ⇐ Type
                let mut new_ctx = ctx;
                new_ctx.insert(param.clone(), tipo.clone());
                self.check(new_ctx, body.clone(), typ());
                typ()
            }
            // Γ ⊢ (e : A) => A
            Syntax::Ann { expr, typ: tipo } => {
                // A ⇐ Type
                self.check(ctx.clone(), tipo.clone(), typ());
                // e ⇐ A
                self.check(ctx, expr.clone(), tipo.clone());
                tipo.clone()
            }
            // Γ ⊢ Type => Type
            Syntax::Typ => typ(),
        }
    }
}
fn main() {
let mut tyctx = TyCtx { name_counter: 1 };
// Encoding Nat as pi type with church encoding
// type nat : Type {
// zero : nat
// succ : nat -> nat
// }
let nat =
pi("nat", typ(),
pi("zero", var("nat"),
pi("succ", pi("_", var("nat"), var("nat")),
var("nat"))));
// Nat is a type
tyctx.check(Default::default(), nat.clone(), typ());
// \_ -> \z -> \s -> \z
let zero =
lam("_",
lam("z",
lam("s",
var("z"))));
// Zero is a natural
let zero = ann(zero.clone(), nat.clone());
tyctx.infer(Default::default(), zero.clone());
// \m -> \ty -> \z -> \s -> (s (m ty z s))
let succ =
lam("m",
lam("ty",
lam("z",
lam("s",
app(var("s"), app(app(app(var("m"), var("ty")), var("z")), var("s")))))));
// Succ is a (natural -> natural)
let succ = ann(succ.clone(), pi("_", nat.clone(), nat.clone()));
tyctx.infer(Default::default(), succ.clone());
// \m -> \n -> \ty -> \z -> \s -> ((m ty) (n ty z s)) s
let add =
lam("m",
lam("n",
lam("ty",
lam("z",
lam("s",
app(app(app(var("m"), var("ty")), app(app(app(var("n"), var("ty")), var("z")), var("s"))), var("s")))))));
// We always have to anotate these things \:P
// Add is a (natural -> -> nat natural)
let add = ann(add.clone(), pi("_", nat.clone(), pi("_", nat.clone(), nat.clone())));
tyctx.infer(Default::default(), add.clone());
// Remember it's in Weak head normal form so it does not reduce until the end
let one = app(succ.clone(), zero.clone());
let two = app(succ.clone(), one.clone());
let three = app(succ.clone(), two.clone());
let four = app(succ.clone(), three.clone());
let five = app(succ.clone(), four.clone());
let added_five = app(app(add.clone(), three.clone()), two.clone());
tyctx.check(Default::default(), added_five.clone(), nat.clone());
let redc_added = tyctx.reduce(added_five.clone());
let redc_five = tyctx.reduce(five.clone());
let added_five = app(app(add.clone(), three.clone()), two.clone());
let added_inv_five = app(app(add.clone(), two.clone()), three.clone());
// Testing if the strong is equal to the non evaluated
println!("{}", tyctx.conv(added_five.clone(), added_inv_five));
println!("{}", tyctx.conv(redc_added, redc_five));
println!("{}", tyctx.conv(added_five, five));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment