Last active
April 16, 2021 10:44
-
-
Save WheretIB/3a92d2a6bae7f81f337f5f54c5681eef to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import std.string; | |
import std.io; | |
import std.hashmap; | |
import std.typeinfo; | |
import std.vector; | |
// Standard library extensions (stuff that should have been in stdlib but isn't there for some reason) | |
// Converts the text held by a string object to an integer by forwarding
// its character buffer to the int() conversion that accepts char[].
int int(string str)
{
	return int(str.data);
}
// Hash function for string keys (lets string be used as a hashmap key):
// the hash of a string is the hash of its character buffer.
int hash_value(string str)
{
	return hash_value(str.data);
}
// Stream output for string objects: prints the character buffer and
// returns the stream so that further << calls can be chained.
StdOut operator<<(StdOut out, string str)
{
	Print(str.data);
	return out;
}
// Stores value v under key k by assigning through the map's index operator.
void hashmap:insert(Key k, Value v)
{
	this[k] = v;
}
// Iterator factory for a hashmap (used by 'for(el in map)' loops in this file).
// Returns a coroutine that yields every node of every bucket chain in bucket order
// and returns nullptr once all buckets have been walked.
auto hashmap:start()
{
	return coroutine auto()
	{
		for(int bucket = 0; bucket < entries.size; bucket++)
		{
			auto node = entries[bucket];
			while(node)
			{
				yield node;
				// Advance only after the consumer resumes the coroutine
				node = node.next;
			}
		}
		return nullptr;
	};
}
// Types | |
// Base of the type hierarchy; 'extendable' allows the concrete types below to derive from it
class Type extendable
{
}

// Base types carry no data
class UnitType: Type
{
}
class BoolType: Type
{
}
class NatType: Type
{
}
class StringType: Type
{
}
class FloatType: Type
{
}

// Function type: argument type and result type
class FunctionType: Type
{
	Type ref arg;
	Type ref res;
}

// Tuple type: ordered list of element types
class TupleType: Type
{
	vector<Type ref> elements;
}

// Type of a single named record field
class RecordPairType: Type
{
	string name;
	Type ref type;
}

// Record type: list of field types (structurally almost interchangeable with TupleType)
class RecordType: Type
{
	vector<Type ref> elements;
}
// Type constructors | |
// Builds a function type from its argument and result types
void FunctionType:FunctionType(Type ref arg, Type ref res)
{
	this.arg = arg;
	this.res = res;
}

// Builds a tuple type from a list of element types
void TupleType:TupleType(vector<Type ref> elements)
{
	this.elements = elements;
}

// Builds the type of a record field from the field name and the field type
void RecordPairType:RecordPairType(string name, Type ref type)
{
	this.name = name;
	this.type = type;
}

// Builds a record type from a list of field types
void RecordType:RecordType(vector<Type ref> elements)
{
	this.elements = elements;
}
// Structural equality of two types.
// Two null types are equal; a null type never equals a non-null one.
// Base types are equal when their dynamic type matches; compound types are compared
// member by member (record fields are compared in order, including their names).
// Returns false for a type kind this function does not know about.
bool equal(Type ref a, Type ref b)
{
	// At least one side is missing: equal only if both are missing
	if (!a || !b)
		return !a && !b;
	if (typeid(a) != typeid(b))
		return false;
	switch (typeid(a))
	{
	case UnitType:
	case BoolType:
	case NatType:
	case StringType:
	case FloatType:
		return true;
	case FunctionType:
		FunctionType ref funcA = a;
		FunctionType ref funcB = b;
		return equal(funcA.arg, funcB.arg) && equal(funcA.res, funcB.res);
	case TupleType:
		TupleType ref tupA = a;
		TupleType ref tupB = b;
		if (tupA.elements.size() != tupB.elements.size())
			return false;
		for (i in tupA.elements, j in tupB.elements)
		{
			if (!equal(i, j))
				return false;
		}
		return true;
	case RecordPairType:
		RecordPairType ref rpairA = a;
		RecordPairType ref rpairB = b;
		return rpairA.name == rpairB.name && equal(rpairA.type, rpairB.type);
	case RecordType:
		RecordType ref recA = a;
		RecordType ref recB = b;
		if (recA.elements.size() != recB.elements.size())
			return false;
		for (i in recA.elements, j in recB.elements)
		{
			if (!equal(i, j))
				return false;
		}
		return true;
	}
	// Unknown type kind: never leave the function without a result
	return false;
}
// Output | |
// Produces the printable name of a type.
// A missing (null) type is shown as "-"; base types use their keyword; function types are
// "arg -> res"; tuples join their elements with " * "; record fields are "name:type" and
// records list their fields between braces separated by ",".
string getName(Type ref a)
{
	if (!a)
		return string("-");

	auto kind = typeid(a);

	if (kind == UnitType)
		return string("Unit");
	if (kind == BoolType)
		return string("Bool");
	if (kind == NatType)
		return string("Nat");
	if (kind == StringType)
		return string("String");
	if (kind == FloatType)
		return string("Float");

	if (kind == FunctionType)
	{
		FunctionType ref function = a;
		return getName(function.arg) + " -> " + getName(function.res);
	}

	if (kind == TupleType)
	{
		TupleType ref tupleType = a;
		// The first element is printed without a separator; at least one element is expected
		string text = getName(tupleType.elements[0]);
		int n = 1;
		while (n < tupleType.elements.size())
		{
			text = text + " * " + getName(tupleType.elements[n]);
			n++;
		}
		return text;
	}

	if (kind == RecordPairType)
	{
		RecordPairType ref field = a;
		return field.name + ":" + getName(field.type);
	}

	if (kind == RecordType)
	{
		RecordType ref recordType = a;
		// The first field is printed without a separator; at least one field is expected
		string text = "{" + getName(recordType.elements[0]);
		int n = 1;
		while (n < recordType.elements.size())
		{
			text = text + "," + getName(recordType.elements[n]);
			n++;
		}
		return text + "}";
	}
}
// Stream output for types: writes the printable type name and returns the stream for chaining
StdOut operator<<(StdOut out, Type ref t)
{
	string name = getName(t);
	out << name;
	return out;
}
// Terms | |
// Base of the term (AST) hierarchy; 'type' holds the type assigned to the term (null until one is set)
class Term extendable
{
	Type ref type;
}

// Literals
class Unit: Term
{
}
class Boolean: Term
{
	bool value;
}
class Natural: Term
{
	int value;
}
class String: Term
{
	string value;
}
class Float: Term
{
	float value;
}

// Reference to a name
class Variable: Term
{
	string x;
}

// Lambda abstraction: parameter name, parameter type and body
class Abstraction: Term
{
	string x;
	Type ref xType;
	Term ref t;
}

// Application of t1 to t2
class Application: Term
{
	Term ref t1;
	Term ref t2;
}

// if cond then tru else fls
class Conditional: Term
{
	Term ref cond;
	Term ref tru;
	Term ref fls;
}

// 't as expected'
class Ascription: Term
{
	Term ref t;
	Type ref expected;
}

// let name = t in expr
class Let: Term
{
	string name;
	Term ref t;
	Term ref expr;
}

// let {names...} = t in expr
class Match: Term
{
	vector<string> names;
	Term ref t;
	Term ref expr;
}

// {e1, e2, ...}
class Tuple: Term
{
	vector<Term ref> elements;
}

// t.index
class TupleAccess: Term
{
	Term ref t;
	int index;
}

// name = t (a single record field)
class RecordPair: Term
{
	string name;
	Term ref t;
}

// {name1 = t1, name2 = t2, ...}
class Record: Term
{
	vector<Term ref> elements;
}

// t.name
class RecordAccess: Term
{
	Term ref t;
	string name;
}
// Term constructors | |
// Literal constructors store the literal value
void Boolean:Boolean(bool value)
{
	this.value = value;
}
void Natural:Natural(int value)
{
	this.value = value;
}
void String:String(string value)
{
	this.value = value;
}
void Float:Float(float value)
{
	this.value = value;
}

// Variable by name, optionally with a type already attached
void Variable:Variable(string x)
{
	this.x = x;
}
void Variable:Variable(string x, Type ref type)
{
	this.x = x;
	this.type = type;
}

// Abstraction from parameter name, parameter type and body
void Abstraction:Abstraction(string x, Type ref xType, Term ref t)
{
	this.x = x;
	this.xType = xType;
	this.t = t;
}

// Application of t1 to t2
void Application:Application(Term ref t1, t2)
{
	this.t1 = t1;
	this.t2 = t2;
}

// Conditional from condition and both branches
void Conditional:Conditional(Term ref cond, tru, fls)
{
	this.cond = cond;
	this.tru = tru;
	this.fls = fls;
}

// Single-name let binding
void Let:Let(string name, Term ref t, expr)
{
	this.name = name;
	this.t = t;
	this.expr = expr;
}

// Pattern let binding with several names
void Match:Match(vector<string> names, Term ref t, expr)
{
	this.names = names;
	this.t = t;
	this.expr = expr;
}

// Ascription of term t to the expected type
void Ascription:Ascription(Term ref t, Type ref expected)
{
	this.t = t;
	this.expected = expected;
}

// Tuple from its elements
void Tuple:Tuple(vector<Term ref> elements)
{
	this.elements = elements;
}

// Access to the tuple element with the given index
void TupleAccess:TupleAccess(Term ref t, int index)
{
	this.t = t;
	this.index = index;
}

// Record field from its name and value
void RecordPair:RecordPair(string name, Term ref t)
{
	this.name = name;
	this.t = t;
}

// Record from its fields
void Record:Record(vector<Term ref> elements)
{
	this.elements = elements;
}

// Access to the record field with the given name
void RecordAccess:RecordAccess(Term ref t, string name)
{
	this.t = t;
	this.name = name;
}
// Optional equality (for testing) | |
// Structural equality of two terms (used for testing).
// Two null terms are equal; a null term never equals a non-null one (same convention as
// equal() for types). Terms are equal when they have the same dynamic type, equal assigned
// types and equal members. For abstractions only the parameter name and the body are compared.
// Returns false for a term kind this function does not know about.
bool equal(Term ref a, Term ref b)
{
	// Guard before typeid()/member access: at least one side is missing
	if (!a || !b)
		return !a && !b;
	if (typeid(a) != typeid(b))
		return false;
	if (!equal(a.type, b.type))
		return false;
	switch (typeid(a))
	{
	case Unit:
		return true;
	case Boolean:
		Boolean ref boolA = a;
		Boolean ref boolB = b;
		return boolA.value == boolB.value;
	case Natural:
		Natural ref natA = a;
		Natural ref natB = b;
		return natA.value == natB.value;
	case String:
		String ref strA = a;
		String ref strB = b;
		return strA.value == strB.value;
	case Float:
		Float ref floatA = a;
		Float ref floatB = b;
		return floatA.value == floatB.value;
	case Variable:
		Variable ref varA = a;
		Variable ref varB = b;
		return varA.x == varB.x;
	case Abstraction:
		Abstraction ref absA = a;
		Abstraction ref absB = b;
		return absA.x == absB.x && equal(absA.t, absB.t);
	case Application:
		Application ref appA = a;
		Application ref appB = b;
		return equal(appA.t1, appB.t1) && equal(appA.t2, appB.t2);
	case Conditional:
		Conditional ref condA = a;
		Conditional ref condB = b;
		return equal(condA.cond, condB.cond) && equal(condA.tru, condB.tru) && equal(condA.fls, condB.fls);
	case Ascription:
		Ascription ref ascA = a;
		Ascription ref ascB = b;
		return equal(ascA.t, ascB.t) && equal(ascA.expected, ascB.expected);
	case Let:
		Let ref letA = a;
		Let ref letB = b;
		return letA.name == letB.name && equal(letA.t, letB.t) && equal(letA.expr, letB.expr);
	case Match:
		Match ref matchA = a;
		Match ref matchB = b;
		if (matchA.names.size() != matchB.names.size())
			return false;
		for (i in matchA.names, j in matchB.names)
		{
			if (i != j)
				return false;
		}
		return equal(matchA.t, matchB.t) && equal(matchA.expr, matchB.expr);
	case Tuple:
		Tuple ref tupA = a;
		Tuple ref tupB = b;
		if (tupA.elements.size() != tupB.elements.size())
			return false;
		for (i in tupA.elements, j in tupB.elements)
		{
			if (!equal(i, j))
				return false;
		}
		return true;
	case TupleAccess:
		TupleAccess ref tupaccA = a;
		TupleAccess ref tupaccB = b;
		return equal(tupaccA.t, tupaccB.t) && tupaccA.index == tupaccB.index;
	case RecordPair:
		RecordPair ref rpairA = a;
		RecordPair ref rpairB = b;
		return rpairA.name == rpairB.name && equal(rpairA.t, rpairB.t);
	case Record:
		Record ref recA = a;
		Record ref recB = b;
		if (recA.elements.size() != recB.elements.size())
			return false;
		for (i in recA.elements, j in recB.elements)
		{
			if (!equal(i, j))
				return false;
		}
		return true;
	case RecordAccess:
		RecordAccess ref recaccA = a;
		RecordAccess ref recaccB = b;
		return equal(recaccA.t, recaccB.t) && recaccA.name == recaccB.name;
	}
	// Unknown term kind: never leave the function without a result
	return false;
}
// Term output | |
// Defined below; declared here so that the stream operator can call it
void output(Term ref a);

// Stream output for terms. The term is printed by output(), which writes to io.out;
// the stream passed in is returned unchanged for chaining.
StdOut operator<<(StdOut out, Term ref t)
{
	output(t);
	return out;
}
// Prints a term to io.out in the concrete syntax of the language.
// Sub-terms are printed recursively through the Term stream operator.
void output(Term ref a)
{
	switch (typeid(a))
	{
	case Unit:
		io.out << "unit";
		break;
	case Boolean:
		Boolean ref boolean = a;
		if (boolean.value)
			io.out << "true";
		else
			io.out << "false";
		break;
	case Natural:
		Natural ref natural = a;
		io.out << natural.value;
		break;
	case String:
		String ref quoted = a;
		io.out << "'" << quoted.value << "'";
		break;
	case Float:
		Float ref rational = a;
		io.out << rational.value;
		break;
	case Variable:
		Variable ref variable = a;
		io.out << variable.x;
		break;
	case Abstraction:
		Abstraction ref lambda = a;
		// The parameter type annotation is printed only when it is present
		io.out << "/" << lambda.x;
		if (lambda.xType)
			io.out << ":" << lambda.xType;
		io.out << ". " << lambda.t;
		break;
	case Application:
		Application ref app = a;
		io.out << app.t1;
		// A nested application on the right is parenthesized
		if (typeid(app.t2) == Application)
			io.out << " (" << app.t2 << ")";
		else
			io.out << " " << app.t2;
		break;
	case Conditional:
		Conditional ref conditional = a;
		io.out << "if " << conditional.cond;
		io.out << " then " << conditional.tru;
		io.out << " else " << conditional.fls;
		break;
	case Ascription:
		Ascription ref ascription = a;
		io.out << "" << ascription.t << " as " << ascription.expected;
		break;
	case Let:
		Let ref let = a;
		io.out << "let " << let.name << "=" << let.t;
		io.out << " in " << let.expr;
		break;
	case Match:
		Match ref match = a;
		// First name without separator, the rest prefixed with ','
		io.out << "let {" << match.names[0];
		int nameIndex = 1;
		while (nameIndex < match.names.size())
		{
			io.out << "," << match.names[nameIndex];
			nameIndex++;
		}
		io.out << "}=" << match.t << " in " << match.expr;
		break;
	case Tuple:
		Tuple ref tupleTerm = a;
		io.out << "{" << tupleTerm.elements[0];
		int tupleIndex = 1;
		while (tupleIndex < tupleTerm.elements.size())
		{
			io.out << "," << tupleTerm.elements[tupleIndex];
			tupleIndex++;
		}
		io.out << "}";
		break;
	case TupleAccess:
		TupleAccess ref tupleAccess = a;
		io.out << tupleAccess.t << "." << tupleAccess.index;
		break;
	case RecordPair:
		RecordPair ref field = a;
		io.out << field.name << "=" << field.t;
		break;
	case Record:
		Record ref recordTerm = a;
		io.out << "{" << recordTerm.elements[0];
		int fieldIndex = 1;
		while (fieldIndex < recordTerm.elements.size())
		{
			io.out << "," << recordTerm.elements[fieldIndex];
			fieldIndex++;
		}
		io.out << "}";
		break;
	case RecordAccess:
		RecordAccess ref recordAccess = a;
		io.out << recordAccess.t << "." << recordAccess.name;
		break;
	}
}
// Prints a term to io.out followed by a line break
void outputln(Term ref t)
{
	io.out << t;
	io.out << io.endl;
}
// Lexer | |
// Kinds of lexemes produced by the Lexer (member order is unchanged)
enum LexemeType
{
	// '/' that introduces an abstraction
	Lambda,
	// Literals and names
	Number, Rational, Str, QuotedStr,
	// Brackets: ( ) { }
	Oparen, Cparen, Ofigure, Cfigure,
	// Punctuation: . , : ; -> * =
	Point, Comma, Colon, Semicolon, Arrow, Mult, Equal,
	// Keywords: if then else let in
	If, Then, Else, Let_, In,
	// Unrecognized input and end of input
	Unknown, Eof
}
// A single lexeme: its kind and the source text it covers.
// A default-constructed Lexeme has kind Eof.
class Lexeme
{
	LexemeType type = LexemeType.Eof;
	string str;

	// Lexeme with text given as a string object
	void Lexeme(LexemeType type, string str)
	{
		this.type = type;
		this.str = str;
	}

	// Lexeme with text given as a character array
	void Lexeme(LexemeType type, char[] str)
	{
		this.type = type;
		this.str = str;
	}
}
// Splits source text into lexemes on demand.
// peek() returns the next lexeme without consuming it; consume() returns it and advances
// the position by the length of the lexeme text.
class Lexer
{
	void Lexer(string str){ this.str = str; }

	// Advances the position past whitespace and control characters.
	// Negative char values (bytes of non-ASCII text, which the identifier rule in peek()
	// accepts through 'x < 0') are not treated as whitespace.
	void skipSpaces()
	{
		while (pos < str.length() && str[pos] >= 0 && str[pos] <= ' ')
			pos++;
	}

	// Returns the next lexeme without consuming it (leading whitespace is skipped).
	// Returns an Eof lexeme at the end of input and an Unknown lexeme with empty text when
	// the input cannot be recognized, including a quoted string with no closing quote.
	auto peek()
	{
		skipSpaces();

		// Local cursor; the member position only moves in skipSpaces() and consume()
		int pos = this.pos;
		int end = str.length();

		if (pos == end)
			return Lexeme();

		auto start = pos;

		// Single-character lexemes
		if (str[pos] == '/') return Lexeme(LexemeType.Lambda, str.substr(start, 1));
		if (str[pos] == '(') return Lexeme(LexemeType.Oparen, str.substr(start, 1));
		if (str[pos] == ')') return Lexeme(LexemeType.Cparen, str.substr(start, 1));
		if (str[pos] == '{') return Lexeme(LexemeType.Ofigure, str.substr(start, 1));
		if (str[pos] == '}') return Lexeme(LexemeType.Cfigure, str.substr(start, 1));
		if (str[pos] == '.') return Lexeme(LexemeType.Point, str.substr(start, 1));
		if (str[pos] == ',') return Lexeme(LexemeType.Comma, str.substr(start, 1));
		if (str[pos] == ':') return Lexeme(LexemeType.Colon, str.substr(start, 1));
		if (str[pos] == ';') return Lexeme(LexemeType.Semicolon, str.substr(start, 1));
		if (str[pos] == '*') return Lexeme(LexemeType.Mult, str.substr(start, 1));
		if (str[pos] == '=') return Lexeme(LexemeType.Equal, str.substr(start, 1));

		// '->' (the second character is only read when it exists)
		if (str[pos] == '-' && pos + 1 < end && str[pos + 1] == '>') return Lexeme(LexemeType.Arrow, "->");

		// Quoted string; the lexeme text includes both quotes
		if (str[pos] == '\'')
		{
			pos++;

			while (pos < end && str[pos] != '\'')
				pos++;

			// No closing quote before the end of input: report it instead of reading past the end
			if (pos == end)
				return Lexeme(LexemeType.Unknown, "");

			pos++;

			return Lexeme(LexemeType.QuotedStr, str.substr(start, pos - start));
		}

		// Lex identifier or a number
		auto isAlnum = <char x> ((x | 32) >= 'a' && (x | 32) <= 'z') || x < 0 || x == '_';
		auto isDigit = <char x> x >= '0' && x <= '9';

		if (isAlnum(str[pos]))
		{
			pos++;

			while (pos < end && (isAlnum(str[pos]) || isDigit(str[pos])))
				pos++;
		}
		else if (isDigit(str[pos]))
		{
			pos++;

			while (pos < end && isDigit(str[pos]))
				pos++;

			// Digits followed by '.' form a rational number (the '.' is only read when it exists)
			if (pos < end && str[pos] == '.')
			{
				pos++;

				while (pos < end && isDigit(str[pos]))
					pos++;

				return Lexeme(LexemeType.Rational, str.substr(start, pos - start));
			}

			return Lexeme(LexemeType.Number, str.substr(start, pos - start));
		}

		string result = str.substr(start, pos - start);

		// Keywords
		if (result == "if") return Lexeme(LexemeType.If, result);
		if (result == "then") return Lexeme(LexemeType.Then, result);
		if (result == "else") return Lexeme(LexemeType.Else, result);
		if (result == "let") return Lexeme(LexemeType.Let_, result);
		if (result == "in") return Lexeme(LexemeType.In, result);

		// Nothing was recognized at this position
		if (result.empty()) return Lexeme(LexemeType.Unknown, "");

		return Lexeme(LexemeType.Str, str.substr(start, pos - start));
	}

	// Returns the next lexeme and moves the position past its text
	auto consume()
	{
		skipSpaces();

		auto next = peek();

		pos += next.str.length();

		return next;
	}

	string str;
	int pos;
}
// Recursive-descent parser for types and terms.
//
// Types (loosest to tightest): function 'A -> B' (right-associative), tuple 'A * B',
// record '{name:Type, ...}', simple type name.
// Terms (loosest to tightest): sequence 't1; t2', application 't1 t2', ascription 't as T',
// access 't.1' / 't.name', primary term.
// Syntax errors are reported through assert().
class Parser
{
	Term ref parseExpr();
	Type ref parseType();

	// Simple type name: Unit, Bool, Nat, String or Float
	Type ref parseSimpleType()
	{
		assert(lexer.peek().type == LexemeType.Str, "type name expected");
		auto name = lexer.consume().str;

		switch (name)
		{
		case "Unit": return new UnitType();
		case "Bool": return new BoolType();
		case "Nat": return new NatType();
		case "String": return new StringType();
		case "Float": return new FloatType();
		}

		assert(false, "unknown type name");
		// Explicit result for the failure path
		return nullptr;
	}

	// Record field type: name ':' type
	Type ref parseRecordPairType()
	{
		assert(lexer.peek().type == LexemeType.Str, "record name expected");
		auto name = lexer.consume().str;

		assert(lexer.peek().type == LexemeType.Colon, "':' expected after name");
		lexer.consume();

		auto type = parseType();

		return new RecordPairType(name, type);
	}

	// Record type '{' field (',' field)* '}' or a simple type
	Type ref parseRecordType()
	{
		if (lexer.peek().type == LexemeType.Ofigure)
		{
			lexer.consume();

			vector<Type ref> elements;
			elements.push_back(parseRecordPairType());

			while (lexer.peek().type == LexemeType.Comma)
			{
				lexer.consume();
				elements.push_back(parseRecordPairType());
			}

			assert(lexer.peek().type == LexemeType.Cfigure, "'}' expected after type");
			lexer.consume();

			return new RecordType(elements);
		}

		return parseSimpleType();
	}

	// Tuple type: element ('*' element)*; a single element is returned as is
	Type ref parseTupleType()
	{
		vector<Type ref> elements;
		elements.push_back(parseRecordType());

		while (lexer.peek().type == LexemeType.Mult)
		{
			lexer.consume();
			elements.push_back(parseRecordType());
		}

		if (elements.size() == 1)
			return elements[0];

		return new TupleType(elements);
	}

	// Function type: tuple ('->' function)?
	Type ref parseFunctionType()
	{
		Type ref t = parseTupleType();

		if (lexer.peek().type == LexemeType.Arrow)
		{
			lexer.consume();
			// '->' is right-associative, so the result side may itself be a function type:
			// 'A -> B -> C' is 'A -> (B -> C)'
			t = new FunctionType(t, parseFunctionType());
		}

		return t;
	}

	Type ref parseType()
	{
		return parseFunctionType();
	}

	// Primary term. Returns nullptr without consuming anything when the next lexeme
	// cannot start a term (callers use that to detect the end of a term sequence).
	Term ref parseTerm()
	{
		switch (lexer.peek().type)
		{
		case LexemeType.Lambda:
			// '/' name ':' type '.' body
			lexer.consume();

			assert(lexer.peek().type == LexemeType.Str, "abstraction name expected after '/'");
			string x = lexer.consume().str;

			assert(lexer.peek().type == LexemeType.Colon, "type expected after abstraction name");
			lexer.consume();

			Type ref type = parseType();

			assert(lexer.peek().type == LexemeType.Point, "'.' not found after abstracton name");
			lexer.consume();

			return new Abstraction(x, type, parseExpr());
		case LexemeType.Number:
			return new Natural(int(lexer.consume().str.data));
		case LexemeType.Rational:
			return new Float(float(lexer.consume().str.data));
		case LexemeType.Str:
			auto str = lexer.consume().str;

			if (str == "unit")
				return new Unit();
			if (str == "true")
				return new Boolean(true);
			if (str == "false")
				return new Boolean(false);

			return new Variable(str);
		case LexemeType.QuotedStr:
			auto qstr = lexer.consume().str;

			// Strip the surrounding quotes
			return new String(qstr.substr(1, qstr.length() - 2));
		case LexemeType.Oparen:
			lexer.consume();

			auto t = parseExpr();
			assert(t != nullptr, "term expected after '('");

			assert(lexer.peek().type == LexemeType.Cparen, "')' not found after '('");
			lexer.consume();

			return t;
		case LexemeType.Ofigure:
			lexer.consume();

			auto first = parseExpr();
			assert(first != nullptr, "term expected after '{'");

			// Looks like a record instead of a tuple
			if (typeid(first) == Variable && lexer.peek().type == LexemeType.Equal)
			{
				lexer.consume();

				auto name = Variable ref(first).x;
				auto value = parseExpr();
				assert(value != nullptr, "term expected after '='");

				vector<Term ref> elements;
				elements.push_back(new RecordPair(name, value));

				while (lexer.peek().type == LexemeType.Comma)
				{
					lexer.consume();

					assert(lexer.peek().type == LexemeType.Str, "name expected after ','");
					name = lexer.consume().str;

					assert(lexer.peek().type == LexemeType.Equal, "'=' expected after name");
					lexer.consume();

					value = parseExpr();
					assert(value != nullptr, "term expected after '='");

					elements.push_back(new RecordPair(name, value));
				}

				assert(lexer.peek().type == LexemeType.Cfigure, "'}' not found after term");
				lexer.consume();

				return new Record(elements);
			}

			vector<Term ref> elements;
			elements.push_back(first);

			while (lexer.peek().type == LexemeType.Comma)
			{
				lexer.consume();

				elements.push_back(parseExpr());
				assert(*elements.back() != nullptr, "term expected after ','");
			}

			assert(lexer.peek().type == LexemeType.Cfigure, "'}' not found after term");
			lexer.consume();

			return new Tuple(elements);
		case LexemeType.If:
			lexer.consume();

			auto cond = parseExpr();
			assert(cond != nullptr, "condition expected after 'if'");

			assert(lexer.peek().type == LexemeType.Then, "'then' not found after term");
			lexer.consume();

			auto tru = parseExpr();
			assert(tru != nullptr, "term expected after 'then'");

			assert(lexer.peek().type == LexemeType.Else, "'else' not found after term");
			lexer.consume();

			auto fls = parseExpr();
			assert(fls != nullptr, "term expected after 'else'");

			return new Conditional(cond, tru, fls);
		case LexemeType.Let_:
			lexer.consume();

			// Either a single name or a '{' name (',' name)* '}' pattern
			vector<string> names;

			switch(lexer.peek().type)
			{
			case LexemeType.Str:
				names.push_back(lexer.consume().str);
				break;
			case LexemeType.Ofigure:
				lexer.consume();

				assert(lexer.peek().type == LexemeType.Str, "name expected after '{'");
				names.push_back(lexer.consume().str);

				while (lexer.peek().type == LexemeType.Comma)
				{
					lexer.consume();

					assert(lexer.peek().type == LexemeType.Str, "name expected after ','");
					names.push_back(lexer.consume().str);
				}

				assert(lexer.peek().type == LexemeType.Cfigure, "'}' not found after name");
				lexer.consume();
				break;
			default:
				assert(false, "name or '{' expected after 'let'");
			}

			assert(lexer.peek().type == LexemeType.Equal, "'=' not found after name");
			lexer.consume();

			auto term = parseExpr();
			assert(term != nullptr, "term expected after '='");

			assert(lexer.peek().type == LexemeType.In, "'in' not found after term");
			lexer.consume();

			auto expr = parseExpr();
			assert(expr != nullptr, "term expected after 'in'");

			// A pattern with a single name is an ordinary let
			if (names.size() == 1)
				return new Let(names[0], term, expr);

			return new Match(names, term, expr);
		case LexemeType.Unknown:
			assert(false, "unknown lexeme at " + lexer.pos.str());
		}

		return nullptr;
	}

	// Primary term followed by any number of '.index' / '.name' accesses
	Term ref parseAccess()
	{
		auto t = parseTerm();

		while (lexer.peek().str == ".")
		{
			lexer.consume();

			switch(lexer.peek().type)
			{
			case LexemeType.Number:
				int index = int(lexer.consume().str);
				assert(index > 0, "index has to start at '1'");
				t = new TupleAccess(t, index);
				break;
			case LexemeType.Str:
				string name = lexer.consume().str;
				t = new RecordAccess(t, name);
				break;
			default:
				// The '.' was already consumed; do not drop it silently
				assert(false, "index or field name expected after '.'");
			}
		}

		return t;
	}

	// Access term followed by any number of 'as' type ascriptions
	Term ref parseAs()
	{
		auto t = parseAccess();

		while (lexer.peek().str == "as")
		{
			lexer.consume();

			auto type = parseType();

			t = new Ascription(t, type);
		}

		return t;
	}

	// Left-associative application: t1 t2 t3 is (t1 t2) t3
	Term ref parseTerms()
	{
		auto t = parseAs();
		auto t2 = parseAs();

		// parseAs() returns nullptr when no further term starts here
		while (t2)
		{
			t = new Application(t, t2);
			t2 = parseAs();
		}

		return t;
	}

	// Sequence: 't1; t2' is built as the application (/_:Unit. t2) t1
	Term ref parseSeq()
	{
		auto t = parseTerms();

		while (lexer.peek().type == LexemeType.Semicolon)
		{
			lexer.consume();

			auto t2 = parseTerms();
			assert(t2 != nullptr, "term expected after ';'");

			t = new Application(new Abstraction(string("_"), new UnitType(), t2), t);
		}

		return t;
	}

	Term ref parseExpr()
	{
		return parseSeq();
	}

	// Parses a complete expression from source text; all input must be consumed
	Term ref parse(char[] code)
	{
		lexer = Lexer(string(code));

		auto t = parseExpr();

		// The message is built only on failure, so the offending character is never read
		// when the position is already at the end of the text
		if (lexer.pos != lexer.str.length())
			assert(false, "unknown symbol " + {lexer.str[lexer.pos]});

		return t;
	}

	Lexer lexer;
}
// State shared by the type checker and the evaluator
class Interpreter
{
	// Parser used to turn source text into terms and types
	Parser p;

	// Types of the names that are currently bound during type checking
	hashmap<string, Type ref> bindings;

	// Named global terms (optional)
	hashmap<string, Term ref> globals;
}
// Computes the type of term 'a', stores it in a.type (and in the 'type' of every sub-term
// that was checked) and returns it.
// On failure an error is printed to io.out, a.type is left unset and the result is a.type
// (nullptr unless the term already carried a type).
// 'bindings' holds the types of the names bound by enclosing abstractions / lets while their
// body is checked; a name that is not bound there is looked up in 'globals'.
Type ref Interpreter:typecheck(Term ref a)
{
	switch (typeid(a))
	{
	case Unit:
		if (Unit ref t = a)
			t.type = new UnitType();
		break;
	case Boolean:
		if (Boolean ref t = a)
			t.type = new BoolType();
		break;
	case Natural:
		if (Natural ref t = a)
			t.type = new NatType();
		break;
	case String:
		if (String ref t = a)
			t.type = new StringType();
		break;
	case Float:
		if (Float ref t = a)
			t.type = new FloatType();
		break;
	case Variable:
		if (Variable ref t = a)
		{
			auto local = bindings.find(t.x);

			// Scoped bindings are restored by writing the previous type back into the map,
			// which leaves an entry holding nullptr for a name that was unbound before.
			// Such an entry must not hide a global with the same name.
			if (local && *local)
				t.type = *local;
			else if (auto global = globals.find(t.x))
				t.type = global.type;
			else
				io.out << "ERROR: Unknown variable " << t.x << io.endl;
		}
		break;
	case Abstraction:
		if (Abstraction ref t = a)
		{
			// Bind the parameter for the body and restore the previous binding afterwards;
			// '_' never introduces a binding
			auto lastType = bindings[t.x];

			if (t.x != "_")
				bindings[t.x] = t.xType;

			if (auto tType = typecheck(t.t))
				t.type = new FunctionType(t.xType, tType);

			bindings[t.x] = lastType;
		}
		break;
	case Application:
		if (Application ref t = a)
		{
			auto t2Type = typecheck(t.t2);

			// Builtin functions are recognized by name and checked against their argument
			if (t2Type && typeid(t.t1) == Variable)
			{
				Variable ref v = t.t1;

				switch(v.x)
				{
				case "succ":
					if (typeid(t2Type) == NatType)
						t.type = new NatType();
					else
						io.out << v.x << " expects Nat as an argument, got " << t2Type << io.endl;
					break;
				case "pred":
					if (typeid(t2Type) == NatType)
						t.type = new NatType();
					else
						io.out << v.x << " expects Nat as an argument, got " << t2Type << io.endl;
					break;
				case "iszero":
					if (typeid(t2Type) == NatType)
						t.type = new BoolType();
					else
						io.out << v.x << " expects Nat as an argument, got " << t2Type << io.endl;
					break;
				case "print":
					if (typeid(t2Type) == UnitType || typeid(t2Type) == BoolType || typeid(t2Type) == NatType || typeid(t2Type) == StringType || typeid(t2Type) == FloatType)
						t.type = new UnitType();
					else
						io.out << v.x << " expects base type as an argument, got " << t2Type << io.endl;
					break;
				}

				// Builtin call was accepted
				if (t.type)
					break;
			}

			auto t1Type = typecheck(t.t1);

			if (t1Type && t2Type)
			{
				if (typeid(t1Type) == FunctionType)
				{
					FunctionType ref t1FuncType = t1Type;

					if (equal(t1FuncType.arg, t2Type))
						t.type = t1FuncType.res;
					else
						io.out << "ERROR: Application rhs type " << t2Type << " has to match function argument type " << t1FuncType.arg << io.endl;
				}
				else
				{
					io.out << "ERROR: Application lhs type has to be function got " << t1Type << io.endl;
				}
			}
		}
		break;
	case Conditional:
		if (Conditional ref t = a)
		{
			auto condType = typecheck(t.cond);
			auto truType = typecheck(t.tru);
			auto flsType = typecheck(t.fls);

			if (condType && truType && flsType)
			{
				if (typeid(condType) == BoolType)
				{
					if (equal(truType, flsType))
						t.type = truType;
					else
						io.out << "ERROR: Condition branch types have to equal " << truType << " != " << flsType << io.endl;
				}
				else
				{
					io.out << "ERROR: Condition type has to be Bool got " << condType << io.endl;
				}
			}
		}
		break;
	case Ascription:
		if (Ascription ref t = a)
		{
			auto tType = typecheck(t.t);

			if (tType)
			{
				if (equal(t.expected, tType))
					t.type = tType;
				else
					io.out << "ERROR: Ascription failed to match " << tType << " to " << t.expected << io.endl;
			}
		}
		break;
	case Let:
		if (Let ref t = a)
		{
			auto tType = typecheck(t.t);

			if (tType)
			{
				// Bind the name for the body and restore the previous binding afterwards
				auto lastType = bindings[t.name];

				if (t.name != "_")
					bindings[t.name] = tType;

				if (auto exprType = typecheck(t.expr))
					t.type = exprType;

				bindings[t.name] = lastType;
			}
		}
		break;
	case Match:
		if (Match ref t = a)
		{
			auto tType = typecheck(t.t);

			if (tType)
			{
				if (typeid(tType) == TupleType)
				{
					TupleType ref matchTupleType = tType;

					if (t.names.size() == matchTupleType.elements.size())
					{
						// Bind every pattern name to the matching tuple element type,
						// remembering the previous bindings so they can be restored
						hashmap<string, Type ref> saved;

						for (name in t.names, type in matchTupleType.elements)
						{
							saved[name] = bindings[name];

							if (name != "_")
								bindings[name] = type;
						}

						if (auto exprType = typecheck(t.expr))
							t.type = exprType;

						for (el in saved)
							bindings[el.key] = el.value;
					}
					else
					{
						io.out << "ERROR: Mismatch between let pattern " << t << " and tuple type " << tType << io.endl;
					}
				}
				else
				{
					io.out << "ERROR: Cannot match let pattern to " << tType << io.endl;
				}
			}
		}
		break;
	case Tuple:
		if (Tuple ref t = a)
		{
			vector<Type ref> elements;

			for (i in t.elements)
			{
				if (auto type = typecheck(i))
					elements.push_back(type);
				else
					break 2; // leave the loop and the switch: the tuple has no type
			}

			t.type = new TupleType(elements);
		}
		break;
	case TupleAccess:
		if (TupleAccess ref t = a)
		{
			auto tType = typecheck(t.t);

			if (tType)
			{
				if (typeid(tType) == TupleType)
				{
					TupleType ref tupType = tType;

					// Indices are 1-based
					if (t.index >= 1 && t.index - 1 < tupType.elements.size())
						t.type = tupType.elements[t.index - 1];
					else
						io.out << "ERROR: Index " << t.index << " is out of bounds of type " << tupType << io.endl;
				}
				else
				{
					io.out << "ERROR: Cannot index type that is not a Tuple " << tType << io.endl;
				}
			}
		}
		break;
	case RecordPair:
		if (RecordPair ref t = a)
		{
			if (auto tType = typecheck(t.t))
				t.type = new RecordPairType(t.name, tType);
		}
		break;
	case Record:
		if (Record ref t = a)
		{
			vector<Type ref> elements;

			// Any failing field leaves the loop and the switch, so that a record with a
			// reported error does not receive a type
			for (i in t.elements)
			{
				auto type = typecheck(i);

				if (!type)
					break 2;

				if (typeid(type) != RecordPairType)
				{
					io.out << "ERROR: Record element type is not a RecordPair: " << type << io.endl;
					break 2;
				}

				RecordPairType ref recPairType = type;

				if (elements.count_if(<x> RecordPairType ref(x).name == recPairType.name) != 0)
				{
					io.out << "ERROR: record already has field named " << recPairType.name << io.endl;
					break 2;
				}

				elements.push_back(type);
			}

			t.type = new RecordType(elements);
		}
		break;
	case RecordAccess:
		if (RecordAccess ref t = a)
		{
			auto tType = typecheck(t.t);

			if (tType)
			{
				if (typeid(tType) == RecordType)
				{
					RecordType ref recType = tType;

					if (recType.elements.count_if(<x> RecordPairType ref(x).name == t.name) == 0)
					{
						io.out << "ERROR: record has no field named " << t.name << " in " << recType << io.endl;
					}
					else
					{
						for (i in recType.elements)
						{
							if (RecordPairType ref(i).name == t.name)
								t.type = RecordPairType ref(i).type;
						}
					}
				}
				else
				{
					io.out << "ERROR: Cannot index type that is not a Record " << tType << io.endl;
				}
			}
		}
		break;
	}

	if (!a.type)
		io.out << "ERROR: Failed to typecheck " << typeid(a).name << " '" << a << "'" << io.endl;

	return a.type;
}
// Declares a global name that has a type but no definition.
// 'code' is parsed as a type; the global is registered as a Variable term carrying that type.
void Interpreter:free(char[] name, char[] code)
{
	string key = string(name);

	p.lexer = Lexer(string(code));

	Type ref declared = p.parseType();
	assert(declared != nullptr, "failed to parse type");

	Variable ref placeholder = new Variable(key, declared);

	globals[key] = placeholder;
}
// Defines a global name: 'code' is parsed as a term, type checked and registered under 'name'.
// Parsing and type checking failures are reported through assert().
void Interpreter:global(char[] name, char[] code)
{
	Term ref definition = p.parse(code);
	assert(definition != nullptr, "failed to parse");

	typecheck(definition);
	assert(definition.type != nullptr, "failed to typecheck");

	string key = string(name);
	globals[key] = definition;
}
// Returns term 'a' with the free occurrences of variable 'x' replaced by 'rhs'.
// Compound terms are rebuilt as new nodes (the new nodes do not carry the 'type' of the
// originals); terms without sub-terms are returned as is.
// Substitution stops at an abstraction, let or match that binds 'x' again.
// NOTE(review): 'rhs' is inserted as is, bound names are never renamed - confirm that
// callers only substitute terms for which variable capture cannot happen.
Term ref Interpreter:substitute(Term ref a, string x, Term ref rhs)
{
	switch (typeid(a))
	{
	case Variable:
		if (Variable ref t = a)
		{
			if (t.x == x)
				return rhs;
		}
		break;
	case Abstraction:
		if (Abstraction ref t = a)
		{
			// Do not step into the abstractions that have a binding over same name
			if (t.x != "_" && t.x == x)
				return a;

			// Keep the declared parameter type (not the type of the abstraction itself)
			return new Abstraction(t.x, t.xType, substitute(t.t, x, rhs));
		}
		break;
	case Application:
		if (Application ref t = a)
			return new Application(substitute(t.t1, x, rhs), substitute(t.t2, x, rhs));
		break;
	case Conditional:
		if (Conditional ref t = a)
			return new Conditional(substitute(t.cond, x, rhs), substitute(t.tru, x, rhs), substitute(t.fls, x, rhs));
		break;
	case Ascription:
		if (Ascription ref t = a)
			return new Ascription(substitute(t.t, x, rhs), t.expected);
		break;
	case Let:
		if (Let ref t = a)
		{
			// The bound value is outside the scope of the let name, the body is inside
			auto value = substitute(t.t, x, rhs);
			auto expr = t.name == x ? t.expr : substitute(t.expr, x, rhs);

			return new Let(t.name, value, expr);
		}
		break;
	case Match:
		if (Match ref t = a)
		{
			auto value = substitute(t.t, x, rhs);
			auto expr = t.names.count_if(<name> name == x) != 0 ? t.expr : substitute(t.expr, x, rhs);

			return new Match(t.names, value, expr);
		}
		break;
	case Tuple:
		if (Tuple ref t = a)
		{
			vector<Term ref> elements;

			for (i in t.elements)
				elements.push_back(substitute(i, x, rhs));

			return new Tuple(elements);
		}
		break;
	case TupleAccess:
		if (TupleAccess ref t = a)
			return new TupleAccess(substitute(t.t, x, rhs), t.index);
		break;
	case RecordPair:
		// Field names are labels, not variables: only the field value is substituted
		if (RecordPair ref t = a)
			return new RecordPair(t.name, substitute(t.t, x, rhs));
		break;
	case Record:
		if (Record ref t = a)
		{
			vector<Term ref> elements;

			for (i in t.elements)
				elements.push_back(substitute(i, x, rhs));

			return new Record(elements);
		}
		break;
	case RecordAccess:
		if (RecordAccess ref t = a)
			return new RecordAccess(substitute(t.t, x, rhs), t.name);
		break;
	}

	return a;
}
// Forward declaration: evalBuiltin and eval call each other
Term ref Interpreter:eval(Term ref t);

// Applies the builtin function named by variable `v` to the argument term `t`.
// The argument is evaluated first.
// Handled names: "succ", "pred", "iszero" (on Natural) and "print" (Unit, Boolean, Natural, String, Float).
// Returns the resulting term, or nullptr when `v.x` is not one of the handled names.
Term ref Interpreter:evalBuiltin(Variable ref v, Term ref t)
{
	t = eval(t);

	switch(v.x)
	{
	case "succ":
		return new Natural(Natural ref(t).value + 1);
	case "pred":
		// Every path returns here: previously a zero argument fell through into "iszero"
		// and 'pred 0' produced a Boolean. The predecessor of zero is zero.
		auto n = Natural ref(t).value;
		return new Natural(n != 0 ? n - 1 : n);
	case "iszero":
		return new Boolean(Natural ref(t).value == 0);
	case "print":
		// Terms of any other kind print nothing; the result is unit in every case
		switch (typeid(t))
		{
		case Unit: io.out << "STDOUT: " << io.endl; break;
		case Boolean: io.out << "STDOUT: " << Boolean ref(t).value << io.endl; break;
		case Natural: io.out << "STDOUT: " << Natural ref(t).value << io.endl; break;
		case String: io.out << "STDOUT: " << String ref(t).value << io.endl; break;
		case Float: io.out << "STDOUT: " << Float ref(t).value << io.endl; break;
		}
		return new Unit();
	}

	// Not a builtin
	return nullptr;
}
// Evaluates term `t` and returns the resulting term.
// Applications, global variables, conditionals, let, match and tuple access are reduced;
// every other term kind is returned as it is.
Term ref Interpreter:eval(Term ref t)
{
	switch (typeid(t))
	{
	case Application:
		if (Application ref application = t)
		{
			auto t1 = eval(application.t1);

			// A variable in function position may name a builtin
			if (typeid(t1) == Variable)
			{
				if (auto r = evalBuiltin(t1, application.t2))
					return r;
			}

			if (typeid(t1) != Abstraction)
			{
				io.out << "expected abstraction for t1, got: (" << typeid(t1).name << ") " << t1 << io.endl;

				// There is nothing to apply: return the application unreduced instead of
				// converting a non-abstraction to 'Abstraction ref' below
				return t;
			}

			Abstraction ref lhs = t1;
			auto value = eval(application.t2);

			// "_" binds nothing, so the body is evaluated without substitution
			return lhs.x == "_" ? eval(lhs.t) : eval(substitute(lhs.t, lhs.x, value));
		}
		break;
	case Variable:
		if (Variable ref variable = t)
		{
			if (auto found = globals.find(variable.x))
			{
				Term ref bound = *found;

				// free() registers a name as a Variable with that same name;
				// evaluating it would look itself up again without end, so it stays a variable
				if (typeid(bound) == Variable && Variable ref(bound).x == variable.x)
					return t;

				return eval(bound);
			}
		}
		break;
	case Conditional:
		if (Conditional ref conditional = t)
		{
			auto cond = eval(conditional.cond);
			if (Boolean ref(cond).value)
				return eval(conditional.tru);
			else
				return eval(conditional.fls);
		}
		break;
	case Let:
		if (Let ref a = t)
		{
			// The bound term is evaluated even when the name is "_" and the value is not used
			auto value = eval(a.t);
			return a.name == "_" ? eval(a.expr) : eval(substitute(a.expr, a.name, value));
		}
		break;
	case Match:
		if (Match ref a = t)
		{
			Tuple ref value = eval(a.t);
			auto expr = a.expr;

			// Bind each name to the tuple element at the same position; "_" skips an element
			for (i in a.names, j in value.elements)
				expr = i == "_" ? expr : substitute(expr, i, j);
			return eval(expr);
		}
		break;
	case TupleAccess:
		if (TupleAccess ref access = t)
		{
			Tuple ref value = eval(access.t);
			return eval(value.elements[access.index - 1]); // should be typechecked already
		}
		break;
	}

	// Nothing to reduce
	return t;
}
// Parses `code` as a term, typechecks it and returns the result of evaluating it.
// Asserts when the code fails to parse or to typecheck.
auto Interpreter:eval(char[] code)
{
	auto parsed = p.parse(code);
	assert(parsed != nullptr, "failed to parse");

	// typecheck() records the resulting type on the term itself
	typecheck(parsed);
	assert(parsed.type != nullptr, "failed to typecheck");

	auto result = eval(parsed);
	return result;
}
// Prints `code`, evaluates it, and prints the resulting term together with the name of its type.
void Interpreter:printeval(char[] code)
{
	io.out << code << io.endl;

	// The evaluated term is typechecked again so that it carries a type of its own
	auto result = eval(code);
	typecheck(result);

	io.out << "> " << result << " : ";
	io.out << getName(result.type) << io.endl;
}
// Prints and evaluates `code` like printeval, then evaluates `result`
// and asserts that the two evaluated terms are equal.
void Interpreter:printtest(char[] code, char[] result)
{
	io.out << code << io.endl;

	auto actual = eval(code);
	typecheck(actual);

	io.out << "> " << actual << " : ";
	io.out << getName(actual.type) << io.endl;

	// The expected value is given as source code and goes through the same pipeline
	auto expected = eval(result);
	typecheck(expected);

	assert(equal(actual, expected), "result mismatch");
}
// Parses and typechecks `code` without evaluating it, then prints the term and the name of its type.
// Asserts when the code fails to parse or to typecheck.
void Interpreter:printtypecheck(char[] code)
{
	auto parsed = p.parse(code);
	assert(parsed != nullptr, "failed to parse");

	typecheck(parsed);
	assert(parsed.type != nullptr, "failed to typecheck");

	io.out << "> " << parsed << " : ";
	io.out << getName(parsed.type) << io.endl;
}
// Driver script: one interpreter instance runs every check below, in order.
// global() stores a parsed and typechecked definition, free() declares a name with a type only.
// printeval() evaluates and prints, printtypecheck() prints the type without evaluating,
// printtest() additionally asserts that the evaluated result equals the evaluated second argument.
Interpreter i; | |
// "not" is defined first because the next expression refers to it
i.global("not", "/x:Bool. if x then false else true"); | |
i.printeval("if true then (/x:Bool. x) else (/x:Bool. not x)"); | |
// "f" has a type but no definition, so it is only typechecked, never evaluated
i.free("f", "Bool -> Bool"); | |
i.printtypecheck("f (if false then true else false)"); | |
i.printtypecheck("/x:Bool. f (if x then false else x)"); | |
// literals, sequencing and ascription
i.printeval("unit"); | |
i.printeval("true"); | |
i.printeval("1"); | |
i.printeval("'hello'"); | |
i.printeval("1.3"); | |
i.printeval("unit;2"); | |
i.printeval("2 as Nat"); | |
i.printtypecheck("f"); | |
// tuples: type syntax, construction and access by position
i.free("tupTest", "Bool * Bool -> Bool * Nat"); | |
i.printtypecheck("tupTest"); | |
i.printeval("{2, 'hello', false}"); | |
i.printtest("{2, 'hello', false}.2", "'hello'"); | |
i.printtest("pred 3", "2"); | |
i.printtest("{pred 4, if true then false else false}.1", "3"); | |
i.printtest("(/x:Nat*Nat. x.2) {pred 4, pred 5}", "4"); | |
i.printeval("let a = {2, 3} in a.2"); | |
// some substitution tests | |
i.printeval("(/x:Nat. x as Nat) 2"); | |
i.printeval("(/x:Nat. let a=x in a) 2"); | |
i.printeval("(/x:Nat. let x={x, x} in x.1) 2"); | |
i.printeval("(/x:Nat. {x, x}) 2"); | |
i.printeval("(/x:Nat*Nat. x.2) {2, 3}"); | |
i.printeval("print 'hello world'; 2"); | |
// records, with and without ascription
i.printeval("{x=5}"); | |
i.printeval("{partno=5524,cost=30.27}"); | |
i.printeval("{x=5} as {x:Nat}"); | |
i.printeval("{partno=5524,cost=30.27} as {partno:Nat,cost:Float}"); | |
// pattern-matching let and an abstraction with an ignored "_" parameter
i.printeval("let {a,b} = {2, 3} in print b;a"); | |
i.printtest("(/_:Nat. 2) (print 'test';3)", "2"); |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment