Last active
September 29, 2015 19:52
-
-
Save torokati44/4521f03e5baa66a3ceae to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <math.h>

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "llvm/Analysis/Passes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/MCJIT.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/PassManager.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
using namespace llvm;
// Single IR builder shared by the genCode() implementations of the AST nodes.
// JIT() points it at the entry block of the function being generated before
// asking the tree to emit code.
IRBuilder<> Builder{getGlobalContext()};
class ASTNode { | |
public: | |
virtual Value *genCode(Module *m, Function *f) = 0; | |
virtual double eval(double x) = 0; | |
virtual void print(std::ostream &os = std::cout) = 0; | |
virtual ~ASTNode() {} | |
}; | |
class ConstantNode : public ASTNode { | |
double v; | |
ASTNode *l, *r; | |
public: | |
ConstantNode(double value) : v(value) {} | |
double eval(double x) { return v; } | |
Value *genCode(Module *m, Function *f) override { | |
return ConstantFP::get(getGlobalContext(), APFloat(v)); | |
} | |
void print(std::ostream &os = std::cout) override { | |
os << v; | |
} | |
}; | |
class ArgumentNode : public ASTNode { | |
public: | |
double eval(double x) { | |
return x; | |
} | |
Value *genCode(Module *m, Function *f) { | |
return &f->getArgumentList().front(); | |
} | |
void print(std::ostream &os = std::cout) override { | |
os << 'x'; | |
} | |
}; | |
class BinaryOperatorNode : public ASTNode { | |
char o; | |
std::shared_ptr<ASTNode> l, r; | |
public: | |
BinaryOperatorNode(char op, std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right) : o(op), l(left), r(right) { } | |
double eval(double x) override { | |
switch (o) { | |
case '+': return l->eval(x) + r->eval(x); | |
case '-': return l->eval(x) - r->eval(x); | |
case '*': return l->eval(x) * r->eval(x); | |
case '^': return std::pow(l->eval(x), r->eval(x)); | |
default: throw std::runtime_error(std::string("Unsupported binary operator: ") + o); | |
} | |
} | |
Value *genCode(Module *m, Function *f) override { | |
auto lc = l->genCode(m, f); | |
auto rc = r->genCode(m, f); | |
switch (o) { | |
case '+': return Builder.CreateFAdd(lc, rc); | |
case '-': return Builder.CreateFSub(lc, rc); | |
case '*': return Builder.CreateFMul(lc, rc); | |
case '^': return Builder.CreateCall2(m->getFunction("pow"), lc, rc); | |
default: throw std::runtime_error(std::string("Unsupported binary operator: ") + o); | |
} | |
} | |
void print(std::ostream &os = std::cout) override { | |
l->print(os); | |
os << ' '; | |
r->print(os); | |
os << ' ' << o; | |
} | |
}; | |
class FunctionNode : public ASTNode { | |
std::string n; | |
std::shared_ptr<ASTNode> a; | |
public: | |
FunctionNode(const std::string &name, std::shared_ptr<ASTNode> arg) : n(name), a(arg) { } | |
double eval(double x) override { | |
if (n == "sin") { | |
return std::sin(a->eval(x)); | |
} else if (n == "cos") { | |
return std::cos(a->eval(x)); | |
} else if (n == "sqrt") { | |
return std::sqrt(a->eval(x)); | |
} else if (n == "abs") { | |
return std::abs(a->eval(x)); | |
} else if (n == "neg") { | |
return -a->eval(x); | |
} else if (n == "inv") { | |
return 1.0 / a->eval(x); | |
} | |
} | |
Value *genCode(Module *m, Function *f) override { | |
if (n == "inv") { | |
return Builder.CreateFDiv(ConstantFP::get(getGlobalContext(), APFloat(1.0)), a->genCode(m, f)); | |
} else if (n == "neg") { | |
return Builder.CreateFSub(ConstantFP::get(getGlobalContext(), APFloat(0.0)), a->genCode(m, f)); | |
} else { | |
auto fun = m->getFunction(n == "abs" ? "fabs" : n); | |
if (!fun) | |
throw std::runtime_error("No such function: " + n); | |
return Builder.CreateCall(fun, ArrayRef<Value*>(a->genCode(m, f))); | |
} | |
} | |
void print(std::ostream &os = std::cout) override { | |
a->print(os); | |
os << ' ' << n; | |
} | |
}; | |
// Native signature of the JIT-compiled "expr" function: x in, expression value out.
using expr_fun_t = double (*)(double);
//===----------------------------------------------------------------------===// | |
// Main driver code. | |
//===----------------------------------------------------------------------===// | |
std::shared_ptr<ExecutionEngine> JIT(std::shared_ptr<ASTNode> root, Module **moduleOut = nullptr) { | |
// Make the module, which holds all the code. | |
std::unique_ptr<Module> Owner = make_unique<Module>("expression", getGlobalContext()); | |
Module *TheModule = Owner.get(); | |
std::vector<Type *> ArgList(1, Type::getDoubleTy(getGlobalContext())); | |
FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()), ArgList, false); | |
Function::Create(FT, Function::ExternalLinkage, "sin", TheModule); | |
Function::Create(FT, Function::ExternalLinkage, "cos", TheModule); | |
Function::Create(FT, Function::ExternalLinkage, "sqrt", TheModule); | |
Function::Create(FT, Function::ExternalLinkage, "pow", TheModule); | |
Function::Create(FT, Function::ExternalLinkage, "fabs", TheModule); | |
// Create the JIT. This takes ownership of the module. | |
std::string ErrStr; | |
TargetOptions opts; | |
opts.AllowFPOpFusion = FPOpFusion::Fast; | |
opts.UnsafeFPMath = true; | |
ExecutionEngine *TheExecutionEngine = EngineBuilder(std::move(Owner)) | |
.setErrorStr(&ErrStr) | |
.setMCJITMemoryManager(llvm::make_unique<SectionMemoryManager>()) | |
.setTargetOptions(opts) | |
.create(); | |
FunctionPassManager OurFPM(TheModule); | |
// Set up the optimizer pipeline. Start with registering info about how the | |
// target lays out data structures. | |
TheModule->setDataLayout(TheExecutionEngine->getDataLayout()); | |
OurFPM.add(new DataLayoutPass()); | |
// Provide basic AliasAnalysis support for GVN. | |
OurFPM.add(createBasicAliasAnalysisPass()); | |
// Do simple "peephole" optimizations and bit-twiddling optzns. | |
OurFPM.add(createInstructionCombiningPass()); | |
// Reassociate expressions. | |
OurFPM.add(createReassociatePass()); | |
// Eliminate Common SubExpressions. | |
OurFPM.add(createGVNPass()); | |
// Simplify the control flow graph (deleting unreachable blocks, etc). | |
OurFPM.add(createCFGSimplificationPass()); | |
OurFPM.doInitialization(); | |
std::vector<Type*> params; | |
params.push_back(Type::getDoubleTy(getGlobalContext())); | |
Function *TheFunction = Function::Create(FunctionType::get(Type::getDoubleTy(getGlobalContext()), | |
params, | |
false), | |
GlobalValue::InternalLinkage, "expr", | |
TheModule); | |
// Create a new basic block to start insertion into. | |
BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction); | |
Builder.SetInsertPoint(BB); | |
Builder.CreateRet(root->genCode(TheModule, TheFunction)); | |
// Validate the generated code, checking for consistency. | |
verifyFunction(*TheFunction); | |
// Optimize the function. | |
OurFPM.run(*TheFunction); | |
TheExecutionEngine->finalizeObject(); | |
// Print out all of the generated code.* | |
if (moduleOut) | |
*moduleOut = TheModule; | |
return std::shared_ptr<ExecutionEngine>(TheExecutionEngine); | |
} | |
void draw(std::shared_ptr<ASTNode> root) { | |
char buffer[20][80]; | |
for (int i = 0; i < 20; ++i) { | |
for (int j = 0; j < 80; ++j) { | |
buffer[i][j] = ' '; | |
} | |
} | |
for (int i = -40; i < 40; ++i) { | |
int v = std::min(std::max((int)std::round(-10 * root->eval(i / 10.0)), -10), 9); | |
buffer[v + 10][i + 40] = '*'; | |
} | |
for (int i = 0; i < 20; ++i) { | |
for (int j = 0; j < 80; ++j) { | |
std::cout << buffer[i][j]; | |
} | |
std::cout << std::endl; | |
} | |
} | |
std::shared_ptr<ASTNode> parseRPN(const std::string &expr) { | |
std::stack<std::shared_ptr<ASTNode>> stack; | |
std::stringstream ss(expr); | |
std::string token; | |
while (ss >> token) { | |
if (token == "x") { | |
stack.push(std::make_shared<ArgumentNode>()); | |
} else if (token == "+" || token == "*" || token == "-" || token == "^") { | |
auto r = stack.top(); stack.pop(); | |
auto l = stack.top(); stack.pop(); | |
stack.push(std::make_shared<BinaryOperatorNode>(token[0], l, r)); | |
} else if (std::all_of(token.begin(), token.end(), [](char a) { return isdigit(a) || a == '.'; })) { | |
stack.push(std::make_shared<ConstantNode>(atof(token.c_str()))); | |
} else { | |
auto arg = stack.top(); stack.pop(); | |
stack.push(std::make_shared<FunctionNode>(token, arg)); | |
} | |
} | |
if (stack.size() == 1) { | |
return stack.top(); | |
} else { | |
std::stringstream ss; | |
ss << "RPN parsing failed at token '" << token << "', stack has " << stack.size() << " elements."; | |
throw std::runtime_error(ss.str()); | |
} | |
} | |
std::shared_ptr<ASTNode> generateRandomAST(int depth = 0) { | |
int type = rand() % 100; | |
const int maxDepth = 30; | |
if (((depth >= maxDepth) && (type % 2)) || type < 15) | |
return std::make_shared<ConstantNode>(200.0 * rand() / RAND_MAX - 100.0); | |
if (depth >= maxDepth || type < 30) | |
return std::make_shared<ArgumentNode>(); | |
if (type < 65) | |
return std::make_shared<BinaryOperatorNode>("+-*^"[rand()%4], | |
generateRandomAST(depth + 1), generateRandomAST(depth + 1)); | |
return std::make_shared<FunctionNode>( | |
(const char*[]){ "sin", "cos", "abs", "neg", "inv"}[rand()%5], | |
generateRandomAST(depth + 1)); | |
} | |
int main() { | |
srand(time(nullptr)); | |
InitializeNativeTarget(); | |
InitializeNativeTargetAsmPrinter(); | |
InitializeNativeTargetAsmParser(); | |
std::string line; | |
while (std::getline(std::cin, line)) { | |
auto root = generateRandomAST(); | |
std::cout << "The generated expression is:" << std::endl; | |
root->print(); | |
std::cout << std::endl << std::endl; | |
constexpr int num_runs = 100000; | |
double values1[num_runs], values2[num_runs]; | |
{ | |
std::cout << "Evaluating it " << num_runs << " times without JIT...\n"; | |
auto start = std::clock(); | |
for (int i = 0; i < num_runs; ++i) | |
values1[i] = root->eval(i); | |
auto end = std::clock(); | |
std::cout << "time elapsed: " | |
<< 1000000.0 * (end - start) / CLOCKS_PER_SEC << " us\n\n"; | |
} | |
{ | |
auto jit_start = std::clock(); | |
Module *module; | |
auto eng = JIT(root, &module); | |
auto jit_end = std::clock(); | |
std::cout << "JIT-ing took " << 1000000.0 * (jit_end - jit_start) / CLOCKS_PER_SEC << " us\n"; | |
std::cout << "The generated LLVM code is:\n\n"; | |
module->dump(); | |
std::cout << std::endl << std::endl; | |
std::cout << "Evaluating it " << num_runs << " times with JIT...\n"; | |
expr_fun_t expr_fun = (expr_fun_t)(eng->getFunctionAddress("expr")); | |
auto start = std::clock(); | |
for (int i = 0; i < num_runs; ++i) { | |
values2[i] = expr_fun(i); | |
} | |
auto end = std::clock(); | |
std::cout << "time elapsed: " | |
<< 1000000.0 * (end - start) / CLOCKS_PER_SEC << " us " | |
<< "(" << 1000000.0 * ((jit_end - jit_start) + (end - start)) / CLOCKS_PER_SEC << " us total)\n\n"; | |
} | |
int max_mismatches_to_print = 100; | |
std::cout << "Mismatches:" << std::endl; | |
int mismatches = 0; | |
int denormals = 0; | |
for (int i = 0; i < num_runs; ++i) { | |
if (!(std::isnormal(values1[i]) && std::isnormal(values2[i]))) { | |
++denormals; | |
continue; | |
} | |
if (values1[i] != values2[i]) { | |
if (mismatches == max_mismatches_to_print) { | |
std::cout << "(too many to print)" << std::endl; | |
} else if (mismatches < max_mismatches_to_print) { | |
std::cout << "at " << i << std::setprecision(100) << " without JIT: " << values1[i] << " but with JIT: " << values2[i] << std::endl; | |
} | |
++mismatches; | |
} | |
} | |
std::cout << "Found " << mismatches << " mismatches and " << denormals << " denormals" << std::endl; | |
std::cout << "----" << std::endl; | |
continue; | |
if (!line.empty()) { | |
//auto root = parseRPN(line); | |
auto root = generateRandomAST(); | |
//draw(root, true); | |
try { | |
for (int i = 0; i < 10; ++i) { | |
std::cout << root->eval(i) << " "; | |
} | |
std::cout << std::endl; | |
} catch (std::runtime_error &err) { | |
std::cerr << "ERROR! " << err.what() << std::endl; | |
} | |
} | |
} | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
The generated expression is: | |
23.3544 neg x 12.6982 51.0453 ^ + sin inv cos cos + inv -72.4637 - inv 23.5312 - | |
Evaluating it 100000 times without JIT... | |
time elapsed: 67980 us | |
JIT-ing took 3930 us | |
The generated LLVM code is: | |
; ModuleID = 'expression' | |
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" | |
declare double @sin(double) | |
declare double @cos(double) | |
declare double @sqrt(double) | |
declare double @pow(double) | |
declare double @fabs(double) | |
define internal double @expr(double) { | |
entry: | |
%1 = fadd double %0, 0x4BA1E15E18E6398A | |
%2 = call double @sin(double %1) | |
%3 = fdiv double 1.000000e+00, %2 | |
%4 = call double @cos(double %3) | |
%5 = call double @cos(double %4) | |
%6 = fadd double %5, 0xC0375AB769E6B570 | |
%7 = fdiv double 1.000000e+00, %6 | |
%8 = fadd double %7, 0x40521DADC2023B5C | |
%9 = fdiv double 1.000000e+00, %8 | |
%10 = fadd double %9, 0xC03787FB0AB70FF8 | |
ret double %10 | |
} | |
Evaluating it 100000 times with JIT... | |
time elapsed: 16215 us (20145 us total) | |
Mismatches: | |
Found 0 mismatches and 0 denormals | |
---- |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment