// Scala gist by @AdamBrouwersHarries, created January 25, 2018 22:47
// https://gist.github.com/AdamBrouwersHarries/f3dd65825dad7b1564eeab073f49f5a1
// (GitHub page navigation text from the original page capture has been reduced to this header.)
/** A small demo of an extensible AST: nodes can be walked with a generic
  * fold (`visit`) and pretty-printed into a [[Tree.PrintBuilder]].
  */
object Tree {

  /** Mutable, line-oriented text accumulator with an indentation level.
    *
    * Each `+=` appends one line prefixed by `indentUnit` repeated `indent` times.
    * Lines are separated by "\n"; no separator is written before the first line
    * of a builder that is still empty.
    *
    * @param content    text accumulated so far (may be seeded with a header)
    * @param indent     current indentation depth
    * @param indentUnit string repeated once per indentation level (one space by default)
    */
  class PrintBuilder(var content: String, var indent: Int, val indentUnit: String = " ") {

    // Becomes true once a line has been appended, so that a deliberately empty
    // first line is still followed by a separator.
    private var hasLines: Boolean = false

    /** Appends `line` as a new line at the current indentation depth. */
    def +=(line: String): Unit = {
      val indented = (indentUnit * indent) + line
      // Separate from existing text only; a fresh empty builder must not start with a blank line.
      val needsSeparator = hasLines || content.nonEmpty
      content = if (needsSeparator) content + "\n" + indented else indented
      hasLines = true
    }

    /** Increases the indentation depth by one level. */
    def in: Unit = {
      indent += 1
    }

    /** Decreases the indentation depth by one level.
      *
      * Not clamped: unbalanced calls can make `indent` negative, in which case
      * `indentUnit * indent` produces no indentation at all.
      */
    def out: Unit = {
      indent -= 1
    }

    /** Returns the text accumulated so far. */
    def get: String = content
  }

  /** Common interface for every tree node: a generic fold plus pretty-printing. */
  trait ASTNode {

    /** Folds `visitFun` over this node (and any descendants), starting from `z`. */
    def visit[T](z: T)(visitFun: (T, ASTNode) => T): T

    /** Appends this node's textual form to `acc` and returns the builder to continue with. */
    def print(acc: PrintBuilder): PrintBuilder
  }

  /** Syntax helper: `value |> f` (or `value.pipe(f)`) is `f(value)`. */
  implicit class Pipe[A](a: A) {
    def |>[B](f: A => B): B = f(a)
    def pipe[B](f: A => B): B = f(a)
  }

  /** A "top level" node type that concrete nodes inherit from. */
  trait Node extends ASTNode

  /** A node grouping an ordered sequence of child nodes. */
  class Block(val functions: Seq[ASTNode]) extends Node {

    /** Pre-order fold: applies `visitFun` to this block first, then visits each child in order. */
    override def visit[T](z: T)(visitFun: (T, ASTNode) => T): T =
      functions.foldLeft(visitFun(z, this)) { (acc, node) =>
        node.visit(acc)(visitFun)
      }

    /** Prints the children one level deeper, between an opening and a closing brace line. */
    override def print(acc: PrintBuilder): PrintBuilder = {
      acc += "{"
      acc.in
      // Continue with whatever builder each child returns instead of assuming
      // every child mutated and handed back `acc` itself.
      val inner = functions.foldLeft(acc)((builder, node) => node.print(builder))
      inner.out
      inner += "}"
      inner
    }
  }

  /** A named function node whose body is raw source text; a leaf in the tree. */
  trait Fun extends Node {
    val name: String
    val code: String

    /** Leaf node: applies `visitFun` to this node only. */
    override def visit[T](z: T)(visitFun: (T, ASTNode) => T): T = visitFun(z, this)

    /** Prints `header`, then `code` one level deeper, then a closing brace line. */
    protected def printFunction(acc: PrintBuilder, header: String): PrintBuilder = {
      acc += header
      acc.in
      acc += code
      acc.out
      acc += "}"
      acc
    }
  }

  /** Our "generic" function node, which the OpenCL node parallels. */
  case class Function(name: String, code: String) extends Fun {
    override def print(acc: PrintBuilder): PrintBuilder =
      printFunction(acc, s"$name() {")
  }

  /** Extra data carried by OpenCL-specific nodes. */
  trait CLAttributes {
    val attributes: String
  }

  /** An OpenCL function node: a [[Fun]] whose header is prefixed with `attributes`. */
  case class OCLFunction(name: String, code: String, attributes: String)
      extends Fun with CLAttributes {
    override def print(acc: PrintBuilder): PrintBuilder =
      printFunction(acc, s"$attributes $name() {")
  }

  /** Describes `node` in one sentence, dispatching on its runtime type.
    * Node types without a dedicated overload below get a fixed fallback message.
    */
  def print(node: ASTNode): String = node match {
    case cn: OCLFunction => print(cn)
    case gn: Function    => print(gn)
    case b: Block        => print(b)
    case _               => "Shit, we can't print that!"
  }

  /** Describes a generic function node. */
  def print(thing: Function): String =
    s"I am a generic node, with code ${thing.code} and nothing else."

  /** Describes an OpenCL function node, including its attributes. */
  def print(thing: OCLFunction): String =
    s"I am a specialised OpenCL node, with code ${thing.code} and an extra attribute ${thing.attributes}."

  /** Describes a block by the number of direct children it holds. */
  def print(b: Block): String =
    s"I am a block, with some functions. To be precise: ${b.functions.length} functions. "

  /** Demo: builds a two-function block, folds it into a description, then pretty-prints it.
    *
    * This is a parameterless method, not a `main(args: Array[String])` program
    * entry point, so it has to be invoked explicitly (e.g. `Tree.main`).
    */
  def main: Unit = {
    val genericNode = Function("generic_function", "return 1.0;")
    val clNode      = OCLFunction("KERNEL", "//something in parallel", "__kernel")
    val block       = new Block(Seq(genericNode, clNode))
    println(block.visit("Result block: ")((acc, node) => acc + print(node) + "\n"))
    println(block.print(new PrintBuilder("", 0)).get)
  }
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment