-
-
Save WilsonGramer/8d6438163213d47ced0eedc8fb87f588 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap; | |
/// Name of a program variable.
///
/// A thin, cheaply copyable wrapper around a static string so that variable
/// names can be hashed and compared (for example as `HashMap` keys).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct Var(&'static str);
#[derive(Debug, Clone)] | |
enum TyError { | |
UnresolvedVar(unresolved::TyVar), | |
Recursive(unresolved::TyVar), | |
MismatchedTypes(unresolved::Ty, unresolved::Ty), | |
} | |
#[derive(Debug, Clone, Default)] | |
struct TyCtx { | |
vars: HashMap<Var, unresolved::Ty>, | |
next_ty_var: unresolved::TyVar, | |
substitutions: HashMap<unresolved::TyVar, unresolved::Ty>, | |
errors: Vec<TyError>, | |
} | |
mod untypechecked { | |
use super::*; | |
#[derive(Debug, Clone)] | |
pub struct Expr { | |
pub kind: ExprKind, | |
} | |
#[derive(Debug, Clone)] | |
pub enum ExprKind { | |
String(&'static str), | |
Int(i64), | |
Let(Var, Box<Expr>), | |
Var(Var), | |
Function(Var, Box<Expr>), | |
Annotate(Box<Expr>, unresolved::Ty), | |
Call(Box<Expr>, Box<Expr>), | |
} | |
} | |
impl TyCtx { | |
fn new_ty_var(&mut self) -> unresolved::TyVar { | |
let ty_var = self.next_ty_var; | |
self.next_ty_var.0 += 1; | |
ty_var | |
} | |
} | |
mod unresolved { | |
use super::*; | |
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)] | |
pub struct TyVar(pub usize); | |
#[derive(Debug, Clone)] | |
pub struct Expr { | |
pub kind: ExprKind, | |
pub ty: Ty, | |
} | |
#[derive(Debug, Clone)] | |
pub enum ExprKind { | |
String(&'static str), | |
Int(i64), | |
Let(Var, Box<Expr>), | |
Var(Var), | |
Function(Var, Box<Expr>), | |
Call(Box<Expr>, Box<Expr>), | |
} | |
#[derive(Debug, Clone, PartialEq, Eq)] | |
pub enum Ty { | |
Var(TyVar), | |
String, | |
Int, | |
Void, | |
Function(Box<Ty>, Box<Ty>), | |
Error, | |
} | |
} | |
mod resolved { | |
use super::*; | |
#[derive(Debug, Clone)] | |
pub struct Expr { | |
pub kind: ExprKind, | |
pub ty: Ty, | |
} | |
#[derive(Debug, Clone)] | |
pub enum ExprKind { | |
String(&'static str), | |
Int(i64), | |
Let(Var, Box<Expr>), | |
Var(Var), | |
Function(Var, Box<Expr>), | |
Call(Box<Expr>, Box<Expr>), | |
} | |
#[derive(Debug, Clone, PartialEq, Eq)] | |
pub enum Ty { | |
String, | |
Int, | |
Void, | |
Function(Box<Ty>, Box<Ty>), | |
Error, | |
} | |
} | |
fn typecheck(expr: untypechecked::Expr, ctx: &mut TyCtx) -> unresolved::Expr { | |
match expr.kind { | |
untypechecked::ExprKind::String(value) => unresolved::Expr { | |
kind: unresolved::ExprKind::String(value), | |
ty: unresolved::Ty::String, | |
}, | |
untypechecked::ExprKind::Int(value) => unresolved::Expr { | |
kind: unresolved::ExprKind::Int(value), | |
ty: unresolved::Ty::Int, | |
}, | |
untypechecked::ExprKind::Let(var, value) => { | |
let value = typecheck(*value, ctx); | |
ctx.vars.insert(var, value.ty.clone()); | |
unresolved::Expr { | |
kind: unresolved::ExprKind::Let(var, Box::new(value)), | |
ty: unresolved::Ty::Void, | |
} | |
} | |
untypechecked::ExprKind::Var(var) => unresolved::Expr { | |
kind: unresolved::ExprKind::Var(var), | |
ty: ctx | |
.vars | |
.get(&var) | |
.expect("variable used before being defined") | |
.clone(), | |
}, | |
untypechecked::ExprKind::Function(var, body) => { | |
let input_ty = unresolved::Ty::Var(ctx.new_ty_var()); | |
ctx.vars.insert(var, input_ty.clone()); | |
let body = typecheck(*body, ctx); | |
let body_ty = body.ty.clone(); | |
unresolved::Expr { | |
kind: unresolved::ExprKind::Function(var, Box::new(body)), | |
ty: unresolved::Ty::Function(Box::new(input_ty), Box::new(body_ty)), | |
} | |
} | |
untypechecked::ExprKind::Annotate(expr, ty) => { | |
let expr = typecheck(*expr, ctx); | |
if let Err(error) = unify(expr.ty.clone(), ty, ctx) { | |
ctx.errors.push(error); | |
return unresolved::Expr { | |
kind: expr.kind, | |
ty: unresolved::Ty::Error, | |
}; | |
} | |
expr | |
} | |
untypechecked::ExprKind::Call(function, input) => { | |
let function = typecheck(*function, ctx); | |
let input = typecheck(*input, ctx); | |
let output_ty = unresolved::Ty::Var(ctx.new_ty_var()); | |
if let Err(error) = unify( | |
function.ty.clone(), | |
unresolved::Ty::Function(Box::new(input.ty.clone()), Box::new(output_ty.clone())), | |
ctx, | |
) { | |
ctx.errors.push(error); | |
return unresolved::Expr { | |
kind: unresolved::ExprKind::Call(Box::new(function), Box::new(input)), | |
ty: unresolved::Ty::Error, | |
}; | |
} | |
unresolved::Expr { | |
kind: unresolved::ExprKind::Call(Box::new(function), Box::new(input)), | |
ty: output_ty, | |
} | |
} | |
} | |
} | |
fn unify( | |
mut actual: unresolved::Ty, | |
mut expected: unresolved::Ty, | |
ctx: &mut TyCtx, | |
) -> Result<(), TyError> { | |
actual.apply(ctx); | |
expected.apply(ctx); | |
match (actual, expected) { | |
(unresolved::Ty::Var(var), ty) | (ty, unresolved::Ty::Var(var)) => { | |
// Don't want to cause an infinite loop by substituting a variable with | |
// itself! | |
if let unresolved::Ty::Var(other) = ty { | |
if var == other { | |
return Ok(()); | |
} | |
} | |
if ty.contains(&var) { | |
Err(TyError::Recursive(var)) | |
} else { | |
ctx.substitutions.insert(var, ty); | |
Ok(()) | |
} | |
} | |
(unresolved::Ty::String, unresolved::Ty::String) | |
| (unresolved::Ty::Int, unresolved::Ty::Int) | |
| (unresolved::Ty::Void, unresolved::Ty::Void) => Ok(()), | |
( | |
unresolved::Ty::Function(actual_input, actual_output), | |
unresolved::Ty::Function(expected_input, expected_output), | |
) => { | |
unify(*actual_input, *expected_input, ctx)?; | |
unify(*actual_output, *expected_output, ctx)?; | |
Ok(()) | |
} | |
(unresolved::Ty::Error, _) | (_, unresolved::Ty::Error) => Ok(()), | |
(actual, expected) => Err(TyError::MismatchedTypes(actual, expected)), | |
} | |
} | |
impl unresolved::Ty { | |
fn apply(&mut self, ctx: &TyCtx) { | |
match self { | |
unresolved::Ty::Var(var) => { | |
if let Some(ty) = ctx.substitutions.get(var) { | |
*self = ty.clone(); | |
// This is necessary when there's a multi-step substitution | |
// (say, A -> B, B -> C, C -> D) -- we want to resolve as | |
// deeply as possible (ie. A -> D) | |
self.apply(ctx); | |
} | |
} | |
unresolved::Ty::Function(input, output) => { | |
input.apply(ctx); | |
output.apply(ctx); | |
} | |
_ => {} | |
} | |
} | |
fn contains(&self, var: &unresolved::TyVar) -> bool { | |
match self { | |
unresolved::Ty::Var(other) => var == other, | |
unresolved::Ty::Function(input, output) => input.contains(var) || output.contains(var), | |
_ => false, | |
} | |
} | |
} | |
fn resolve(expr: unresolved::Expr, ctx: &mut TyCtx) -> resolved::Expr { | |
resolved::Expr { | |
kind: match expr.kind { | |
unresolved::ExprKind::String(value) => resolved::ExprKind::String(value), | |
unresolved::ExprKind::Int(value) => resolved::ExprKind::Int(value), | |
unresolved::ExprKind::Let(var, expr) => { | |
resolved::ExprKind::Let(var, Box::new(resolve(*expr, ctx))) | |
} | |
unresolved::ExprKind::Var(var) => resolved::ExprKind::Var(var), | |
unresolved::ExprKind::Function(input, body) => { | |
resolved::ExprKind::Function(input, Box::new(resolve(*body, ctx))) | |
} | |
unresolved::ExprKind::Call(function, input) => resolved::ExprKind::Call( | |
Box::new(resolve(*function, ctx)), | |
Box::new(resolve(*input, ctx)), | |
), | |
}, | |
ty: resolve_ty(expr.ty, ctx), | |
} | |
} | |
fn resolve_ty(mut ty: unresolved::Ty, ctx: &mut TyCtx) -> resolved::Ty { | |
ty.apply(ctx); | |
match ty { | |
unresolved::Ty::Var(var) => { | |
ctx.errors.push(TyError::UnresolvedVar(var)); | |
resolved::Ty::Error | |
} | |
unresolved::Ty::String => resolved::Ty::String, | |
unresolved::Ty::Int => resolved::Ty::Int, | |
unresolved::Ty::Void => resolved::Ty::Void, | |
unresolved::Ty::Function(input, output) => resolved::Ty::Function( | |
Box::new(resolve_ty(*input, ctx)), | |
Box::new(resolve_ty(*output, ctx)), | |
), | |
unresolved::Ty::Error => resolved::Ty::Error, | |
} | |
} | |
fn main() { | |
let mut ctx = TyCtx::default(); | |
let untypechecked_program = vec![ | |
// let identity = (x) => x | |
untypechecked::Expr { | |
kind: untypechecked::ExprKind::Let( | |
Var("identity"), | |
Box::new(untypechecked::Expr { | |
kind: untypechecked::ExprKind::Function( | |
Var("x"), | |
Box::new(untypechecked::Expr { | |
kind: untypechecked::ExprKind::Var(Var("x")), | |
}), | |
), | |
}), | |
), | |
}, | |
// identity(42) | |
untypechecked::Expr { | |
kind: untypechecked::ExprKind::Call( | |
Box::new(untypechecked::Expr { | |
kind: untypechecked::ExprKind::Var(Var("identity")), | |
}), | |
Box::new(untypechecked::Expr { | |
kind: untypechecked::ExprKind::Int(42), | |
}), | |
), | |
}, | |
]; | |
let unresolved_program = untypechecked_program | |
.into_iter() | |
.map(|expr| typecheck(expr, &mut ctx)) | |
.collect::<Vec<_>>(); | |
let resolved_program = unresolved_program | |
.into_iter() | |
.map(|expr| resolve(expr, &mut ctx)) | |
.collect::<Vec<_>>(); | |
println!("{:#?}", resolved_program); | |
println!("{:#?}", ctx.errors); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment