Skip to content

Instantly share code, notes, and snippets.

@oowekyala
Last active June 14, 2020 23:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save oowekyala/17126fb441798dffd22cd3808e1c8cd3 to your computer and use it in GitHub Desktop.
Save oowekyala/17126fb441798dffd22cd3808e1c8cd3 to your computer and use it in GitHub Desktop.
Brainfuck to C "optimizing compiler"
#!/bin/sh
exec scala "$0" "$@"
!#
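/*
A compiler from Brainfuck to C, with some peephole optimizations.
The generated C file is compiled with gcc; any extra arguments are passed on to gcc.
Usage (no compilation required):
./bfc.scala -f <brainfuck source file> -o <output file name> (<gcc flag>)*
./bfc.scala '<brainfuck code>' [-o <output file name>] (<gcc flag>)*
When -o is omitted, the output file is `bf` in the working directory.
*/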
import java.nio.file.{Files, Paths}
import scala.collection.mutable.ListBuffer
import scala.io.Source
import scala.language.postfixOps
/**
 * Brainfuck-to-C compiler: parses Brainfuck into a small tree, applies peephole rewrites until a
 * fixpoint, splices the result into a C skeleton and hands it to `gcc`.
 */
object Bfc {

  /**
   * C translation unit the program is spliced into; the `%code` marker is replaced by the emitted
   * statements. The tape is a fixed array of `tapelen` cells, `ptr` starts at its first cell, and
   * pointer moves are not bounds-checked.
   */
  val Skeleton =
    """
      |
      |#include <stdio.h>
      |
      |#define tapelen 4096
      |
      |unsigned char tape[tapelen] = {0};
      |
      |int main()
      |{
      |    unsigned char* ptr = tape;
      |
      |%code
      |
      |}
      |
      |""".stripMargin

  /** Placeholder inside [[Skeleton]] where the program body is inserted. */
  private val CodeMarker = "%code"

  /** A node of the program tree: a single instruction, a loop, or the whole program. */
  sealed trait Node {

    /** Applies one round of peephole rewrites; nodes with nothing to rewrite return themselves. */
    def stepOptimized: Node = this

    /** Applies [[stepOptimized]] repeatedly until the tree no longer changes (structural equality). */
    def optimized: Node = {
      @scala.annotation.tailrec
      def fixpoint(current: Node): Node = {
        val next = current.stepOptimized
        if (next == current) current else fixpoint(next)
      }
      fixpoint(this)
    }

    /**
     * Writes this node as C source.
     *
     * @param lineBreaker `lineBreaker(depth)(text)` writes `text` as one complete line indented to
     *                    nesting level `depth`
     * @param out         writes raw text verbatim (used for the [[Skeleton]] around the program body)
     * @param depth       nesting level of this node
     */
    def emit(lineBreaker: Int => String => Unit, out: String => Unit, depth: Int): Unit = {
      val line = lineBreaker(depth)
      this match {
        case Incr(lhs, diff) => line(s"${lhs.render} += ${diff.render};")
        case Set(lhs, value) => line(s"${lhs.render} = ${value.render};")
        case Print(v) => line(s"putchar(${v.render});")
        case Read(into) => line(s"${into.render} = getchar();")
        case Loop(seq) =>
          line("while(*ptr) {")
          seq.foreach(_.emit(lineBreaker, out, depth + 1))
          line("}")
        case Program(seq) =>
          val at = Skeleton.indexOf(CodeMarker)
          out(Skeleton.substring(0, at))
          seq.foreach(_.emit(lineBreaker, out, depth + 1))
          out(Skeleton.substring(at + CodeMarker.length))
      }
    }
  }

  /** A straight-line instruction (no control flow). */
  sealed trait Atomic extends Node

  /** An instruction that writes to an lvalue. */
  sealed trait Write extends Atomic

  /** Set the lval to the given val. */
  final case class Set(lhs: LVal, value: Val) extends Write

  /**
   * Increment the lval by the given amount.
   * Move instructions (`><`) are represented as Incr([[Ptr]], _),
   * and increment instructions (`+-`) are represented as Incr([[Deref]], _).
   */
  final case class Incr(lhs: LVal, diff: Val) extends Write

  /** Output the value as a character (`putchar`). */
  final case class Print(v: Val) extends Atomic

  /** Store one input character (`getchar`) into the lval. */
  final case class Read(into: LVal) extends Atomic

  /** An rvalue expression built from immediates, the pointer and tape cells. */
  sealed trait Val {

    /** Addition; folds two immediates and drops a zero operand. */
    def +(v: Val): Val = (this, v) match {
      case (Imm(a), Imm(b)) => Imm(a + b)
      case (Imm(0), b) => b
      case (a, Imm(0)) => a
      case _ => Add(this, v)
    }

    /** Multiplication; folds two immediates and drops a unit operand. */
    def *(v: Val): Val = (this, v) match {
      case (Imm(a), Imm(b)) => Imm(a * b)
      case (Imm(1), b) => b
      case (a, Imm(1)) => a
      case _ => Mul(this, v)
    }

    /**
     * This expression as seen from a pointer `off` cells further along: evaluating the result at
     * `ptr` gives the same value as evaluating `this` at `ptr + off`.
     */
    def shiftedBy(off: Val): Val = this match {
      case Deref(o) => Deref(off + o)
      case Add(a, b) => a.shiftedBy(off) + b.shiftedBy(off)
      case Mul(a, b) => a.shiftedBy(off) * b.shiftedBy(off)
      case Ptr => Ptr + off
      case imm: Imm => imm
    }

    /** C source text of this expression. */
    def render: String = this match {
      case Imm(i) => i.toString
      case Add(a, b) => s"(${a.render} + ${b.render})"
      case Mul(a, b) => s"(${a.render} * ${b.render})"
      case Ptr => "ptr"
      case Deref(Zero) => "*ptr"
      case Deref(offset) => s"*(ptr + ${offset.render})"
    }

    /** Writes the C source text of this expression to `out`. */
    def emit(out: String => Unit): Unit = out(render)
  }

  /** A value that can be assigned to. */
  sealed trait LVal extends Val

  /** Immediate value. */
  final case class Imm(int: Int) extends Val
  final case class Add(lhs: Val, rhs: Val) extends Val
  final case class Mul(lhs: Val, rhs: Val) extends Val

  val Zero: Imm = Imm(0)

  // lvalues
  /** The pointer value, not dereferenced. */
  case object Ptr extends LVal

  /** Dereference the pointer, and add the given [offset]. */
  final case class Deref(offset: Val) extends LVal

  /** The current cell, `*ptr`. */
  val PtrDeref: Deref = Deref(Zero)

  /**
   * Matches `ptr += k; <action>; ptr -= k; rest...` and yields the single action rewritten to
   * address its cells at offset `k` (both its target and any cells its value reads), plus `rest`.
   */
  private object Excursion {
    def unapply(ns: List[Node]): Option[(Atomic, List[Node])] = ns match {
      case Incr(Ptr, off @ Imm(a)) :: (action: Atomic) :: Incr(Ptr, Imm(b)) :: rest if a + b == 0 =>
        relocate(action, off).map(moved => (moved, rest))
      case _ => None
    }

    /** `action` executed with the pointer moved by `off`, or None if it cannot be expressed. */
    private def relocate(action: Atomic, off: Val): Option[Atomic] = action match {
      case Incr(Deref(p), d) => Some(Incr(Deref(off + p), d.shiftedBy(off)))
      case Set(Deref(p), d) => Some(Set(Deref(off + p), d.shiftedBy(off)))
      case Print(Deref(p)) => Some(Print(Deref(off + p)))
      case Read(Deref(p)) => Some(Read(Deref(off + p)))
      case _ => None
    }
  }

  /**
   * Matches a "multiply" loop whose body only adds constants to cells at constant offsets and whose
   * net change to the current cell per iteration is -1 or +1 (mod 256). With counter `c`, such a loop
   * runs `c` times (for -1) or `256 - c` times (for +1), so each other cell simply gains
   * `delta * c` (resp. `-delta * c`) mod 256; the counter ends at zero. Yields the replacement
   * straight-line code.
   */
  private object MultiplyLoop {
    def unapply(n: Node): Option[List[Node]] = n match {
      case Loop(body) =>
        val cells = body.collect { case Incr(Deref(Imm(o)), Imm(v)) => (o, v) }
        // any other instruction (moves, sets, I/O, nested loops, computed values) disqualifies
        if (cells.length != body.length) None
        else {
          val (counter, targets) = cells.partition { case (o, _) => o == 0 }
          val factor = java.lang.Math.floorMod(counter.map(_._2).sum, 256) match {
            case 255 => Some(1)
            case 1 => Some(-1)
            case _ => None
          }
          factor.map { k =>
            targets.map { case (o, v) => Incr(Deref(Imm(o)), Imm(v * k) * PtrDeref) } :+
              Set(PtrDeref, Zero)
          }
        }
      case _ => None
    }
  }

  /** A node holding a sequence of child nodes: a loop body or the whole program. */
  sealed abstract class Sub(seq: List[Node]) extends Node {

    /**
     * One round of peephole rewrites over the children, left to right; nested loops that are not
     * rewritten themselves get one round recursively.
     */
    override def stepOptimized: Sub = clone(foldPrefix(seq) {
      // a no-op increment or move
      case Incr(_, Zero) :: ns => (Nil, ns)
      // an increment right after an assignment to the same place folds into the assignment
      case Set(l1, a) :: Incr(l2, b) :: ns if l1 == l2 => (List(Set(l1, a + b)), ns)
      // collapse runs of increments (this also takes care of moves)
      case Incr(l1, d1) :: Incr(l2, d2) :: ns if l1 == l2 => (List(Incr(l1, d1 + d2)), ns)
      // going somewhere to do something and coming back
      case Excursion(action, ns) => (List(action), ns)
      // an odd step is coprime with 256, so the loop always ends with the cell at zero
      case Loop(Incr(PtrDeref, Imm(a)) :: Nil) :: ns if a % 2 != 0 => (List(Set(PtrDeref, Zero)), ns)
      case MultiplyLoop(expanded) :: ns => (expanded, ns)
      case (loop: Loop) :: ns => (List(loop.stepOptimized), ns) // recurse
    })

    /** A node of the same kind as this one, with `seq` as its children. */
    def clone(seq: List[Node]): Sub

    /**
     * Rewrites `vs` left to right. At each position, if `f` matches the remaining suffix and
     * consumes at least one element, the replacement it returns is emitted and scanning resumes at
     * the suffix it returns; otherwise the head is kept unchanged.
     *
     * @param f given the remaining suffix, returns (replacement, unconsumed rest). The rest must be
     *          a tail of its argument; returning the argument itself counts as "no rewrite".
     */
    def foldPrefix[V](vs: List[V])(f: PartialFunction[List[V], (List[V], List[V])]): List[V] = {
      val buffer = ListBuffer[V]()
      var remaining = vs
      while (remaining.nonEmpty) {
        f.lift(remaining) match {
          // compare with the current suffix (by reference), not the original input: a rewrite that
          // consumed nothing must not be accepted or the loop would never advance
          case Some((mapped, rest)) if rest ne remaining =>
            buffer ++= mapped
            remaining = rest
          case _ =>
            buffer += remaining.head
            remaining = remaining.tail
        }
      }
      buffer.toList
    }
  }

  /** A `[...]` loop: repeats `seq` while the current cell is nonzero. */
  final case class Loop(seq: List[Node]) extends Sub(seq) {
    override def clone(seq: List[Node]): Loop = Loop(seq)
  }

  /** The top-level program. */
  final case class Program(seq: List[Node]) extends Sub(seq) {
    override def clone(seq: List[Node]): Program = Program(seq)
  }

  // Shadows the implicit `Predef.any2stringadd` inside this object, so an expression like
  // `someVal + "text"` is a type error instead of silently becoming string concatenation.
  object any2stringadd

  /**
   * Command-line entry point: compiles Brainfuck to C in a temporary directory, then runs `gcc`.
   *
   * Accepted argument forms (any remaining arguments are passed on to gcc):
   *  - `-f <file> -o <output> ...`
   *  - `-f <file> ...` (output `bf`)
   *  - `<brainfuck code> -o <output> ...`
   *  - `<brainfuck code> ...` (output `bf`)
   * Relative paths are resolved against the working directory.
   *
   * @throws IllegalArgumentException if no input is given or the brackets are unbalanced
   */
  def main(args: Array[String]): Unit = {
    val wd = Paths.get(System.getProperty("user.dir"))
    def fromFile(name: String) = Source.fromFile(wd.resolve(name).toFile)
    val defaultOut = wd.resolve("bf")
    val (source, oFile, rest) = args.toList match {
      case "-f" :: fname :: "-o" :: oname :: tl => (fromFile(fname), wd.resolve(oname), tl)
      case "-f" :: fname :: tl => (fromFile(fname), defaultOut, tl)
      case head :: "-o" :: oname :: tl => (Source.fromString(head), wd.resolve(oname), tl)
      case head :: tl => (Source.fromString(head), defaultOut, tl)
      case _ => throw new IllegalArgumentException("Must provide some input!")
    }
    val prog: Sub = parse(source)
    val fullCode: String = emit(prog.optimized)
    val cFile = Files.createTempDirectory("brainfuck-tmp").resolve("program.c")
    Files.write(cFile, fullCode.getBytes(java.nio.charset.StandardCharsets.UTF_8))
    import scala.sys.process._
    // argv form: paths containing spaces and the extra flags reach gcc without shell-style splitting
    (Seq("gcc", cFile.toString, "-o", oFile.toString) ++ rest).!!
    println(s"C output in $cFile")
    println(s"Run $oFile")
  }

  /**
   * Parses Brainfuck source into a [[Program]]; every character other than `+-<>.,[]` is ignored.
   * The source is closed afterwards, also on failure.
   *
   * @throws IllegalArgumentException if the brackets are unbalanced
   */
  def parse(source: Source): Program =
    try {
      // innermost open loop body first; the last element is the top-level program body
      var loopStack = List(ListBuffer[Node]())
      source.zipWithIndex.foreach { case (c, i) =>
        c match {
          case '+' => loopStack.head += Incr(PtrDeref, Imm(1))
          case '-' => loopStack.head += Incr(PtrDeref, Imm(-1))
          case '>' => loopStack.head += Incr(Ptr, Imm(1))
          case '<' => loopStack.head += Incr(Ptr, Imm(-1))
          case '.' => loopStack.head += Print(PtrDeref)
          case ',' => loopStack.head += Read(PtrDeref)
          case '[' => loopStack = ListBuffer[Node]() :: loopStack
          case ']' =>
            if (loopStack.tail.isEmpty)
              throw new IllegalArgumentException(s"Unmatched ']' at offset $i")
            val top = Loop(loopStack.head.toList)
            loopStack = loopStack.tail
            loopStack.head += top
          case _ => /* Ignore */
        }
      }
      if (loopStack.tail.nonEmpty)
        throw new IllegalArgumentException(s"Unmatched '[': ${loopStack.length - 1} loop(s) left open")
      Program(loopStack.head.toList)
    } finally source.close()

  /** Renders a program tree as C text, indenting 4 spaces per nesting level. */
  private def emit(prog: Node): String = {
    val builder = new StringBuilder()
    def writeLine(depth: Int)(text: String): Unit = {
      builder.append("    " * depth).append(text).append('\n')
      ()
    }
    prog.emit(writeLine, text => builder.append(text), 0)
    builder.toString()
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment