Skip to content

Instantly share code, notes, and snippets.

@kmizu
Created November 22, 2009 06:09
Show Gist options
Save kmizu/240458 to your computer and use it in GitHub Desktop.
import scala.io._
import scala.util.parsing.combinator.syntactical._
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Label
import org.objectweb.asm.Opcodes._
/**
 * A tiny calculator language with two back ends.
 *
 * Statements are `name = exp;`, `print exp;` and `exp;`. Expressions are built
 * from integer literals, identifiers, parentheses and the left-associative
 * operators `+ -` (lower precedence) and `* /` (higher precedence).
 *
 *  - With no command-line arguments, [[main]] runs a line-oriented REPL that
 *    interprets each statement with [[eval]].
 *  - With a file argument, [[main]] parses the whole file, compiles it to JVM
 *    bytecode with ASM ([[compileAndRun]]) and executes the result.
 */
object Calculator extends StandardTokenParsers {
  lexical.delimiters ++= List(
    "(", ")", "+", "-", "*", "/", ">", "<", ">=", "<=", "=", ";"
  )
  // NOTE(review): only "print" is used by the grammar below; the other reserved
  // words and the comparison delimiters are declared but never parsed.
  lexical.reserved ++= List("if", "true", "false", "print", "while")

  /** A statement of the language. */
  abstract sealed class Stmt
  /** `sym = exp;` -- evaluate `exp` and bind the result to `sym`. */
  case class SetStmt(sym: Symbol, exp: Exp) extends Stmt
  /** `print exp;` -- evaluate `exp` and print it on its own line. */
  case class PrintStmt(exp: Exp) extends Stmt
  /** `exp;` -- evaluate `exp` for its value. */
  case class ExpStmt(exp: Exp) extends Stmt

  /** An `Int`-valued expression. */
  abstract sealed class Exp
  case class Add(lhs: Exp, rhs: Exp) extends Exp
  case class Sub(lhs: Exp, rhs: Exp) extends Exp
  case class Mul(lhs: Exp, rhs: Exp) extends Exp
  case class Div(lhs: Exp, rhs: Exp) extends Exp
  case class Num(value: Int) extends Exp
  case class Ident(sym: Symbol) extends Exp

  /** Raised when an expression reads a variable that has not been assigned. */
  class VarNotFoundException(msg: String) extends Exception(msg)

  /** Variable bindings used by the interpreter. */
  type Env = Map[Symbol, Int]

  /**
   * A program: zero or more statements. This parser does not require the
   * input to be fully consumed; wrap it in `phrase` for that.
   */
  def prog: Parser[List[Stmt]] = rep(stmt)

  /** One statement, terminated by `;`. Assignment is tried first. */
  def stmt: Parser[Stmt] =
    (ident <~ "=") ~ exp <~ ";" ^^ { case l ~ r => SetStmt(Symbol(l), r) } |
    "print" ~> exp <~ ";" ^^ { case e => PrintStmt(e) } |
    exp <~ ";" ^^ { case e => ExpStmt(e) }

  /** Entry point of the expression grammar. */
  def exp: Parser[Exp] = add

  /** `mul (('+' | '-') mul)*`, folded left so the operators are left-associative. */
  def add: Parser[Exp] = {
    mul ~ rep("+" ~ mul | "-" ~ mul) ^^ { case a ~ b =>
      b.foldLeft[Exp](a) { (accum, e) =>
        e match {
          case "+" ~ r => Add(accum, r)
          case "-" ~ r => Sub(accum, r)
        }
      }
    }
  }

  /** `prm (('*' | '/') prm)*`, folded left so the operators are left-associative. */
  def mul: Parser[Exp] = {
    prm ~ rep("*" ~ prm | "/" ~ prm) ^^ { case a ~ b =>
      b.foldLeft[Exp](a) { (accum, e) =>
        e match {
          case "*" ~ r => Mul(accum, r)
          case "/" ~ r => Div(accum, r)
        }
      }
    }
  }

  /** A primary: integer literal, parenthesised expression, or identifier. */
  def prm: Parser[Exp] = {
    numericLit ^^ { case n => Num(n.toInt) } |
    "(" ~> exp <~ ")" |
    ident ^^ { case n => Ident(Symbol(n)) }
  }

  /**
   * Interprets one statement against `env`.
   *
   * @return the statement's value paired with the resulting environment;
   *         only a `SetStmt` extends the environment
   * @throws VarNotFoundException if the statement reads an unbound variable
   * @throws ArithmeticException  on integer division by zero
   */
  def eval(stmt: Stmt, env: Env): (Int, Env) = {
    def evalExp(exp: Exp): Int = exp match {
      case l Add r => evalExp(l) + evalExp(r)
      case l Sub r => evalExp(l) - evalExp(r)
      case l Mul r => evalExp(l) * evalExp(r)
      case l Div r => evalExp(l) / evalExp(r)
      case Num(v) => v
      case Ident(n) => env.get(n) match {
        case Some(v) => v
        case None =>
          throw new VarNotFoundException("variable " + n.name + " not found")
      }
    }
    stmt match {
      case SetStmt(l, r) =>
        val result = evalExp(r)
        // Env is an immutable Map: `+` returns an extended copy.
        (result, env + (l -> result))
      case ExpStmt(e) => (evalExp(e), env)
      case PrintStmt(e) =>
        val result = evalExp(e)
        println(result)
        (result, env)
    }
  }

  /**
   * Opens a method on `cw`, lets `block` emit its body, and always calls
   * `visitEnd` afterwards, even if `block` throws.
   */
  private def visitMethod(cw: ClassWriter)(
    access: Int, name: String, desc: String, signature: String,
    exceptions: Array[String]
  )(block: MethodVisitor => Unit): Unit = {
    val mv = cw.visitMethod(access, name, desc, signature, exceptions)
    try {
      block(mv)
    } finally {
      mv.visitEnd
    }
  }

  /**
   * Compiles `prog` to a fresh class and runs it.
   *
   * The generated class has a public no-arg constructor and a public
   * `compute()V` method containing the whole program. Each variable becomes
   * an `int` local of `compute`. The class is defined in a throw-away
   * `ClassLoader` and instantiated reflectively, and then `compute` is invoked.
   * Exceptions thrown by the generated code (e.g. division by zero) surface
   * wrapped in `InvocationTargetException`, as `Method.invoke` does.
   *
   * @param prog      statements to compile, in order
   * @param className name of the generated class; `"Main"` when `None`
   * @throws VarNotFoundException if an expression reads a variable that is
   *                              not assigned earlier in `prog`
   */
  def compileAndRun(prog: List[Stmt], className: Option[String]): Unit = {
    val name = className.getOrElse("Main")
    // COMPUTE_MAXS lets ASM compute max stack/locals; the visitMaxs arguments are ignored.
    val cw = new ClassWriter(ClassWriter.COMPUTE_MAXS)
    val indexMap = new scala.collection.mutable.HashMap[Symbol, Int]
    var countVar = 1 // In instance methods, index 0 indicates `this`
    cw.visit(
      V1_1, ACC_PUBLIC, name, null, "java/lang/Object", new Array[String](0)
    )
    visitMethod(cw)(ACC_PUBLIC, "<init>", "()V", null, null) { mv =>
      mv.visitVarInsn(ALOAD, 0)
      mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V")
      mv.visitInsn(RETURN)
      mv.visitMaxs(1, 1)
    }
    visitMethod(cw)(ACC_PUBLIC, "compute", "()V", null, null) { mv =>
      // Emits code leaving the expression's int value on the operand stack.
      def compileExp(arg: Exp): Unit = {
        arg match {
          case l Add r =>
            compileExp(l)
            compileExp(r)
            mv.visitInsn(IADD)
          case l Sub r =>
            compileExp(l)
            compileExp(r)
            mv.visitInsn(ISUB)
          case l Mul r =>
            compileExp(l)
            compileExp(r)
            mv.visitInsn(IMUL)
          case l Div r =>
            compileExp(l)
            compileExp(r)
            mv.visitInsn(IDIV)
          case Num(v) =>
            mv.visitLdcInsn(java.lang.Integer.valueOf(v))
          case Ident(sym) =>
            (indexMap get sym) match {
              case Some(index) =>
                mv.visitVarInsn(ILOAD, index)
              case None =>
                throw new VarNotFoundException("variable " + sym.name + " not found")
            }
        }
      }
      // Emits code for one statement; leaves the operand stack as it found it.
      def compileStmt(arg: Stmt): Unit = {
        arg match {
          case SetStmt(varSym, exp) =>
            // Compile the right-hand side BEFORE allocating a slot for the
            // target. Then `x = x;` with `x` unbound raises VarNotFoundException,
            // as the interpreter does, instead of emitting an ILOAD of an
            // uninitialised local that the JVM verifier would reject.
            compileExp(exp)
            val index = indexMap.get(varSym).getOrElse {
              indexMap(varSym) = countVar
              countVar += 1
              countVar - 1
            }
            mv.visitVarInsn(ISTORE, index)
          case PrintStmt(exp) =>
            mv.visitFieldInsn(
              GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;"
            )
            compileExp(exp)
            mv.visitMethodInsn(
              INVOKEVIRTUAL, "java/io/PrintStream", "println", "(I)V"
            )
          case ExpStmt(exp) =>
            compileExp(exp)
            mv.visitInsn(POP)
        }
      }
      prog.foreach(compileStmt)
      mv.visitInsn(RETURN)
      mv.visitMaxs(0, 0)
    }
    val content = cw.toByteArray
    // Defines only the generated class; all other lookups go to the parent loader.
    val loader = new ClassLoader {
      protected override def findClass(target: String): Class[_] = {
        if (target == name)
          defineClass(content, 0, content.length)
        else
          super.findClass(target)
      }
    }
    val target = loader.loadClass(name).newInstance.asInstanceOf[AnyRef]
    val method = target.getClass.getMethod("compute")
    method.invoke(target)
  }

  /**
   * With no arguments, runs a REPL: one statement per line, with its value
   * printed. An empty line or end of input quits.
   * With a file path as `args(0)`, parses the entire file, then compiles
   * and runs it.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 1) {
      var line: String = null
      var env: Env = Map[Symbol, Int]()
      while ({ line = readLine("> "); line != null && line != "" }) try {
        // phrase(...) rejects lines with trailing tokens instead of silently
        // evaluating only the first statement and discarding the rest.
        phrase(stmt)(new lexical.Scanner(line)) match {
          case Success(ast, _) =>
            val (result, newEnv) = eval(ast, env)
            env = newEnv
            println(result)
          case Failure(msg, _) => println(msg)
          case Error(msg, _) => println(msg)
        }
      } catch {
        case e: VarNotFoundException => println(e.getMessage)
        // Integer division by zero: report it and keep the session alive.
        case e: ArithmeticException => println(e.getMessage)
      }
    } else {
      val src = Source.fromFile(args(0))
      try {
        // phrase(...) makes a syntax error anywhere in the file a parse
        // failure; a bare rep(stmt) would stop early and report Success.
        phrase(prog)(new lexical.Scanner(src mkString "")) match {
          case Success(ast, _) =>
            try {
              compileAndRun(ast, None)
            } catch {
              case e: VarNotFoundException => println(e.getMessage)
            }
          case Failure(msg, _) => println(msg)
          case Error(msg, _) => println(msg)
        }
      } finally {
        src.asInstanceOf[BufferedSource].close
      }
    }
  }
}
// Script entry point: forward the script's command-line arguments to
// Calculator.main (REPL when none are given, compile-and-run for a file path).
val scriptArgs: Array[String] = args
Calculator.main(scriptArgs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.