Skip to content

Instantly share code, notes, and snippets.

@torokati44
Last active September 29, 2015 19:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save torokati44/4521f03e5baa66a3ceae to your computer and use it in GitHub Desktop.
Save torokati44/4521f03e5baa66a3ceae to your computer and use it in GitHub Desktop.
#include "llvm/Analysis/Passes.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/MCJIT.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/PassManager.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <map>
#include <math.h>
#include <memory>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <string>
#include <vector>
using namespace llvm;
// Single IR builder shared by every ASTNode::genCode implementation. JIT()
// points its insertion point at the entry block of the function being built
// before asking the AST to emit instructions.
// NOTE(review): this is global mutable state (insertion point), so it is
// presumably not safe to use from more than one thread at a time.
IRBuilder<> Builder(getGlobalContext());
/// Abstract node of a single-variable (`x`) expression tree.
///
/// Every node can be interpreted directly, lowered to LLVM IR through the
/// global Builder, and printed in postfix (reverse Polish) form.
class ASTNode {
public:
    virtual ~ASTNode() = default;

    /// Emits IR computing this subtree. @p m provides the external function
    /// declarations, @p f is the function whose body is being generated.
    virtual Value *genCode(Module *m, Function *f) = 0;

    /// Interprets this subtree for the argument value @p x.
    virtual double eval(double x) = 0;

    /// Writes this subtree in postfix notation to @p os.
    virtual void print(std::ostream &os = std::cout) = 0;
};
class ConstantNode : public ASTNode {
double v;
ASTNode *l, *r;
public:
ConstantNode(double value) : v(value) {}
double eval(double x) { return v; }
Value *genCode(Module *m, Function *f) override {
return ConstantFP::get(getGlobalContext(), APFloat(v));
}
void print(std::ostream &os = std::cout) override {
os << v;
}
};
class ArgumentNode : public ASTNode {
public:
double eval(double x) {
return x;
}
Value *genCode(Module *m, Function *f) {
return &f->getArgumentList().front();
}
void print(std::ostream &os = std::cout) override {
os << 'x';
}
};
/// Inner node combining two subtrees with `+`, `-`, `*` or `^` (power).
class BinaryOperatorNode : public ASTNode {
    char o;
    std::shared_ptr<ASTNode> l, r;
public:
    /// Takes the operands by value and moves them in (no extra refcount traffic).
    BinaryOperatorNode(char op, std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right)
        : o(op), l(std::move(left)), r(std::move(right)) { }

    /// @throws std::runtime_error for an operator other than + - * ^.
    double eval(double x) override {
        switch (o) {
        case '+': return l->eval(x) + r->eval(x);
        case '-': return l->eval(x) - r->eval(x);
        case '*': return l->eval(x) * r->eval(x);
        case '^': return std::pow(l->eval(x), r->eval(x));
        default: throw std::runtime_error(std::string("Unsupported binary operator: ") + o);
        }
    }

    /// Emits both operands, then the operation. `^` becomes a call to the
    /// `pow` declaration found in @p m.
    /// @throws std::runtime_error for an unknown operator, or if @p m has no
    ///         two-argument `pow` declaration.
    Value *genCode(Module *m, Function *f) override {
        Value *lc = l->genCode(m, f);
        Value *rc = r->genCode(m, f);
        switch (o) {
        case '+': return Builder.CreateFAdd(lc, rc);
        case '-': return Builder.CreateFSub(lc, rc);
        case '*': return Builder.CreateFMul(lc, rc);
        case '^': {
            // A missing declaration would be dereferenced as null, and a
            // declaration of the wrong arity yields a malformed call (invalid
            // IR). Check both before emitting the call.
            Function *powFn = m->getFunction("pow");
            if (!powFn || powFn->arg_size() != 2)
                throw std::runtime_error("No suitable declaration of pow(double, double) in module");
            return Builder.CreateCall2(powFn, lc, rc);
        }
        default: throw std::runtime_error(std::string("Unsupported binary operator: ") + o);
        }
    }

    /// Postfix form: left operand, right operand, operator.
    void print(std::ostream &os = std::cout) override {
        l->print(os);
        os << ' ';
        r->print(os);
        os << ' ' << o;
    }
};
/// Inner node applying a named unary function to one subtree.
///
/// Interpreted names: `sin`, `cos`, `sqrt`, `abs`, `neg` (negation) and `inv`
/// (reciprocal). In generated code `inv` and `neg` become arithmetic, `abs`
/// calls the module's `fabs` declaration, and any other name calls the
/// module declaration of that name.
class FunctionNode : public ASTNode {
    std::string n;
    std::shared_ptr<ASTNode> a;
public:
    FunctionNode(const std::string &name, std::shared_ptr<ASTNode> arg) : n(name), a(std::move(arg)) { }

    /// @throws std::runtime_error for an unsupported function name.
    double eval(double x) override {
        if (n == "sin")
            return std::sin(a->eval(x));
        if (n == "cos")
            return std::cos(a->eval(x));
        if (n == "sqrt")
            return std::sqrt(a->eval(x));
        if (n == "abs")
            return std::abs(a->eval(x));
        if (n == "neg")
            return -a->eval(x);
        if (n == "inv")
            return 1.0 / a->eval(x);
        // Falling off the end of a value-returning function is undefined
        // behaviour; reject unknown names the same way genCode() does.
        throw std::runtime_error("No such function: " + n);
    }

    /// @throws std::runtime_error if @p m lacks a one-argument declaration
    ///         for the function.
    Value *genCode(Module *m, Function *f) override {
        if (n == "inv") {
            return Builder.CreateFDiv(ConstantFP::get(getGlobalContext(), APFloat(1.0)), a->genCode(m, f));
        }
        if (n == "neg") {
            // Subtract from -0.0, not +0.0. -0.0 - x equals -x for every x,
            // whereas 0.0 - (+0.0) is +0.0 and would disagree with eval().
            return Builder.CreateFSub(ConstantFP::get(getGlobalContext(), APFloat(-0.0)), a->genCode(m, f));
        }
        Function *fun = m->getFunction(n == "abs" ? "fabs" : n);
        // Reject missing declarations, declarations taking a different number
        // of arguments (e.g. pow), and the function under construction itself.
        if (!fun || fun == f || fun->arg_size() != 1)
            throw std::runtime_error("No such function: " + n);
        Value *args[] = {a->genCode(m, f)};
        return Builder.CreateCall(fun, args);
    }

    /// Postfix form: argument, then function name.
    void print(std::ostream &os = std::cout) override {
        a->print(os);
        os << ' ' << n;
    }
};
/// Native signature of the JIT-compiled `expr` function: x -> expression value.
using expr_fun_t = double (*)(double);
//===----------------------------------------------------------------------===//
// Main driver code.
//===----------------------------------------------------------------------===//
std::shared_ptr<ExecutionEngine> JIT(std::shared_ptr<ASTNode> root, Module **moduleOut = nullptr) {
// Make the module, which holds all the code.
std::unique_ptr<Module> Owner = make_unique<Module>("expression", getGlobalContext());
Module *TheModule = Owner.get();
std::vector<Type *> ArgList(1, Type::getDoubleTy(getGlobalContext()));
FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()), ArgList, false);
Function::Create(FT, Function::ExternalLinkage, "sin", TheModule);
Function::Create(FT, Function::ExternalLinkage, "cos", TheModule);
Function::Create(FT, Function::ExternalLinkage, "sqrt", TheModule);
Function::Create(FT, Function::ExternalLinkage, "pow", TheModule);
Function::Create(FT, Function::ExternalLinkage, "fabs", TheModule);
// Create the JIT. This takes ownership of the module.
std::string ErrStr;
TargetOptions opts;
opts.AllowFPOpFusion = FPOpFusion::Fast;
opts.UnsafeFPMath = true;
ExecutionEngine *TheExecutionEngine = EngineBuilder(std::move(Owner))
.setErrorStr(&ErrStr)
.setMCJITMemoryManager(llvm::make_unique<SectionMemoryManager>())
.setTargetOptions(opts)
.create();
FunctionPassManager OurFPM(TheModule);
// Set up the optimizer pipeline. Start with registering info about how the
// target lays out data structures.
TheModule->setDataLayout(TheExecutionEngine->getDataLayout());
OurFPM.add(new DataLayoutPass());
// Provide basic AliasAnalysis support for GVN.
OurFPM.add(createBasicAliasAnalysisPass());
// Do simple "peephole" optimizations and bit-twiddling optzns.
OurFPM.add(createInstructionCombiningPass());
// Reassociate expressions.
OurFPM.add(createReassociatePass());
// Eliminate Common SubExpressions.
OurFPM.add(createGVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
OurFPM.add(createCFGSimplificationPass());
OurFPM.doInitialization();
std::vector<Type*> params;
params.push_back(Type::getDoubleTy(getGlobalContext()));
Function *TheFunction = Function::Create(FunctionType::get(Type::getDoubleTy(getGlobalContext()),
params,
false),
GlobalValue::InternalLinkage, "expr",
TheModule);
// Create a new basic block to start insertion into.
BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
Builder.SetInsertPoint(BB);
Builder.CreateRet(root->genCode(TheModule, TheFunction));
// Validate the generated code, checking for consistency.
verifyFunction(*TheFunction);
// Optimize the function.
OurFPM.run(*TheFunction);
TheExecutionEngine->finalizeObject();
// Print out all of the generated code.*
if (moduleOut)
*moduleOut = TheModule;
return std::shared_ptr<ExecutionEngine>(TheExecutionEngine);
}
/// Plots root(x) as an 80x20 ASCII chart on stdout.
///
/// Column c samples x = (c - 40) / 10, i.e. x in [-4.0, 3.9]. Rows run from
/// y = 1.0 (top) down to y = -0.9 (bottom) in steps of 0.1. Values outside
/// that band are pinned to the top or bottom row; NaN samples are not plotted.
void draw(std::shared_ptr<ASTNode> root) {
    constexpr int kRows = 20;
    constexpr int kCols = 80;
    std::vector<std::string> canvas(kRows, std::string(kCols, ' '));

    for (int col = 0; col < kCols; ++col) {
        const double y = root->eval((col - kCols / 2) / 10.0);
        if (std::isnan(y))
            continue;  // no meaningful row, and NaN -> int conversion is UB
        // Clamp while still in floating point: converting an out-of-range
        // double (including +/-inf) to int is undefined behaviour.
        const double level = std::min(std::max(std::round(-10.0 * y), -10.0), 9.0);
        canvas[static_cast<int>(level) + 10][col] = '*';
    }

    for (const std::string &row : canvas)
        std::cout << row << '\n';
    std::cout.flush();
}
std::shared_ptr<ASTNode> parseRPN(const std::string &expr) {
std::stack<std::shared_ptr<ASTNode>> stack;
std::stringstream ss(expr);
std::string token;
while (ss >> token) {
if (token == "x") {
stack.push(std::make_shared<ArgumentNode>());
} else if (token == "+" || token == "*" || token == "-" || token == "^") {
auto r = stack.top(); stack.pop();
auto l = stack.top(); stack.pop();
stack.push(std::make_shared<BinaryOperatorNode>(token[0], l, r));
} else if (std::all_of(token.begin(), token.end(), [](char a) { return isdigit(a) || a == '.'; })) {
stack.push(std::make_shared<ConstantNode>(atof(token.c_str())));
} else {
auto arg = stack.top(); stack.pop();
stack.push(std::make_shared<FunctionNode>(token, arg));
}
}
if (stack.size() == 1) {
return stack.top();
} else {
std::stringstream ss;
ss << "RPN parsing failed at token '" << token << "', stack has " << stack.size() << " elements.";
throw std::runtime_error(ss.str());
}
}
std::shared_ptr<ASTNode> generateRandomAST(int depth = 0) {
int type = rand() % 100;
const int maxDepth = 30;
if (((depth >= maxDepth) && (type % 2)) || type < 15)
return std::make_shared<ConstantNode>(200.0 * rand() / RAND_MAX - 100.0);
if (depth >= maxDepth || type < 30)
return std::make_shared<ArgumentNode>();
if (type < 65)
return std::make_shared<BinaryOperatorNode>("+-*^"[rand()%4],
generateRandomAST(depth + 1), generateRandomAST(depth + 1));
return std::make_shared<FunctionNode>(
(const char*[]){ "sin", "cos", "abs", "neg", "inv"}[rand()%5],
generateRandomAST(depth + 1));
}
int main() {
srand(time(nullptr));
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
InitializeNativeTargetAsmParser();
std::string line;
while (std::getline(std::cin, line)) {
auto root = generateRandomAST();
std::cout << "The generated expression is:" << std::endl;
root->print();
std::cout << std::endl << std::endl;
constexpr int num_runs = 100000;
double values1[num_runs], values2[num_runs];
{
std::cout << "Evaluating it " << num_runs << " times without JIT...\n";
auto start = std::clock();
for (int i = 0; i < num_runs; ++i)
values1[i] = root->eval(i);
auto end = std::clock();
std::cout << "time elapsed: "
<< 1000000.0 * (end - start) / CLOCKS_PER_SEC << " us\n\n";
}
{
auto jit_start = std::clock();
Module *module;
auto eng = JIT(root, &module);
auto jit_end = std::clock();
std::cout << "JIT-ing took " << 1000000.0 * (jit_end - jit_start) / CLOCKS_PER_SEC << " us\n";
std::cout << "The generated LLVM code is:\n\n";
module->dump();
std::cout << std::endl << std::endl;
std::cout << "Evaluating it " << num_runs << " times with JIT...\n";
expr_fun_t expr_fun = (expr_fun_t)(eng->getFunctionAddress("expr"));
auto start = std::clock();
for (int i = 0; i < num_runs; ++i) {
values2[i] = expr_fun(i);
}
auto end = std::clock();
std::cout << "time elapsed: "
<< 1000000.0 * (end - start) / CLOCKS_PER_SEC << " us "
<< "(" << 1000000.0 * ((jit_end - jit_start) + (end - start)) / CLOCKS_PER_SEC << " us total)\n\n";
}
int max_mismatches_to_print = 100;
std::cout << "Mismatches:" << std::endl;
int mismatches = 0;
int denormals = 0;
for (int i = 0; i < num_runs; ++i) {
if (!(std::isnormal(values1[i]) && std::isnormal(values2[i]))) {
++denormals;
continue;
}
if (values1[i] != values2[i]) {
if (mismatches == max_mismatches_to_print) {
std::cout << "(too many to print)" << std::endl;
} else if (mismatches < max_mismatches_to_print) {
std::cout << "at " << i << std::setprecision(100) << " without JIT: " << values1[i] << " but with JIT: " << values2[i] << std::endl;
}
++mismatches;
}
}
std::cout << "Found " << mismatches << " mismatches and " << denormals << " denormals" << std::endl;
std::cout << "----" << std::endl;
continue;
if (!line.empty()) {
//auto root = parseRPN(line);
auto root = generateRandomAST();
//draw(root, true);
try {
for (int i = 0; i < 10; ++i) {
std::cout << root->eval(i) << " ";
}
std::cout << std::endl;
} catch (std::runtime_error &err) {
std::cerr << "ERROR! " << err.what() << std::endl;
}
}
}
return 0;
}
The generated expression is:
23.3544 neg x 12.6982 51.0453 ^ + sin inv cos cos + inv -72.4637 - inv 23.5312 -
Evaluating it 100000 times without JIT...
time elapsed: 67980 us
JIT-ing took 3930 us
The generated LLVM code is:
; ModuleID = 'expression'
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
declare double @sin(double)
declare double @cos(double)
declare double @sqrt(double)
declare double @pow(double)
declare double @fabs(double)
define internal double @expr(double) {
entry:
%1 = fadd double %0, 0x4BA1E15E18E6398A
%2 = call double @sin(double %1)
%3 = fdiv double 1.000000e+00, %2
%4 = call double @cos(double %3)
%5 = call double @cos(double %4)
%6 = fadd double %5, 0xC0375AB769E6B570
%7 = fdiv double 1.000000e+00, %6
%8 = fadd double %7, 0x40521DADC2023B5C
%9 = fdiv double 1.000000e+00, %8
%10 = fadd double %9, 0xC03787FB0AB70FF8
ret double %10
}
Evaluating it 100000 times with JIT...
time elapsed: 16215 us (20145 us total)
Mismatches:
Found 0 mismatches and 0 denormals
----
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment