Code for JIT-compiling C++: loads LLVM bitcode into an ORC LLJIT instance and exposes the compiled functions.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#include <climits>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Mangler.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
namespace llvmOrc = llvm::orc;
template <typename T> | |
llvm::Type* getLLVMTypeImpl(llvm::LLVMContext& context) | |
{ | |
if constexpr (std::is_same<T, void>::value) | |
{ | |
return llvm::Type::getVoidTy(context); | |
} | |
if constexpr (std::is_same<T, float>::value) | |
{ | |
return llvm::Type::getFloatTy(context); | |
} | |
if constexpr (std::is_same<T, double>::value) | |
{ | |
return llvm::Type::getDoubleTy(context); | |
} | |
if constexpr (std::is_same<T, int8_t>::value) | |
{ | |
return llvm::Type::getIntNTy(context, 8); | |
} | |
if constexpr (std::is_same<T, int16_t>::value) | |
{ | |
return llvm::Type::getIntNTy(context, 16); | |
} | |
if constexpr (std::is_same<T, int32_t>::value) | |
{ | |
return llvm::Type::getIntNTy(context, 32); | |
} | |
if constexpr (std::is_same<T, int64_t>::value) | |
{ | |
return llvm::Type::getIntNTy(context, 64); | |
} | |
throw std::runtime_error("No llvm Type correspondance implemented yet. Add new cases as needed."); | |
} | |
template <typename T> | |
llvm::Type* getLLVMType(llvm::LLVMContext& context) | |
{ | |
auto type = getLLVMTypeImpl<T>(context); | |
if constexpr (!std::is_pointer<T>::value) | |
{ | |
return type; | |
} | |
else | |
{ | |
constexpr uint32_t addressSpace = 0; | |
return type->getPointerto(addressSpace); | |
} | |
} | |
template <typename T> | |
class FunctionType; | |
template <typename R, typename... Args> | |
class FunctionType<R(Args...)> | |
{ | |
public: | |
static auto Get(llvm::LLVMContext& context) | |
{ | |
auto returnType = getLLVMType<R>(context); | |
return llvm::FunctionType::get(returnType, {getLLVMType<Args>(context)...}, /*isVarArg=*/false); | |
} | |
}; | |
class Jit | |
{ | |
public: | |
Jit(std::unique_ptr<llvmOrc::LLJIT> lljit, llvm::DataLayout dataLayout) | |
: m_lljit(std::move(lljit)) | |
, m_context(llvm::make_unique<llvm::LLVMContext>()) | |
, m_dataLayout(dataLayout) | |
{ | |
m_lljit->getMainJITDylib().setGenerator(llvm::cantFail(llvmOrc::DynamicLibrarySearchGenerator::GetForCurrentProcess(dataLayout))); | |
m_lljit->getObjLinkingLayer().setOverrideObjectFlagsWithResponsibilityFlags(true); | |
//llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr); | |
//auto success = llvm::sys::DynamicLibrary::LoadLibraryPermanently("vcruntime140d.dll"); | |
} | |
static std::unique_ptr<Jit> create() | |
{ | |
llvm::InitializeNativeTarget(); | |
llvm::InitializeNativeTargetAsmParser(); | |
llvm::InitializeNativeTargetAsmPrinter(); | |
auto jitTargetMachineBuilder = llvm::cantFail(llvmOrc::JITTargetMachineBuilder::detectHost()); | |
auto dataLayout = llvm::cantFail(jitTargetMachineBuilder.getDefaultDataLayoutForTarget()); | |
auto lljit = llvm::cantFail(llvmOrc::LLJIT::Create(jitTargetMachineBuilder, dataLayout)); | |
return llvm::make_unique<Jit>(std::move(lljit), dataLayout); | |
} | |
void createAddModule(const std::string& bytecode) | |
{ | |
//if (!llvm::isBitcode(bytecode.data(), bytecode.data()+ bytecode.size())) | |
//{ | |
// throw std::exception("String buffer doesn't contain bytecode."); | |
//} | |
std::unique_ptr<llvm::MemoryBuffer> buffer = llvm::MemoryBuffer::getMemBuffer(bytecode, "BytecodeBuffer"); | |
std::unique_ptr<llvm::Module> bitecodeModule = llvm::cantFail(llvm::parseBitcodeFile(buffer.get()->getMemBufferRef(), *m_context.getContext())); | |
llvm::verifyModule(*bitecodeModule.get(), &llvm::errs()); | |
auto threadSafeModule = llvmOrc::ThreadSafeModule(std::move(bitecodeModule), m_context); | |
llvm::cantFail(m_lljit->addIRModule(std::move(threadSafeModule))); | |
{ | |
std::error_code EC; | |
llvm::raw_fd_ostream rawFileOstream("MainJITDylib.txt", EC); | |
m_lljit->getMainJITDylib().dump(rawFileOstream); | |
} | |
} | |
template <typename Type> | |
std::optional<std::function<Type>> getSymbolAsFunction(std::string name) | |
{ | |
llvm::Module llvmModule("TmpModule", *m_context.getContext()); | |
llvmModule.setDataLayout(m_dataLayout); | |
auto fctType = FunctionType<void()>::Get(*m_context.getContext()); | |
auto function = llvm::Function::Create(fctType, llvm::GlobalValue::ExternalLinkage, llvm::StringRef(name), llvmModule); | |
function->setCallingConv(llvm::CallingConv::X86_StdCall); | |
std::string mangled; | |
llvm::raw_string_ostream stringStream(mangled); | |
llvm::Mangler mangler; | |
mangler.getNameWithPrefix(stringStream, function, false); | |
stringStream.flush(); | |
function->eraseFromParent(); | |
m_lljit->runConstructors(); | |
auto jitEvaluatedSymbolExpected = m_lljit->lookupLinkerMangled(mangled); | |
if (!jitEvaluatedSymbolExpected) | |
{ | |
llvm::raw_os_ostream errStream{std::cerr}; | |
llvm::logAllUnhandledErrors(jitEvaluatedSymbolExpected.takeError(), errStream, {}); | |
return std::nullopt; | |
} | |
auto symbolAddress = jitEvaluatedSymbolExpected->getAddress(); | |
auto typedFunctionPtr = reinterpret_cast<Type*>(symbolAddress); | |
return std::function<Type>(typedFunctionPtr); | |
} | |
private: | |
std::unique_ptr<llvmOrc::LLJIT> m_lljit; | |
llvmOrc::ThreadSafeContext m_context; | |
llvm::DataLayout m_dataLayout; | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment