Skip to content

Instantly share code, notes, and snippets.

@felko
Last active September 2, 2019 23:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save felko/83f66298218a442dd95b3bf9454aab85 to your computer and use it in GitHub Desktop.
Save felko/83f66298218a442dd95b3bf9454aab85 to your computer and use it in GitHub Desktop.
# Types
# Root of the object-language type hierarchy.
# This deliberately reuses the name `Type`, so within this file `Type` means this
# abstract type rather than the one exported by Base.
abstract type Type end

# Type variable. Rigid variables are never replaced by `substitute` and cannot be
# bound during unification.
struct VarType <: Type
    name::String
    rigid::Bool
end

# Function type: parameter types and result type.
struct FunType <: Type
    domain::Array{Type}
    codomain::Type
end

# Tuple type: one component type per element.
struct TupleType <: Type
    types::Array{Type}
end

# Reference to a value of the wrapped type (produced for `var` bindings).
struct RefType <: Type
    type_::Type
end

# Field-less base types.
abstract type AtomType <: Type end

struct IntType <: AtomType end

struct BoolType <: AtomType end

struct StringType <: AtomType end

struct VoidType <: AtomType end
# Polymorphic types (rank 1)
# Type scheme: `type_` with the variables named in `vars` universally quantified.
struct Scheme
    vars::Array{String}
    type_::Type
end
# Wrap a type into a scheme that quantifies over no variables.
# The empty variable list is built as `String[]` so it already has the element
# type of `Scheme.vars` instead of being an untyped `[]`.
function monotype(type_ :: Type)
    return Scheme(String[], type_)
end
# AST
# Root of the expression AST. This reuses the name `Expr`, so within this file
# `Expr` means this abstract type rather than Base's expression type.
abstract type Expr end

# Reference to a variable by name.
struct VarExpr <: Expr
    name::String
end

# Function parameter. `type_` is left as `Any`: the type checker treats it as an
# annotation only when it is a `Type` value.
struct Param
    name::String
    type_::Any
end

# Anonymous function.
struct FunExpr <: Expr
    params::Array{Param}
    body::Expr
end

# Application of `called` to `arguments`.
struct CallExpr <: Expr
    called::Expr
    arguments::Array{Expr}
end
# Statements that may appear inside a block.
abstract type Instruction end

# `var name = value`
struct VarInstr <: Instruction
    name::String
    value::Expr
end

# `name = value`
struct AssignInstr <: Instruction
    name::String
    value::Expr
end

# `let name = value`
struct LetInstr <: Instruction
    name::String
    value::Expr
end

# An expression used as a statement.
struct ExprInstr <: Instruction
    expr::Expr
end
# Block: a sequence of instructions followed by the expression giving its value.
struct BlockExpr <: Expr
    instructions::Array{Instruction}
    last::Expr
end

# Conditional with both branches present.
struct IfExpr <: Expr
    condition::Expr
    true_case::Expr
    false_case::Expr
end

# Tuple literal.
struct TupleExpr <: Expr
    elements::Array{Expr}
end
# Binary operators; every concrete subtype carries `lhs` and `rhs`.
abstract type BinaryOperation <: Expr end

# `lhs + rhs`
struct AddExpr <: BinaryOperation
    lhs::Expr
    rhs::Expr
end

# Comparison operators.
abstract type ComparisonExpr <: BinaryOperation end

# `lhs == rhs`
struct EQExpr <: ComparisonExpr
    lhs::Expr
    rhs::Expr
end

# `lhs > rhs`
struct GTExpr <: ComparisonExpr
    lhs::Expr
    rhs::Expr
end

# `lhs < rhs`
struct LTExpr <: ComparisonExpr
    lhs::Expr
    rhs::Expr
end
# Literals; every concrete subtype stores its payload in `value`.
abstract type LitExpr <: Expr end

struct IntLit <: LitExpr
    value::Int
end

struct BoolLit <: LitExpr
    value::Bool
end

struct StringLit <: LitExpr
    value::String
end
# Free variables
# Union of the free variables of every node in `nodes`.
# Folding from an explicit empty set keeps this defined for an empty collection;
# splatting an empty list into `union` would call `union()` with no arguments,
# which has no method.
freeVariablesOfAll(nodes) = mapreduce(freeVariables, union, nodes; init = Set{String}())

# freeVariables(node) -> Set{String}
# Names referenced by `node` that are not bound by a function parameter inside it.
# Names introduced by `let`/`var` instructions of a block are NOT removed.
freeVariables(expr :: VarExpr) = Set{String}([expr.name])
function freeVariables(expr :: FunExpr)
    bound = Set{String}(param.name for param in expr.params)
    return setdiff(freeVariables(expr.body), bound)
end
freeVariables(expr :: CallExpr) = union(freeVariables(expr.called), freeVariablesOfAll(expr.arguments))
freeVariables(expr :: VarInstr) = freeVariables(expr.value)
freeVariables(expr :: LetInstr) = freeVariables(expr.value)
freeVariables(expr :: AssignInstr) = freeVariables(expr.value)
freeVariables(expr :: ExprInstr) = freeVariables(expr.expr)
freeVariables(expr :: BlockExpr) = union(freeVariables(expr.last), freeVariablesOfAll(expr.instructions))
freeVariables(expr :: IfExpr) = freeVariablesOfAll((expr.condition, expr.true_case, expr.false_case))
# An empty tuple `()` has no free variables (this case used to raise a MethodError).
freeVariables(expr :: TupleExpr) = freeVariablesOfAll(expr.elements)
freeVariables(expr :: E) where E <: BinaryOperation = union(freeVariables(expr.lhs), freeVariables(expr.rhs))
# Fallback for the remaining expressions (literals): nothing is referenced.
freeVariables(expr :: E) where E <: Expr = Set{String}()
# Free type variables
# freeTypeVars(x) -> Set{String}
# Names of the type variables occurring in a type, or, for a scheme, those
# occurring in its body that are not quantified by it.
freeTypeVars(type_ :: VarType) = Set{String}((type_.name,))

function freeTypeVars(type_ :: FunType)
    fromDomain = mapreduce(freeTypeVars, union, type_.domain; init = Set{String}())
    return union(fromDomain, freeTypeVars(type_.codomain))
end

function freeTypeVars(type_ :: TupleType)
    return mapreduce(freeTypeVars, union, type_.types; init = Set{String}())
end

function freeTypeVars(type_ :: RefType)
    return freeTypeVars(type_.type_)
end

# Atoms contain no variables.
freeTypeVars(type_ :: T) where T <: AtomType = Set{String}()

function freeTypeVars(scheme :: Scheme)
    quantified = Set{String}(scheme.vars)
    return setdiff(freeTypeVars(scheme.type_), quantified)
end
# Substitution
# substitute(varName, varType, target)
# Return `target` with every non-rigid occurrence of the variable `varName`
# replaced by `varType`.
function substitute(varName :: String, varType :: Type, ctxType :: VarType)
    isTarget = ctxType.name == varName && !ctxType.rigid
    return isTarget ? varType : ctxType
end

function substitute(varName :: String, varType :: Type, ctxType :: FunType)
    newDomain = map(paramType -> substitute(varName, varType, paramType), ctxType.domain)
    newCodomain = substitute(varName, varType, ctxType.codomain)
    return FunType(newDomain, newCodomain)
end

function substitute(varName :: String, varType :: Type, ctxType :: TupleType)
    return TupleType(Type[substitute(varName, varType, elemType) for elemType in ctxType.types])
end

function substitute(varName :: String, varType :: Type, ctxType :: RefType)
    return RefType(substitute(varName, varType, ctxType.type_))
end

# Atoms contain no variables, so they are returned unchanged.
function substitute(varName :: String, varType :: Type, ctxType :: T) where T <: AtomType
    return ctxType
end

# A scheme that quantifies over `varName` shadows it and is left untouched.
function substitute(varName :: String, varType :: Type, ctxType :: Scheme)
    varName in ctxType.vars && return ctxType
    return Scheme(ctxType.vars, substitute(varName, varType, ctxType.type_))
end
# Typechecking utilities
# Kinds of failure the type checker can report.
abstract type TypecheckErrorType end

# A name was found neither in the environment nor in the bindings.
struct ScopeError <: TypecheckErrorType
    name::String
end

# Two types could not be unified.
struct TypeMismatch <: TypecheckErrorType
    a::Type
    b::Type
end

# A rigid type variable would have had to be bound to `type_`.
struct RigidityError <: TypecheckErrorType
    var::String
    type_::Type
end

# Binding `var` to `type_` would build an infinite type (`var` occurs in `type_`).
struct OccursCheckFailure <: TypecheckErrorType
    var::String
    type_::Type
end

# Value thrown by the type checker: the failure itself, the declaration being
# checked, the stack of enclosing expressions and the bindings worth displaying.
struct TypecheckError
    type_::TypecheckErrorType
    decl::String
    nesting::Array{Expr}
    bindings::Dict{String, Type}
end
# Mutable state threaded through the inference of one declaration.
mutable struct Context
    environment::Dict{String, Scheme}  # generalized names (prelude, declarations, `let`)
    bindings::Dict{String, Type}       # monomorphic names (parameters, `var`)
    freshVarSupply::Int                # counter used to name fresh type variables
    decl::String                       # name of the declaration being checked
    nesting::Array{Expr}               # stack of expressions currently being checked
end
# @focus <method definition>
# Rewrites the body of a `typecheck`-style method so that the node being checked
# is pushed on `ctx.nesting` for the duration of the call and popped afterwards.
# The wrapped method must have a first parameter named `ctx` and an annotated
# second parameter (the node); both are assumed, not checked.
# NOTE(review): the returned definition is not wrapped in `esc`, so it goes
# through macro hygiene — confirm the method ends up defined under the expected
# name on the Julia version in use.
macro focus(func)
# func.args[1] is the signature, func.args[2] the original body.
instrs = func.args[2]
# A signature with exactly two sub-expressions is taken to be `call where T`
# (every use in this file has two parameters, so a plain call has three);
# the node parameter's name is then one level deeper.
if length(func.args[1].args) == 2
node = Symbol(func.args[1].args[1].args[3].args[1])
else
node = Symbol(func.args[1].args[3].args[1])
end
func.args[2] = quote
# Only AST expressions are recorded; instructions are not part of the stack.
if isa($(node), Expr)
push!(ctx.nesting, $(node))
end
try
# The original body is spliced in here; `eval` is applied to the value it
# produces. A `return` inside the body leaves the method directly.
ret = eval($instrs)
return ret
catch err
if isa(err, TypecheckError)
# Re-wrap with a copy of the nesting (`[1:end]`), taken before the `finally`
# clause below pops this frame.
rethrow(TypecheckError(err.type_, err.decl, err.nesting[1:end], err.bindings))
else
rethrow(err)
end
finally if isa($(node), Expr)
pop!(ctx.nesting)
end end
end
return func
end
# Throw a `TypecheckError` describing `err` in the current context. Never returns.
# The "relevant bindings" are at most 10 free variables of the outermost
# expression on the nesting stack that are not environment names and that
# currently have a binding.
function throwTypecheckError(ctx :: Context, err :: TypecheckErrorType)
    relevantBindings = Dict{String, Type}()
    # The stack is empty when the failure happens outside any focused expression.
    if !isempty(ctx.nesting)
        candidates = setdiff(freeVariables(ctx.nesting[1]), keys(ctx.environment))
        for var in candidates
            length(relevantBindings) >= 10 && break
            # An out-of-scope name is free but unbound; skipping it keeps the
            # report from failing with a KeyError instead of the real error.
            if haskey(ctx.bindings, var)
                relevantBindings[var] = ctx.bindings[var]
            end
        end
    end
    # Snapshot the stack: the live one keeps being popped while the error unwinds.
    throw(TypecheckError(err, ctx.decl, copy(ctx.nesting), relevantBindings))
end
# Return a variable name not handed out before by this context ("#0", "#1", ...)
# and advance the counter.
function freshVar(ctx :: Context)
    index = ctx.freshVarSupply
    ctx.freshVarSupply = index + 1
    return string("#", index)
end

# Fresh non-rigid type variable.
function freshType(ctx :: Context)
    return VarType(freshVar(ctx), false)
end
# Apply the substitution `varName := varType` in place to every type stored in
# the context's bindings. Returns nothing.
function substitute(varName :: String, varType :: Type, ctx :: Context)
    bindings = ctx.bindings
    for var in collect(keys(bindings))
        bindings[var] = substitute(varName, varType, bindings[var])
    end
end
# rigidify(type_) / rigidify(type_, ftvs)
# Return `type_` with every type variable whose name is in `ftvs` marked rigid.
# The one-argument form uses all free type variables of `type_`.
rigidify(type_ :: T) where T <: Type = rigidify(type_, freeTypeVars(type_))

function rigidify(type_ :: VarType, ftvs :: Set{String})
    return type_.name in ftvs ? VarType(type_.name, true) : type_
end

# The codomain is rigidified like the domain (it used to be passed to `relax`,
# which left result variables non-rigid).
function rigidify(type_ :: FunType, ftvs :: Set{String})
    domain = Type[rigidify(t, ftvs) for t in type_.domain]
    return FunType(domain, rigidify(type_.codomain, ftvs))
end

# Previously called the undefined name `rigidity`, so any tuple type failed here.
function rigidify(type_ :: TupleType, ftvs :: Set{String})
    return TupleType(Type[rigidify(t, ftvs) for t in type_.types])
end

# Reference types had no method at all; recurse into the referenced type.
rigidify(type_ :: RefType, ftvs :: Set{String}) = RefType(rigidify(type_.type_, ftvs))

rigidify(type_ :: T, ftvs :: Set{String}) where T <: AtomType = type_
# Display name for the `index`-th generated type variable:
# "A".."Z", then "A1".."Z1", "A2", ...
function generalizedVarName(index :: Int)
    letter = Char(Int('A') + index % 26)
    cycle = index ÷ 26
    return cycle == 0 ? string(letter) : string(letter, cycle)
end

# Turn a type into a scheme quantified over all of its free type variables.
# Inference variables (names starting with "#") are renamed to readable names;
# a generated name is skipped when the type already uses it, so it can never be
# merged with an existing variable such as a user-written `A`.
# The variables of the resulting scheme are rigidified.
function generalize(type_ :: Type)
    ftvs = collect(freeTypeVars(type_))
    taken = Set{String}(ftvs)
    nextIndex = 0
    for (i, ftv) in enumerate(ftvs)
        startswith(ftv, "#") || continue
        newName = generalizedVarName(nextIndex)
        while newName in taken
            nextIndex += 1
            newName = generalizedVarName(nextIndex)
        end
        nextIndex += 1
        push!(taken, newName)
        ftvs[i] = newName
        type_ = substitute(ftv, VarType(newName, false), type_)
    end
    return Scheme(ftvs, rigidify(type_))
end
# relax(type_)
# Return `type_` with every type variable marked non-rigid.
relax(type_ :: VarType) = VarType(type_.name, false)
relax(type_ :: FunType) = FunType(map(relax, type_.domain), relax(type_.codomain))
relax(type_ :: TupleType) = TupleType(map(relax, type_.types))
# Reference types had no method, so relaxing a type containing `ref` raised a
# MethodError; recurse into the referenced type.
relax(type_ :: RefType) = RefType(relax(type_.type_))
relax(type_ :: T) where T <: AtomType = type_
# Instantiate a scheme: relax its body, then replace each quantified variable,
# in order, by a fresh type variable drawn from `ctx`.
function instantiate(ctx :: Context, scheme :: Scheme)
    return foldl(scheme.vars; init = relax(scheme.type_)) do partial, var
        substitute(var, freshType(ctx), partial)
    end
end
# Unification
# Unify two type variables. Two distinct rigid variables do not unify; otherwise
# the substitution is recorded in the context's bindings and the surviving
# variable is returned (`a` when it is rigid, `b` otherwise).
function unify(ctx :: Context, a :: VarType, b :: VarType)
    if a.rigid
        if b.rigid && a.name != b.name
            throwTypecheckError(ctx, TypeMismatch(a, b))
        end
        substitute(b.name, a, ctx)
        return a
    end
    substitute(a.name, b, ctx)
    return b
end
# Bind the type variable `var` to `other` in the context's bindings and return
# `other`. Throws a RigidityError when `var` is rigid and an OccursCheckFailure
# when `var` occurs in `other`.
function bindTypeVar(ctx :: Context, var :: VarType, other :: Type)
    if var.rigid
        throwTypecheckError(ctx, RigidityError(var.name, other))
    elseif var.name in freeTypeVars(other)
        throwTypecheckError(ctx, OccursCheckFailure(var.name, other))
    end
    substitute(var.name, other, ctx)
    return other
end

# Variable against anything else, in either argument order. The two `RefType`
# methods are spelled out separately from the generic ones, as in the original
# set of methods.
unify(ctx :: Context, a :: VarType, b :: RefType) = bindTypeVar(ctx, a, b)
unify(ctx :: Context, a :: RefType, b :: VarType) = bindTypeVar(ctx, b, a)
unify(ctx :: Context, a :: VarType, b :: T) where T <: Type = bindTypeVar(ctx, a, b)
unify(ctx :: Context, a :: T, b :: VarType) where T <: Type = bindTypeVar(ctx, b, a)
# Function types unify when they have the same number of parameters and their
# parameters (left to right) and then their results unify.
function unify(ctx :: Context, a :: FunType, b :: FunType)
    length(a.domain) == length(b.domain) || throwTypecheckError(ctx, TypeMismatch(a, b))
    unifiedDomain = Type[unify(ctx, paramA, paramB) for (paramA, paramB) in zip(a.domain, b.domain)]
    unifiedCodomain = unify(ctx, a.codomain, b.codomain)
    return FunType(unifiedDomain, unifiedCodomain)
end

# Tuple types unify component-wise and must have the same length.
function unify(ctx :: Context, a :: TupleType, b :: TupleType)
    length(a.types) == length(b.types) || throwTypecheckError(ctx, TypeMismatch(a, b))
    return TupleType(Type[unify(ctx, x, y) for (x, y) in zip(a.types, b.types)])
end
# Two references unify when the referenced types do.
function unify(ctx :: Context, a :: RefType, b :: RefType)
    return RefType(unify(ctx, a.type_, b.type_))
end

# A reference against a non-reference: unify the referenced type with the other
# side and wrap the result in a reference again.
function unify(ctx :: Context, a :: RefType, b :: T) where T <: Type
    return RefType(unify(ctx, a.type_, b))
end

function unify(ctx :: Context, a :: T, b :: RefType) where T <: Type
    return RefType(unify(ctx, a, b.type_))
end

# Identical atoms unify to themselves.
function unify(ctx :: Context, a :: T, b :: T) where T <: AtomType
    return a
end

# Every remaining combination is a mismatch.
function unify(ctx :: Context, a :: T, b :: U) where T <: Type where U <: Type
    return throwTypecheckError(ctx, TypeMismatch(a, b))
end
# Unify two schemes: rename the quantified variables of `a`, pairwise, to rigid
# variables named after those of `b`, then unify the bodies. Pairing stops at
# the shorter of the two variable lists.
function unify(ctx :: Context, a :: Scheme, b :: Scheme)
    renamed = foldl(zip(a.vars, b.vars); init = a.type_) do acc, (ta, tb)
        substitute(ta, VarType(tb, true), acc)
    end
    return unify(ctx, renamed, b.type_)
end
# Inference algorithm
# Type of a variable reference: a fresh instance of its scheme when the name is
# in the environment, otherwise its monomorphic binding, otherwise a ScopeError.
@focus function typecheck(ctx :: Context, expr :: VarExpr)
    name = expr.name
    if haskey(ctx.environment, name)
        return instantiate(ctx, ctx.environment[name])
    elseif haskey(ctx.bindings, name)
        return ctx.bindings[name]
    else
        throwTypecheckError(ctx, ScopeError(name))
    end
end
# Type of an anonymous function.
# Each parameter is bound to its annotation when that is a `Type`, otherwise to
# a fresh type variable; the body is checked; the parameter types are then read
# back from the bindings so they reflect what was learnt while checking the body.
@focus function typecheck(ctx :: Context, expr :: FunExpr)
    # Remember outer bindings hidden by a parameter of the same name so they can
    # be put back afterwards (they used to be deleted together with the parameter).
    # Substitutions made while checking the body are not applied to these saved types.
    shadowed = Dict{String, Type}(
        param.name => ctx.bindings[param.name]
        for param in expr.params if haskey(ctx.bindings, param.name)
    )
    for param in expr.params
        ctx.bindings[param.name] = isa(param.type_, Type) ? param.type_ : freshType(ctx)
    end
    codomain = typecheck(ctx, expr.body)
    # Collect every parameter type before removing any binding, so a repeated
    # parameter name cannot hit an already-deleted key.
    domain = Type[ctx.bindings[param.name] for param in expr.params]
    for param in expr.params
        delete!(ctx.bindings, param.name)
    end
    merge!(ctx.bindings, shadowed)
    return FunType(domain, codomain)
end
# Type of a call: unify the callee's type with a function type built from the
# argument types and a fresh result variable, and return the unified codomain.
@focus function typecheck(ctx :: Context, expr :: CallExpr)
    calledType = typecheck(ctx, expr.called)
    argTypes = [typecheck(ctx, arg) for arg in expr.arguments]
    expected = FunType(argTypes, freshType(ctx))
    return unify(ctx, calledType, expected).codomain
end
# `let`: generalize the value's type and record the scheme in the environment.
function typecheck(ctx :: Context, instr :: LetInstr)
    scheme = generalize(typecheck(ctx, instr.value))
    ctx.environment[instr.name] = scheme
    return scheme
end

# `var`: record a reference to the value's type in the bindings.
function typecheck(ctx :: Context, instr :: VarInstr)
    refType = RefType(typecheck(ctx, instr.value))
    ctx.bindings[instr.name] = refType
    return refType
end

# An expression statement has the type of its expression.
function typecheck(ctx :: Context, instr :: ExprInstr)
    return typecheck(ctx, instr.expr)
end
# `name = value`: the assigned name must have a binding, and its type is unified
# with a reference to the value's type. Returns the unified type.
# Only the lookup of `instr.name` can produce a ScopeError here; a KeyError
# raised anywhere else is no longer caught and relabelled as one.
@focus function typecheck(ctx :: Context, instr :: AssignInstr)
    if !haskey(ctx.bindings, instr.name)
        throwTypecheckError(ctx, ScopeError(instr.name))
    end
    typeBefore = ctx.bindings[instr.name]
    valType = typecheck(ctx, instr.value)
    # Read the binding again after checking the value, so substitutions made by
    # that check are taken into account; fall back to the earlier type in case
    # the binding was removed in the meantime.
    varType = get(ctx.bindings, instr.name, typeBefore)
    return unify(ctx, varType, RefType(valType))
end
# Type of a block: check the instructions in order, then the final expression,
# whose type is the block's type. Names introduced by the block's `let`/`var`
# instructions are dropped from the environment and bindings on exit.
@focus function typecheck(ctx :: Context, expr :: BlockExpr)
    outerEnvironment = copy(ctx.environment)
    outerBindings = copy(ctx.bindings)
    for instr in expr.instructions
        typecheck(ctx, instr)
    end
    last = typecheck(ctx, expr.last)
    # Unification records what it learns by rewriting `ctx.bindings`. Restoring
    # the pre-block copy unchanged would throw away everything learnt inside the
    # block about outer names (e.g. a parameter's type), so carry the current
    # type of each outer name over — except names re-declared with `var` in this
    # block, whose current entry belongs to the inner variable.
    redeclared = Set{String}(instr.name for instr in expr.instructions if isa(instr, VarInstr))
    for name in collect(keys(outerBindings))
        if !(name in redeclared) && haskey(ctx.bindings, name)
            outerBindings[name] = ctx.bindings[name]
        end
    end
    ctx.environment = outerEnvironment
    ctx.bindings = outerBindings
    return last
end
# Type of a conditional: the condition must unify with Bool and the result is
# the unification of the two branch types.
@focus function typecheck(ctx :: Context, expr :: IfExpr)
    conditionType = typecheck(ctx, expr.condition)
    unify(ctx, conditionType, BoolType())
    thenType = typecheck(ctx, expr.true_case)
    elseType = typecheck(ctx, expr.false_case)
    return unify(ctx, thenType, elseType)
end
# Literals have their corresponding atomic type.
@focus function typecheck(ctx :: Context, expr :: IntLit)
    return IntType()
end

@focus function typecheck(ctx :: Context, expr :: BoolLit)
    return BoolType()
end

@focus function typecheck(ctx :: Context, expr :: StringLit)
    return StringType()
end

# A tuple has the tuple type of its elements' types, checked left to right.
@focus function typecheck(ctx :: Context, expr :: TupleExpr)
    elementTypes = [typecheck(ctx, elt) for elt in expr.elements]
    return TupleType(elementTypes)
end
# Addition: both operands must unify with Int (left first), and the result is Int.
@focus function typecheck(ctx :: Context, expr :: AddExpr)
    for operand in (expr.lhs, expr.rhs)
        unify(ctx, typecheck(ctx, operand), IntType())
    end
    return IntType()
end
# Comparison: both operands must have unifiable types, and the result is Bool.
@focus function typecheck(ctx :: Context, expr :: E) where E <: ComparisonExpr
    unify(ctx, typecheck(ctx, expr.lhs), typecheck(ctx, expr.rhs))
    return BoolType()
end
# Module
# Top-level declarations of a module.
abstract type Declaration end

# Named function declaration. `typeParams` and `returnType` are left as `Any`;
# the checker uses `returnType` only when it is a `Type` value.
struct FunDecl <: Declaration
    name::String
    typeParams::Any
    params::Array{Param}
    returnType::Any
    body::Expr
end

# A module is a list of declarations. This reuses the name `Module`, so within
# this file `Module` means this struct rather than Julia's module type.
struct Module
    declarations::Array{Declaration}
end

# State shared across the declarations of a module: the schemes known so far.
mutable struct ModuleContext
    environment::Dict{String, Scheme}
end
# Infer the type of a function declaration and record its generalized scheme
# under the declaration's name in the module environment. Returns that scheme.
# The declaration's own name is bound to a fresh variable while its body is
# checked, so recursive calls are typed monomorphically.
function typecheck(modctx :: ModuleContext, decl :: FunDecl)
    # Work on a copy of the module environment: `let` instructions write into
    # `ctx.environment`, and with a shared Dict those block-local names ended up
    # in the module's environment.
    ctx = Context(copy(modctx.environment), Dict{String, Type}(), 0, decl.name, Expr[])
    ctx.bindings[decl.name] = freshType(ctx)
    funType = typecheck(ctx, FunExpr(decl.params, decl.body))
    if isa(decl.returnType, Type)
        unify(ctx, funType.codomain, decl.returnType)
    end
    declType = ctx.bindings[decl.name]
    scheme = generalize(unify(ctx, declType, funType))
    modctx.environment[decl.name] = scheme
    return scheme
end
# Check every declaration of the module in order. Returns nothing.
function typecheck(ctx :: ModuleContext, mod :: Module)
    foreach(decl -> typecheck(ctx, decl), mod.declarations)
end
# Pretty printing
# Render each item with `pretty` and separate the results with ", ".
prettyCommaList(items) = join(map(pretty, items), ", ")

# pretty(node) -> String
# Source-like rendering of expressions, instructions and parameters.
pretty(expr :: VarExpr) = expr.name

function pretty(expr :: FunExpr)
    return string("fun(", prettyCommaList(expr.params), ") -> ", pretty(expr.body))
end

function pretty(expr :: CallExpr)
    return string(pretty(expr.called), "(", prettyCommaList(expr.arguments), ")")
end

pretty(expr :: VarInstr) = string("var ", expr.name, " = ", pretty(expr.value))
pretty(expr :: LetInstr) = string("let ", expr.name, " = ", pretty(expr.value))
pretty(expr :: AssignInstr) = string(expr.name, " = ", pretty(expr.value))
pretty(expr :: ExprInstr) = pretty(expr.expr)

# The final expression is printed like one more statement of the block.
function pretty(expr :: BlockExpr)
    parts = String[pretty(instr) for instr in expr.instructions]
    push!(parts, pretty(expr.last))
    return string("{ ", join(parts, "; "), " }")
end

function pretty(expr :: IfExpr)
    return string("if ", pretty(expr.condition), " ", pretty(expr.true_case),
                  " else ", pretty(expr.false_case))
end

pretty(expr :: TupleExpr) = string("(", prettyCommaList(expr.elements), ")")
pretty(expr :: AddExpr) = string(pretty(expr.lhs), " + ", pretty(expr.rhs))
pretty(expr :: EQExpr) = string(pretty(expr.lhs), " == ", pretty(expr.rhs))
pretty(expr :: GTExpr) = string(pretty(expr.lhs), " > ", pretty(expr.rhs))
pretty(expr :: LTExpr) = string(pretty(expr.lhs), " < ", pretty(expr.rhs))
pretty(expr :: E) where E <: LitExpr = repr(expr.value)

# The annotation is shown only when it is an actual type.
function pretty(param :: Param)
    isa(param.type_, Type) || return param.name
    return string(param.name, ": ", pretty(param.type_))
end
# pretty(type) -> String
# Textual rendering of types.
pretty(typ :: VarType) = typ.name

function pretty(typ :: FunType)
    params = join((pretty(t) for t in typ.domain), ", ")
    return string("fun(", params, ") -> ", pretty(typ.codomain))
end

function pretty(typ :: TupleType)
    return string("(", join((pretty(t) for t in typ.types), ", "), ")")
end

pretty(typ :: RefType) = string("ref ", pretty(typ.type_))
pretty(typ :: IntType) = "Int"
pretty(typ :: BoolType) = "Bool"
pretty(typ :: StringType) = "String"
pretty(typ :: VoidType) = "Void"
# Render a scheme. Without quantified variables it prints like its body;
# otherwise the variables appear in angle brackets, placed after `fun` when the
# body is a function type and in front of the body in every other case.
function pretty(scheme :: Scheme)
    isempty(scheme.vars) && return pretty(scheme.type_)
    quantified = join(scheme.vars, ", ")
    body = scheme.type_
    if !isa(body, FunType)
        return "<$(quantified)>$(pretty(body))"
    end
    params = join(map(pretty, body.domain), ", ")
    return "fun<$(quantified)>($(params)) -> $(pretty(body.codomain))"
end
# pretty(errorKind) -> String
# One-line description of each kind of type-checking failure.
function pretty(err :: ScopeError)
    return string("Name not in scope: `", err.name, "`")
end

function pretty(err :: TypeMismatch)
    return string("Type mismatch: Expected value of type `", pretty(err.a),
                  "`, got `", pretty(err.b), "`")
end

function pretty(err :: RigidityError)
    return string("Cannot unify rigid type variable `", err.var, "` to `",
                  pretty(err.type_), "`")
end

function pretty(err :: OccursCheckFailure)
    return string("Occurs check fails: Cannot construct infinite type `", err.var,
                  " ~ ", pretty(err.type_), "`")
end
# exprCategory(expr) -> String
# Human-readable description of an expression, used in error traces:
# a category label followed by the expression's rendering in backticks.
exprCategory(expr :: VarExpr) = string("expression `", pretty(expr), "`")
exprCategory(expr :: FunExpr) = string("anonymous function `", pretty(expr), "`")
exprCategory(expr :: CallExpr) = string("function call `", pretty(expr), "`")
exprCategory(expr :: BlockExpr) = string("block `", pretty(expr), "`")
exprCategory(expr :: IfExpr) = string("if-expression `", pretty(expr), "`")
exprCategory(expr :: TupleExpr) = string("tuple `", pretty(expr), "`")
exprCategory(expr :: AddExpr) = string("arithmetic formula `", pretty(expr), "`")
exprCategory(expr :: E) where E <: ComparisonExpr = string("comparison `", pretty(expr), "`")
exprCategory(expr :: E) where E <: LitExpr = string("literal `", pretty(expr), "`")
# Render a full error report: the failure, the innermost (at most six) entries
# of the nesting stack, and the relevant bindings.
function pretty(err :: TypecheckError)
    buf = IOBuffer()
    print(buf, "A type error was raised while trying to typecheck `", err.decl, "`:\n",
          pretty(err.type_), "\n")
    for expr in err.nesting[max(1,end-5):end]
        print(buf, "\t• In ", exprCategory(expr), "\n")
    end
    print(buf, "Relevant bindings include:\n")
    for (name, type_) in err.bindings
        print(buf, "\t• ", name, " : ", pretty(type_), "\n")
    end
    return String(take!(buf))
end
# Example
# Demo: type-check a small example module against a three-function prelude,
# print the type error if one is raised, then print every name of the module
# environment with its scheme.
# The placeholder for "no annotation" in the declarations below is the value
# `Nothing`; the checker only treats annotations that are `Type` values as present.
function main()
prelude = Dict(
"input" => Scheme([], FunType([], StringType())),
"print" => Scheme(["T"], FunType([VarType("T", true)], VoidType())),
"to_int" => Scheme([], FunType([StringType()], IntType()))
)
mod = Module(
[
FunDecl("app", Nothing, [Param("f", Nothing), Param("x", Nothing)], Nothing, CallExpr(VarExpr("f"), [VarExpr("x")])),
# fun app(f, x) = f(x)
FunDecl("positive", Nothing, [Param("x", Nothing)], Nothing, GTExpr(VarExpr("x"), IntLit(0))),
# fun positive(x) = x > 0
FunDecl("increment", Nothing, [Param("x", Nothing)], Nothing, AddExpr(VarExpr("x"), IntLit(1))),
# fun increment(x) = x + 1
FunDecl("fix", Nothing, [Param("f", Nothing)], Nothing, CallExpr(VarExpr("f"), [CallExpr(VarExpr("fix"), [VarExpr("f")])])),
# fun fix(f) = f(fix(f))
FunDecl("id", ["A"], [Param("x", VarType("A", true))], VarType("A", true), VarExpr("x")),
# fun id<A>(x: A) -> A = x
FunDecl("letGeneralization", [], [], Nothing, BlockExpr([LetInstr("f", FunExpr([Param("x", Nothing)], VarExpr("x")))], TupleExpr([CallExpr(VarExpr("f"), [IntLit(0)]), CallExpr(VarExpr("f"), [StringLit("test")])]))),
# fun letGeneralization() = {
# let f = fun(x) -> x
# (f(0), f("test"))
# }
FunDecl(
"play",
Nothing,
[Param("num", Nothing)],
Nothing,
BlockExpr(
[
ExprInstr(CallExpr(VarExpr("print"), [StringLit("> ")])),
LetInstr("guess", CallExpr(VarExpr("to_int"), [CallExpr(VarExpr("input"), [])]))
],
IfExpr(
AddExpr(VarExpr("guess"), VarExpr("num")),
CallExpr(VarExpr("print"), [StringLit("gagné")]),
IfExpr(
LTExpr(VarExpr("guess"), VarExpr("num")),
BlockExpr(
[ExprInstr(CallExpr(VarExpr("print"), [StringLit("+")]))],
CallExpr(VarExpr("play"), [VarExpr("num")])
),
BlockExpr(
[ExprInstr(CallExpr(VarExpr("print"), [StringLit("-")]))],
CallExpr(VarExpr("play"), [VarExpr("num")])
)
)
)
)
),
# fun play(num) = {
# print("> ")
# let guess = to_int(input())
# if guess + num
# print("gagné")
# else if guess < num {
# print("+")
# play(num)
# } else {
# print("-")
# play(num)
# }
# }
# NOTE(review): the outer condition is built with AddExpr, i.e. `guess + num`,
# which the checker types as Int while an if-condition must unify with Bool.
# The original pseudo-code read `guess == num`, so EQExpr was presumably
# intended — confirm whether this mismatch is a deliberate error demonstration.
FunDecl(
"mutation",
Nothing,
[Param("x", Nothing)],
Nothing,
BlockExpr(
[
AssignInstr("x", AddExpr(VarExpr("x"), IntLit(1)))
],
CallExpr(VarExpr("print"), [VarExpr("x")])
)
),
# fun mutation(x) = {
# x = x + 1
# print(x)
# }
]
)
ctx = ModuleContext(prelude)
try
typecheck(ctx, mod)
catch err
# Type errors are reported and checking stops at the failing declaration;
# any other exception is propagated unchanged.
if isa(err, TypecheckError)
println(pretty(err))
else
rethrow(err)
end
end
# Print whatever was inferred (prelude entries included).
for (name, typ) in ctx.environment
println(name, " : ", pretty(typ))
end
end
# Run the demo when the file is executed.
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment