Skip to content

Instantly share code, notes, and snippets.

@WheretIB
Last active April 16, 2021 10:14
Show Gist options
  • Save WheretIB/3218cf832956de3d45b201a8adc8af32 to your computer and use it in GitHub Desktop.
Save WheretIB/3218cf832956de3d45b201a8adc8af32 to your computer and use it in GitHub Desktop.
import std.string;
import std.io;
import std.hashmap;
import std.typeinfo;
import std.vector;
// Standard library extensions (stuff that should have been in stdlib but isn't there for some reason)
// Converts a std.string to an int by forwarding its character data to the int(char[]) conversion.
int int(string str)
{
    return int(str.data);
}
// Hash for std.string keys: forwards to the char[] overload of hash_value on the string's data,
// which is what lets 'string' be used as a hashmap key in this file.
int hash_value(string str)
{
    return hash_value(str.data);
}
// Stream-style output of a std.string: prints the string's character data and
// hands the stream back so that further '<<' calls can be chained.
StdOut operator<<(StdOut out, string str)
{
    Print(str.data);
    return out;
}
// Stores value 'v' under key 'k' through the map's index operator.
void hashmap:insert(Key k, Value v)
{
    this[k] = v;
}
// Returns a coroutine that walks every slot of 'entries' and yields each node of the
// slot's 'next'-linked chain in turn; after the last node it returns nullptr.
// NOTE(review): presumably this is the hook that lets 'for (i in globals)' in
// Interpreter:wrap iterate a hashmap - confirm against the NULLC for-each protocol.
auto hashmap:start()
{
return coroutine auto(){
// Outer loop: slots of the table; inner loop: nodes chained in that slot
for(int i = 0; i < entries.size; i++) for(auto node = entries[i]; node; node = node.next) yield node;
// No nodes left to hand out
return nullptr;
};
}
// Types
// Base class of every type in the calculus; BoolType and FunctionType derive from it
// and the rest of the file dispatches on typeid() of a 'Type ref'.
class Type extendable{}
// The only base type: Bool
class BoolType: Type{}
// Function (arrow) type 'arg -> res'
class FunctionType: Type{ Type ref arg, res; }
// Type constructors
// Builds the function type 'arg -> res'
void FunctionType:FunctionType(Type ref arg, Type ref res)
{
    this.arg = arg;
    this.res = res;
}
// Structural equality of two types.
// A missing (null) type is equal only to another missing type; two Bool types are always
// equal; two function types are equal when both their argument and result types are.
bool equal(Type ref a, Type ref b)
{
    // At least one side is missing: they match only if both are
    if (!a || !b)
        return !a && !b;

    // Different kinds of type can never match
    if (typeid(a) != typeid(b))
        return false;

    switch (typeid(a))
    {
    case BoolType:
        return true;
    case FunctionType:
        FunctionType ref lhs = a;
        FunctionType ref rhs = b;
        if (!equal(lhs.arg, rhs.arg))
            return false;
        return equal(lhs.res, rhs.res);
    }
}
// Output
// Returns the printable name of a type.
//   null         -> ""
//   BoolType     -> "Bool"
//   FunctionType -> "<arg> -> <res>"
// Parser.parseType reads '->' as right-associative, so 'A -> B -> C' means 'A -> (B -> C)'.
// A function type in argument position is therefore wrapped in parentheses; without them
// '(Bool -> Bool) -> Bool' would be printed exactly like 'Bool -> Bool -> Bool'.
// Note that Parser.parseSimpleType does not accept parentheses, so such names are for
// display only and cannot be fed back to the parser.
string getName(Type ref a)
{
    if (!a)
        return string("");
    switch (typeid(a))
    {
    case BoolType:
        return string("Bool");
    case FunctionType:
        FunctionType ref func = a;
        string argName = getName(func.arg);
        // Disambiguate a function-typed argument
        if (func.arg)
        {
            if (typeid(func.arg) == FunctionType)
                argName = string("(") + argName + ")";
        }
        return argName + " -> " + getName(func.res);
    }
}
// Stream-style output of a type: writes the name produced by getName and returns the stream for chaining.
StdOut operator<<(StdOut out, Type ref t)
{
    out << getName(t);
    return out;
}
// Terms
// Base class of every term. 'type' is null until Interpreter:typecheck fills it in
// (the Variable constructor taking a type is the one exception that sets it directly).
class Term extendable{ Type ref type; }
// Boolean literal 'true' / 'false'
class Boolean : Term{ bool value; }
// Reference to a variable named 'x'
class Variable: Term{ string x; }
// Lambda abstraction '/x:xType. t' - parameter name, parameter type and body
class Abstraction: Term{ string x; Type ref xType; Term ref t; }
// Application 't1 t2' - function term and argument term
class Application: Term{ Term ref t1, t2; }
// Conditional 'if cond then tru else fls'
class Conditional: Term{ Term ref cond, tru, fls; }
// Term constructors
// Boolean literal with the given value
void Boolean:Boolean(bool value)
{
    this.value = value;
}

// Untyped variable reference; the type is left for the typechecker to fill in
void Variable:Variable(string x)
{
    this.x = x;
}

// Variable reference whose type is already known
void Variable:Variable(string x, Type ref type)
{
    this.x = x;
    this.type = type;
}

// Abstraction with parameter 'x' of type 'xType' and body 't'
void Abstraction:Abstraction(string x, Type ref xType, Term ref t)
{
    this.x = x;
    this.xType = xType;
    this.t = t;
}

// Application of 't1' to 't2'
void Application:Application(Term ref t1, t2)
{
    this.t1 = t1;
    this.t2 = t2;
}

// Conditional choosing between 'tru' and 'fls' depending on 'cond'
void Conditional:Conditional(Term ref cond, tru, fls)
{
    this.cond = cond;
    this.tru = tru;
    this.fls = fls;
}
// Optional equality (for testing)
// Structural equality of two terms (used by Interpreter:wrap to recognise globals).
// Two terms are equal when they are the same kind of node, carry equal types and all of
// their components are equal:
//   Boolean     - same value
//   Variable    - same name
//   Abstraction - same parameter name, same parameter type, equal bodies
//   Application - equal function and argument terms
//   Conditional - equal condition and both branches
// A missing (null) term is equal only to another missing term.
// Abstractions are compared by parameter name, not up to renaming of bound variables.
bool equal(Term ref a, Term ref b)
{
    // Guard against missing sub-terms before touching any member
    if (!a || !b)
        return !a && !b;
    if (typeid(a) != typeid(b))
        return false;
    if (!equal(a.type, b.type))
        return false;
    switch (typeid(a))
    {
    case Boolean:
        Boolean ref boolA = a;
        Boolean ref boolB = b;
        // 'true' and 'false' are different terms
        return boolA.value == boolB.value;
    case Variable:
        Variable ref varA = a;
        Variable ref varB = b;
        return varA.x == varB.x;
    case Abstraction:
        Abstraction ref absA = a;
        Abstraction ref absB = b;
        if (!(absA.x == absB.x))
            return false;
        // The annotated parameter type is part of the term
        if (!equal(absA.xType, absB.xType))
            return false;
        return equal(absA.t, absB.t);
    case Application:
        Application ref appA = a;
        Application ref appB = b;
        return equal(appA.t1, appB.t1) && equal(appA.t2, appB.t2);
    case Conditional:
        Conditional ref condA = a;
        Conditional ref condB = b;
        return equal(condA.cond, condB.cond) && equal(condA.tru, condB.tru) && equal(condA.fls, condB.fls);
    }
}
// Term output
// Defined below; declared here because the stream operator and output() call each other
void output(Term ref a);

// Stream-style output of a term: prints it through output() and returns the stream for chaining.
StdOut operator<<(StdOut out, Term ref t)
{
    output(t);
    return out;
}
// Prints a term to io.out in the same concrete syntax the parser accepts:
//   Boolean     - true / false
//   Variable    - its name
//   Abstraction - /x:Type. body   (or /x. body when no parameter type is set)
//   Application - t1 t2
//   Conditional - if cond then tru else fls
// Abstraction and conditional bodies extend as far to the right as possible and
// application associates to the left, so inside an application:
//   - the function term is parenthesized when it is an abstraction or a conditional
//   - the argument term is parenthesized when it is an application, abstraction or conditional
// This keeps the printed text from regrouping when read back, e.g. '(/x:Bool. x) true'
// is not printed as '/x:Bool. x true'.
void output(Term ref a)
{
    switch (typeid(a))
    {
    case Boolean:
        if (Boolean ref t = a)
            io.out << (t.value ? "true" : "false");
        break;
    case Variable:
        if (Variable ref t = a)
            io.out << t.x;
        break;
    case Abstraction:
        if (Abstraction ref t = a)
        {
            if (t.xType)
                io.out << "/" << t.x << ":" << t.xType << ". " << t.t;
            else
                io.out << "/" << t.x << ". " << t.t;
        }
        break;
    case Application:
        if (Application ref t = a)
        {
            // Function position: an abstraction/conditional would otherwise swallow the argument
            if (typeid(t.t1) == Abstraction || typeid(t.t1) == Conditional)
                io.out << "(" << t.t1 << ")";
            else
                io.out << t.t1;

            io.out << " ";

            // Argument position: anything that is not a single token gets grouped
            if (typeid(t.t2) == Application || typeid(t.t2) == Abstraction || typeid(t.t2) == Conditional)
                io.out << "(" << t.t2 << ")";
            else
                io.out << t.t2;
        }
        break;
    case Conditional:
        if (Conditional ref t = a)
            io.out << "if " << t.cond << " then " << t.tru << " else " << t.fls;
        break;
    }
}
// Prints a term followed by a line break
void outputln(Term ref t)
{
    io.out << t;
    io.out << io.endl;
}
// Lexer
// Kinds of tokens produced by Lexer.peek
enum LexemeType
{
Lambda, // '/' - starts an abstraction
Number, // run of digits (lexed, but Parser.parseTerm has no case for it)
String, // identifier that is not one of the keywords below (includes 'true', 'false' and type names)
Oparen, // '('
Cparen, // ')'
Point, // '.'
Colon, // ':'
Arrow, // '->'
If, // keyword 'if'
Then, // keyword 'then'
Else, // keyword 'else'
Unknown, // no token could be recognised at the current position
Eof // end of input; also the type of a default-constructed Lexeme
}
// A single token: its kind plus the source text it covers.
// Lexer.consume advances the input position by str.length(), so 'str' has to be exactly
// the text of the token as it appears in the input.
class Lexeme
{
// Token of the given kind with its text supplied as a std.string
void Lexeme(LexemeType type, string str){ this.type = type; this.str = str; }
// Token of the given kind with its text supplied as a char[] (e.g. a literal)
void Lexeme(LexemeType type, char[] str){ this.type = type; this.str = str; }
// Defaults to Eof: Lexer.peek returns a default-constructed Lexeme at the end of input
LexemeType type = LexemeType.Eof;
string str;
}
// Splits the input text into lexemes on demand.
// 'peek' looks at the next lexeme without consuming it, 'consume' returns it and moves past it.
class Lexer
{
    // Starts lexing 'str' from its first character
    void Lexer(string str)
    {
        this.str = str;
        this.pos = 0;
    }

    // Moves 'pos' past whitespace and control characters.
    // Negative char values (bytes of multi-byte characters) are NOT skipped: the identifier
    // rule in peek() accepts them, so they must stay in the input to start an identifier.
    void skipSpaces()
    {
        while (pos < str.length() && str[pos] >= 0 && str[pos] <= ' ')
            pos++;
    }

    // Returns the next lexeme without consuming it (only leading whitespace is skipped).
    // Returns an Eof lexeme at the end of input and an Unknown lexeme with empty text
    // when the character at the current position does not start any token.
    auto peek()
    {
        skipSpaces();

        // Work on a local copy so that the lexer position itself is not advanced
        int pos = this.pos;

        if (pos == str.length())
            return Lexeme();

        // Single-character tokens
        if (str[pos] == '/') return Lexeme(LexemeType.Lambda, "/");
        if (str[pos] == '(') return Lexeme(LexemeType.Oparen, "(");
        if (str[pos] == ')') return Lexeme(LexemeType.Cparen, ")");
        if (str[pos] == '.') return Lexeme(LexemeType.Point, ".");
        if (str[pos] == ':') return Lexeme(LexemeType.Colon, ":");

        // '->' needs a second character: make sure it exists before looking at it
        if (str[pos] == '-' && pos + 1 < str.length() && str[pos + 1] == '>') return Lexeme(LexemeType.Arrow, "->");

        // Identifier start: latin letter (case folded with '| 32'), '_' or a negative char
        auto isIdentStart = <char x> ((x | 32) >= 'a' && (x | 32) <= 'z') || x < 0 || x == '_';
        auto isDigit = <char x> x >= '0' && x <= '9';

        int start = pos;

        if (isIdentStart(str[pos]))
        {
            // Identifier: start character followed by any mix of identifier characters and digits
            pos++;
            while (pos < str.length() && (isIdentStart(str[pos]) || isDigit(str[pos])))
                pos++;
        }
        else if (isDigit(str[pos]))
        {
            // Number: run of digits
            pos++;
            while (pos < str.length() && isDigit(str[pos]))
                pos++;
            return Lexeme(LexemeType.Number, str.substr(start, pos - start));
        }

        string result = str.substr(start, pos - start);

        // Keywords are identifiers with a dedicated lexeme type
        if (result == "if") return Lexeme(LexemeType.If, result);
        if (result == "then") return Lexeme(LexemeType.Then, result);
        if (result == "else") return Lexeme(LexemeType.Else, result);

        // Nothing was matched at all
        if (result.empty()) return Lexeme(LexemeType.Unknown, "");

        return Lexeme(LexemeType.String, result);
    }

    // Returns the next lexeme and moves the position past its text.
    // Eof and Unknown lexemes have empty text, so the position does not move for them.
    auto consume()
    {
        // peek() has already skipped the leading whitespace, 'pos' now points at the lexeme
        auto next = peek();
        pos += next.str.length();
        return next;
    }

    string str; // text being lexed
    int pos; // index of the next unread character in 'str'
}
// Recursive descent parser for types and terms.
//   type  ::= 'Bool' [ '->' type ]                      ('->' is right-associative)
//   term  ::= '/' name ':' type '.' terms
//           | name | 'true' | 'false'
//           | '(' terms ')'
//           | 'if' terms 'then' terms 'else' terms
//   terms ::= term { term }                             (application, left-associative)
// Syntax errors are reported through assert.
class Parser
{
    Term ref parseTerms();

    // Parses a base type name; 'Bool' is the only one known
    Type ref parseSimpleType()
    {
        assert(lexer.peek().type == LexemeType.String, "type name expected");
        auto name = lexer.consume().str;
        if (name == "Bool")
            return new BoolType();
        assert(false, "unknown type name");
    }

    // Parses a type; recursing on the right-hand side makes '->' right-associative
    Type ref parseType()
    {
        Type ref t = parseSimpleType();
        if (lexer.peek().type == LexemeType.Arrow)
        {
            lexer.consume();
            t = new FunctionType(t, parseType());
        }
        return t;
    }

    // Parses a single term.
    // Returns nullptr WITHOUT consuming anything when the next lexeme cannot start a term
    // (')', 'then', 'else', end of input, ...); parseTerms uses that as its stop signal.
    Term ref parseTerm()
    {
        switch (lexer.peek().type)
        {
        case LexemeType.Lambda:
            lexer.consume();
            assert(lexer.peek().type == LexemeType.String, "abstraction name expected after '/'");
            string x = lexer.consume().str;
            assert(lexer.peek().type == LexemeType.Colon, "type expected after abstraction name");
            lexer.consume();
            Type ref type = parseType();
            assert(lexer.peek().type == LexemeType.Point, "'.' not found after abstraction type");
            lexer.consume();
            auto body = parseTerms();
            // An abstraction without a body is not a term
            assert(body != nullptr, "term expected after '.'");
            return new Abstraction(x, type, body);
        case LexemeType.String:
            auto str = lexer.consume().str;
            if (str == "true")
                return new Boolean(true);
            if (str == "false")
                return new Boolean(false);
            return new Variable(str);
        case LexemeType.Oparen:
            lexer.consume();
            auto t = parseTerms();
            // '()' must not turn into nullptr: the caller would take it for "no more terms"
            // and silently drop the rest of the input
            assert(t != nullptr, "term expected after '('");
            assert(lexer.peek().type == LexemeType.Cparen, "')' not found after '('");
            lexer.consume();
            return t;
        case LexemeType.If:
            lexer.consume();
            auto cond = parseTerms();
            assert(cond != nullptr, "condition expected after 'if'");
            assert(lexer.peek().type == LexemeType.Then, "'then' not found after term");
            lexer.consume();
            auto tru = parseTerms();
            assert(tru != nullptr, "term expected after 'then'");
            assert(lexer.peek().type == LexemeType.Else, "'else' not found after term");
            lexer.consume();
            auto fls = parseTerms();
            assert(fls != nullptr, "term expected after 'else'");
            return new Conditional(cond, tru, fls);
        case LexemeType.Unknown:
            assert(false, "unknown lexeme at " + lexer.pos.str());
        }
        return nullptr;
    }

    // Parses a sequence of terms and folds it into left-associative applications:
    // 'a b c' becomes '(a b) c'. Returns nullptr when no term is present at all.
    Term ref parseTerms()
    {
        auto t = parseTerm();
        auto t2 = parseTerm();
        while (t2)
        {
            t = new Application(t, t2);
            t2 = parseTerm();
        }
        return t;
    }

    // Parses 'code' as one complete term. Returns nullptr for empty input.
    // Anything left over after the term (a stray ')', a number, ...) is an error
    // rather than being silently ignored.
    Term ref parse(char[] code)
    {
        lexer = Lexer(string(code));
        auto t = parseTerms();
        assert(lexer.peek().type == LexemeType.Eof, "unexpected input after the end of the term");
        return t;
    }

    Lexer lexer;
}
// Typechecker and evaluator state; the member functions are defined below the class.
class Interpreter
{
// Parser reused by every entry point that takes source text
Parser p;
// Variable name -> type. Written by free() and, temporarily, by typecheck() for abstraction parameters
hashmap<string, Type ref> bindings;
// Name -> typechecked term. Written by global(); read by typecheck(), eval() and wrap()
hashmap<string, Term ref> globals; // optional
}
// Computes the type of term 'a', stores it in 'a.type' (and in the 'type' of every
// sub-term on the way) and returns it. Returns nullptr and prints an 'ERROR: ...' line
// to io.out when the term is ill-typed.
//   Boolean     : Bool
//   Variable    : type bound in 'bindings', otherwise type of the global with that name
//   Abstraction : xType -> T, where T is the type of the body with x bound to xType
//   Application : R, when t1 : A -> R and t2 : A
//   Conditional : T, when cond : Bool and both branches have type T
Type ref Interpreter:typecheck(Term ref a)
{
    switch (typeid(a))
    {
    case Boolean:
        if (Boolean ref t = a)
            t.type = new BoolType();
        break;
    case Variable:
        if (Variable ref t = a)
        {
            // The Abstraction case below restores the previous binding when it leaves a
            // lambda; for a name that was not bound before, that leaves an entry holding
            // nullptr in the map. Such an entry means "not bound" and must not hide a
            // global with the same name.
            Type ref boundType = nullptr;
            if (auto bound = bindings.find(t.x))
                boundType = *bound;

            if (boundType)
                t.type = boundType;
            else if (auto global = globals.find(t.x))
                t.type = global.type;
            else
                io.out << "ERROR: Unknown variable " << t.x << io.endl;
        }
        break;
    case Abstraction:
        if (Abstraction ref t = a)
        {
            // Bind the parameter for the duration of the body, then put the old binding back
            auto lastType = bindings[t.x];
            bindings[t.x] = t.xType;
            if (auto tType = typecheck(t.t))
                t.type = new FunctionType(t.xType, tType);
            bindings[t.x] = lastType;
        }
        break;
    case Application:
        if (Application ref t = a)
        {
            auto t1Type = typecheck(t.t1);
            auto t2Type = typecheck(t.t2);
            // Errors in the operands have already been reported by the recursive calls
            if (t1Type && t2Type)
            {
                if (typeid(t1Type) == FunctionType)
                {
                    FunctionType ref t1FuncType = t1Type;
                    if (equal(t1FuncType.arg, t2Type))
                        t.type = t1FuncType.res;
                    else
                        io.out << "ERROR: Application rhs type " << t2Type << " has to match function argument type " << t1FuncType.arg << io.endl;
                }
                else
                {
                    io.out << "ERROR: Application lhs type has to be function got " << t1Type << io.endl;
                }
            }
        }
        break;
    case Conditional:
        if (Conditional ref t = a)
        {
            auto condType = typecheck(t.cond);
            auto truType = typecheck(t.tru);
            auto flsType = typecheck(t.fls);
            if (condType && truType && flsType)
            {
                if (typeid(condType) == BoolType)
                {
                    if (equal(truType, flsType))
                        t.type = truType;
                    else
                        io.out << "ERROR: Condition branch types have to equal " << truType << " != " << flsType << io.endl;
                }
                else
                {
                    io.out << "ERROR: Condition type has to be Bool got " << condType << io.endl;
                }
            }
        }
        break;
    }
    // Summary line for every node that ended up without a type
    if (!a.type)
        io.out << "ERROR: Failed to typecheck " << typeid(a).name << " '" << a << "'" << io.endl;
    return a.type;
}
// Declares a free variable: parses 'code' as a type and binds 'name' to it in 'bindings',
// so that later terms can mention 'name' without defining it.
void Interpreter:free(char[] name, char[] code)
{
    // The parser has no entry point for bare types, so its lexer is pointed at the text directly
    p.lexer = Lexer(string(code));

    auto parsedType = p.parseType();
    assert(parsedType != nullptr, "failed to parse type");

    bindings[string(name)] = parsedType;
}
// Defines a global: parses and typechecks 'code' and stores the resulting term in 'globals'
// under 'name'. Asserts when the text does not parse or does not typecheck.
void Interpreter:global(char[] name, char[] code)
{
    auto term = p.parse(code);
    assert(term != nullptr, "failed to parse");

    typecheck(term);
    assert(term.type != nullptr, "failed to typecheck");

    globals[string(name)] = term;
}
// Returns term 't' with every free occurrence of variable 'x' replaced by 'rhs'.
// Nodes that contain no occurrence to replace on their own level (booleans, other
// variables, abstractions that re-bind 'x') are returned as is; abstractions,
// applications and conditionals are rebuilt around their substituted children.
// Rebuilt nodes keep the parameter type and the 'type' of the node they replace: the
// constructors leave 'type' unset, and replacing a variable with a term of the same type
// does not change the type of the surrounding term.
// NOTE: the substitution is not capture-avoiding - free variables of 'rhs' are not
// renamed when they collide with a binder inside 't'.
Term ref Interpreter:substitute(Term ref t, string x, Term ref rhs)
{
    switch (typeid(t))
    {
    case Variable:
        if (Variable ref variable = t)
        {
            if (variable.x == x)
                return rhs;
        }
        break;
    case Abstraction:
        if (Abstraction ref abstraction = t)
        {
            // Do not step into the abstractions that have a binding over same name
            if (abstraction.x == x)
                return t;

            // The second constructor argument is the PARAMETER type (xType), not the type of the whole abstraction
            auto copy = new Abstraction(abstraction.x, abstraction.xType, substitute(abstraction.t, x, rhs));
            copy.type = abstraction.type;
            return copy;
        }
        break;
    case Application:
        if (Application ref application = t)
        {
            auto copy = new Application(substitute(application.t1, x, rhs), substitute(application.t2, x, rhs));
            copy.type = application.type;
            return copy;
        }
        break;
    case Conditional:
        if (Conditional ref conditional = t)
        {
            auto copy = new Conditional(substitute(conditional.cond, x, rhs), substitute(conditional.tru, x, rhs), substitute(conditional.fls, x, rhs));
            copy.type = conditional.type;
            return copy;
        }
        break;
    }
    return t;
}
// Big-step evaluation of a term.
//   Application : evaluate the function term, substitute the evaluated argument for the
//                 parameter in the body, evaluate the result
//   Variable    : a global is replaced by the value of its definition; any other variable
//                 evaluates to itself
//   Conditional : evaluate the condition, then evaluate the selected branch only
//   anything else (booleans, abstractions) is already a value and is returned as is
// When evaluation cannot proceed - the function position is not an abstraction or the
// condition is not a boolean, e.g. because of a free variable - a message is printed to
// io.out and the term is returned unevaluated.
Term ref Interpreter:eval(Term ref t)
{
    switch (typeid(t))
    {
    case Application:
        if (Application ref application = t)
        {
            auto t1 = eval(application.t1);
            if (typeid(t1) != Abstraction)
            {
                io.out << "expected abstraction for t1, got: (" << typeid(t1).name << ") " << t1 << io.endl;
                // There is no body to substitute into: leave the application as it is
                return t;
            }
            Abstraction ref lhs = t1;
            return eval(substitute(lhs.t, lhs.x, eval(application.t2))); // eval result for big-step
        }
        break;
    case Variable:
        if (Variable ref variable = t)
        {
            if (auto global = globals.find(variable.x))
                return eval(*global);
        }
        break;
    case Conditional:
        if (Conditional ref conditional = t)
        {
            auto cond = eval(conditional.cond);
            if (typeid(cond) != Boolean)
            {
                io.out << "expected boolean for condition, got: (" << typeid(cond).name << ") " << cond << io.endl;
                // No branch can be selected: leave the conditional as it is
                return t;
            }
            Boolean ref flag = cond;
            if (flag.value)
                return eval(conditional.tru);
            return eval(conditional.fls);
        }
        break;
    }
    return t;
}
// Makes a term more readable for printing: every sub-term that is structurally equal to
// the definition of a global is replaced by a variable carrying that global's name and type.
// The search is applied to the term itself first and then, recursively, to the children of
// abstractions, applications and conditionals. Terms without a match are returned as is.
Term ref Interpreter:wrap(Term ref t)
{
    for (i in globals)
    {
        if (equal(i.value, t))
            return new Variable(i.key, i.value.type);
    }
    switch (typeid(t))
    {
    case Abstraction:
        if (Abstraction ref abstraction = t)
            return new Abstraction(abstraction.x, abstraction.xType, wrap(abstraction.t));
        break;
    case Application:
        if (Application ref application = t)
            return new Application(wrap(application.t1), wrap(application.t2));
        break;
    case Conditional:
        // Globals can appear inside conditionals too (e.g. in the body of a returned abstraction)
        if (Conditional ref conditional = t)
            return new Conditional(wrap(conditional.cond), wrap(conditional.tru), wrap(conditional.fls));
        break;
    }
    return t;
}
// Parses, typechecks and evaluates the source text 'code'; returns the resulting term.
// Asserts when the text does not parse or does not typecheck.
auto Interpreter:eval(char[] code)
{
    auto term = p.parse(code);
    assert(term != nullptr, "failed to parse");

    typecheck(term);
    assert(term.type != nullptr, "failed to typecheck");

    return eval(term);
}
// Echoes 'code', evaluates it and prints '> <result> : <type>' where the result is shown
// with known globals folded back into their names.
void Interpreter:printeval(char[] code)
{
    io.out << code << io.endl;

    auto result = eval(code);

    auto shown = wrap(result);
    string typeName = getName(result.type);

    io.out << "> " << shown << " : " << typeName << io.endl;
}
// Parses and typechecks 'code' without evaluating it and prints '> <term> : <type>'.
// Asserts when the text does not parse or does not typecheck.
void Interpreter:printtypecheck(char[] code)
{
    auto term = p.parse(code);
    assert(term != nullptr, "failed to parse");

    typecheck(term);
    assert(term.type != nullptr, "failed to typecheck");

    string typeName = getName(term.type);
    io.out << "> " << term << " : " << typeName << io.endl;
}
// Demo script
Interpreter i;
// Define boolean negation as a global named 'not'
i.global("not", "/x:Bool. if x then false else true");
// Evaluate a conditional that selects between two abstractions and print the result with its type
i.printeval("if true then (/x:Bool. x) else (/x:Bool. not x)");
// Declare a free variable 'f' of type Bool -> Bool (it has a type but no definition)
i.free("f", "Bool -> Bool");
// Typecheck only: terms that mention the free variable 'f'
i.printtypecheck("f (if false then true else false)");
i.printtypecheck("/x:Bool. f (if x then false else x)");
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment