Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@WilsonGramer
Created August 8, 2022 03:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save WilsonGramer/8d6438163213d47ced0eedc8fb87f588 to your computer and use it in GitHub Desktop.
Save WilsonGramer/8d6438163213d47ced0eedc8fb87f588 to your computer and use it in GitHub Desktop.
use std::collections::HashMap;
/// A source-level variable name.
///
/// Two `Var`s are equal (and hash identically) exactly when their names are
/// the same text; it is cheap to copy because it only holds a `&'static str`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct Var(&'static str);
/// A diagnostic produced while inferring or resolving types.
///
/// Errors are pushed onto `TyCtx::errors` instead of aborting inference, so a
/// single run can report several problems. `PartialEq`/`Eq` are derived so
/// diagnostics can be compared (e.g. in tests or when deduplicating).
#[derive(Debug, Clone, PartialEq, Eq)]
enum TyError {
    /// A type variable still had no solution when types were resolved.
    UnresolvedVar(unresolved::TyVar),
    /// Binding this type variable would produce an infinite type (the
    /// variable occurs inside the type it would be bound to).
    Recursive(unresolved::TyVar),
    /// Two types with incompatible shapes, as `(actual, expected)`.
    MismatchedTypes(unresolved::Ty, unresolved::Ty),
}
/// Mutable state threaded through type inference and resolution.
///
/// `Default` gives an empty environment, no substitutions, no errors, and a
/// type-variable counter starting at zero.
#[derive(Clone, Debug, Default)]
struct TyCtx {
    /// Type recorded for each variable name by `let` bindings and function
    /// parameters.
    vars: HashMap<Var, unresolved::Ty>,
    /// The next fresh type variable to hand out.
    next_ty_var: unresolved::TyVar,
    /// Solved type variables: each maps to the type it was unified with,
    /// which may itself mention further (possibly solved) type variables.
    substitutions: HashMap<unresolved::TyVar, unresolved::Ty>,
    /// Diagnostics collected so far.
    errors: Vec<TyError>,
}
/// Surface syntax before type inference: nodes carry no type information.
mod untypechecked {
    use super::*;

    /// An expression node awaiting type inference.
    #[derive(Clone, Debug)]
    pub struct Expr {
        pub kind: ExprKind,
    }

    /// Every shape an untyped expression can take.
    #[derive(Clone, Debug)]
    pub enum ExprKind {
        /// A string literal.
        String(&'static str),
        /// An integer literal.
        Int(i64),
        /// `let <name> = <value>`; the binding expression itself has type `Void`.
        Let(Var, Box<Expr>),
        /// A reference to a bound variable.
        Var(Var),
        /// A single-parameter function `(<param>) => <body>`.
        Function(Var, Box<Expr>),
        /// An expression with an explicit type annotation. It is checked by
        /// unification and then dropped (the inferred tree has no such node).
        Annotate(Box<Expr>, unresolved::Ty),
        /// Application of a function to one argument: `<callee>(<argument>)`.
        Call(Box<Expr>, Box<Expr>),
    }
}
impl TyCtx {
fn new_ty_var(&mut self) -> unresolved::TyVar {
let ty_var = self.next_ty_var;
self.next_ty_var.0 += 1;
ty_var
}
}
/// Output of inference: every node has a type, but that type may still
/// mention unsolved type variables.
mod unresolved {
    use super::*;

    /// An inference variable standing for a type that is not known yet.
    ///
    /// The wrapped number is the variable's identity; fresh ones come from
    /// `TyCtx::new_ty_var`.
    #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
    pub struct TyVar(pub usize);

    /// An expression paired with its (possibly partially inferred) type.
    #[derive(Clone, Debug)]
    pub struct Expr {
        pub kind: ExprKind,
        pub ty: Ty,
    }

    /// Expression shapes after inference. Annotations have already been
    /// checked and removed, so there is no `Annotate` variant.
    #[derive(Clone, Debug)]
    pub enum ExprKind {
        String(&'static str),
        Int(i64),
        Let(Var, Box<Expr>),
        Var(Var),
        Function(Var, Box<Expr>),
        Call(Box<Expr>, Box<Expr>),
    }

    /// A type that may still contain unsolved type variables.
    #[derive(Clone, Debug, Eq, PartialEq)]
    pub enum Ty {
        /// Not known yet; its solution, if any, lives in `TyCtx::substitutions`.
        Var(TyVar),
        String,
        Int,
        /// The type of a `let` binding expression.
        Void,
        /// A single-parameter function type: `Function(input, output)`.
        Function(Box<Ty>, Box<Ty>),
        /// Stand-in after an error has been reported; unifies with any type.
        Error,
    }
}
/// Final representation: every type is concrete and contains no type
/// variables.
mod resolved {
    use super::*;

    /// An expression paired with its fully resolved type.
    #[derive(Clone, Debug)]
    pub struct Expr {
        pub kind: ExprKind,
        pub ty: Ty,
    }

    /// Expression shapes of the resolved program (same shapes as after
    /// inference).
    #[derive(Clone, Debug)]
    pub enum ExprKind {
        String(&'static str),
        Int(i64),
        Let(Var, Box<Expr>),
        Var(Var),
        Function(Var, Box<Expr>),
        Call(Box<Expr>, Box<Expr>),
    }

    /// A concrete type. Both unsolved type variables and types that failed
    /// to check end up as `Error`.
    #[derive(Clone, Debug, Eq, PartialEq)]
    pub enum Ty {
        String,
        Int,
        Void,
        /// A single-parameter function type: `Function(input, output)`.
        Function(Box<Ty>, Box<Ty>),
        Error,
    }
}
fn typecheck(expr: untypechecked::Expr, ctx: &mut TyCtx) -> unresolved::Expr {
match expr.kind {
untypechecked::ExprKind::String(value) => unresolved::Expr {
kind: unresolved::ExprKind::String(value),
ty: unresolved::Ty::String,
},
untypechecked::ExprKind::Int(value) => unresolved::Expr {
kind: unresolved::ExprKind::Int(value),
ty: unresolved::Ty::Int,
},
untypechecked::ExprKind::Let(var, value) => {
let value = typecheck(*value, ctx);
ctx.vars.insert(var, value.ty.clone());
unresolved::Expr {
kind: unresolved::ExprKind::Let(var, Box::new(value)),
ty: unresolved::Ty::Void,
}
}
untypechecked::ExprKind::Var(var) => unresolved::Expr {
kind: unresolved::ExprKind::Var(var),
ty: ctx
.vars
.get(&var)
.expect("variable used before being defined")
.clone(),
},
untypechecked::ExprKind::Function(var, body) => {
let input_ty = unresolved::Ty::Var(ctx.new_ty_var());
ctx.vars.insert(var, input_ty.clone());
let body = typecheck(*body, ctx);
let body_ty = body.ty.clone();
unresolved::Expr {
kind: unresolved::ExprKind::Function(var, Box::new(body)),
ty: unresolved::Ty::Function(Box::new(input_ty), Box::new(body_ty)),
}
}
untypechecked::ExprKind::Annotate(expr, ty) => {
let expr = typecheck(*expr, ctx);
if let Err(error) = unify(expr.ty.clone(), ty, ctx) {
ctx.errors.push(error);
return unresolved::Expr {
kind: expr.kind,
ty: unresolved::Ty::Error,
};
}
expr
}
untypechecked::ExprKind::Call(function, input) => {
let function = typecheck(*function, ctx);
let input = typecheck(*input, ctx);
let output_ty = unresolved::Ty::Var(ctx.new_ty_var());
if let Err(error) = unify(
function.ty.clone(),
unresolved::Ty::Function(Box::new(input.ty.clone()), Box::new(output_ty.clone())),
ctx,
) {
ctx.errors.push(error);
return unresolved::Expr {
kind: unresolved::ExprKind::Call(Box::new(function), Box::new(input)),
ty: unresolved::Ty::Error,
};
}
unresolved::Expr {
kind: unresolved::ExprKind::Call(Box::new(function), Box::new(input)),
ty: output_ty,
}
}
}
}
fn unify(
mut actual: unresolved::Ty,
mut expected: unresolved::Ty,
ctx: &mut TyCtx,
) -> Result<(), TyError> {
actual.apply(ctx);
expected.apply(ctx);
match (actual, expected) {
(unresolved::Ty::Var(var), ty) | (ty, unresolved::Ty::Var(var)) => {
// Don't want to cause an infinite loop by substituting a variable with
// itself!
if let unresolved::Ty::Var(other) = ty {
if var == other {
return Ok(());
}
}
if ty.contains(&var) {
Err(TyError::Recursive(var))
} else {
ctx.substitutions.insert(var, ty);
Ok(())
}
}
(unresolved::Ty::String, unresolved::Ty::String)
| (unresolved::Ty::Int, unresolved::Ty::Int)
| (unresolved::Ty::Void, unresolved::Ty::Void) => Ok(()),
(
unresolved::Ty::Function(actual_input, actual_output),
unresolved::Ty::Function(expected_input, expected_output),
) => {
unify(*actual_input, *expected_input, ctx)?;
unify(*actual_output, *expected_output, ctx)?;
Ok(())
}
(unresolved::Ty::Error, _) | (_, unresolved::Ty::Error) => Ok(()),
(actual, expected) => Err(TyError::MismatchedTypes(actual, expected)),
}
}
impl unresolved::Ty {
fn apply(&mut self, ctx: &TyCtx) {
match self {
unresolved::Ty::Var(var) => {
if let Some(ty) = ctx.substitutions.get(var) {
*self = ty.clone();
// This is necessary when there's a multi-step substitution
// (say, A -> B, B -> C, C -> D) -- we want to resolve as
// deeply as possible (ie. A -> D)
self.apply(ctx);
}
}
unresolved::Ty::Function(input, output) => {
input.apply(ctx);
output.apply(ctx);
}
_ => {}
}
}
fn contains(&self, var: &unresolved::TyVar) -> bool {
match self {
unresolved::Ty::Var(other) => var == other,
unresolved::Ty::Function(input, output) => input.contains(var) || output.contains(var),
_ => false,
}
}
}
fn resolve(expr: unresolved::Expr, ctx: &mut TyCtx) -> resolved::Expr {
resolved::Expr {
kind: match expr.kind {
unresolved::ExprKind::String(value) => resolved::ExprKind::String(value),
unresolved::ExprKind::Int(value) => resolved::ExprKind::Int(value),
unresolved::ExprKind::Let(var, expr) => {
resolved::ExprKind::Let(var, Box::new(resolve(*expr, ctx)))
}
unresolved::ExprKind::Var(var) => resolved::ExprKind::Var(var),
unresolved::ExprKind::Function(input, body) => {
resolved::ExprKind::Function(input, Box::new(resolve(*body, ctx)))
}
unresolved::ExprKind::Call(function, input) => resolved::ExprKind::Call(
Box::new(resolve(*function, ctx)),
Box::new(resolve(*input, ctx)),
),
},
ty: resolve_ty(expr.ty, ctx),
}
}
fn resolve_ty(mut ty: unresolved::Ty, ctx: &mut TyCtx) -> resolved::Ty {
ty.apply(ctx);
match ty {
unresolved::Ty::Var(var) => {
ctx.errors.push(TyError::UnresolvedVar(var));
resolved::Ty::Error
}
unresolved::Ty::String => resolved::Ty::String,
unresolved::Ty::Int => resolved::Ty::Int,
unresolved::Ty::Void => resolved::Ty::Void,
unresolved::Ty::Function(input, output) => resolved::Ty::Function(
Box::new(resolve_ty(*input, ctx)),
Box::new(resolve_ty(*output, ctx)),
),
unresolved::Ty::Error => resolved::Ty::Error,
}
}
/// Demo driver: infers and resolves a two-statement program, then prints the
/// typed program followed by any collected type errors.
fn main() {
    // Wraps an expression kind in an untyped node.
    fn node(kind: untypechecked::ExprKind) -> untypechecked::Expr {
        untypechecked::Expr { kind }
    }

    let source = vec![
        // let identity = (x) => x
        node(untypechecked::ExprKind::Let(
            Var("identity"),
            Box::new(node(untypechecked::ExprKind::Function(
                Var("x"),
                Box::new(node(untypechecked::ExprKind::Var(Var("x")))),
            ))),
        )),
        // identity(42)
        node(untypechecked::ExprKind::Call(
            Box::new(node(untypechecked::ExprKind::Var(Var("identity")))),
            Box::new(node(untypechecked::ExprKind::Int(42))),
        )),
    ];

    let mut ctx = TyCtx::default();

    // Infer every statement before resolving any of them: a later statement
    // (here, the call) can solve type variables introduced by an earlier one.
    let mut inferred = Vec::with_capacity(source.len());
    for statement in source {
        inferred.push(typecheck(statement, &mut ctx));
    }

    let mut resolved_program = Vec::with_capacity(inferred.len());
    for statement in inferred {
        resolved_program.push(resolve(statement, &mut ctx));
    }

    println!("{:#?}", resolved_program);
    println!("{:#?}", ctx.errors);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment