Last active
April 16, 2021 10:14
-
-
Save WheretIB/3218cf832956de3d45b201a8adc8af32 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import std.string; | |
import std.io; | |
import std.hashmap; | |
import std.typeinfo; | |
import std.vector; | |
// Standard library extensions (helpers that arguably belong in the stdlib but are missing from it)
// Converts the text held by a `string` to an int by forwarding its character
// buffer to the int(char[]) conversion.
int int(string str)
{
    return int(str.data);
}
// Hash of a `string`, computed by hashing its character buffer, so that
// `string` can be used as a hashmap key.
int hash_value(string str)
{
    return hash_value(str.data);
}
// Stream output for `string`: prints the character buffer with Print and
// returns the stream so that `<<` calls can be chained.
StdOut operator<<(StdOut out, string str)
{
    Print(str.data);
    return out;
}
// Stores value `v` under key `k` through the map's index operator.
void hashmap:insert(Key k, Value v)
{
    this[k] = v;
}
// Iteration support for `for(node in map)`.
// Returns a coroutine that yields every stored node: it visits each slot of
// `entries` in index order and follows the `next` links of the chain found there.
// After the last node the coroutine returns nullptr.
auto hashmap:start()
{
    return coroutine auto()
    {
        for(int bucket = 0; bucket < entries.size; bucket++)
        {
            for(auto node = entries[bucket]; node; node = node.next)
            {
                yield node;
            }
        }
        return nullptr;
    };
}
// Types
// Root of the type hierarchy; declared `extendable` so the classes below can derive from it.
class Type extendable
{
}
// The Bool type (no payload).
class BoolType: Type
{
}
// Function type `arg -> res`.
class FunctionType: Type
{
    Type ref arg;
    Type ref res;
}
// Type constructors
// Builds the function type `arg -> res`.
void FunctionType:FunctionType(Type ref arg, Type ref res)
{
    this.arg = arg;
    this.res = res;
}
// Structural equality of two types.
// Two null references compare equal; a null reference never equals a real type.
// Function types are equal when both their argument and result types are equal.
bool equal(Type ref a, Type ref b)
{
    if (!a || !b)
        return !a && !b;
    if (typeid(a) != typeid(b))
        return false;
    switch (typeid(a))
    {
    case BoolType:
        return true;
    case FunctionType:
        FunctionType ref funcA = a;
        FunctionType ref funcB = b;
        return equal(funcA.arg, funcB.arg) && equal(funcA.res, funcB.res);
    }
    // A Type subclass this function does not know how to compare: report
    // "not equal" explicitly instead of leaving the function without a result.
    return false;
}
// Output
// Returns the printable name of a type.
//  - null type     -> ""
//  - BoolType      -> "Bool"
//  - FunctionType  -> "<arg> -> <res>"
// The parser reads `->` as right-associative (parseType recurses on the right-hand
// side), so a function type in argument position is wrapped in parentheses;
// otherwise `(Bool -> Bool) -> Bool` would print exactly like `Bool -> (Bool -> Bool)`.
string getName(Type ref a)
{
    if (!a)
        return string("");
    switch (typeid(a))
    {
    case BoolType:
        return string("Bool");
    case FunctionType:
        FunctionType ref func = a;
        string argName = getName(func.arg);
        if (func.arg && typeid(func.arg) == FunctionType)
            argName = string("(") + argName + ")";
        return argName + " -> " + getName(func.res);
    }
    // Unknown Type subclass: there is no name to print.
    return string("");
}
// Stream output for types: writes getName(t) to `out` and returns the stream for chaining.
StdOut operator<<(StdOut out, Type ref t)
{
    out << getName(t);
    return out;
}
// Terms
// Base class of all terms; `type` is filled in by Interpreter:typecheck.
class Term extendable
{
    Type ref type;
}
// Boolean literal.
class Boolean: Term
{
    bool value;
}
// Reference to a variable named `x`.
class Variable: Term
{
    string x;
}
// Lambda abstraction: parameter `x` of type `xType` with body `t`.
class Abstraction: Term
{
    string x;
    Type ref xType;
    Term ref t;
}
// Application of `t1` to `t2`.
class Application: Term
{
    Term ref t1;
    Term ref t2;
}
// `if cond then tru else fls`.
class Conditional: Term
{
    Term ref cond;
    Term ref tru;
    Term ref fls;
}
// Term constructors
// Boolean literal with the given value.
void Boolean:Boolean(bool value)
{
    this.value = value;
}
// Variable reference without a known type.
void Variable:Variable(string x)
{
    this.x = x;
}
// Variable reference with its type already known.
void Variable:Variable(string x, Type ref type)
{
    this.x = x;
    this.type = type;
}
// Abstraction over `x` of type `xType` with body `t`.
void Abstraction:Abstraction(string x, Type ref xType, Term ref t)
{
    this.x = x;
    this.xType = xType;
    this.t = t;
}
// Application of `t1` to `t2`.
void Application:Application(Term ref t1, t2)
{
    this.t1 = t1;
    this.t2 = t2;
}
// Conditional with its condition and both branches.
void Conditional:Conditional(Term ref cond, tru, fls)
{
    this.cond = cond;
    this.tru = tru;
    this.fls = fls;
}
// Optional equality (for testing)
// Structural equality of two terms (used by Interpreter:wrap to recognise globals).
// Terms are equal when they have the same dynamic class, equal `type` fields and
// equal contents:
//  - Boolean:     same value
//  - Variable:    same name
//  - Abstraction: same parameter name, parameter type and body
//  - Application / Conditional: equal sub-terms
// Two null references are equal; a null reference never equals a real term.
bool equal(Term ref a, Term ref b)
{
    // Guard before touching a.type / b.type below.
    if (!a || !b)
        return !a && !b;
    if (typeid(a) != typeid(b))
        return false;
    if (!equal(a.type, b.type))
        return false;
    switch (typeid(a))
    {
    case Boolean:
        // The literal value has to take part in the comparison,
        // otherwise `true` and `false` are indistinguishable.
        Boolean ref boolA = a;
        Boolean ref boolB = b;
        return boolA.value == boolB.value;
    case Variable:
        Variable ref varA = a;
        Variable ref varB = b;
        return varA.x == varB.x;
    case Abstraction:
        Abstraction ref absA = a;
        Abstraction ref absB = b;
        return absA.x == absB.x && equal(absA.xType, absB.xType) && equal(absA.t, absB.t);
    case Application:
        Application ref appA = a;
        Application ref appB = b;
        return equal(appA.t1, appB.t1) && equal(appA.t2, appB.t2);
    case Conditional:
        Conditional ref condA = a;
        Conditional ref condB = b;
        return equal(condA.cond, condB.cond) && equal(condA.tru, condB.tru) && equal(condA.fls, condB.fls);
    }
    // Term subclass without a comparison rule.
    return false;
}
// Term output
// Forward declaration: output() and the stream operator below call each other.
void output(Term ref a);
// Stream output for terms. The term is printed by output(), which writes to io.out;
// `out` itself is only returned so that the `<<` chain can continue.
StdOut operator<<(StdOut out, Term ref t)
{
    output(t);
    return out;
}
// Prints a term to io.out in the same concrete syntax that the parser accepts.
// Inside an application, operands are parenthesised whenever the printed text
// would otherwise be read back differently:
//  - the function part, when it is an abstraction or a conditional (their bodies
//    extend as far to the right as possible and would swallow the argument);
//  - the argument, when it is an application (application is left-associative),
//    an abstraction or a conditional (they would swallow any later argument).
void output(Term ref a)
{
    switch (typeid(a))
    {
    case Boolean:
        if (Boolean ref t = a)
            io.out << (t.value ? "true" : "false");
        break;
    case Variable:
        if (Variable ref t = a)
            io.out << t.x;
        break;
    case Abstraction:
        if (Abstraction ref t = a)
        {
            if (t.xType)
                io.out << "/" << t.x << ":" << t.xType << ". " << t.t;
            else
                io.out << "/" << t.x << ". " << t.t;
        }
        break;
    case Application:
        if (Application ref t = a)
        {
            bool wrapLhs = typeid(t.t1) == Abstraction || typeid(t.t1) == Conditional;
            bool wrapRhs = typeid(t.t2) == Application || typeid(t.t2) == Abstraction || typeid(t.t2) == Conditional;

            if (wrapLhs)
                io.out << "(" << t.t1 << ")";
            else
                io.out << t.t1;

            if (wrapRhs)
                io.out << " (" << t.t2 << ")";
            else
                io.out << " " << t.t2;
        }
        break;
    case Conditional:
        if (Conditional ref t = a)
            io.out << "if " << t.cond << " then " << t.tru << " else " << t.fls;
        break;
    }
}
// Prints a term to io.out followed by io.endl.
void outputln(Term ref t)
{
    io.out << t << io.endl;
}
// Lexer
// Kinds of lexemes produced by Lexer.
enum LexemeType
{
    Lambda,     // '/'
    Number,     // run of decimal digits
    String,     // identifier (also carries the words `true`, `false` and type names)
    Oparen,     // '('
    Cparen,     // ')'
    Point,      // '.'
    Colon,      // ':'
    Arrow,      // '->'
    If,         // keyword `if`
    Then,       // keyword `then`
    Else,       // keyword `else`
    Unknown,    // character that starts no known lexeme
    Eof         // end of input
}
// A single lexeme: its kind and the exact source text it covers.
// A default-constructed Lexeme has kind Eof.
class Lexeme
{
    // Construct from a kind and the text as a `string`.
    void Lexeme(LexemeType type, string str)
    {
        this.type = type;
        this.str = str;
    }
    // Construct from a kind and the text as a character array.
    void Lexeme(LexemeType type, char[] str)
    {
        this.type = type;
        this.str = str;
    }

    LexemeType type = LexemeType.Eof;
    string str;
}
// Splits the source text into lexemes on demand.
// `peek` returns the next lexeme without consuming it; `consume` returns it and
// moves `pos` past its text.
class Lexer
{
    void Lexer(string str){ this.str = str; }

    // Advances `pos` past every character that is <= ' ' (spaces and control characters).
    void skipSpaces()
    {
        while (pos < str.length() && str[pos] <= ' ')
            pos++;
    }

    // Returns the lexeme that starts at the current position (after skipping
    // whitespace). Scanning uses a local copy of `pos`, so apart from the
    // whitespace skip the lexer position is left untouched.
    auto peek()
    {
        skipSpaces();

        int pos = this.pos;

        if (pos >= str.length())
            return Lexeme();

        // Punctuation
        if (str[pos] == '/') return Lexeme(LexemeType.Lambda, "/");
        if (str[pos] == '(') return Lexeme(LexemeType.Oparen, "(");
        if (str[pos] == ')') return Lexeme(LexemeType.Cparen, ")");
        if (str[pos] == '.') return Lexeme(LexemeType.Point, ".");
        if (str[pos] == ':') return Lexeme(LexemeType.Colon, ":");
        // Only look at the following character when there is one: a '-' at the very
        // end of the input must not index past the end of the string.
        if (str[pos] == '-' && pos + 1 < str.length() && str[pos + 1] == '>') return Lexeme(LexemeType.Arrow, "->");

        // Identifier or number.
        // An identifier starts with a letter (case folded with `| 32`), '_' or a
        // character with a negative value, and continues with those or digits.
        auto isIdentStart = <char x> ((x | 32) >= 'a' && (x | 32) <= 'z') || x < 0 || x == '_';
        auto isDigit = <char x> x >= '0' && x <= '9';

        int start = pos;

        if (isIdentStart(str[pos]))
        {
            pos++;

            while (pos < str.length() && (isIdentStart(str[pos]) || isDigit(str[pos])))
                pos++;
        }
        else if (isDigit(str[pos]))
        {
            pos++;

            while (pos < str.length() && isDigit(str[pos]))
                pos++;

            return Lexeme(LexemeType.Number, str.substr(start, pos - start));
        }

        string result = str.substr(start, pos - start);

        if (result == "if") return Lexeme(LexemeType.If, result);
        if (result == "then") return Lexeme(LexemeType.Then, result);
        if (result == "else") return Lexeme(LexemeType.Else, result);

        // Nothing was scanned: the character starts no known lexeme.
        if (result.empty()) return Lexeme(LexemeType.Unknown, "");

        return Lexeme(LexemeType.String, result);
    }

    // Returns the next lexeme and moves past its text.
    // Eof and Unknown lexemes have empty text, so they do not advance the position.
    auto consume()
    {
        skipSpaces();

        auto next = peek();

        pos += next.str.length();

        return next;
    }

    string str;
    int pos;
}
// Recursive-descent parser for types and terms.
//   type  ::= 'Bool' [ '->' type ]                      (arrow is right-associative)
//   term  ::= '/' name ':' type '.' terms
//           | name | 'true' | 'false'
//           | '(' terms ')'
//           | 'if' terms 'then' terms 'else' terms
//   terms ::= term { term }                             (application, left-associative)
// Syntax errors are reported through assert.
class Parser
{
    Term ref parseTerms();

    // Parses a type name; only `Bool` is known.
    Type ref parseSimpleType()
    {
        assert(lexer.peek().type == LexemeType.String, "type name expected");

        auto name = lexer.consume().str;

        if (name == "Bool")
            return new BoolType();

        assert(false, "unknown type name");
        return nullptr;
    }

    // Parses a simple type optionally followed by `-> type`.
    Type ref parseType()
    {
        Type ref t = parseSimpleType();

        if (lexer.peek().type == LexemeType.Arrow)
        {
            lexer.consume();

            t = new FunctionType(t, parseType());
        }

        return t;
    }

    // Parses a single term. Returns nullptr (consuming nothing) when the next
    // lexeme cannot start a term; parseTerms relies on that to stop.
    Term ref parseTerm()
    {
        switch (lexer.peek().type)
        {
        case LexemeType.Lambda:
            lexer.consume();

            assert(lexer.peek().type == LexemeType.String, "abstraction name expected after '/'");
            string x = lexer.consume().str;

            assert(lexer.peek().type == LexemeType.Colon, "type expected after abstraction name");
            lexer.consume();

            Type ref type = parseType();

            assert(lexer.peek().type == LexemeType.Point, "'.' not found after abstraction type");
            lexer.consume();

            return new Abstraction(x, type, parseTerms());
        case LexemeType.String:
            auto str = lexer.consume().str;

            if (str == "true")
                return new Boolean(true);

            if (str == "false")
                return new Boolean(false);

            return new Variable(str);
        case LexemeType.Oparen:
            lexer.consume();

            auto t = parseTerms();

            assert(lexer.peek().type == LexemeType.Cparen, "')' not found after '('");
            lexer.consume();

            return t;
        case LexemeType.If:
            lexer.consume();

            auto cond = parseTerms();
            assert(cond != nullptr, "condition expected after 'if'");

            assert(lexer.peek().type == LexemeType.Then, "'then' not found after term");
            lexer.consume();

            auto tru = parseTerms();
            assert(tru != nullptr, "term expected after 'then'");

            assert(lexer.peek().type == LexemeType.Else, "'else' not found after term");
            lexer.consume();

            auto fls = parseTerms();
            assert(fls != nullptr, "term expected after 'else'");

            return new Conditional(cond, tru, fls);
        case LexemeType.Unknown:
            assert(false, "unknown lexeme at " + lexer.pos.str());
        }

        return nullptr;
    }

    // Parses a sequence of terms and folds it into left-nested applications.
    // Returns nullptr when no term is present.
    Term ref parseTerms()
    {
        auto t = parseTerm();

        auto t2 = parseTerm();

        while (t2)
        {
            t = new Application(t, t2);

            t2 = parseTerm();
        }

        return t;
    }

    // Parses a complete program text into a term.
    // The whole input has to be consumed: leftovers such as an unmatched ')' or a
    // stray 'then' are reported instead of being silently dropped.
    Term ref parse(char[] code)
    {
        lexer = Lexer(string(code));

        Term ref t = parseTerms();

        assert(lexer.peek().type == LexemeType.Eof, "unexpected input after term");

        return t;
    }

    Lexer lexer;
}
// Interpreter state shared by type checking and evaluation.
class Interpreter
{
    // Parser used for every piece of source text handed to the interpreter.
    Parser p;

    // Variable name -> type: free variables declared with `free` and, while an
    // abstraction body is being type checked, its parameter.
    hashmap<string, Type ref> bindings;

    // Named, already type checked terms registered with `global`.
    hashmap<string, Term ref> globals; // optional
}
// Computes the type of `a`, stores it in `a.type` (sub-terms are annotated on the
// way) and returns it. Type errors are reported on io.out, in which case the
// affected term keeps a null type and null is returned.
Type ref Interpreter:typecheck(Term ref a)
{
    switch (typeid(a))
    {
    case Boolean:
        if (Boolean ref t = a)
            t.type = new BoolType();
        break;
    case Variable:
        if (Variable ref t = a)
        {
            // The Abstraction case below saves and restores bindings[name], which
            // leaves an entry holding null for names that were not bound before.
            // Such an entry must count as "not bound"; otherwise it would hide a
            // global of the same name and suppress the unknown-variable message.
            Type ref bound = nullptr;

            if (auto type = bindings.find(t.x))
                bound = *type;

            if (bound)
                t.type = bound;
            else if (auto global = globals.find(t.x))
                t.type = global.type;
            else
                io.out << "ERROR: Unknown variable " << t.x << io.endl;
        }
        break;
    case Abstraction:
        if (Abstraction ref t = a)
        {
            // Bind the parameter while checking the body, then restore whatever the
            // name was bound to before (null when it was not bound).
            auto lastType = bindings[t.x];

            bindings[t.x] = t.xType;

            if (auto tType = typecheck(t.t))
                t.type = new FunctionType(t.xType, tType);

            bindings[t.x] = lastType;
        }
        break;
    case Application:
        if (Application ref t = a)
        {
            auto t1Type = typecheck(t.t1);
            auto t2Type = typecheck(t.t2);

            if (t1Type && t2Type)
            {
                if (typeid(t1Type) == FunctionType)
                {
                    FunctionType ref t1FuncType = t1Type;

                    // The argument type has to match the parameter type exactly.
                    if (equal(t1FuncType.arg, t2Type))
                        t.type = t1FuncType.res;
                    else
                        io.out << "ERROR: Application rhs type " << t2Type << " has to match function argument type " << t1FuncType.arg << io.endl;
                }
                else
                {
                    io.out << "ERROR: Application lhs type has to be function got " << t1Type << io.endl;
                }
            }
        }
        break;
    case Conditional:
        if (Conditional ref t = a)
        {
            auto condType = typecheck(t.cond);
            auto truType = typecheck(t.tru);
            auto flsType = typecheck(t.fls);

            if (condType && truType && flsType)
            {
                if (typeid(condType) == BoolType)
                {
                    // Both branches must agree; their common type is the result.
                    if (equal(truType, flsType))
                        t.type = truType;
                    else
                        io.out << "ERROR: Condition branch types have to equal " << truType << " != " << flsType << io.endl;
                }
                else
                {
                    io.out << "ERROR: Condition type has to be Bool got " << condType << io.endl;
                }
            }
        }
        break;
    }

    if (!a.type)
        io.out << "ERROR: Failed to typecheck " << typeid(a).name << " '" << a << "'" << io.endl;

    return a.type;
}
// Declares a free variable: parses `code` as a type and binds `name` to it for
// type checking.
void Interpreter:free(char[] name, char[] code)
{
    p.lexer = Lexer(string(code));

    Type ref parsed = p.parseType();
    assert(parsed != nullptr, "failed to parse type");

    bindings[string(name)] = parsed;
}
// Registers a global: parses and type checks `code`, then stores the resulting
// term under `name`.
void Interpreter:global(char[] name, char[] code)
{
    Term ref parsed = p.parse(code);
    assert(parsed != nullptr, "failed to parse");

    typecheck(parsed);
    assert(parsed.type != nullptr, "failed to typecheck");

    globals[string(name)] = parsed;
}
// Returns `t` with every free occurrence of variable `x` replaced by `rhs`.
// Nodes on the path to a replacement are rebuilt; `t` itself is not modified.
// Rebuilt nodes keep the type annotation of the node they replace, so the result
// of an evaluation still carries its type.
// NOTE(review): substitution is not capture-avoiding - free variables of `rhs` can
// be captured by an abstraction inside `t` that binds the same name.
Term ref Interpreter:substitute(Term ref t, string x, Term ref rhs)
{
    switch (typeid(t))
    {
    case Variable:
        if (Variable ref variable = t)
        {
            if (variable.x == x)
                return rhs;
        }
        break;
    case Abstraction:
        if (Abstraction ref abstraction = t)
        {
            // Do not step into the abstractions that have a binding over same name
            if (abstraction.x == x)
                return t;

            // The second constructor argument is the parameter type (xType), not
            // the type of the abstraction as a whole.
            Term ref result = new Abstraction(abstraction.x, abstraction.xType, substitute(abstraction.t, x, rhs));
            result.type = t.type;
            return result;
        }
        break;
    case Application:
        if (Application ref application = t)
        {
            Term ref result = new Application(substitute(application.t1, x, rhs), substitute(application.t2, x, rhs));
            result.type = t.type;
            return result;
        }
        break;
    case Conditional:
        if (Conditional ref conditional = t)
        {
            Term ref result = new Conditional(substitute(conditional.cond, x, rhs), substitute(conditional.tru, x, rhs), substitute(conditional.fls, x, rhs));
            result.type = t.type;
            return result;
        }
        break;
    }

    return t;
}
// Big-step evaluation of a term.
//  - Application: evaluates the function and the argument, substitutes the
//    argument into the function body and evaluates the result.
//  - Variable: a global is replaced by its evaluated definition; any other
//    variable is returned as is.
//  - Conditional: evaluates the condition, then only the selected branch.
//  - Anything else is already a value and is returned unchanged.
// A term that cannot be reduced (for example an application of a free variable)
// is reported on io.out and returned unreduced.
Term ref Interpreter:eval(Term ref t)
{
    switch (typeid(t))
    {
    case Application:
        if (Application ref application = t)
        {
            auto t1 = eval(application.t1);

            if (typeid(t1) != Abstraction)
            {
                io.out << "expected abstraction for t1, got: (" << typeid(t1).name << ") " << t1 << io.endl;

                // There is no body to substitute into; converting t1 to an
                // Abstraction here would be invalid.
                return t;
            }

            Abstraction ref lhs = t1;

            return eval(substitute(lhs.t, lhs.x, eval(application.t2))); // eval result for big-step
        }
        break;
    case Variable:
        if (Variable ref variable = t)
        {
            if (auto t = globals.find(variable.x))
                return eval(*t);
        }
        break;
    case Conditional:
        if (Conditional ref conditional = t)
        {
            auto cond = eval(conditional.cond);

            if (typeid(cond) != Boolean)
            {
                io.out << "expected boolean for condition, got: (" << typeid(cond).name << ") " << cond << io.endl;

                // The condition did not reduce to a literal, so no branch can be chosen.
                return t;
            }

            if (Boolean ref(cond).value)
                return eval(conditional.tru);
            else
                return eval(conditional.fls);
        }
        break;
    }

    return t;
}
// Makes a term more readable for printing: every sub-term that is structurally
// equal to the definition of a global is replaced by a variable with that
// global's name and type. Abstractions, applications and conditionals that
// are not themselves a global are rebuilt with their sub-terms wrapped.
Term ref Interpreter:wrap(Term ref t)
{
    for (i in globals)
    {
        if (equal(i.value, t))
            return new Variable(i.key, i.value.type);
    }

    switch (typeid(t))
    {
    case Abstraction:
        if (Abstraction ref abstraction = t)
            return new Abstraction(abstraction.x, abstraction.xType, wrap(abstraction.t));
        break;
    case Application:
        if (Application ref application = t)
            return new Application(wrap(application.t1), wrap(application.t2));
        break;
    case Conditional:
        // Conditionals can contain global definitions too (for example after a
        // global was substituted into a branch), so they are traversed as well.
        if (Conditional ref conditional = t)
            return new Conditional(wrap(conditional.cond), wrap(conditional.tru), wrap(conditional.fls));
        break;
    }

    return t;
}
// Parses, type checks and evaluates source text; returns the resulting term.
// Parse and type check failures are raised through assert.
auto Interpreter:eval(char[] code)
{
    auto parsed = p.parse(code);
    assert(parsed != nullptr, "failed to parse");

    typecheck(parsed);
    assert(parsed.type != nullptr, "failed to typecheck");

    return eval(parsed);
}
// Echoes `code`, evaluates it and prints "> <result> : <type>", where the result
// is shown with globals folded back to their names (see wrap) and the type is the
// one stored on the evaluated term.
void Interpreter:printeval(char[] code)
{
    io.out << code << io.endl;

    auto result = eval(code);

    io.out << "> " << wrap(result) << " : " << getName(result.type) << io.endl;
}
// Parses and type checks `code` without evaluating it, then prints
// "> <term> : <type>".
void Interpreter:printtypecheck(char[] code)
{
    auto parsed = p.parse(code);
    assert(parsed != nullptr, "failed to parse");

    typecheck(parsed);
    assert(parsed.type != nullptr, "failed to typecheck");

    io.out << "> " << parsed << " : " << getName(parsed.type) << io.endl;
}
// Demo session
Interpreter interpreter;

// Boolean negation as a global, then a conditional that selects between two abstractions.
interpreter.global("not", "/x:Bool. if x then false else true");
interpreter.printeval("if true then (/x:Bool. x) else (/x:Bool. not x)");

// A free variable `f` of type Bool -> Bool, used in terms that are only type checked.
interpreter.free("f", "Bool -> Bool");
interpreter.printtypecheck("f (if false then true else false)");
interpreter.printtypecheck("/x:Bool. f (if x then false else x)");
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment