Skip to content

Instantly share code, notes, and snippets.

@oopsmishap
Created May 26, 2024 18:57
Show Gist options
  • Save oopsmishap/622734a7623e8a8124e002b357e6fc10 to your computer and use it in GitHub Desktop.
Save oopsmishap/622734a7623e8a8124e002b357e6fc10 to your computer and use it in GitHub Desktop.
#pragma once
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <fmt/core.h>
#include <zasm/zasm.hpp>
#include <windows.h>
using namespace zasm;
class ASTNode;

/// An owning, ordered sequence of AST nodes (a program body or a loop body).
using Nodes = std::vector<std::unique_ptr<ASTNode>>;

/// Abstract base class for every Brainfuck AST node.
class ASTNode
{
public:
    virtual ~ASTNode() = default;

    /// Writes a debug representation of the node.
    virtual void print() const = 0;

    /// Appends this node's optimized form (zero or more nodes) to `out`.
    virtual void optimize(Nodes& out) const = 0;
};

/// Optimizes `nodes`, appending the result to `optimized` (defined later in this file).
void optimizeSequence(const Nodes& nodes, Nodes& optimized);
class Operation : public ASTNode
{
public:
explicit Operation(int count) : value(count)
{
}
void print() const override
{
fmt::println("Operation({})", value);
}
void optimize(Nodes& optimized) const override
{
if (value != 0)
{
optimized.push_back(std::make_unique<Operation>(*this));
}
}
int value;
};
class Move : public ASTNode
{
public:
explicit Move(int count) : value(count)
{
}
void print() const override
{
fmt::println("Move({})", value);
}
void optimize(Nodes& optimized) const override
{
if (value != 0)
{
optimized.push_back(std::make_unique<Move>(*this));
}
}
int value;
};
class Output : public ASTNode
{
public:
void print() const override
{
fmt::println("Output()");
}
void optimize(Nodes& optimized) const override
{
optimized.push_back(std::make_unique<Output>(*this));
}
};
class Input : public ASTNode
{
public:
void print() const override
{
fmt::println("Input()");
}
void optimize(Nodes& optimized) const override
{
optimized.push_back(std::make_unique<Input>(*this));
}
};
class SetToZero : public ASTNode
{
public:
void print() const override
{
fmt::println("SetToZero()");
}
void optimize(Nodes& optimized) const override
{
optimized.push_back(std::make_unique<SetToZero>(*this));
}
};
/// Current nesting depth used by Loop::print() to indent debug output.
/// `inline` so that including this header from several translation units
/// does not produce duplicate definitions.
inline int indent = 0;
/// AST node for a `[ ... ]` loop: `body` runs while the current cell is non-zero.
class Loop : public ASTNode
{
public:
    explicit Loop(Nodes body) : body(std::move(body))
    {
    }

    /// Prints "LoopStart()", the body one level deeper, then "LoopEnd()".
    /// As with every print(), the caller has already written the indentation
    /// for the first line, so only the body lines and the closing line are
    /// indented here. This keeps nested loops aligned with their siblings.
    void print() const override
    {
        fmt::println("LoopStart()");
        ++indent;
        for (const auto& node : body)
        {
            fmt::print("{}", std::string(indent, ' '));
            node->print();
        }
        --indent;
        fmt::println("{}LoopEnd()", std::string(indent, ' '));
    }

    /// Appends a Loop whose body has been optimized.
    /// The loop is kept even if its body optimizes away: `[]` spins forever on
    /// a non-zero cell, so removing it would change program behavior.
    void optimize(Nodes& optimized) const override
    {
        Nodes optimizedBody;
        optimizeSequence(body, optimizedBody);
        optimized.push_back(std::make_unique<Loop>(std::move(optimizedBody)));
    }

    /// Nodes executed on each iteration.
    Nodes body;
};
void optimizeSequence(const Nodes& nodes, Nodes& optimized)
{
int moveCount = 0;
int opCount = 0;
auto pushMove = [&optimized](int count)
{
if (count != 0)
{
optimized.push_back(std::make_unique<Move>(count));
}
};
auto pushOperation = [&optimized](int count)
{
if (count != 0)
{
optimized.push_back(std::make_unique<Operation>(count));
}
};
auto finalizeCurrent = [&]()
{
pushMove(moveCount);
pushOperation(opCount);
moveCount = 0;
opCount = 0;
};
for (const auto& node : nodes)
{
if (auto move = dynamic_cast<Move*>(node.get()))
{
finalizeCurrent();
moveCount += move->value;
}
else if (auto op = dynamic_cast<Operation*>(node.get()))
{
finalizeCurrent();
opCount += op->value;
}
else if( auto loop = dynamic_cast< Loop* >( node.get() ) )
{
// Check if the loop is a SetToZero pattern
if (loop->body.size() == 1)
{
if (auto loopOp = dynamic_cast<Operation*>(loop->body.front().get()))
{
if (loopOp->value == -1)
{
finalizeCurrent();
optimized.push_back(std::make_unique<SetToZero>());
continue;
}
}
}
finalizeCurrent();
Nodes optimizedBody;
optimizeSequence(loop->body, optimizedBody);
optimized.push_back(std::make_unique<Loop>(std::move(optimizedBody)));
}
else
{
finalizeCurrent();
node->optimize(optimized);
}
}
finalizeCurrent();
}
/// Returns an optimized copy of the top-level program `nodes`; the input is not modified.
/// `inline` because this function is defined in a header.
[[nodiscard]] inline Nodes optimizeAST(const Nodes& nodes)
{
    Nodes optimized;
    optimizeSequence(nodes, optimized);
    return optimized;
}
/// Tokenizer for Brainfuck source text.
///
/// Only the eight command characters `+-<>[].,` are tokens. Every other byte
/// (including an embedded '\0') is treated as a comment and skipped.
/// End of input is reported as '\0'.
class Lexer
{
public:
    explicit Lexer(const std::string& source) : source(source), pos(0)
    {
    }

    /// Consumes and returns the next token, or returns '\0' once input is exhausted.
    char next()
    {
        const char token = peek();
        if (token == '\0')
        {
            return token;
        }
        ++pos;
        return token;
    }

    /// Returns the next token without consuming it, or '\0' at end of input.
    /// Any comment characters before the token are consumed as a side effect.
    char peek()
    {
        pos = source.find_first_of(kTokenChars, pos);
        if (pos == std::string::npos)
        {
            pos = source.size();
            return '\0';
        }
        return source[pos];
    }

private:
    /// The only characters that carry meaning in Brainfuck.
    static constexpr const char* kTokenChars = "+-<>[].,";

    std::string source;
    size_t pos;
};
class Parser
{
public:
explicit Parser(Lexer& lexer) : lexer(lexer)
{
}
Nodes parse()
{
Nodes nodes;
char token = 0;
while ((token = lexer.next()) != '\0')
{
nodes.push_back(parseToken(token));
}
return nodes;
}
private:
Lexer& lexer;
std::unique_ptr<ASTNode> parseToken(char token)
{
int count = 1;
switch (token)
{
case '+':
while (lexer.peek() == '+')
{
lexer.next();
++count;
}
return std::make_unique<Operation>(count);
case '-':
while (lexer.peek() == '-')
{
lexer.next();
++count;
}
return std::make_unique<Operation>(-count);
case '>':
while (lexer.peek() == '>')
{
lexer.next();
++count;
}
return std::make_unique<Move>(count);
case '<':
while (lexer.peek() == '<')
{
lexer.next();
++count;
}
return std::make_unique<Move>(-count);
case '.':
return std::make_unique<Output>();
case ',':
return std::make_unique<Input>();
case '[':
return parseLoop();
default:
throw std::runtime_error("Invalid token");
}
}
std::unique_ptr<ASTNode> parseLoop()
{
Nodes nodes;
char token;
while ((token = lexer.next()) != ']')
{
if (token == '\0')
{
throw std::runtime_error("Unmatched '['");
}
nodes.push_back(parseToken(token));
}
return std::make_unique<Loop>(std::move(nodes));
}
};
struct BrainFuckGenerator
{
using FunctionType = void(__fastcall*)(void*);
static void generate(Label labelFunc, Program& program, const Nodes& nodes)
{
x86::Assembler a(program);
a.bind(labelFunc);
a.mov(x86::rsi, x86::rcx);
a.mov(x86::r14, reinterpret_cast<uintptr_t>(&putchar));
a.mov(x86::r15, reinterpret_cast<uintptr_t>(&getchar));
generateAsm(a, nodes);
a.ret();
}
private:
static void generateAsm(x86::Assembler& a, const Nodes& nodes)
{
for (const auto& node : nodes)
{
if (auto move = dynamic_cast<Move*>(node.get()))
{
if (move->value > 0)
{
a.add(x86::rsi, move->value);
}
else
{
a.sub(x86::rsi, -move->value);
}
}
else if (auto op = dynamic_cast<Operation*>(node.get()))
{
if (op->value > 0)
{
a.add(x86::byte_ptr(x86::rsi), op->value);
}
else
{
a.sub(x86::byte_ptr(x86::rsi), -op->value);
}
}
else if (auto output = dynamic_cast<Output*>(node.get()))
{
a.mov(x86::cl, x86::byte_ptr(x86::rsi));
a.call(x86::r14);
}
else if (auto input = dynamic_cast<Input*>(node.get()))
{
a.call(x86::r15);
a.mov(x86::byte_ptr(x86::rsi), x86::al);
}
else if (auto setToZero = dynamic_cast<SetToZero*>(node.get()))
{
a.mov(x86::byte_ptr(x86::rsi), 0);
}
else if (auto loop = dynamic_cast<Loop*>(node.get()))
{
auto startLabel = a.createLabel();
auto endLabel = a.createLabel();
a.bind(startLabel);
a.cmp(x86::byte_ptr(x86::rsi), 0);
a.jz(endLabel);
generateAsm(a, loop->body);
a.jmp(startLabel);
a.bind(endLabel);
}
}
}
};
size_t estimateCodeSize(const Program& program)
{
std::size_t size = 0;
for (auto* node = program.getHead(); node != nullptr; node = node->getNext())
{
if (auto* nodeData = node->getIf<Data>(); nodeData != nullptr)
{
size += nodeData->getTotalSize();
}
else if (auto* nodeInstr = node->getIf<Instruction>(); nodeInstr != nullptr)
{
const auto& instrInfo = nodeInstr->getDetail(program.getMode());
if (instrInfo.hasValue())
{
size += instrInfo->getLength();
}
else
{
fmt::println("Error: Unable to get instruction info");
}
}
else if (auto* nodeEmbeddedLabel = node->getIf<EmbeddedLabel>(); nodeEmbeddedLabel != nullptr)
{
const auto bitSize = nodeEmbeddedLabel->getSize();
if (bitSize == BitSize::_32)
size += 4;
if (bitSize == BitSize::_64)
size += 8;
}
}
return size;
}
/// Reserves and commits `size` bytes of read/write/execute memory with VirtualAlloc.
/// The caller must release it with `VirtualFree(ptr, 0, MEM_RELEASE)`.
/// `inline` because this function is defined in a header.
/// @throws std::runtime_error if the allocation fails (including size == 0).
inline void* allocate(size_t size)
{
    void* ptr = VirtualAlloc(nullptr, size, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
    if (ptr == nullptr)
    {
        throw std::runtime_error("Failed to allocate memory");
    }
    return ptr;
}
void jitExecute(const Nodes& nodes)
{
Program program(MachineMode::AMD64);
auto labelFunc = program.createLabel("BrainFuck");
BrainFuckGenerator::generate(labelFunc, program, nodes);
const auto codeSize = estimateCodeSize(program);
void* code = VirtualAlloc(nullptr, codeSize, MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE);
if (code == nullptr)
{
throw std::runtime_error("Failed to allocate memory");
}
Serializer serializer;
if (auto err = serializer.serialize(program, reinterpret_cast<int64_t>(code)); err != zasm::ErrorCode::None)
{
fmt::println("Serialization failure: {}", err.getErrorName());
throw std::runtime_error("Serialization failure");
}
memcpy(code, serializer.getCode(), serializer.getCodeSize());
const auto funcAddress = serializer.getLabelAddress(labelFunc.getId());
assert(funcAddress != -1);
const auto brainFuckEntry = reinterpret_cast<BrainFuckGenerator::FunctionType>(funcAddress);
std::vector<uint8_t> memory(30000, 0);
brainFuckEntry(memory.data());
VirtualFree(code, 0, MEM_RELEASE);
}
#include "brainfuck.h"
#include <iostream>
#include <fstream>
#include <sstream>
int main(int argc, char** argv)
{
if (argc < 2)
{
fmt::println("Usage: {} <file>", argv[0]);
return 1;
}
std::ifstream file(argv[1]);
if (!file)
{
fmt::println("Failed to open file: {}", argv[1]);
return 1;
}
std::stringstream buffer;
buffer << file.rdbuf();
std::string fileContents = buffer.str();
Lexer lexer(fileContents);
Parser parser(lexer);
auto nodes = parser.parse();
auto optimized = optimizeAST(nodes);
/*for( const auto& node : optimized )
{
node->print();
}*/
jitExecute(optimized);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment