Created November 22, 2009 06:09
Save kmizu/240458 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.io._ | |
import scala.util.parsing.combinator.syntactical._ | |
import org.objectweb.asm.ClassWriter | |
import org.objectweb.asm.MethodVisitor | |
import org.objectweb.asm.Label | |
import org.objectweb.asm.Opcodes._ | |
/** A tiny integer calculator language with two back ends.
  *
  *  - An interpreter (`eval`), used by the interactive REPL.
  *  - A compiler (`compileAndRun`) that emits JVM bytecode with ASM and
  *    immediately loads and runs it.
  *
  * A program is a sequence of `;`-terminated statements: `name = exp;`,
  * `print exp;`, or a bare `exp;`. Expressions are built from integer literals,
  * identifiers, `+ - * /` and parentheses.
  */
object Calculator extends StandardTokenParsers {
  // The comparison operators and the keywords `if`, `true`, `false` and `while`
  // are registered with the lexer, but no grammar rule below uses them yet.
  lexical.delimiters ++= List(
    "(", ")", "+", "-", "*", "/", ">", "<", ">=", "<=", "=", ";"
  )
  lexical.reserved ++= List("if", "true", "false", "print", "while")

  /** A `;`-terminated statement. */
  abstract sealed class Stmt
  /** `name = exp;`: binds `sym` to the value of `exp`. */
  case class SetStmt(sym: Symbol, exp: Exp) extends Stmt
  /** `print exp;`: prints the value of `exp` on its own line. */
  case class PrintStmt(exp: Exp) extends Stmt
  /** `exp;`: evaluates `exp` for its value only. */
  case class ExpStmt(exp: Exp) extends Stmt

  /** An `Int`-valued expression. */
  abstract sealed class Exp
  case class Add(lhs: Exp, rhs: Exp) extends Exp
  case class Sub(lhs: Exp, rhs: Exp) extends Exp
  case class Mul(lhs: Exp, rhs: Exp) extends Exp
  case class Div(lhs: Exp, rhs: Exp) extends Exp
  case class Num(value: Int) extends Exp
  case class Ident(sym: Symbol) extends Exp

  /** Raised when an expression reads a variable that has not been assigned. */
  class VarNotFoundException(msg: String) extends Exception(msg)

  /** Variable bindings used by the interpreter. */
  type Env = Map[Symbol, Int]

  /** prog ::= stmt* */
  def prog: Parser[List[Stmt]] = rep(stmt)

  /** stmt ::= ident "=" exp ";" | "print" exp ";" | exp ";" */
  def stmt: Parser[Stmt] =
    (ident <~ "=") ~ exp <~ ";" ^^ { case l ~ r => SetStmt(Symbol(l), r) } |
    "print" ~> exp <~ ";" ^^ { case e => PrintStmt(e) } |
    exp <~ ";" ^^ { case e => ExpStmt(e) }

  /** exp ::= add */
  def exp: Parser[Exp] = add

  /** add ::= mul (("+" | "-") mul)*. It is folded left, so `a - b - c`
    * means `(a - b) - c`. */
  def add: Parser[Exp] = {
    mul ~ rep("+" ~ mul | "-" ~ mul) ^^ { case a ~ b =>
      b.foldLeft[Exp](a) { (accum, e) =>
        e match {
          case "+" ~ r => Add(accum, r)
          case "-" ~ r => Sub(accum, r)
        }
      }
    }
  }

  /** mul ::= prm (("*" | "/") prm)*. It is folded left, like `add`. */
  def mul: Parser[Exp] = {
    prm ~ rep("*" ~ prm | "/" ~ prm) ^^ { case a ~ b =>
      b.foldLeft[Exp](a) { (accum, e) =>
        e match {
          case "*" ~ r => Mul(accum, r)
          case "/" ~ r => Div(accum, r)
        }
      }
    }
  }

  /** prm ::= numericLit | "(" exp ")" | ident
    *
    * A literal that does not fit in an `Int` makes `toInt` throw
    * `NumberFormatException` during parsing. */
  def prm: Parser[Exp] = {
    numericLit ^^ { case n => Num(n.toInt) } |
    "(" ~> exp <~ ")" |
    ident ^^ { case n => Ident(Symbol(n)) }
  }

  /** Interprets one statement against `env`.
    *
    * @return the statement's value paired with the environment after it runs.
    *         Only `SetStmt` changes the environment, and `PrintStmt`
    *         additionally prints the value.
    * @throws VarNotFoundException if the statement reads an unbound variable
    * @throws ArithmeticException on integer division by zero
    */
  def eval(stmt: Stmt, env: Env): (Int, Env) = {
    def evalExp(exp: Exp): Int = exp match {
      case l Add r => evalExp(l) + evalExp(r)
      case l Sub r => evalExp(l) - evalExp(r)
      case l Mul r => evalExp(l) * evalExp(r)
      case l Div r => evalExp(l) / evalExp(r)
      case Num(v) => v
      case Ident(n) => env.get(n) match {
        case Some(v) => v
        case None =>
          throw new VarNotFoundException("variable " + n.name + " not found")
      }
    }
    stmt match {
      case SetStmt(l, r) =>
        val result = evalExp(r)
        // `Map + (k -> v)` returns an updated copy. The former `env(l) = result`
        // depended on the immutable-Map `update` method, which was deprecated
        // in Scala 2.8 and later removed.
        (result, env + (l -> result))
      case ExpStmt(e) => (evalExp(e), env)
      case PrintStmt(e) =>
        val result = evalExp(e)
        println(result)
        (result, env)
    }
  }

  /** Opens a method on `cw`, lets `block` emit its body, and always calls
    * `visitEnd` on the method visitor, even if `block` throws. */
  private def visitMethod(cw: ClassWriter)(
    access: Int, name: String, desc: String, signature: String,
    exceptions: Array[String]
  )(block: MethodVisitor => Unit): Unit = {
    val mv = cw.visitMethod(access, name, desc, signature, exceptions)
    try {
      block(mv)
    } finally {
      mv.visitEnd()
    }
  }

  /** Compiles `prog` into a class with a public no-argument `compute()V`
    * method, loads it with a fresh class loader, and invokes `compute`.
    *
    * @param prog      the statements to compile, executed in order
    * @param className name of the generated class; defaults to `"Main"`
    * @throws VarNotFoundException at compile time, if a variable is read
    *         before any assignment to it has been compiled
    */
  def compileAndRun(prog: List[Stmt], className: Option[String]): Unit = {
    val name = className.getOrElse("Main")
    val cw = new ClassWriter(ClassWriter.COMPUTE_MAXS)
    // Maps each variable to its local-variable slot in `compute`.
    val indexMap = new scala.collection.mutable.HashMap[Symbol, Int]
    var countVar = 1 // In instance methods, index 0 indicates `this`
    cw.visit(
      V1_1, ACC_PUBLIC, name, null, "java/lang/Object", new Array[String](0)
    )
    // public <init>()V { super(); }
    visitMethod(cw)(ACC_PUBLIC, "<init>", "()V", null, null) { mv =>
      mv.visitVarInsn(ALOAD, 0)
      mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V")
      mv.visitInsn(RETURN)
      mv.visitMaxs(1, 1)
    }
    visitMethod(cw)(ACC_PUBLIC, "compute", "()V", null, null) { mv =>
      // Emits code that leaves the value of `arg` on the operand stack.
      def compileExp(arg: Exp): Unit = {
        arg match {
          case l Add r =>
            compileExp(l)
            compileExp(r)
            mv.visitInsn(IADD)
          case l Sub r =>
            compileExp(l)
            compileExp(r)
            mv.visitInsn(ISUB)
          case l Mul r =>
            compileExp(l)
            compileExp(r)
            mv.visitInsn(IMUL)
          case l Div r =>
            compileExp(l)
            compileExp(r)
            mv.visitInsn(IDIV)
          case Num(v) =>
            mv.visitLdcInsn(v.asInstanceOf[AnyRef])
          case Ident(sym) =>
            (indexMap get sym) match {
              case Some(index) =>
                mv.visitVarInsn(ILOAD, index)
              case None =>
                throw new VarNotFoundException("variable " + sym.name + " not found")
            }
        }
      }
      // Emits code for one statement; the operand stack is empty before and after.
      def compileStmt(arg: Stmt): Unit = {
        arg match {
          case SetStmt(varSym, exp) =>
            // Compile the right-hand side BEFORE allocating a slot for a new
            // variable. This makes `x = x + 1;` with `x` unassigned raise
            // VarNotFoundException, as `eval` does, instead of emitting an ILOAD
            // of an uninitialised local, which fails bytecode verification.
            compileExp(exp)
            val index = indexMap.get(varSym).getOrElse {
              indexMap(varSym) = countVar
              countVar += 1
              countVar - 1
            }
            mv.visitVarInsn(ISTORE, index)
          case PrintStmt(exp) =>
            mv.visitFieldInsn(
              GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;"
            )
            compileExp(exp)
            mv.visitMethodInsn(
              INVOKEVIRTUAL, "java/io/PrintStream", "println", "(I)V"
            )
          case ExpStmt(exp) =>
            compileExp(exp)
            mv.visitInsn(POP)
        }
      }
      prog.foreach(compileStmt)
      mv.visitInsn(RETURN)
      // The arguments are placeholders: with COMPUTE_MAXS, ClassWriter
      // computes the real maxima itself.
      mv.visitMaxs(0, 0)
    }
    // Complete the class per the ASM visitor protocol before serialising it.
    cw.visitEnd()
    val content = cw.toByteArray
    val loader = new ClassLoader {
      protected override def findClass(target: String): Class[_] = {
        if (target == name)
          // Pass the binary name explicitly; the nameless overload is deprecated.
          defineClass(name, content, 0, content.length)
        else
          super.findClass(target)
      }
    }
    val target = loader.loadClass(name).newInstance.asInstanceOf[AnyRef]
    val method = target.getClass.getMethod("compute")
    method.invoke(target)
  }

  /** With no arguments, runs an interactive interpreter. Otherwise compiles
    * and runs the program in the file named by `args(0)`. */
  def main(args: Array[String]): Unit = {
    if (args.length < 1) runRepl()
    else runFile(args(0))
  }

  /** Reads one statement per line from standard input and interprets it,
    * printing its value. An empty line, or `readLine` returning null, ends
    * the session. Because `eval` itself prints for `print`, a `print`
    * statement's value is shown twice.
    */
  private def runRepl(): Unit = {
    var env: Env = Map[Symbol, Int]()
    var line: String = readLine("> ")
    while (line != null && line != "") {
      try {
        // `phrase` rejects input with tokens left over after one statement.
        phrase(stmt)(new lexical.Scanner(line)) match {
          case Success(ast, _) =>
            val (result, newEnv) = eval(ast, env)
            env = newEnv
            println(result)
          case Failure(msg, _) => println(msg)
          case Error(msg, _) => println(msg)
        }
      } catch {
        case e: VarNotFoundException => println(e.getMessage)
        // A bad line must not end the session: division by zero, or a
        // literal too large for Int (thrown by `toInt` in `prm`).
        case e: ArithmeticException => println(e.getMessage)
        case e: NumberFormatException => println(e.getMessage)
      }
      line = readLine("> ")
    }
  }

  /** Parses the whole file at `path` as a program and passes it to
    * `compileAndRun`. The file is closed whether or not this succeeds. */
  private def runFile(path: String): Unit = {
    val src = Source.fromFile(path)
    try {
      // `phrase` requires all input to be consumed, so stray trailing tokens
      // are reported instead of being silently dropped by `rep`.
      phrase(prog)(new lexical.Scanner(src mkString "")) match {
        case Success(ast, _) =>
          compileAndRun(ast, None)
        case Failure(msg, _) => println(msg)
        case Error(msg, _) => println(msg)
      }
    } finally {
      src.asInstanceOf[BufferedSource].close()
    }
  }
}
// Script entry point: when this file is run as a Scala script, `args` holds
// the command-line arguments (optionally a program file to compile and run).
Calculator.main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment