Skip to content

Instantly share code, notes, and snippets.

@monamimani monamimani/Jit.h Secret
Created Apr 8, 2019

Embed
What would you like to do?
Code for JIT-compiling C++
#pragma once

#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <system_error>
#include <type_traits>

#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Mangler.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"

namespace llvmOrc = llvm::orc;
/// Maps a non-pointer C++ type T to the corresponding LLVM IR type in `context`.
///
/// Supported: void, float, double, and every integral type except bool.
/// LLVM integer types are signless, so signed and unsigned C++ types of the
/// same width map to the same iN. Top-level const/volatile is ignored.
/// Pointer types are handled by getLLVMType, not here.
///
/// @throws std::runtime_error for any other T.
template <typename T>
llvm::Type* getLLVMTypeImpl(llvm::LLVMContext& context)
{
    using U = std::remove_cv_t<T>;
    if constexpr (std::is_void<U>::value)
    {
        return llvm::Type::getVoidTy(context);
    }
    else if constexpr (std::is_same<U, float>::value)
    {
        return llvm::Type::getFloatTy(context);
    }
    else if constexpr (std::is_same<U, double>::value)
    {
        return llvm::Type::getDoubleTy(context);
    }
    else if constexpr (std::is_integral<U>::value && !std::is_same<U, bool>::value)
    {
        // Width in bits, assuming 8-bit bytes (as the previous int8_t..int64_t cases did).
        return llvm::Type::getIntNTy(context, static_cast<unsigned>(sizeof(U) * 8));
    }
    else
    {
        // Only instantiated for unsupported types, so supported ones have no dead throw.
        throw std::runtime_error("No llvm Type correspondance implemented yet. Add new cases as needed.");
    }
}
/// Maps a C++ type T to its LLVM IR type, adding pointer support on top of
/// getLLVMTypeImpl.
///
/// Non-pointer types are forwarded to getLLVMTypeImpl unchanged. A pointer
/// (including a nested or cv-qualified one) becomes a pointer, in address
/// space 0, to the mapped pointee type. `void*` becomes `i8*`, because LLVM
/// does not allow pointers to void.
///
/// @throws std::runtime_error if the (pointee) type is unsupported.
template <typename T>
llvm::Type* getLLVMType(llvm::LLVMContext& context)
{
    if constexpr (!std::is_pointer<T>::value)
    {
        return getLLVMTypeImpl<T>(context);
    }
    else
    {
        // Map the pointee, not the pointer type itself; recursion handles T**.
        using Pointee = std::remove_cv_t<std::remove_pointer_t<T>>;
        constexpr unsigned addressSpace = 0;
        llvm::Type* pointeeType = nullptr;
        if constexpr (std::is_void<Pointee>::value)
        {
            pointeeType = llvm::Type::getInt8Ty(context);
        }
        else
        {
            pointeeType = getLLVMType<Pointee>(context);
        }
        return pointeeType->getPointerTo(addressSpace);
    }
}
template <typename T>
class FunctionType;
template <typename R, typename... Args>
class FunctionType<R(Args...)>
{
public:
static auto Get(llvm::LLVMContext& context)
{
auto returnType = getLLVMType<R>(context);
return llvm::FunctionType::get(returnType, {getLLVMType<Args>(context)...}, /*isVarArg=*/false);
}
};
class Jit
{
public:
Jit(std::unique_ptr<llvmOrc::LLJIT> lljit, llvm::DataLayout dataLayout)
: m_lljit(std::move(lljit))
, m_context(llvm::make_unique<llvm::LLVMContext>())
, m_dataLayout(dataLayout)
{
m_lljit->getMainJITDylib().setGenerator(llvm::cantFail(llvmOrc::DynamicLibrarySearchGenerator::GetForCurrentProcess(dataLayout)));
m_lljit->getObjLinkingLayer().setOverrideObjectFlagsWithResponsibilityFlags(true);
//llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
//auto success = llvm::sys::DynamicLibrary::LoadLibraryPermanently("vcruntime140d.dll");
}
static std::unique_ptr<Jit> create()
{
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmParser();
llvm::InitializeNativeTargetAsmPrinter();
auto jitTargetMachineBuilder = llvm::cantFail(llvmOrc::JITTargetMachineBuilder::detectHost());
auto dataLayout = llvm::cantFail(jitTargetMachineBuilder.getDefaultDataLayoutForTarget());
auto lljit = llvm::cantFail(llvmOrc::LLJIT::Create(jitTargetMachineBuilder, dataLayout));
return llvm::make_unique<Jit>(std::move(lljit), dataLayout);
}
void createAddModule(const std::string& bytecode)
{
//if (!llvm::isBitcode(bytecode.data(), bytecode.data()+ bytecode.size()))
//{
// throw std::exception("String buffer doesn't contain bytecode.");
//}
std::unique_ptr<llvm::MemoryBuffer> buffer = llvm::MemoryBuffer::getMemBuffer(bytecode, "BytecodeBuffer");
std::unique_ptr<llvm::Module> bitecodeModule = llvm::cantFail(llvm::parseBitcodeFile(buffer.get()->getMemBufferRef(), *m_context.getContext()));
llvm::verifyModule(*bitecodeModule.get(), &llvm::errs());
auto threadSafeModule = llvmOrc::ThreadSafeModule(std::move(bitecodeModule), m_context);
llvm::cantFail(m_lljit->addIRModule(std::move(threadSafeModule)));
{
std::error_code EC;
llvm::raw_fd_ostream rawFileOstream("MainJITDylib.txt", EC);
m_lljit->getMainJITDylib().dump(rawFileOstream);
}
}
template <typename Type>
std::optional<std::function<Type>> getSymbolAsFunction(std::string name)
{
llvm::Module llvmModule("TmpModule", *m_context.getContext());
llvmModule.setDataLayout(m_dataLayout);
auto fctType = FunctionType<void()>::Get(*m_context.getContext());
auto function = llvm::Function::Create(fctType, llvm::GlobalValue::ExternalLinkage, llvm::StringRef(name), llvmModule);
function->setCallingConv(llvm::CallingConv::X86_StdCall);
std::string mangled;
llvm::raw_string_ostream stringStream(mangled);
llvm::Mangler mangler;
mangler.getNameWithPrefix(stringStream, function, false);
stringStream.flush();
function->eraseFromParent();
m_lljit->runConstructors();
auto jitEvaluatedSymbolExpected = m_lljit->lookupLinkerMangled(mangled);
if (!jitEvaluatedSymbolExpected)
{
llvm::raw_os_ostream errStream{std::cerr};
llvm::logAllUnhandledErrors(jitEvaluatedSymbolExpected.takeError(), errStream, {});
return std::nullopt;
}
auto symbolAddress = jitEvaluatedSymbolExpected->getAddress();
auto typedFunctionPtr = reinterpret_cast<Type*>(symbolAddress);
return std::function<Type>(typedFunctionPtr);
}
private:
std::unique_ptr<llvmOrc::LLJIT> m_lljit;
llvmOrc::ThreadSafeContext m_context;
llvm::DataLayout m_dataLayout;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.