Last active
April 14, 2019 10:20
-
-
Save sorrge/2d2271e57ad1e91a50ea to your computer and use it in GitHub Desktop.
Dependent types in C++
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream>
#include <type_traits>
using namespace std;
// TypeValue<ID> is a type and a value at the same time.
// Every ID names a distinct type, and that type can be bound to a runtime value only once:
// constructing a second TypeValue<ID> throws. This prevents (only at runtime) two vectors
// from sharing the same Length type while having different actual lengths.
// Copies go through the implicit copy constructor, which performs no check, so a
// TypeValue can be passed around by value after its first construction.
template<int ID>
class TypeValue
{
    static int value;        // runtime value of this type; zero until the first construction
    static bool initialized; // becomes true once a value has been bound
public:
    TypeValue(int v)
    {
        if (!initialized)
        {
            initialized = true;
            value = v;
            return;
        }
        throw "Tried to reinitialize a TypeValue";
    }

    // Returns the runtime value bound to this type.
    static int Eval() { return value; }
};
template<int ID>
int TypeValue<ID>::value;
template<int ID>
bool TypeValue<ID>::initialized = false;
// the presense of this class in the context means that TypeExpr1 is proven to be equal to TypeExpr2 | |
template<class TypeExpr1, class TypeExpr2> | |
class TypesAreEqual; | |
// static TypeExpr equality check using a context | |
template<class TV1, class TV2, class Context> | |
struct TypeExprEq : false_type{}; | |
template<int ID, class Context> | |
struct TypeExprEq<TypeValue<ID>, TypeValue<ID>, Context> : true_type{}; | |
template<class TV1, class TV2> | |
struct TypeExprEq<TV1, TV2, TypesAreEqual<TV1, TV2>> : true_type{}; | |
template<class TV1, class TV2> | |
struct TypeExprEq<TV1, TV2, TypesAreEqual<TV2, TV1>> : true_type{}; | |
// Plus<E1, E2> is the TypeExpr for the sum of two TypeExprs E1 and E2.
template<class E1, class E2>
class Plus
{
public:
    // Values of both operand types must be handed over: they serve as the proof that the
    // operands have been created and initialized. Nothing is stored.
    Plus(E1, E2) {}

    // Runtime value of the sum.
    static int Eval()
    {
        auto lhs = E1::Eval();
        auto rhs = E2::Eval();
        return lhs + rhs;
    }
};
// plusCommutative | |
template<class TV1, class TV2, class TV3, class TV4, class Context> | |
struct TypeExprEq<Plus<TV1, TV2>, Plus<TV3, TV4>, Context> | |
{ | |
static const bool value = TypeExprEq<TV1, TV3, Context>::value && TypeExprEq<TV2, TV4, Context>::value || | |
TypeExprEq<TV1, TV4, Context>::value && TypeExprEq<TV2, TV3, Context>::value; | |
}; | |
// runtime TypeExpr equality check | |
template<class TypeExpr1, class TypeExpr2> | |
class TypeExprEqRT | |
{ | |
public: | |
TypeExprEqRT(TypeExpr1, TypeExpr2) {} | |
static bool Eval() | |
{ | |
return TypeExpr1::Eval() == TypeExpr2::Eval(); | |
} | |
typedef TypesAreEqual<TypeExpr1, TypeExpr2> ContextIfTrue; | |
}; | |
template<int ID1, int ID2> | |
TypeExprEqRT<TypeValue<ID1>, TypeValue<ID2>> operator==(TypeValue<ID1> tv1, TypeValue<ID2> tv2) | |
{ | |
return TypeExprEqRT<TypeValue<ID1>, TypeValue<ID2>>(tv1, tv2); | |
} | |
// Runs `oper` only if the runtime condition holds, handing it the type-level proof.
// Condition must provide `static bool Eval()` and a nested type `ContextIfTrue` (as
// TypeExprEqRT does). Operation must provide a member template `Eval<Context>()`.
// When Condition::Eval() is true, oper.Eval<Condition::ContextIfTrue>() is called, so static
// checks inside the operation can rely on the equality recorded in ContextIfTrue.
template<class Condition, class Operation>
void If(Condition, Operation oper)
{
    if (Condition::Eval())
        // `oper` has a dependent type, so `template` is required for `Eval<...>` to parse as
        // a member-template call. `typename` is required to name the dependent nested type.
        oper.template Eval<typename Condition::ContextIfTrue>();
}
// Utility wrapper that passes the context through to a context-dependent operation and
// prints its result. Eval<Context>() evaluates f.Eval<Context>() and writes the value
// followed by std::endl to std::cout, so a PrintLN can be the Operation argument of If.
template<class Func>
class PrintLN
{
    Func f;
public:
    PrintLN(Func _f) : f(_f) {}
    template<class Context>
    void Eval()
    {
        // f's type depends on Func, so `template` is needed to parse Eval<Context> as a
        // member-template call rather than as `f.Eval < Context > ()`.
        std::cout << f.template Eval<Context>() << std::endl;
    }
};
template<class Func> | |
PrintLN<Func> printLn(Func f) | |
{ | |
return PrintLN<Func>(f); | |
} | |
// The vector class. Length is a TypeExpr: a type with `static int Eval()` that gives the
// element count. The Length value is stored only as proof that it was initialized; the count
// always comes from Length::Eval(). The element buffer is owned and freed in the destructor.
// Copy construction and copy assignment both deep-copy the buffer. Move operations transfer
// it. This keeps two Vec objects from ever sharing (and double-deleting) one allocation.
template<class ElemType, class Length>
class Vec
{
    Length length;
    ElemType *data;
public:
    Vec(Length l) : length(l), data(new ElemType[Length::Eval()]) {}

    Vec(const Vec<ElemType, Length>& v) : length(v.length), data(new ElemType[Length::Eval()])
    {
        for (int i = 0; i < Length::Eval(); ++i)
            data[i] = v.data[i];
    }

    // Takes over v's buffer. v is left empty (null buffer) and must not be indexed afterwards.
    Vec(Vec<ElemType, Length>&& v) noexcept : length(v.length), data(v.data)
    {
        v.data = nullptr;
    }

    // Copy-and-swap: if copying throws, *this is unchanged. Self-assignment is safe.
    Vec& operator=(const Vec<ElemType, Length>& v)
    {
        Vec<ElemType, Length> copy(v);
        ElemType *old = data;
        data = copy.data;
        copy.data = old; // `copy` frees our previous buffer when it goes out of scope
        length = v.length;
        return *this;
    }

    // Swaps buffers with v. v's destructor then releases our previous buffer.
    Vec& operator=(Vec<ElemType, Length>&& v) noexcept
    {
        ElemType *old = data;
        data = v.data;
        v.data = old;
        length = v.length;
        return *this;
    }

    virtual ~Vec()
    {
        delete[] data;
    }

    ElemType& operator[](int idx) { return data[idx]; }
    const ElemType& operator[](int idx) const { return data[idx]; }
    Length Len() const { return length; }
};
// Vector operations | |
template<class Length> | |
Vec<unsigned, Length> NatVec(Length l) | |
{ | |
Vec<unsigned, Length> res(l); | |
for (int i = 0; i < Length::Eval(); ++i) | |
res[i] = i; | |
return res; | |
} | |
template<class L1, class L2, class T> | |
class DotProduct | |
{ | |
Vec<T, L1> v1; | |
Vec<T, L2> v2; | |
public: | |
DotProduct(Vec<T, L1> _v1, Vec<T, L2> _v2) : v1(_v1), v2(_v2) {} | |
template<class Context> | |
T Eval() | |
{ | |
static_assert(TypeExprEq<L1, L2, Context>::value, "Can't prove that vectors have the same length"); | |
T acc = {}; | |
for (int i = 0; i < L1::Eval(); ++i) | |
acc += v1[i] * v2[i]; | |
return acc; | |
} | |
operator T() | |
{ | |
return Eval<void>(); | |
} | |
}; | |
template<class L1, class L2, class T> | |
DotProduct<L1, L2, T> operator*(Vec<T, L1> v1, Vec<T, L2> v2) | |
{ | |
return DotProduct<L1, L2, T>(v1, v2); | |
} | |
template<class L1, class L2, class T> | |
Vec<T, Plus<L1, L2>> operator+(Vec<T, L1> v1, Vec<T, L2> v2) | |
{ | |
Vec<T, Plus<L1, L2>> res(Plus<L1, L2>(v1.Len(), v2.Len())); | |
for (int i = 0; i < L1::Eval(); ++i) | |
res[i] = v1[i]; | |
for (int i = 0; i < L2::Eval(); ++i) | |
res[i + L1::Eval()] = v2[i]; | |
return res; | |
} | |
int main(int argc, char* argv[]) | |
{ | |
int nn, mm; | |
cin >> nn >> mm; | |
TypeValue<1> n(nn); | |
TypeValue<2> m(mm); | |
auto a = NatVec(n); | |
auto b = NatVec(m); | |
cout << a * a << endl; | |
// cout << a * b << endl; // Error | |
cout << (a + b) * (a + b) << endl; | |
cout << (a + b) * (b + a) << endl; // OK if plusCommutative is present | |
cout << (a + b + a + a + b) * (b + a + a + a + b) << endl; // OK if plusCommutative is present | |
// cout << (a + b + a + b + a) * (b + a + a + a + b) << endl; // Error: associativity of + is not implemented | |
If(n == m, printLn(a * b)); | |
If(n == m, printLn((a + a) * (b + a))); | |
// If(n == m, printLn((a + a) * (b + a + a))); // Error | |
return 0; | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment