Skip to content

Instantly share code, notes, and snippets.

@WheretIB
Created April 30, 2021 23:05
Show Gist options
  • Save WheretIB/5d63666f64ad3c30a32a082bfc55cb8f to your computer and use it in GitHub Desktop.
Save WheretIB/5d63666f64ad3c30a32a082bfc55cb8f to your computer and use it in GitHub Desktop.
// C++17
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <iostream>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace std; // haha
// Minimal invariant check: terminates the process through abort() when `cond` is false.
// NOTE: this shares its name with the <cassert> macro; including <cassert> here would break it.
void assert(bool cond)
{
	if(cond)
		return;
	abort();
}
// Invariant check with a diagnostic: when `cond` is false, writes `message` to stdout
// (endl flushes it, so the text is visible before the process dies) and then calls abort().
void assert(bool cond, string message)
{
	if(cond)
		return;
	cout << message << endl;
	abort();
}
// Types
// Abstract base of every type node. Types are compared structurally by isEqual, not by identity.
struct Type
{
	virtual ~Type() = default;
};
// Primitive types carry no payload; they are identified purely by their dynamic class.
struct UnitType : Type{};
struct BoolType : Type{};
struct NatType : Type{};
struct StringType : Type{};
struct FloatType : Type{};
// Function type `arg -> res`.
struct FunctionType : Type
{
	FunctionType(Type* arg, Type* res) : arg(arg), res(res) {}
	Type* arg; Type* res;
};
// Product type `T1 * T2 * ...`.
struct TupleType : Type
{
	// Sink parameter: moved into the member instead of being copied a second time.
	TupleType(vector<Type*> elements) : elements(std::move(elements)) {}
	vector<Type*> elements;
};
// A named component `name:type`, used as a record field or a variant case.
struct TaggedType : Type
{
	TaggedType(string name, Type* type) : name(std::move(name)), type(type) {}
	string name; Type* type;
};
// Record type `{name:T, ...}`.
struct RecordType : Type
{
	RecordType(vector<TaggedType*> elements) : elements(std::move(elements)) {}
	vector<TaggedType*> elements;
}; // almost like it can be interchanged with Tuple
// Binary sum type `left + right`.
struct SumType : Type
{
	SumType(Type* left, Type* right) : left(left), right(right) {}
	Type* left; Type* right;
};
// Variant type `<label:T, ...>`.
struct VariantType : Type
{
	VariantType(vector<TaggedType*> elements) : elements(std::move(elements)) {}
	vector<TaggedType*> elements;
};
// List type `[element]`.
struct ListType : Type
{
	ListType(Type* element) : element(element) {}
	Type* element;
};
// Structural type equality.
// Two null pointers are equal; a null and a non-null pointer are not. Otherwise the types must
// share the same dynamic class and all components must be equal recursively. Tuple, record and
// variant components are compared positionally, and tagged components must also share a name,
// so record/variant field order matters.
bool isEqual(Type* a, Type* b)
{
	if(!a || !b)
		return !!a == !!b;
	if(typeid(*a) != typeid(*b))
		return false;
	// Element-wise comparison of component lists; the four-iterator std::equal also rejects
	// lists of different length.
	auto sameElements = [](const auto& xs, const auto& ys) {
		return std::equal(xs.begin(), xs.end(), ys.begin(), ys.end(),
			[](Type* x, Type* y) { return isEqual(x, y); });
	};
	// Primitive types carry no payload, so matching dynamic classes is sufficient.
	if(dynamic_cast<UnitType*>(a) || dynamic_cast<BoolType*>(a) || dynamic_cast<NatType*>(a) ||
		dynamic_cast<StringType*>(a) || dynamic_cast<FloatType*>(a))
		return true;
	// From here on typeid(*a) == typeid(*b), so static_cast on `b` is safe.
	if(auto funcA = dynamic_cast<FunctionType*>(a))
	{
		auto funcB = static_cast<FunctionType*>(b);
		return isEqual(funcA->arg, funcB->arg) && isEqual(funcA->res, funcB->res);
	}
	if(auto tupA = dynamic_cast<TupleType*>(a))
		return sameElements(tupA->elements, static_cast<TupleType*>(b)->elements);
	if(auto tagA = dynamic_cast<TaggedType*>(a))
	{
		auto tagB = static_cast<TaggedType*>(b);
		// Compare against the other tag's name (this used to compare tagA->name with itself,
		// so e.g. {a:Nat} and {b:Nat} were considered equal).
		return tagA->name == tagB->name && isEqual(tagA->type, tagB->type);
	}
	if(auto recA = dynamic_cast<RecordType*>(a))
		return sameElements(recA->elements, static_cast<RecordType*>(b)->elements);
	if(auto sumA = dynamic_cast<SumType*>(a))
	{
		auto sumB = static_cast<SumType*>(b);
		return isEqual(sumA->left, sumB->left) && isEqual(sumA->right, sumB->right);
	}
	if(auto varA = dynamic_cast<VariantType*>(a))
		return sameElements(varA->elements, static_cast<VariantType*>(b)->elements);
	if(auto listA = dynamic_cast<ListType*>(a))
		return isEqual(listA->element, static_cast<ListType*>(b)->element);
	// Unknown Type subclass.
	assert(false);
	return false;
}
// Output
// User-declared type aliases (alias name -> aliased type).
// Written by Parser::parseStatement for `type Name = T`, read by Parser::parseSimpleType to resolve
// type names, and scanned by getName so that types structurally equal to an alias print as its name.
unordered_map<string, Type*> aliases; // global state, eww
// Renders a type as source-like text.
// A null type prints as "-". A type structurally equal (isEqual) to a registered alias prints as
// that alias name. A function argument that is itself a function type is parenthesized; other
// nested composite types are printed without grouping.
string getName(Type* a)
{
	if(!a)
		return string("-");
	// Bind by reference: iterating by value copied every alias name string on each call.
	for(const auto& [alias, aliased] : aliases)
	{
		if(isEqual(a, aliased))
			return alias;
	}
	if(dynamic_cast<UnitType*>(a))
		return string("Unit");
	if(dynamic_cast<BoolType*>(a))
		return string("Bool");
	if(dynamic_cast<NatType*>(a))
		return string("Nat");
	if(dynamic_cast<StringType*>(a))
		return string("String");
	if(dynamic_cast<FloatType*>(a))
		return string("Float");
	// Joins the names of a component list with `sep`, appending in place. An empty list yields
	// "" instead of the previous out-of-range elements[0] access.
	auto join = [](const auto& items, const string& sep) {
		string joined;
		for(size_t i = 0; i < items.size(); i++)
		{
			if(i != 0)
				joined += sep;
			joined += getName(items[i]);
		}
		return joined;
	};
	if(FunctionType* func = dynamic_cast<FunctionType*>(a))
	{
		if(dynamic_cast<FunctionType*>(func->arg))
			return "(" + getName(func->arg) + ")->" + getName(func->res);
		return getName(func->arg) + "->" + getName(func->res);
	}
	if(TupleType* tup = dynamic_cast<TupleType*>(a))
		return join(tup->elements, " * ");
	if(TaggedType* rpair = dynamic_cast<TaggedType*>(a))
		return rpair->name + ":" + getName(rpair->type);
	if(RecordType* rec = dynamic_cast<RecordType*>(a))
		return "{" + join(rec->elements, ",") + "}";
	if(SumType* sum = dynamic_cast<SumType*>(a))
		return getName(sum->left) + "+" + getName(sum->right);
	if(VariantType* var = dynamic_cast<VariantType*>(a))
		return "<" + join(var->elements, ",") + ">";
	if(ListType* list = dynamic_cast<ListType*>(a))
		return "[" + getName(list->element) + "]";
	// Unknown Type subclass.
	assert(false);
	return "";
}
// Streams a type using its alias-aware display name (see getName).
ostream& operator<<(ostream& out, Type* t)
{
	return out << getName(t);
}
// Terms
// Abstract base of every AST node.
// `type` starts as nullptr. The type checker fills it in, and a few constructors set it up front
// (e.g. List, and the two-argument Variable).
struct Term
{
	virtual ~Term() = default;
	Type* type = nullptr;
};
// The unit value, written `unit`.
struct Unit : Term
{
};
// Boolean literal `true` / `false`.
struct Boolean : Term
{
	Boolean(bool value) : value(value) {}
	bool value;
};
// Natural-number literal (stored as an int).
struct Natural : Term
{
	Natural(int value) : value(value) {}
	int value;
};
// String literal; `value` holds the text between the quotes.
struct String : Term
{
	// Sink parameter: moved into the member instead of being copied again.
	String(string value) : value(std::move(value)) {}
	string value;
};
// Floating-point literal such as `1.5`.
struct Float : Term
{
	Float(float value) : value(value) {}
	float value;
};
// Reference to the variable named `x`. The two-argument form also records a (non-null) type.
struct Variable : Term
{
	// The name is a sink parameter: moved rather than copied.
	Variable(string x) : x(std::move(x)) {}
	Variable(string x, Type* type) : x(std::move(x))
	{
		this->type = type; assert(type != nullptr);
	}
	string x;
};
struct Abstraction : Term
{
Abstraction(string x, Type* xType, Term* t)
{
this->x = x; this->xType = xType; this->t = t; assert(xType != nullptr);
}
string x; Type* xType; Term* t;
};
// Application `t1 t2`. The parser also uses it, with a `/_:Unit.` abstraction, to encode
// sequencing.
struct Application : Term
{
	Application(Term* t1, Term* t2) : t1(t1), t2(t2) {}
	Term* t1; Term* t2;
};
// `if cond then tru else fls`.
struct Conditional : Term
{
	Conditional(Term* cond, Term* tru, Term* fls) : cond(cond), tru(tru), fls(fls) {}
	Term* cond; Term* tru; Term* fls;
};
struct TypeAlias : Term
{
TypeAlias(string name, Type* value)
{
this->name = name; this->value = value; assert(value != nullptr);
}
string name; Type* value;
}; // TODO: if we allow type name to be a stand-alone term, we can use regular assignment
struct Assignment : Term
{
Assignment(string name, Term* value)
{
this->name = name; this->value = value;
}
string name; Term* value;
};
// Type ascription `t as expected`; the expected type is mandatory.
struct Ascription : Term
{
	Ascription(Term* t, Type* expected) : t(t), expected(expected)
	{
		assert(expected != nullptr);
	}
	Term* t; Type* expected;
};
struct Let : Term
{
Let(string name, Term* t, Term* expr)
{
this->name = name; this->t = t; this->expr = expr;
}
string name; Term* t; Term* expr;
};
struct Match : Term
{
Match(vector<string> names, Term* t, Term* expr)
{
this->names = names; this->t = t; this->expr = expr;
}
vector<string> names; Term* t; Term* expr;
};
// Tuple literal `{t1,t2,...}`.
struct Tuple : Term
{
	Tuple(vector<Term*> elements) : elements(std::move(elements)) {}
	vector<Term*> elements;
};
// Tuple projection `t.index`; the parser only accepts indices starting at 1.
struct TupleAccess : Term
{
	TupleAccess(Term* t, int index) : t(t), index(index) {}
	Term* t; int index;
};
struct TaggedPair : Term
{
TaggedPair(string name, Term* t)
{
this->name = name; this->t = t;
}
string name; Term* t;
};
// Record literal `{name=t, ...}`.
struct Record : Term
{
	Record(vector<TaggedPair*> elements) : elements(std::move(elements)) {}
	vector<TaggedPair*> elements;
};
// Record projection `t.name`.
struct RecordAccess : Term
{
	RecordAccess(Term* t, string name) : t(t), name(std::move(name)) {}
	Term* t; string name;
};
struct Variant : Term
{
Variant(string label, Term* t, Type* variant)
{
this->label = label; this->t = t; this->variant = variant; assert(variant != nullptr);
}
string label; Term* t; Type* variant;
};
struct CaseOption : Term
{
CaseOption(string label, string name, Term* t)
{
this->label = label; this->name = name; this->t = t;
}
string label; string name; Term* t;
}; // cannot be typechecked by itself
// `case t of option | option | ...`.
struct Case : Term
{
	Case(Term* t, vector<CaseOption*> options) : t(t), options(std::move(options)) {}
	Term* t; vector<CaseOption*> options;
};
// List value: `nil[T]` when `head` is null, otherwise printed as `cons` of head and tail.
// The list type is stored in the inherited Term::type and must be non-null.
struct List : Term
{
	List(Term* head, Term* tail, Type* type) : head(head), tail(tail)
	{
		this->type = type; assert(type != nullptr);
	}
	Term* head;
	Term* tail;
};
// Term output
void output(Term* a);
ostream& operator<<(ostream& out, Term* t)
{
output(t); return out;
}
void output(Term* a)
{
if(auto t = dynamic_cast<Unit*>(a))
{
cout << "unit";
}
if(auto t = dynamic_cast<Boolean*>(a))
{
cout << (t->value ? "true" : "false");
}
if(auto t = dynamic_cast<Natural*>(a))
{
cout << t->value;
}
if(auto t = dynamic_cast<String*>(a))
{
cout << "'" << t->value << "'";
}
if(auto t = dynamic_cast<Float*>(a))
{
cout << t->value;
}
if(auto t = dynamic_cast<Variable*>(a))
{
cout << t->x;
}
if(auto t = dynamic_cast<Abstraction*>(a))
{
if(t->xType)
cout << "/" << t->x << ":" << t->xType << ". " << t->t;
else
cout << "/" << t->x << ". " << t->t;
}
if(auto t = dynamic_cast<Application*>(a))
{
if(dynamic_cast<Application*>(t->t2))
cout << t->t1 << " (" << t->t2 << ")";
else
cout << t->t1 << " " << t->t2;
}
if(auto t = dynamic_cast<Conditional*>(a))
{
cout << "if " << t->cond << " then " << t->tru << " else " << t->fls;
}
if(auto t = dynamic_cast<TypeAlias*>(a))
{
cout << "type " << t->name << " = " << t->value;
}
if(auto t = dynamic_cast<Assignment*>(a))
{
cout << "local " << t->name << " = " << t->value;
}
if(auto t = dynamic_cast<Ascription*>(a))
{
cout << t->t << " as " << t->expected;
}
if(auto t = dynamic_cast<Let*>(a))
{
cout << "let " << t->name << "=" << t->t << " in " << t->expr;
}
if(auto t = dynamic_cast<Match*>(a))
{
cout << "let {" << t->names[0];
for(unsigned i = 1; i < t->names.size(); i++)
cout << "," << t->names[i];
cout << "}=" << t->t << " in " << t->expr;
}
if(auto t = dynamic_cast<Tuple*>(a))
{
cout << "{" << t->elements[0];
for(unsigned i = 1; i < t->elements.size(); i++)
cout << "," << t->elements[i];
cout << "}";
}
if(auto t = dynamic_cast<TupleAccess*>(a))
{
cout << t->t << "." << t->index;
}
if(auto t = dynamic_cast<TaggedPair*>(a))
{
cout << t->name << "=" << t->t;
}
if(auto t = dynamic_cast<Record*>(a))
{
cout << "{" << t->elements[0];
for(unsigned i = 1; i < t->elements.size(); i++)
cout << "," << t->elements[i];
cout << "}";
}
if(auto t = dynamic_cast<RecordAccess*>(a))
{
cout << t->t << "." << t->name;
}
if(auto t = dynamic_cast<Variant*>(a))
{
cout << "<" << t->label << "=" << t->t << "> as " << t->variant;
}
if(auto t = dynamic_cast<CaseOption*>(a))
{
cout << "<" << t->label << "=" << t->name << "> -> " << t->t;
}
if(auto t = dynamic_cast<Case*>(a))
{
cout << "case " << t->t << " of" << endl;
cout << t->options[0] << endl;
for(unsigned i = 1; i < t->options.size(); i++)
cout << "| " << t->options[i];
}
if(auto t = dynamic_cast<List*>(a))
{
if(!t->head)
cout << "nil" << t->type;
else
cout << "cons" << t->type << " " << t->head << " " << t->tail;
}
}
// Prints a term to cout followed by a newline; endl also flushes the stream.
void outputln(Term* t)
{
	output(t);
	cout << endl;
}
// Lexer
// Kinds of lexemes produced by Lexer. The enum is unscoped, so its enumerators live in the
// enclosing scope; Let_ and Case_ presumably carry an underscore to avoid clashing with the
// Let/Case term structs.
enum LexemeType
{
	Lambda,    // '/'
	Number,    // digit run
	Rational,  // digit run '.' digit run (the fraction may be empty)
	Str,       // identifier; also names the parser treats specially, such as 'as', 'type', 'local'
	QuotedStr, // '...' including both quotes
	Oparen,    // '('
	Cparen,    // ')'
	Ofigure,   // '{'
	Cfigure,   // '}'
	Obracket,  // '['
	Cbracket,  // ']'
	Less,      // '<'
	Greater,   // '>'
	Point,     // '.'
	Comma,     // ','
	Colon,     // ':'
	Semicolon, // ';'
	Pipe,      // '|'
	Arrow,     // '->'
	Add,       // '+'
	Mult,      // '*'
	Equal,     // '='
	If,        // 'if'
	Then,      // 'then'
	Else,      // 'else'
	Let_,      // 'let'
	Letrec,    // 'letrec'
	In,        // 'in'
	Case_,     // 'case'
	Of,        // 'of'
	Nil,       // 'nil'
	Unknown,   // unrecognized character or malformed token (empty text)
	Eof        // end of input (empty text)
};
// A lexeme kind plus the source text it was lexed from. The text length is what consume()
// advances by, so Eof/Unknown (empty text) never advance. Default-constructed means Eof.
struct Lexeme
{
	Lexeme() = default;
	Lexeme(LexemeType type, string str) : type(type), str(std::move(str)) {}
	LexemeType type = LexemeType::Eof;
	string str;
};
// Hand-written lexer over an in-memory string.
// peek() lexes the next lexeme without consuming it (re-lexing on every call); consume() returns
// that lexeme and advances `pos` past it. Bytes <= ' ' are whitespace.
struct Lexer
{
	Lexer() = default;
	Lexer(string str) : str(std::move(str)) {}
	void skipSpaces()
	{
		// Compare as unsigned: with a signed char, every UTF-8 byte (>= 0x80) is negative and
		// used to be skipped as whitespace, so non-ASCII identifiers silently lost characters.
		while(pos < str.length() && static_cast<unsigned char>(str[pos]) <= ' ')
			pos++;
	}
	auto peek()
	{
		skipSpaces();
		unsigned pos = this->pos;
		if(pos == str.length())
			return Lexeme();
		auto start = pos;
		auto single = [&](LexemeType type) { return Lexeme(type, str.substr(start, 1)); };
		switch(str[pos])
		{
		case '/': return single(LexemeType::Lambda);
		case '(': return single(LexemeType::Oparen);
		case ')': return single(LexemeType::Cparen);
		case '{': return single(LexemeType::Ofigure);
		case '}': return single(LexemeType::Cfigure);
		case '[': return single(LexemeType::Obracket);
		case ']': return single(LexemeType::Cbracket);
		case '<': return single(LexemeType::Less);
		case '>': return single(LexemeType::Greater);
		case '.': return single(LexemeType::Point);
		case ',': return single(LexemeType::Comma);
		case ':': return single(LexemeType::Colon);
		case ';': return single(LexemeType::Semicolon);
		case '|': return single(LexemeType::Pipe);
		case '+': return single(LexemeType::Add);
		case '*': return single(LexemeType::Mult);
		case '=': return single(LexemeType::Equal);
		case '-':
			// pos < length here, so str[pos + 1] is at most str[size()], which reads as '\0'.
			if(str[pos + 1] == '>')
				return Lexeme(LexemeType::Arrow, "->");
			break;
		case '\'':
			pos++;
			while(pos < str.length() && str[pos] != '\'')
				pos++;
			// Unterminated literal: report it instead of returning a QuotedStr without its closing
			// quote (which the parser then truncated by one character).
			if(pos == str.length())
				return Lexeme(LexemeType::Unknown, "");
			pos++;
			return Lexeme(LexemeType::QuotedStr, str.substr(start, pos - start));
		default:
			break;
		}
		// Identifier characters: ASCII letters, '_', and any byte >= 0x80 (UTF-8 sequences),
		// independent of whether char is signed on this platform.
		auto isAlnum = [](char x){ return ((x | 32) >= 'a' && (x | 32) <= 'z') || static_cast<unsigned char>(x) >= 0x80 || x == '_'; };
		auto isDigit = [](char x){ return x >= '0' && x <= '9'; };
		if(isAlnum(str[pos]))
		{
			pos++;
			while(pos < str.length() && (isAlnum(str[pos]) || isDigit(str[pos])))
				pos++;
		}
		else if(isDigit(str[pos]))
		{
			pos++;
			while(pos < str.length() && isDigit(str[pos]))
				pos++;
			if(str[pos] == '.')
			{
				pos++;
				while(pos < str.length() && isDigit(str[pos]))
					pos++;
				return Lexeme(LexemeType::Rational, str.substr(start, pos - start));
			}
			return Lexeme(LexemeType::Number, str.substr(start, pos - start));
		}
		string result = str.substr(start, pos - start);
		if(result.empty())
			return Lexeme(LexemeType::Unknown, "");
		static const pair<const char*, LexemeType> keywords[] = {
			{"if", LexemeType::If}, {"then", LexemeType::Then}, {"else", LexemeType::Else},
			{"let", LexemeType::Let_}, {"letrec", LexemeType::Letrec}, {"in", LexemeType::In},
			{"case", LexemeType::Case_}, {"of", LexemeType::Of}, {"nil", LexemeType::Nil},
		};
		for(const auto& [word, type] : keywords)
			if(result == word)
				return Lexeme(type, result);
		return Lexeme(LexemeType::Str, result);
	}
	auto consume()
	{
		auto next = peek(); // peek() already skips leading whitespace
		pos += next.str.length();
		return next;
	}
	string str;
	unsigned pos = 0;
};
struct Parser
{
Type* parseSimpleType()
{
assert(lexer.peek().type == LexemeType::Str, "type name expected");
auto name = lexer.consume().str;
if(name == "Unit") return new UnitType();
if(name == "Bool") return new BoolType();
if(name == "Nat") return new NatType();
if(name == "String") return new StringType();
if(name == "Float") return new FloatType();
if(auto t = aliases.find(name); t != aliases.end())
return t->second;
assert(false, "unknown type name " + name);
return nullptr;
}
TaggedType* parseTaggedType()
{
assert(lexer.peek().type == LexemeType::Str, "name expected");
auto name = lexer.consume().str;
if(lexer.peek().type == LexemeType::Colon)
{
lexer.consume();
auto type = parseType();
return new TaggedType(name, type);
}
return new TaggedType(name, new UnitType());
}
Type* parseAtomType()
{
if(lexer.peek().type == LexemeType::Ofigure)
{
lexer.consume();
vector<TaggedType*> elements;
elements.push_back(parseTaggedType());
while(lexer.peek().type == LexemeType::Comma)
{
lexer.consume();
elements.push_back(parseTaggedType());
}
assert(lexer.peek().type == LexemeType::Cfigure, "'}' expected after type");
lexer.consume();
return new RecordType(elements);
}
if(lexer.peek().type == LexemeType::Obracket)
{
lexer.consume();
auto element = parseType();
assert(lexer.peek().type == LexemeType::Cbracket, "']' expected after type");
lexer.consume();
return new ListType(element);
}
if(lexer.peek().type == LexemeType::Less)
{
lexer.consume();
vector<TaggedType*> elements;
elements.push_back(parseTaggedType());
while(lexer.peek().type == LexemeType::Comma)
{
lexer.consume();
elements.push_back(parseTaggedType());
}
assert(lexer.peek().type == LexemeType::Greater, "'>' expected after type");
lexer.consume();
return new VariantType(elements);
}
return parseSimpleType();
}
Type* parseTupleType()
{
vector<Type*> elements;
elements.push_back(parseAtomType());
while(lexer.peek().type == LexemeType::Mult)
{
lexer.consume();
elements.push_back(parseAtomType());
}
if(elements.size() == 1)
return elements[0];
return new TupleType(elements);
}
Type* parseSumType()
{
vector<Type*> elements;
elements.push_back(parseTupleType());
while(lexer.peek().type == LexemeType::Add)
{
lexer.consume();
elements.push_back(parseTupleType());
}
if(elements.size() == 1)
return elements[0];
return new TupleType(elements);
}
Type* parseFunctionType()
{
Type* t = parseSumType();
if(lexer.peek().type == LexemeType::Arrow)
{
lexer.consume();
t = new FunctionType(t, parseFunctionType());
}
return t;
}
Type* parseType()
{
return parseFunctionType();
}
CaseOption* parseCaseOption()
{
assert(lexer.peek().type == LexemeType::Less, "'<' expected");
lexer.consume();
assert(lexer.peek().type == LexemeType::Str, "label expected after '<'");
string label = lexer.consume().str;
assert(lexer.peek().type == LexemeType::Equal, "'=' expected after label");
lexer.consume();
assert(lexer.peek().type == LexemeType::Str, "name expected after '='");
string name = lexer.consume().str;
assert(lexer.peek().type == LexemeType::Greater, "'>' expected after name");
lexer.consume();
assert(lexer.peek().type == LexemeType::Arrow, "'->' expected after '>");
lexer.consume();
auto t = parseExpr();
assert(t != nullptr, "term expected after '->'");
return new CaseOption(label, name, t);
}
Term* parseTerm()
{
switch(lexer.peek().type)
{
case LexemeType::Lambda:
{
lexer.consume();
assert(lexer.peek().type == LexemeType::Str, "abstraction name expected after '/'");
string x = lexer.consume().str;
assert(lexer.peek().type == LexemeType::Colon, "type expected after abstraction name");
lexer.consume();
Type* type = parseType();
assert(lexer.peek().type == LexemeType::Point, "'.' not found after abstracton name");
lexer.consume();
return new Abstraction(x, type, parseExpr());
}
case LexemeType::Number:
{
return new Natural(strtol(lexer.consume().str.c_str(), nullptr, 10));
}
case LexemeType::Rational:
{
return new Float(strtof(lexer.consume().str.c_str(), nullptr));
}
case LexemeType::Str:
{
auto str = lexer.consume().str;
if(str == "unit")
return new Unit();
if(str == "true")
return new Boolean(true);
if(str == "false")
return new Boolean(false);
return new Variable(str);
}
case LexemeType::QuotedStr:
{
auto qstr = lexer.consume().str;
return new String(qstr.substr(1, qstr.length() - 2));
}
case LexemeType::Oparen:
{
lexer.consume();
auto t = parseExpr();
assert(t != nullptr, "term expected after '('");
assert(lexer.peek().type == LexemeType::Cparen, "')' not found after '('");
lexer.consume();
return t;
}
case LexemeType::Ofigure:
{
lexer.consume();
auto first = parseExpr();
assert(first != nullptr, "term expected after '{'");
// Looks like a record instead of a tuple
if(dynamic_cast<Variable*>(first) && lexer.peek().type == LexemeType::Equal)
{
lexer.consume();
auto name = dynamic_cast<Variable*>(first)->x;
auto value = parseExpr();
assert(value != nullptr, "term expected after '='");
vector<TaggedPair*> elements;
elements.push_back(new TaggedPair(name, value));
while(lexer.peek().type == LexemeType::Comma)
{
lexer.consume();
assert(lexer.peek().type == LexemeType::Str, "name expected after ','");
name = lexer.consume().str;
assert(lexer.peek().type == LexemeType::Equal, "'=' expected after name");
lexer.consume();
value = parseExpr();
assert(value != nullptr, "term expected after '='");
elements.push_back(new TaggedPair(name, value));
}
assert(lexer.peek().type == LexemeType::Cfigure, "'}' not found after term");
lexer.consume();
return new Record(elements);
}
vector<Term*> elements;
elements.push_back(first);
while(lexer.peek().type == LexemeType::Comma)
{
lexer.consume();
elements.push_back(parseExpr());
assert(elements.back() != nullptr, "term expected after ','");
}
assert(lexer.peek().type == LexemeType::Cfigure, "'}' not found after term");
lexer.consume();
return new Tuple(elements);
}
case LexemeType::Less:
{
lexer.consume();
assert(lexer.peek().type == LexemeType::Str, "label expected after '<'");
auto label = lexer.consume().str;
Term* value = new Unit();
if(lexer.peek().type == LexemeType::Equal)
{
assert(lexer.peek().type == LexemeType::Equal, "'=' expected after label");
lexer.consume();
value = parseExpr();
assert(value != nullptr, "term expected after '='");
}
assert(lexer.peek().type == LexemeType::Greater, "'>' not found after term");
lexer.consume();
assert(lexer.peek().str == "as", "'as' not found after '>'");
lexer.consume();
auto variant = parseType();
assert(variant != nullptr, "type expected after 'as'");
return new Variant(label, value, variant);
}
case LexemeType::If:
{
lexer.consume();
auto cond = parseExpr();
assert(cond != nullptr, "condition expected after 'if'");
assert(lexer.peek().type == LexemeType::Then, "'then' not found after term");
lexer.consume();
auto tru = parseExpr();
assert(tru != nullptr, "term expected after 'then'");
assert(lexer.peek().type == LexemeType::Else, "'else' not found after term");
lexer.consume();
auto fls = parseExpr();
assert(fls != nullptr, "term expected after 'else'");
return new Conditional(cond, tru, fls);
}
case LexemeType::Let_:
{
lexer.consume();
vector<string> names;
switch(lexer.peek().type)
{
case LexemeType::Str:
names.push_back(lexer.consume().str);
break;
case LexemeType::Ofigure:
lexer.consume();
assert(lexer.peek().type == LexemeType::Str, "name expected after '{'");
names.push_back(lexer.consume().str);
while(lexer.peek().type == LexemeType::Comma)
{
lexer.consume();
assert(lexer.peek().type == LexemeType::Str, "name expected after ','");
names.push_back(lexer.consume().str);
}
assert(lexer.peek().type == LexemeType::Cfigure, "'}' not found after name");
lexer.consume();
break;
default:
assert(false, "name or '{' expected after 'let'");
}
assert(lexer.peek().type == LexemeType::Equal, "'=' not found after name");
lexer.consume();
auto term = parseExpr();
assert(term != nullptr, "term expected after '='");
assert(lexer.peek().type == LexemeType::In, "'in' not found after term");
lexer.consume();
auto expr = parseExpr();
assert(expr != nullptr, "term expected after 'in'");
if(names.size() == 1)
return new Let(names[0], term, expr);
return new Match(names, term, expr);
}
case LexemeType::Case_:
{
lexer.consume();
auto caseExpr = parseExpr();
assert(caseExpr != nullptr, "term expected after 'case'");
assert(lexer.peek().type == LexemeType::Of, "'of' not found after term");
lexer.consume();
vector<CaseOption*> options;
options.push_back(parseCaseOption());
while(lexer.peek().type == LexemeType::Pipe)
{
lexer.consume();
options.push_back(parseCaseOption());
}
return new Case(caseExpr, options);
}
case LexemeType::Letrec:
{
lexer.consume();
assert(lexer.peek().type == LexemeType::Str, "label expected after '<'");
auto x = lexer.consume().str;
assert(lexer.peek().type == LexemeType::Colon, "':' not found after label");
lexer.consume();
auto xType = parseType();
assert(xType != nullptr, "type expected after name");
assert(lexer.peek().type == LexemeType::Equal, "'=' not found after name");
lexer.consume();
auto t1 = parseExpr();
assert(t1 != nullptr, "expression expected after '='");
assert(lexer.peek().type == LexemeType::In, "'in' not found after expression");
lexer.consume();
auto t2 = parseExpr();
assert(t2 != nullptr, "expression expected after 'in'");
return new Let(x, new Application(new Variable(string("fix")), new Abstraction(x, xType, t1)), t2);
}
case LexemeType::Nil:
{
lexer.consume();
assert(lexer.peek().type == LexemeType::Obracket, "'[' not found after 'nil'");
lexer.consume();
auto nilType = parseType();
assert(nilType != nullptr, "type expected after '['");
assert(lexer.peek().type == LexemeType::Cbracket, "']' not found after type");
lexer.consume();
return new List(nullptr, nullptr, new ListType(nilType));
}
case LexemeType::Unknown:
assert(false, "unknown lexeme at " + to_string(lexer.pos));
}
return nullptr;
}
Term* parseAccess()
{
auto t = parseTerm();
while(lexer.peek().str == ".")
{
lexer.consume();
switch(lexer.peek().type)
{
case LexemeType::Number:
{
int index = strtol(lexer.consume().str.c_str(), nullptr, 10);
assert(index > 0, "index has to start at '1'");
t = new TupleAccess(t, index);
break;
}
case LexemeType::Str:
{
string name = lexer.consume().str;
t = new RecordAccess(t, name);
break;
}
}
}
return t;
}
Term* parseAs()
{
auto t = parseAccess();
while(lexer.peek().str == "as")
{
lexer.consume();
auto type = parseType();
t = new Ascription(t, type);
}
return t;
}
Term* parseTerms()
{
auto t = parseAs();
auto t2 = parseAs();
while(t2)
{
t = new Application(t, t2);
t2 = parseAs();
}
return t;
}
Term* parseSeq()
{
auto t = parseTerms();
while(lexer.peek().type == LexemeType::Semicolon)
{
lexer.consume();
auto t2 = parseTerms();
assert(t2 != nullptr, "term expected after ';'");
t = new Application(new Abstraction(string("_"), new UnitType(), t2), t);
}
return t;
}
Term* parseExpr()
{
return parseSeq();
}
Term* parseStatement()
{
if(lexer.peek().str == "type")
{
lexer.consume();
assert(lexer.peek().type == LexemeType::Str, "name expected after 'type'");
string name = lexer.consume().str;
assert(lexer.peek().type == LexemeType::Equal, "'=' expected after name");
lexer.consume();
Type* rhs = parseType();
assert(rhs != nullptr, "type name expected after '='");
aliases[name] = rhs;
return new TypeAlias(name, rhs);
}
if(lexer.peek().str == "local")
{
lexer.consume();
assert(lexer.peek().type == LexemeType::Str, "name expected after 'local'");
string name = lexer.consume().str;
assert(lexer.peek().type == LexemeType::Equal, "'=' expected after name");
lexer.consume();
auto value = parseExpr();
assert(value != nullptr, "term expected after '='");
return new Assignment(name, value);
}
return parseExpr();
}
Term* parseStatements()
{
auto t = parseStatement();
while(lexer.peek().type == LexemeType::Comma)
{
lexer.consume();
auto t2 = parseStatement();
assert(t2 != nullptr, "statement expected after ','");
t = new Application(new Abstraction(string("_"), new UnitType(), t2), t);
}
return t;
}
Term* parse(string code)
{
lexer = Lexer(string(code));
auto t = parseStatements();
assert(lexer.pos == lexer.str.length(), "unknown symbol " + lexer.str[lexer.pos]);
return t;
}
Lexer lexer;
};
// Ties the parser to the type checker and evaluator. Only declarations live here; the member
// functions are defined out of line after this struct.
struct Interpreter
{
// Front end used to turn source code into terms.
Parser p;
// Types of the names currently in scope while type checking. typecheck saves and restores entries
// around abstractions and lets, and `local` assignments write into it.
unordered_map<string, Type*> bindings;
// Named top-level terms; typecheck falls back to a global's recorded `type` for unbound variables.
unordered_map<string, Term*> globals; // optional
Type* typecheck(Term* a);
// NOTE(review): a member named `free` hides ::free for unqualified calls inside member functions;
// its exact behavior is defined outside this view — confirm before relying on it.
void free(string name, string code);
void global(string name, string code);
Term* substitute(Term* a, string x, Term* rhs);
Term* eval(Term* t);
Term* evalBuiltin(Variable* v, Term* t);
// Deduced return type: the definition must appear before any call that needs the type.
auto eval(string code);
void printeval(string code);
void printtest(string code, string result);
void printtypecheck(string code);
};
Type* Interpreter::typecheck(Term* a)
{
if(auto t = dynamic_cast<Unit*>(a))
{
t->type = new UnitType();
}
if(auto t = dynamic_cast<Boolean*>(a))
{
t->type = new BoolType();
}
if(auto t = dynamic_cast<Natural*>(a))
{
t->type = new NatType();
}
if(auto t = dynamic_cast<String*>(a))
{
t->type = new StringType();
}
if(auto t = dynamic_cast<Float*>(a))
{
t->type = new FloatType();
}
if(auto t = dynamic_cast<Variable*>(a))
{
if(auto type = bindings.find(t->x); type != bindings.end())
t->type = type->second;
else if(auto global = globals.find(t->x); global != globals.end())
t->type = global->second->type;
else
cout << "ERROR: Unknown variable " << t->x << endl;
}
if(auto t = dynamic_cast<Abstraction*>(a))
{
auto lastType = bindings[t->x];
if(t->x != "_")
bindings[t->x] = t->xType;
if(auto tType = typecheck(t->t))
t->type = new FunctionType(t->xType, tType);
bindings[t->x] = lastType;
}
if(auto t = dynamic_cast<Application*>(a))
{
auto t2Type = typecheck(t->t2);
if(t2Type && dynamic_cast<Variable*>(t->t1))
{
Variable* v = dynamic_cast<Variable*>(t->t1);
if(v->x == "succ")
{
if(dynamic_cast<NatType*>(t2Type))
t->type = new NatType();
else
cout << v->x << " expects Nat as an argument, got " << t2Type << endl;
}
if(v->x == "pred")
{
if(dynamic_cast<NatType*>(t2Type))
t->type = new NatType();
else
cout << v->x << " expects expects Nat as an argument, got " << t2Type << endl;
}
if(v->x == "iszero")
{
if(dynamic_cast<NatType*>(t2Type))
t->type = new BoolType();
else
cout << v->x << " expects expects Nat as an argument, got " << t2Type << endl;
}
if(v->x == "print")
{
if(dynamic_cast<UnitType*>(t2Type) || dynamic_cast<BoolType*>(t2Type) || dynamic_cast<NatType*>(t2Type) || dynamic_cast<StringType*>(t2Type) || dynamic_cast<FloatType*>(t2Type))
t->type = new UnitType();
else
cout << v->x << " expects expects base type as an argument, got " << t2Type << endl;
}
if(v->x == "fix")
{
if(auto funcType = dynamic_cast<FunctionType*>(t2Type))
{
if(isEqual(funcType->arg, funcType->res))
t->type = funcType->res;
else
cout << v->x << " expects expects T->T function type as an argument, got " << t2Type << endl;
}
else
{
cout << v->x << " expects expects function type as an argument, got " << t2Type << endl;
}
}
if(v->x == "eq")
{
if(dynamic_cast<NatType*>(t2Type))
t->type = new FunctionType(new NatType(), new BoolType());
else
cout << v->x << " expects expects Nat as an argument, got " << t2Type << endl;
}
if(v->x == "add")
{
if(dynamic_cast<NatType*>(t2Type))
t->type = new FunctionType(new NatType(), new NatType());
else
cout << v->x << " expects expects Nat as an argument, got " << t2Type << endl;
}
if(v->x == "mul")
{
if(dynamic_cast<NatType*>(t2Type))
t->type = new FunctionType(new NatType(), new NatType());
else
cout << v->x << " expects expects Nat as an argument, got " << t2Type << endl;
}
if(v->x == "addf")
{
if(dynamic_cast<FloatType*>(t2Type))
t->type = new FunctionType(new FloatType(), new FloatType());
else
cout << v->x << " expects expects Nat as an argument, got " << t2Type << endl;
}
if(v->x == "mulf")
{
if(dynamic_cast<FloatType*>(t2Type))
t->type = new FunctionType(new FloatType(), new FloatType());
else
cout << v->x << " expects expects Nat as an argument, got " << t2Type << endl;
}
if(v->x == "cons")
{
t->type = new FunctionType(new ListType(t2Type), new ListType(t2Type));
}
if(v->x == "isnil")
{
if(dynamic_cast<ListType*>(t2Type))
t->type = new BoolType();
else
cout << v->x << " expects expects List[T] as an argument, got " << t2Type << endl;
}
if(v->x == "head")
{
if(auto t2_ = dynamic_cast<ListType*>(t2Type))
t->type = t2_->element;
else
cout << v->x << " expects expects List[T] as an argument, got " << t2Type << endl;
}
if(v->x == "tail")
{
if(dynamic_cast<ListType*>(t2Type))
t->type = t2Type;
else
cout << v->x << " expects expects List[T] as an argument, got " << t2Type << endl;
}
if(t->type)
return a->type;
}
auto t1Type = typecheck(t->t1);
if(t1Type && t2Type)
{
if(FunctionType* t1FuncType = dynamic_cast<FunctionType*>(t1Type))
{
if(isEqual(t1FuncType->arg, t2Type))
t->type = t1FuncType->res;
else
cout << "ERROR: Application rhs type " << t2Type << " has to match function argument type " << t1FuncType->arg << endl;
}
else
{
cout << "ERROR: Application lhs type has to be function got " << t1Type << endl;
}
}
}
if(auto t = dynamic_cast<Conditional*>(a))
{
auto condType = typecheck(t->cond);
auto truType = typecheck(t->tru);
auto flsType = typecheck(t->fls);
if(condType && truType && flsType)
{
if(dynamic_cast<BoolType*>(condType))
{
if(isEqual(truType, flsType))
t->type = truType;
else
cout << "ERROR: Condition branch types have to isEqual " << truType << " != " << flsType << endl;
}
else
{
cout << "ERROR: Condition type has to be Bool got " << condType << endl;
}
}
}
if(auto t = dynamic_cast<TypeAlias*>(a))
{
t->type = new UnitType();
}
if(auto t = dynamic_cast<Assignment*>(a))
{
auto tType = typecheck(t->value);
if(tType)
{
bindings[t->name] = tType;
t->type = new UnitType();
}
}
if(auto t = dynamic_cast<Ascription*>(a))
{
auto tType = typecheck(t->t);
if(tType)
{
if(isEqual(t->expected, tType))
t->type = tType;
else
cout << "ERROR: Ascription failed to match " << tType << " to " << t->expected << endl;
}
}
if(auto t = dynamic_cast<Let*>(a))
{
auto tType = typecheck(t->t);
if(tType)
{
auto lastType = bindings[t->name];
if(t->name != "_")
bindings[t->name] = tType;
if(auto exprType = typecheck(t->expr))
t->type = exprType;
bindings[t->name] = lastType;
}
}
if(auto t = dynamic_cast<Match*>(a))
{
auto tType = typecheck(t->t);
if(tType)
{
if(TupleType* matchTupleType = dynamic_cast<TupleType*>(tType))
{
if(t->names.size() == matchTupleType->elements.size())
{
unordered_map<string, Type*> saved;
for(unsigned pos = 0; pos < t->names.size(); pos++)
{
auto &name = t->names[pos];
auto &type = matchTupleType->elements[pos];
saved[name] = bindings[name];
if(name != "_")
bindings[name] = type;
}
if(auto exprType = typecheck(t->expr))
t->type = exprType;
for(auto el : saved)
bindings[el.first] = el.second;
}
else
{
cout << "ERROR: Mismatch between let pattern " << t << " and tuple type " << tType << endl;
}
}
else
{
cout << "ERROR: Cannot match let pattern to " << tType << endl;
}
}
}
if(auto t = dynamic_cast<Tuple*>(a))
{
vector<Type*> elements;
for(auto i : t->elements)
{
if(auto type = typecheck(i))
{
elements.push_back(type);
}
else
{
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return nullptr;
}
}
t->type = new TupleType(elements);
}
if(auto t = dynamic_cast<TupleAccess*>(a))
{
auto tType = typecheck(t->t);
if(tType)
{
if(TupleType* tupType = dynamic_cast<TupleType*>(tType))
{
if(unsigned(t->index - 1) < tupType->elements.size())
t->type = tupType->elements[t->index - 1];
else
cout << "ERROR: Index " << t->index << " is out of bounds of type " << tupType << endl;
}
else
{
cout << "ERROR: Cannot index type that is not a Tuple " << tType << endl;
}
}
}
if(auto t = dynamic_cast<TaggedPair*>(a))
{
if(auto tType = typecheck(t->t))
t->type = new TaggedType(t->name, tType);
}
if(auto t = dynamic_cast<Record*>(a))
{
vector<TaggedType*> elements;
for(auto i : t->elements)
{
if(auto type = typecheck(i))
{
if(TaggedType* recPairType = dynamic_cast<TaggedType*>(type))
{
if(count_if(elements.begin(), elements.end(), [&](auto&& x){ return x->name == recPairType->name; }) != 0)
cout << "ERROR: record already has field named " << recPairType->name << endl;
else
elements.push_back(recPairType);
}
else
{
cout << "ERROR: Record element type is not a TaggedPair: " << type << endl;
}
}
else
{
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return nullptr;
}
}
t->type = new RecordType(elements);
}
if(auto t = dynamic_cast<RecordAccess*>(a))
{
auto tType = typecheck(t->t);
if(tType)
{
if(RecordType* recType = dynamic_cast<RecordType*>(tType))
{
if(count_if(recType->elements.begin(), recType->elements.end(), [&](auto&& x){ return x->name == t->name; }) == 0)
{
cout << "ERROR: record has no field named " << t->name << " in " << recType;
}
else
{
for(auto i : recType->elements)
{
if(i->name == t->name)
t->type = i->type;
}
}
}
else
{
cout << "ERROR: Cannot index type that is not a Record: " << tType << endl;
}
}
}
if(auto t = dynamic_cast<Variant*>(a))
{
if(!dynamic_cast<VariantType*>(t->variant))
{
cout << "ERROR: " << t->variant << " is not a VariantType" << endl;
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return nullptr;
}
VariantType* varType = dynamic_cast<VariantType*>(t->variant);
Type* optionType = nullptr;
for(auto option : varType->elements)
{
if(t->label == option->name)
optionType = option->type;
}
if(!optionType)
{
cout << "ERROR: variant has no label named " << t->label << " in " << varType;
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return nullptr;
}
auto tType = typecheck(t->t);
if(tType)
{
if(isEqual(optionType, tType))
t->type = varType;
else
cout << "ERROR: variant failed to match option " << t->label << " of type " << optionType << " to " << tType << endl;
}
}
if(auto t = dynamic_cast<Case*>(a))
{
auto tType = typecheck(t->t);
if(!tType)
{
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return nullptr;
}
if(!dynamic_cast<VariantType*>(tType))
{
cout << "ERROR: " << t->t << " is not a VariantType" << endl;
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return nullptr;
}
VariantType* varType = dynamic_cast<VariantType*>(tType);
if(varType->elements.size() != t->options.size())
{
cout << "ERROR: case does cover all options" << endl;
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return nullptr;
}
Type* result = nullptr;
for(auto option : t->options)
{
auto varOption = find_if(varType->elements.begin(), varType->elements.end(), [&](auto&& x){ return x->name == option->label; });
if(varOption == varType->elements.end())
{
cout << "ERROR: label " << option->label << " not found in " << varType << endl;
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return nullptr;
}
if(count_if(t->options.begin(), t->options.end(), [&](auto&& x){ return x->label == option->label; }) > 1)
{
cout << "ERROR: duplicate case labels" << endl;
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return nullptr;
}
auto lastType = bindings[option->name];
if(option->name != "_")
bindings[option->name] = (*varOption)->type;
auto exprType = typecheck(option->t);
if(result && !isEqual(exprType, result))
{
cout << "ERROR: label " << option->label << " result " << exprType << " doesn't match " << result << endl;
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return nullptr;
}
result = exprType;
bindings[option->name] = lastType;
}
t->type = result;
}
if(auto t = dynamic_cast<List*>(a))
{
assert(t->type != nullptr, "list must be typed internally");
}
if(!a->type)
cout << "ERROR: Failed to typecheck " << typeid(*a).name() << " '" << a << "'" << endl;
return a->type;
}
// Declares `name` as a free (undefined) variable whose type is spelled out by `code`.
// The placeholder Variable carries only a type; it is registered in `globals` so later
// terms can reference and typecheck against it.
void Interpreter::free(string name, string code)
{
    // Point the parser's lexer at the type text and parse a type (not a term).
    p.lexer = Lexer(string(code));
    auto declaredType = p.parseType();
    assert(declaredType != nullptr, "failed to parse type");

    Variable* placeholder = new Variable(name);
    placeholder->type = declaredType;
    globals[name] = placeholder;
}
// Parses and typechecks `code`, then binds the resulting (still unevaluated) term to
// `name` in `globals`. Aborts via assert if parsing or typechecking fails.
void Interpreter::global(string name, string code)
{
    auto term = p.parse(code);
    assert(term != nullptr, "failed to parse");

    typecheck(term);
    assert(term->type != nullptr, "failed to typecheck");

    globals[name] = term;
}
// Folds a saturated curried binary built-in. evalBuiltin() turns e.g. `add n` into an
// abstraction over the primed placeholder "add'" whose body is the application (n add');
// when that placeholder is substituted, the application's left side holds the first operand
// and `rhs` the second. Returns nullptr when `name` is not one of the primed placeholders.
static Term* foldPrimedBuiltin(Application* app, const string& name, Term* rhs)
{
    if(name == "eq'")
        return new Boolean(dynamic_cast<Natural*>(app->t1)->value == dynamic_cast<Natural*>(rhs)->value);
    if(name == "add'")
        return new Natural(dynamic_cast<Natural*>(app->t1)->value + dynamic_cast<Natural*>(rhs)->value);
    if(name == "mul'")
        return new Natural(dynamic_cast<Natural*>(app->t1)->value * dynamic_cast<Natural*>(rhs)->value);
    if(name == "addf'")
        return new Float(dynamic_cast<Float*>(app->t1)->value + dynamic_cast<Float*>(rhs)->value);
    if(name == "mulf'")
        return new Float(dynamic_cast<Float*>(app->t1)->value * dynamic_cast<Float*>(rhs)->value);
    if(name == "cons'")
        return new List(app->t1, rhs, new ListType(app->t1->type));
    return nullptr;
}
// Returns a copy of term `a` with free occurrences of variable `x` replaced by `rhs`.
// Binders that rebind `x` (abstractions, let, tuple match, case options) stop the
// substitution inside their scope. The substitution is not capture-avoiding; free variables
// of `rhs` are assumed not to clash with binders in `a`.
Term* Interpreter::substitute(Term* a, string x, Term* rhs)
{
    // Recurse into a sub-term with the same variable and replacement.
    auto sub = [&](Term* part) { return substitute(part, x, rhs); };

    if(auto var = dynamic_cast<Variable*>(a))
    {
        if(var->x == x)
            return rhs;
    }
    if(auto lam = dynamic_cast<Abstraction*>(a))
    {
        // Do not step into an abstraction whose own binder shadows `x`.
        bool shadows = lam->x != "_" && lam->x == x;
        if(shadows)
            return a;
        return new Abstraction(lam->x, lam->xType, sub(lam->t));
    }
    if(auto app = dynamic_cast<Application*>(a))
    {
        if(Term* folded = foldPrimedBuiltin(app, x, rhs))
            return folded;
        return new Application(sub(app->t1), sub(app->t2));
    }
    if(auto cond = dynamic_cast<Conditional*>(a))
        return new Conditional(sub(cond->cond), sub(cond->tru), sub(cond->fls));
    if(auto asc = dynamic_cast<Ascription*>(a))
        return new Ascription(sub(asc->t), asc->expected);
    if(auto let = dynamic_cast<Let*>(a))
    {
        // The let-bound name shadows `x` in the body, but not in the bound value.
        auto value = sub(let->t);
        auto body = let->name == x ? let->expr : sub(let->expr);
        return new Let(let->name, value, body);
    }
    if(auto match = dynamic_cast<Match*>(a))
    {
        auto value = sub(match->t);
        bool rebound = find(match->names.begin(), match->names.end(), x) != match->names.end();
        auto body = rebound ? match->expr : sub(match->expr);
        return new Match(match->names, value, body);
    }
    if(auto tuple = dynamic_cast<Tuple*>(a))
    {
        vector<Term*> items;
        for(auto element : tuple->elements)
            items.push_back(sub(element));
        return new Tuple(items);
    }
    if(auto access = dynamic_cast<TupleAccess*>(a))
        return new TupleAccess(sub(access->t), access->index);
    if(auto record = dynamic_cast<Record*>(a))
    {
        vector<TaggedPair*> fields;
        for(auto field : record->elements)
            fields.push_back(new TaggedPair(field->name, sub(field->t)));
        return new Record(fields);
    }
    if(auto access = dynamic_cast<RecordAccess*>(a))
        return new RecordAccess(sub(access->t), access->name);
    if(auto variant = dynamic_cast<Variant*>(a))
        return new Variant(variant->label, sub(variant->t), variant->variant);
    if(auto option = dynamic_cast<CaseOption*>(a))
    {
        // A case option whose pattern variable is `x` shadows it in the branch.
        if(option->name != "_" && option->name == x)
            return a;
        return new CaseOption(option->label, option->name, sub(option->t));
    }
    if(auto cs = dynamic_cast<Case*>(a))
    {
        auto value = sub(cs->t);
        vector<CaseOption*> options;
        for(auto option : cs->options)
            options.push_back(dynamic_cast<CaseOption*>(sub(option)));
        return new Case(value, options);
    }
    return a;
}
// Evaluates the built-in primitive named by `v` applied to the argument term `t`.
// The argument is evaluated first (call-by-value). Returns the resulting term, or nullptr
// when `v` is not a built-in or the evaluated argument has a shape the built-in does not
// handle; eval() treats nullptr as "not a built-in application".
// The curried binary built-ins (eq, add, mul, addf, mulf, cons) return an Abstraction over a
// primed placeholder variable (e.g. "add'"); substitute() folds the final value once the
// second argument is substituted for that placeholder.
Term* Interpreter::evalBuiltin(Variable* v, Term* t)
{
    t = eval(t);
    if(v->x == "succ")
    {
        if(Natural* a = dynamic_cast<Natural*>(t))
            return new Natural(a->value + 1);
    }
    if(v->x == "pred")
    {
        if(Natural* a = dynamic_cast<Natural*>(t))
        {
            // The predecessor of 0 is 0. Returning nullptr here would make eval() report
            // "expected abstraction" and then dereference a null abstraction.
            if(a->value == 0)
                return a;
            return new Natural(a->value - 1);
        }
    }
    if(v->x == "iszero")
    {
        if(Natural* a = dynamic_cast<Natural*>(t))
            return new Boolean(a->value == 0);
    }
    if(v->x == "print")
    {
        if(dynamic_cast<Unit*>(t)) cout << "STDOUT: " << endl;
        if(auto a = dynamic_cast<Boolean*>(t)) cout << "STDOUT: " << a->value << endl;
        if(auto a = dynamic_cast<Natural*>(t)) cout << "STDOUT: " << a->value << endl;
        if(auto a = dynamic_cast<String*>(t)) cout << "STDOUT: " << a->value << endl;
        if(auto a = dynamic_cast<Float*>(t)) cout << "STDOUT: " << a->value << endl;
        return new Unit();
    }
    if(v->x == "fix")
    {
        // `t` was already evaluated above, so it can be inspected directly.
        if(Abstraction* tabs = dynamic_cast<Abstraction*>(t))
        {
            // fix (\x.body) unfolds to body[x := fix (\x.body)]. The unfolded body is then
            // evaluated like any other redex so that callers always receive a value.
            auto s = new Application(new Variable(string("fix")), t);
            return eval(substitute(tabs->t, tabs->x, s));
        }
        assert(false, "fix must receive a function");
    }
    if(v->x == "eq")
    {
        return new Abstraction(string("eq'"), new FunctionType(new NatType(), new BoolType()), new Application(t, new Variable(string("eq'"), new NatType())));
    }
    if(v->x == "add")
    {
        return new Abstraction(string("add'"), new FunctionType(new NatType(), new NatType()), new Application(t, new Variable(string("add'"), new NatType())));
    }
    if(v->x == "mul")
    {
        return new Abstraction(string("mul'"), new FunctionType(new NatType(), new NatType()), new Application(t, new Variable(string("mul'"), new NatType())));
    }
    if(v->x == "addf")
    {
        return new Abstraction(string("addf'"), new FunctionType(new FloatType(), new FloatType()), new Application(t, new Variable(string("addf'"), new FloatType())));
    }
    if(v->x == "mulf")
    {
        return new Abstraction(string("mulf'"), new FunctionType(new FloatType(), new FloatType()), new Application(t, new Variable(string("mulf'"), new FloatType())));
    }
    if(v->x == "cons")
    {
        return new Abstraction(string("cons'"), new FunctionType(new ListType(t->type), new ListType(t->type)), new Application(t, new Variable(string("cons'"), new ListType(t->type))));
    }
    if(v->x == "isnil")
    {
        if(List* a = dynamic_cast<List*>(t))
            return new Boolean(a->head == nullptr);
    }
    if(v->x == "head")
    {
        if(List* a = dynamic_cast<List*>(t))
        {
            // A non-null tail marks a cons cell; head of nil has no value to return.
            assert(a->tail != nullptr, "head of an empty list");
            return a->head;
        }
    }
    if(v->x == "tail")
    {
        if(List* a = dynamic_cast<List*>(t))
        {
            assert(a->tail != nullptr, "tail of an empty list");
            return a->tail;
        }
    }
    return nullptr;
}
// Big-step, call-by-value evaluation of `t` to a value.
// Literals, abstractions and lists are already values. A variable is looked up in `globals`.
// An unbound variable (such as a built-in name like `succ`) evaluates to itself, so that the
// Application case can hand it to evalBuiltin(). The term is assumed to have passed
// typecheck(), so the casts on evaluated sub-terms below are not re-validated.
Term* Interpreter::eval(Term* t)
{
    // Values evaluate to themselves.
    if(dynamic_cast<Unit*>(t) || dynamic_cast<Boolean*>(t) || dynamic_cast<Natural*>(t) ||
        dynamic_cast<String*>(t) || dynamic_cast<Float*>(t) || dynamic_cast<Abstraction*>(t) ||
        dynamic_cast<List*>(t))
    {
        return t;
    }
    if(auto a = dynamic_cast<Application*>(t))
    {
        auto t1 = eval(a->t1);
        // A variable left in function position after evaluation is unbound: try the built-ins.
        if(auto v1 = dynamic_cast<Variable*>(t1))
        {
            if(auto r = evalBuiltin(v1, a->t2))
                return r;
        }
        Abstraction* lhs = dynamic_cast<Abstraction*>(t1);
        if(!lhs)
        {
            // Stop with a diagnostic rather than dereferencing a null `lhs` below.
            cout << "expected abstraction for t1, got: (" << typeid(*t1).name() << ") " << t1 << endl;
            assert(false, "cannot apply a term that is not a function");
        }
        // The argument is evaluated even for a `_` binder so its side effects (e.g. print) run.
        auto value = eval(a->t2);
        return lhs->x == "_" ? eval(lhs->t) : eval(substitute(lhs->t, lhs->x, value));
    }
    if(auto a = dynamic_cast<Variable*>(t))
    {
        // Bound globals may hold unevaluated terms (see global()), so evaluate on lookup.
        if(auto found = globals.find(a->x); found != globals.end())
            return eval(found->second);
    }
    if(auto a = dynamic_cast<Conditional*>(t))
    {
        auto cond = eval(a->cond);
        if(dynamic_cast<Boolean*>(cond)->value)
            return eval(a->tru);
        else
            return eval(a->fls);
    }
    if(dynamic_cast<TypeAlias*>(t))
    {
        // Type aliases have no runtime effect; they evaluate to unit.
        return new Unit();
    }
    if(auto a = dynamic_cast<Assignment*>(t))
    {
        // Bind the evaluated value in `globals` and yield unit.
        globals[a->name] = eval(a->value);
        return new Unit();
    }
    if(auto a = dynamic_cast<Ascription*>(t))
    {
        return eval(a->t);
    }
    if(auto a = dynamic_cast<Let*>(t))
    {
        // The bound value is evaluated even for `_` so that sequencing keeps its side effects.
        auto value = eval(a->t);
        return a->name == "_" ? eval(a->expr) : eval(substitute(a->expr, a->name, value));
    }
    if(auto a = dynamic_cast<Match*>(t))
    {
        Tuple* value = dynamic_cast<Tuple*>(eval(a->t));
        auto expr = a->expr;
        for(unsigned pos = 0; pos < a->names.size(); pos++)
        {
            auto& i = a->names[pos];
            auto& j = value->elements[pos];
            expr = i == "_" ? expr : substitute(expr, i, j);
        }
        return eval(expr);
    }
    if(auto a = dynamic_cast<Tuple*>(t))
    {
        vector<Term*> elements;
        for(auto i : a->elements)
            elements.push_back(eval(i));
        return new Tuple(elements);
    }
    if(auto a = dynamic_cast<TupleAccess*>(t))
    {
        Tuple* value = dynamic_cast<Tuple*>(eval(a->t));
        return eval(value->elements[a->index - 1]); // should be typechecked already
    }
    if(auto a = dynamic_cast<Record*>(t))
    {
        vector<TaggedPair*> elements;
        for(auto i : a->elements)
            elements.push_back(new TaggedPair(i->name, eval(i->t)));
        return new Record(elements);
    }
    if(auto a = dynamic_cast<RecordAccess*>(t))
    {
        Record* value = dynamic_cast<Record*>(eval(a->t));
        for(auto i : value->elements)
        {
            if(i->name == a->name)
                return i->t;
        }
        assert(false, "failed to access record element");
    }
    if(auto a = dynamic_cast<Variant*>(t))
    {
        return new Variant(a->label, eval(a->t), a->variant);
    }
    if(auto a = dynamic_cast<Case*>(t))
    {
        Variant* value = dynamic_cast<Variant*>(eval(a->t));
        for(auto option : a->options)
        {
            if(value->label == option->label)
                return eval(substitute(option->t, option->name, value->t));
        }
        // typecheck() requires the case to cover every label, so this is unreachable.
        assert(false, "no case option matches the variant label");
    }
    return t;
}
// Parses `code`, typechecks it (aborting via assert on failure) and evaluates it to a value.
auto Interpreter::eval(string code)
{
    auto parsed = p.parse(code);
    assert(parsed != nullptr, "failed to parse");

    typecheck(parsed);
    assert(parsed->type != nullptr, "failed to typecheck");

    return eval(parsed);
}
// Echoes `code`, evaluates it, re-typechecks the resulting value and prints "> value : Type".
void Interpreter::printeval(string code)
{
    cout << code << endl;

    auto value = eval(code);
    typecheck(value);

    cout << "> " << value << " : " << getName(value->type) << endl;
}
void Interpreter::printtest(string code, string result)
{
cout << code << endl;
auto t = eval(code);
typecheck(t);
cout << "> " << t << " : " << getName(t->type) << endl;
auto r = eval(result);
typecheck(r);
//if(!isEqual(t, r))
// assert(false, "result mismatch");
}
// Parses and typechecks `code` without evaluating it, then prints "> term : Type".
// Aborts via assert if parsing or typechecking fails.
void Interpreter::printtypecheck(string code)
{
    auto term = p.parse(code);
    assert(term != nullptr, "failed to parse");

    typecheck(term);
    assert(term->type != nullptr, "failed to typecheck");

    cout << "> " << term << " : " << getName(term->type) << endl;
}
// Demo / regression driver. It runs the section-numbered examples (11.1-11.13, numbering that
// appears to follow TAPL chapter 11), printing each program with its evaluated result and type.
// Later programs rely on globals and type aliases defined by earlier ones, so order matters.
int main()
{
    Interpreter i;
    i.global("not", "/x:Bool. if x then false else true");
    i.printeval("if true then (/x:Bool. x) else (/x:Bool. not x)");
    i.free("f", "Bool -> Bool");
    i.printtypecheck("f");
    i.printtypecheck("f (if false then true else false)");
    i.printtypecheck("/x:Bool. f (if x then false else x)");
    // 11.1 Base types, Type Aliases
    i.printeval("true");
    i.printeval("1");
    i.printeval("'hello'");
    i.printeval("1.3");
    i.printeval(R"(
type UU = Unit->Unit,
(/f:UU.f unit) (/x:Unit.x)
)");
    // 11.2 Unit type
    i.printeval("unit");
    // 11.3 Sequencing
    i.printeval("unit;2");
    // 11.4 Ascription
    i.printeval("2 as Nat");
    // 11.5 Let Bindings
    // 11.5.1
    i.printeval("let a = {2, 3} in a.2");
    // 11.6, 11.7 Pairs, Tuples
    i.free("tupTest", "Bool * Bool -> Bool * Nat");
    i.printtypecheck("tupTest");
    i.printeval("{2, 'hello', false}");
    i.printtest("{2, 'hello', false}.2", "'hello'");
    i.printtest("pred 3", "2");
    i.printeval("{pred 4, pred 5}");
    i.printtest("{pred 4, if true then false else false}.1", "3");
    i.printtest("(/x:Nat*Nat. x.2) {pred 4, pred 5}", "4");
    // some substitution tests
    i.printeval("(/x:Nat. x as Nat) 2");
    i.printeval("(/x:Nat. let a=x in a) 2");
    i.printeval("(/x:Nat. let x={x, x} in x.1) 2");
    i.printeval("(/x:Nat. {x, x}) 2");
    i.printeval("(/x:Nat*Nat. x.2) {2, 3}");
    i.printeval("print 'hello world'; 2");
    i.printeval("{x=5}");
    i.printeval("{x=pred 10}");
    // 11.8 Records
    i.printeval("{partno=5524,cost=30.27}");
    i.printeval("{x=pred 4} as {x:Nat}");
    i.printeval("{x=5} as {x:Nat}");
    i.printeval("{partno=5524,cost=30.27} as {partno:Nat,cost:Float}");
    // 11.8.2
    i.printeval("let {a,b} = {2, 3} in print b;a");
    i.printtest("(/_:Nat. 2) (print 'test';3)", "2");
    i.printeval(R"(
local a = 4,
a
)");
    i.free("tupTest", "Bool * Bool -> Bool * Nat");
    i.printeval("type PhysicalAddr = {firstlast:String, addr:String}");
    i.printeval("type VirtualAddr = {name:String, email:String}");
    // 11.10 Variants
    i.printeval(R"(
type Addr = <physical:PhysicalAddr, virtual:VirtualAddr>,
local pa = {firstlast='ve', addr='none'},
local a = <physical=pa> as Addr,
a
)");
    i.printeval(R"(
local getName = /a:Addr.
case a of
<physical=x> -> x.firstlast
| <virtual=y> -> y.name,
getName
)");
    i.printeval(R"(
getName a
)");
    // Options (with Unit skip)
    i.printeval("type OptionalNat = <none, some:Nat>");
    i.printeval("type Table = Nat -> OptionalNat");
    i.printeval("local emptyTable = /n:Nat. <none> as OptionalNat");
    i.printeval("emptyTable 1");
    i.printeval("eq 4 5");
    i.printeval("eq 4 4");
    i.printeval("add 4 5");
    i.printeval("mulf 4.2 5.2");
    i.printeval(R"(
local extendTable =
/t:Table./m:Nat./v:Nat.
/n:Nat.
if eq n m then <some=v> as OptionalNat
else t n
)");
    i.printeval("extendTable emptyTable 1 10");
    // Enumerations (with Unit skip)
    i.printeval("type Weekday = <monday, tuesday, wednesday, thursday, friday>");
    // Each weekday maps to the following business day (wednesday -> thursday, friday -> monday).
    i.printeval(R"(
local nextBusinessDay = /w:Weekday.
case w of
<monday=x> -> <tuesday> as Weekday
| <tuesday=x> -> <wednesday> as Weekday
| <wednesday=x> -> <thursday> as Weekday
| <thursday=x> -> <friday> as Weekday
| <friday=x> -> <monday> as Weekday
)");
    i.printeval("nextBusinessDay <tuesday> as Weekday");
    // Single-field variants
    i.printeval("type DollarAmount = <dollars:Float>");
    i.printeval("type EuroAmount = <euros:Float>");
    i.printeval(R"(
local dollars2euros = /d:DollarAmount.
case d of
<dollars=x> -> <euros = mulf x 1.1325> as EuroAmount
)");
    i.printeval(R"(
local euros2dollars = /d:EuroAmount.
case d of
<euros=x> -> <dollars = mulf x 0.883> as DollarAmount
)");
    i.printeval("local mybankballance = <dollars=39.50> as DollarAmount");
    i.printeval("euros2dollars (dollars2euros mybankballance)");
    // TODO: Dynamic?
    // 11.11 General Recursion
    i.printeval(R"(
local ff = /ie:Nat->Bool.
/x:Nat.
if iszero x then true
else if iszero (pred x) then false
else ie (pred (pred x))
)");
    i.printeval("local iseven = fix ff");
    i.printeval("iseven 7");
    i.printeval("iseven 8");
    // 11.11.1
    i.printeval(R"(
local equalrec = fix /self:Nat->Nat->Bool.
/x:Nat./y:Nat.
if iszero x then (if iszero y then true else false)
else if iszero y then (if iszero x then true else false)
else self (pred x) (pred y)
)");
    i.printeval("equalrec 8 7");
    i.printeval("equalrec 14 14");
    i.printeval(R"(
local plusrec = fix /self:Nat->Nat->Nat.
/x:Nat./y:Nat.
if iszero x then y
else if iszero (pred x) then succ y
else self (pred x) (succ y)
)");
    i.printeval("plusrec 8 7");
    i.printeval(R"(
local timesrec = fix /self:Nat->Nat->Nat.
/x:Nat./y:Nat.
if iszero x then 0
else plusrec y (self (pred x) y)
)");
    i.printeval("timesrec 4 8");
    i.printeval(R"(
local factrec = fix /self:Nat->Nat.
/x:Nat.
if iszero x then 1
else if iszero (pred x) then 1
else timesrec x (self (pred x))
)");
    i.printeval("factrec 5");
    i.printeval(R"(
local ff = /ieio:{iseven:Nat->Bool, isodd:Nat->Bool}.
{
iseven = /x:Nat.
if iszero x then true
else ieio.isodd (pred x),
isodd = /x:Nat.
if iszero x then false
else ieio.iseven (pred x)
},
ff
)");
    i.printeval("local r = fix ff");
    i.printtypecheck("r");
    i.printeval("local iseven = r.iseven");
    i.printeval("iseven 7");
    i.printeval(R"(
letrec iseven: Nat->Bool =
/x:Nat.
if iszero x then true
else if iszero (pred x) then false
else iseven (pred (pred x))
in
iseven 7
)");
    // 11.11.2
    i.printeval(R"(
letrec equalrec: Nat->Nat->Bool =
/x:Nat./y:Nat.
if iszero x then (if iszero y then true else false)
else if iszero y then (if iszero x then true else false)
else equalrec (pred x) (pred y)
in
equalrec 8 7
)");
    i.printeval(R"(
letrec plusrec: Nat->Nat->Nat =
/x:Nat./y:Nat.
if iszero x then y
else if iszero (pred x) then succ y
else plusrec (pred x) (succ y)
in
plusrec 8 7
)");
    i.printeval("plusrec 8 7");
    i.printeval(R"(
letrec timesrec: Nat->Nat->Nat =
/x:Nat./y:Nat.
if iszero x then 0
else plusrec y (timesrec (pred x) y)
in
timesrec 4 8
)");
    i.printeval("timesrec 4 8");
    i.printeval(R"(
letrec factrec: Nat->Nat =
/x:Nat.
if iszero x then 1
else if iszero (pred x) then 1
else timesrec x (factrec (pred x))
in
factrec 5
)");
    // 11.13 Lists
    i.free("listTest", "[Nat]");
    i.printtypecheck("listTest");
    i.printeval("nil[Nat]");
    i.printeval("isnil nil[Nat]");
    i.printeval("local l1 = cons 1 nil[Nat]");
    i.printeval("l1");
    i.printeval("head l1");
    i.printeval("tail l1");
    i.printeval("isnil l1");
    i.printeval("local l2 = cons 1 (cons 2 nil[Nat])");
    i.printeval("l2");
    i.printeval("head l2");
    i.printeval("head (tail l2)");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment