Skip to content

Instantly share code, notes, and snippets.

@keveman
Created July 20, 2019 23:21
Show Gist options
  • Save keveman/2f2ab304fd12a3d6bf9e965eef9b1059 to your computer and use it in GitHub Desktop.
Save keveman/2f2ab304fd12a3d6bf9e965eef9b1059 to your computer and use it in GitHub Desktop.
outer product
#include <iostream>
#include <type_traits>
#include "peg.inl.h"
// Compile-time sanity check for one tensor reference taken from a parsed
// loop expression.
//
// `a.tensor_name` is turned into an identifier with Circle's @(...) operator.
// That identifier is presumably resolved by ordinary name lookup at this
// point, so it must name an entity visible here (e.g. a namespace-scope array
// such as W, x or y).
//
// Checks performed:
//   1. the named entity has a built-in C array type, and
//   2. its rank (number of array dimensions) equals a.A.rows.
// NOTE(review): a.A.rows presumably counts the subscripts used in the access
// (the rows of its access matrix); `access` is declared in peg.inl.h, so
// confirm there.
//
// In C++17, static_assert messages must be string literals, so they cannot
// name the offending tensor. They do at least say which check failed.
@meta void check_array(@meta const access &a) {
  static_assert(std::is_array_v<decltype(@(a.tensor_name))>,
                "check_array: tensor referenced in the loop expression is not a C array");
  static_assert(std::rank_v<decltype(@(a.tensor_name))> == a.A.rows,
                "check_array: array rank does not match a.A.rows");
}
// Statement macro: parses a loop-nest DSL string with parse() and checks,
// at compile time, each tensor reference it finds with check_array().
//
// Expects an expression of the shape  LHS = RHS0 * RHS1  (as in the caller
// outer_product): one assignment target and two right-hand operands.
// NOTE(review): operands[0] and operands[1] are read without checking how
// many operands the parser produced. Presumably the grammar guarantees a
// binary product here; confirm against peg.inl.h before using other shapes.
//
// Every statement below is @meta, so expanding this macro only does
// compile-time work (a dump of the parsed object and static checks). It
// emits no real statements.
@macro void foo(const char* str) {
  // Named `parsed` rather than `x` so it is not confused with the tensor `x`
  // that the DSL text itself refers to.
  @meta loop parsed = parse(str);

  // Debug aid: dump the parsed loop object to stdout during compilation.
  @meta parsed.dump(std::cout);

  // Assignment target, e.g. W[i][j].
  @meta const access &dst = parsed.lhs.a;
  @meta check_array(dst);

  // The two factors of the product, e.g. x[i] and y[j].
  @meta const access &src0 = parsed.rhs.operands[0].a;
  @meta const access &src1 = parsed.rhs.operands[1].a;
  @meta check_array(src0);
  @meta check_array(src1);
}
// Operands of the outer product in outer_product():
//   W is the 32 x 16 destination,
//   x (length 32) and y (length 16) are the factor vectors.
// These extents match the iteration space `all(i, j) in (32, 16)` written in
// the DSL string there.
float W[32][16];
float x[32];
float y[16];
// Intended to compute the outer product W[i][j] = x[i] * y[j] over a 32 x 16
// iteration space, written in the loop DSL that parse() understands.
//
// NOTE(review): every statement in the macro foo is @meta, so this expansion
// only parses and checks the DSL at compile time. It does not yet emit the
// loop that would write W. outer_product() is also never called from main().
void outer_product() {
// The backtick-delimited text on the following lines is the DSL program. Its
// extents (32, 16) should agree with the declared shapes of W, x and y;
// check_array only verifies array-ness and rank, not extents.
@macro foo(`
all(i, j) in (32, 16)
W[i][j] = x[i] * y[j]
`);
}
// Prints whether W has a built-in array type, followed by a newline.
// std::cout is not switched to boolalpha, so the value appears as 1 or 0.
int main() {
  const bool w_is_array = std::is_array_v<decltype(W)>;
  std::cout << w_is_array << '\n';
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment