Skip to content

Instantly share code, notes, and snippets.

@sorrge
Last active April 14, 2019 10:20
Show Gist options
  • Save sorrge/2d2271e57ad1e91a50ea to your computer and use it in GitHub Desktop.
Save sorrge/2d2271e57ad1e91a50ea to your computer and use it in GitHub Desktop.
Dependent types in C++
#include <iostream>
#include <type_traits>
using namespace std;
// TypeValue<ID> is simultaneously a type and a runtime value.
// Every distinct ID names a distinct type, and each such type may be given a value
// exactly once: constructing a second TypeValue<ID> throws. This check happens only at
// run time. It stops two vectors from sharing one Length type while having different
// actual lengths.
template<int ID>
class TypeValue
{
    static bool initialized;
    static int value;
public:
    TypeValue(int v)
    {
        if (!initialized)
        {
            value = v;
            initialized = true;
            return;
        }
        throw "Tried to reinitialize a TypeValue";
    }
    // The single runtime value associated with this type.
    static int Eval() { return value; }
};
template<int ID>
int TypeValue<ID>::value;
template<int ID>
bool TypeValue<ID>::initialized = false;
// The presence of TypesAreEqual<TypeExpr1, TypeExpr2> as the context type means that
// TypeExpr1 has been proven to be equal to TypeExpr2 (for example, If() passes
// TypeExprEqRT::ContextIfTrue after the runtime comparison succeeded).
// It is only used as a tag type, so it is declared but never defined.
template<class TypeExpr1, class TypeExpr2>
class TypesAreEqual;
// Static (compile-time) TypeExpr equality check using a proof context.
// TypeExprEq<TV1, TV2, Context>::value is true when TV1 and TV2 can be shown equal,
// either structurally (the very same TypeValue) or because Context is a
// TypesAreEqual proof relating them, in either order.

// Helper: does Context directly prove TV1 == TV2?
template<class TV1, class TV2, class Context>
struct ContextProvesEq : std::false_type{};
template<class TV1, class TV2>
struct ContextProvesEq<TV1, TV2, TypesAreEqual<TV1, TV2>> : std::true_type{};
template<class TV1, class TV2>
struct ContextProvesEq<TV1, TV2, TypesAreEqual<TV2, TV1>> : std::true_type{};
// When TV1 and TV2 are the same type, both specializations above match. This more
// specialized one resolves that ambiguity (which previously broke e.g. If(n == n, ...)).
template<class TV>
struct ContextProvesEq<TV, TV, TypesAreEqual<TV, TV>> : std::true_type{};

// General case: fall back on what the context proves. Further partial
// specializations (such as the one for Plus) refine this.
template<class TV1, class TV2, class Context>
struct TypeExprEq : ContextProvesEq<TV1, TV2, Context>{};
// A TypeValue is always equal to itself, whatever the context. Keeping this as the only
// TypeValue-specific specialization avoids ambiguous partial ordering with the
// context-based rules above.
template<int ID, class Context>
struct TypeExprEq<TypeValue<ID>, TypeValue<ID>, Context> : std::true_type{};
// Plus<E1, E2> is the TypeExpr for the sum of two TypeExprs.
// Its constructor accepts E1 and E2 values. They serve as proof that the operands
// have been created and initialized.
template<class E1, class E2>
class Plus
{
public:
    Plus(E1, E2) {}
    static int Eval()
    {
        const auto left = E1::Eval();
        const auto right = E2::Eval();
        return left + right;
    }
};
// plusCommutative: Plus<TV1, TV2> equals Plus<TV3, TV4> when the operands match
// pairwise, either in order (TV1 == TV3 and TV2 == TV4) or swapped
// (TV1 == TV4 and TV2 == TV3). Operand equality is checked recursively under the same
// Context, so nested sums are accepted only if each level needs at most a swap.
// Associativity is not modeled.
// Deriving from integral_constant matches the other TypeExprEq specializations. It
// provides a constexpr `value` (no out-of-line definition needed) and `::type`.
template<class TV1, class TV2, class TV3, class TV4, class Context>
struct TypeExprEq<Plus<TV1, TV2>, Plus<TV3, TV4>, Context>
    : std::integral_constant<bool,
          (TypeExprEq<TV1, TV3, Context>::value && TypeExprEq<TV2, TV4, Context>::value) ||
          (TypeExprEq<TV1, TV4, Context>::value && TypeExprEq<TV2, TV3, Context>::value)>
{
};
// Runtime equality test between two TypeExprs.
// Constructing it requires instances of both expressions (proof that they exist).
// Eval() compares their runtime values. ContextIfTrue names the proof context that
// may be assumed once the comparison has succeeded.
template<class TypeExpr1, class TypeExpr2>
class TypeExprEqRT
{
public:
    using ContextIfTrue = TypesAreEqual<TypeExpr1, TypeExpr2>;

    TypeExprEqRT(TypeExpr1, TypeExpr2) {}

    static bool Eval()
    {
        const auto lhs = TypeExpr1::Eval();
        const auto rhs = TypeExpr2::Eval();
        return lhs == rhs;
    }
};
// Comparing two TypeValues yields a TypeExprEqRT condition object rather than a bool.
// Its Eval() performs the runtime check. Its ContextIfTrue carries the resulting proof.
template<int ID1, int ID2>
TypeExprEqRT<TypeValue<ID1>, TypeValue<ID2>> operator==(TypeValue<ID1> tv1, TypeValue<ID2> tv2)
{
    using Condition = TypeExprEqRT<TypeValue<ID1>, TypeValue<ID2>>;
    Condition condition(tv1, tv2);
    return condition;
}
// Runs `oper` only when the runtime condition holds.
// Condition must provide a static Eval() returning something testable as bool, and a
// ContextIfTrue type. Operation must provide a member template Eval<Context>(). The proof
// context Condition::ContextIfTrue is passed to it as the template argument, so
// compile-time checks inside the operation may rely on the proven equality.
template<class Condition, class Operation>
void If(Condition, Operation oper)
{
    // `oper` has a dependent type and ContextIfTrue is a dependent type name, so the
    // `template` and `typename` disambiguators are required by the standard (the
    // original form was only accepted by permissive compilers).
    if (Condition::Eval())
        oper.template Eval<typename Condition::ContextIfTrue>();
}
// Utility class which passes the context along.
// PrintLN wraps a context-dependent computation `f` (anything with a member template
// Eval<Context>()). Its own Eval<Context>() forwards the context to `f` and prints the
// result followed by a newline. That makes it usable as the Operation argument of If().
template<class Func>
class PrintLN
{
    Func f;
public:
    PrintLN(Func _f) : f(_f) {}
    template<class Context>
    void Eval()
    {
        // `f` has a dependent type, so the `template` disambiguator is required.
        std::cout << f.template Eval<Context>() << std::endl;
    }
};
// Helper that deduces Func, so callers can simply write printLn(a * b).
template<class Func>
PrintLN<Func> printLn(Func f)
{
    return PrintLN<Func>(f);
}
// The vector class. Length is a TypeExpr (something with a static Eval() and copyable
// instances). The number of elements is Length::Eval(), so two Vecs of the same type
// always have the same size.
//
// Vec owns a heap array, so copy construction, copy assignment and destruction are all
// user-defined (rule of three). Copy assignment copies the elements into the existing
// buffer: both operands share the Length type and therefore the size.
template<class ElemType, class Length>
class Vec
{
    Length length;
    ElemType *data;
public:
    Vec(Length l) : length(l)
    {
        data = new ElemType[Length::Eval()];
    }
    Vec(const Vec<ElemType, Length>& v) : length(v.length)
    {
        data = new ElemType[Length::Eval()];
        // The destructor does not run if a constructor throws, so free the buffer here.
        try
        {
            for (int i = 0; i < Length::Eval(); ++i)
                data[i] = v.data[i];
        }
        catch (...)
        {
            delete[] data;
            throw;
        }
    }
    // Previously missing: the implicit copy assignment copied the `data` pointer. Both
    // objects then owned, and later deleted, the same buffer.
    Vec& operator=(const Vec<ElemType, Length>& v)
    {
        if (this != &v)
        {
            for (int i = 0; i < Length::Eval(); ++i)
                data[i] = v.data[i];
            length = v.length;
        }
        return *this;
    }
    virtual ~Vec()
    {
        delete[] data;
    }
    // Element access without bounds checking; idx must be in [0, Length::Eval()).
    ElemType& operator[](int idx) { return data[idx]; }
    const ElemType& operator[](int idx) const { return data[idx]; }
    // The length as a TypeExpr value (needed e.g. to build a Plus of two lengths).
    Length Len() const { return length; }
};
// Vector operations

// Builds the vector [0, 1, ..., Length::Eval() - 1], whose length type is Length.
template<class Length>
Vec<unsigned, Length> NatVec(Length l)
{
    Vec<unsigned, Length> naturals(l);
    const int count = Length::Eval();
    int k = 0;
    while (k < count)
    {
        naturals[k] = k;
        ++k;
    }
    return naturals;
}
template<class L1, class L2, class T>
class DotProduct
{
Vec<T, L1> v1;
Vec<T, L2> v2;
public:
DotProduct(Vec<T, L1> _v1, Vec<T, L2> _v2) : v1(_v1), v2(_v2) {}
template<class Context>
T Eval()
{
static_assert(TypeExprEq<L1, L2, Context>::value, "Can't prove that vectors have the same length");
T acc = {};
for (int i = 0; i < L1::Eval(); ++i)
acc += v1[i] * v2[i];
return acc;
}
operator T()
{
return Eval<void>();
}
};
// `v1 * v2` builds a DotProduct. The actual value is produced when that object is
// converted to T or evaluated with Eval<Context>().
template<class L1, class L2, class T>
DotProduct<L1, L2, T> operator*(Vec<T, L1> v1, Vec<T, L2> v2)
{
    typedef DotProduct<L1, L2, T> Product;
    return Product(v1, v2);
}
// `v1 + v2` concatenates the two vectors. The result's length type is Plus<L1, L2>, so
// its length is known symbolically to be the sum of the operands' lengths.
template<class L1, class L2, class T>
Vec<T, Plus<L1, L2>> operator+(Vec<T, L1> v1, Vec<T, L2> v2)
{
    typedef Plus<L1, L2> SumLen;
    const SumLen sumLen(v1.Len(), v2.Len());
    Vec<T, SumLen> joined(sumLen);
    const int head = L1::Eval();
    for (int i = 0; i < head; ++i)
        joined[i] = v1[i];
    const int tail = L2::Eval();
    for (int j = 0; j < tail; ++j)
        joined[head + j] = v2[j];
    return joined;
}
int main(int argc, char* argv[])
{
int nn, mm;
cin >> nn >> mm;
TypeValue<1> n(nn);
TypeValue<2> m(mm);
auto a = NatVec(n);
auto b = NatVec(m);
cout << a * a << endl;
// cout << a * b << endl; // Error
cout << (a + b) * (a + b) << endl;
cout << (a + b) * (b + a) << endl; // OK if plusCommutative is present
cout << (a + b + a + a + b) * (b + a + a + a + b) << endl; // OK if plusCommutative is present
// cout << (a + b + a + b + a) * (b + a + a + a + b) << endl; // Error: associativity of + is not implemented
If(n == m, printLn(a * b));
If(n == m, printLn((a + a) * (b + a)));
// If(n == m, printLn((a + a) * (b + a + a))); // Error
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment