Skip to content

Instantly share code, notes, and snippets.

@ochafik
Last active December 13, 2015 21:38
Show Gist options
  • Save ochafik/4978627 to your computer and use it in GitHub Desktop.
Save ochafik/4978627 to your computer and use it in GitHub Desktop.
// Author: Olivier Chafik (http://ochafik.com)
// Feel free to modify and reuse this for any purpose ("public domain / I don't care").
package scalaxy.pretyping.example
import scala.reflect.internal._
import scala.tools.nsc.CompilerCommand
import scala.tools.nsc.Global
import scala.tools.nsc.Phase
import scala.tools.nsc.plugins.Plugin
import scala.tools.nsc.plugins.PluginComponent
import scala.tools.nsc.Settings
import scala.tools.nsc.reporters.ConsoleReporter
import scala.tools.nsc.transform.TypingTransformers
/**
* This compiler plugin demonstrates how to do "useful" stuff before the typer phase.
*
* It defines a toy syntax that uses annotations to define implicit classes:
*
* @extend(Int) def toStr = self.toString
*
* Which gets desugared to:
*
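* implicit class toStr(val self: Int) extends AnyVal {
* def toStr = self.toString
* }
*
* (For target types other than the primitive value types and AnyVal, the class extends AnyRef instead.)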
*
* This example code doesn't try to be hygienic: it assumes @extend and Int are not locally redefined to something else.
*
* To see the AST before and after the rewrite, run the compiler with -Xprint:parser -Xprint:twister.
*/
object TwistCompiler {
  import scala.util.control.NonFatal

  /** Filesystem path of the jar (or directory) that holds the Scala library; passed as `-bootclasspath`. */
  private val scalaLibraryJar: String = {
    val location = classOf[List[_]].getProtectionDomain.getCodeSource.getLocation
    // URL.getFile keeps percent-escapes (e.g. "%20" for a space), which would yield a
    // non-existent path; converting through a URI decodes them.
    new java.io.File(location.toURI).getPath
  }

  /**
   * Compiles the source files named on the command line with the twist phase added
   * to the compiler's internal phases.
   *
   * Exit status: 1 if the command line is invalid or compilation reports errors,
   * 2 if the compiler throws a non-fatal exception, 0 otherwise.
   *
   * @param args scalac-style options and source files
   */
  def main(args: Array[String]): Unit = {
    try {
      val settings = new Settings
      val command =
        new CompilerCommand(List("-bootclasspath", scalaLibraryJar) ++ args, settings)
      if (!command.ok)
        System.exit(1)
      val global = new Global(settings, new ConsoleReporter(settings)) {
        override protected def computeInternalPhases(): Unit = {
          super.computeInternalPhases()
          phasesSet += new TwistComponent(this)
        }
      }
      new global.Run().compile(command.files)
      // Compilation errors are reported, not thrown: surface them through the exit status.
      if (global.reporter.hasErrors)
        System.exit(1)
    } catch {
      // Fatal errors (OOM, linkage errors, ...) are left to propagate to the JVM.
      case NonFatal(ex) =>
        ex.printStackTrace()
        System.exit(2)
    }
  }
}
/**
* To use this, just write the following in `src/main/resources/scalac-plugin.xml`:
* <plugin>
* <name>twist-plugin</name>
* <classname>scalaxy.pretyping.example.TwistPlugin</classname>
* </plugin>
*/
class TwistPlugin(override val global: Global) extends Plugin {
  // Identifier under which this plugin is registered.
  override val name: String = "twist"

  // One-line summary of what the plugin contributes.
  override val description: String =
    "Compiler plugin that adds a `@extend(Int) def toStr = self.toString` syntax to create extension methods."

  // The single phase this plugin contributes to the compiler.
  override val components: List[PluginComponent] = {
    val twister = new TwistComponent(global)
    twister :: Nil
  }
}
/**
* To understand / reproduce this, you should use paulp's :power mode in the scala console:
*
* scala
* > :power
* > :phase parser // will show us ASTs just after parsing
* > val Some(List(ast)) = intp.parse("@extend(Int) def str = self.toString")
* > nodeToString(ast)
* > val DefDef(mods, name, tparams, vparamss, tpt, rhs) = ast // play with extractors to explore the tree and its properties.
*/
class TwistComponent(val global: Global)
  extends PluginComponent
  with TypingTransformers
{
  import global._
  import definitions._

  // Phase name; pass -Xprint:twister to the compiler to print the trees after this phase.
  override val phaseName = "twister"
  // Run straight after the parser and before the typer, so the typer only sees the desugared class.
  override val runsRightAfter = Some("parser")
  override val runsAfter = runsRightAfter.toList
  override val runsBefore = List[String]("typer")

  /**
   * Matches an annotation list consisting of exactly one `@extend(T)` (with `T` a simple
   * identifier) and extracts the name `T`.
   */
  private object ExtendAnnotation {
    def unapply(annotations: List[Tree]): Option[Name] = annotations match {
      case List(Apply(Select(New(Ident(annotationName)), initName), List(Ident(targetName))))
          if annotationName.toString == "extend" && initName == nme.CONSTRUCTOR =>
        Some(targetName)
      case _ =>
        None
    }
  }

  /**
   * Rewrites `@extend(T) def name...` into
   * `implicit class name(val self: T) extends scala.AnyVal|AnyRef { def name... }`.
   *
   * @param dd         the annotated method definition (its annotations are dropped)
   * @param targetName the name `T` taken from the `@extend(T)` annotation
   * @return the implicit class definition that replaces `dd`
   */
  private def desugar(dd: DefDef, targetName: Name): ClassDef = {
    val DefDef(Modifiers(flags, privateWithin, _), name, tparams, vparamss, tpt, rhs) = dd

    // Every use gets a fresh tree: scalac records types and symbols on trees in place,
    // so a single tree instance must not appear at two places in the AST.
    def targetTpt = Ident(targetName.toTypeName)
    def emptyTpt = TypeTree()

    // If the type being extended is an AnyVal, make the implicit class a value class :-)
    val parentName: TypeName = targetName.toString match {
      case "Int" | "Long" | "Short" | "Byte" | "Double" | "Float" | "Char" | "Boolean" | "AnyVal" =>
        "AnyVal"
      case _ =>
        "AnyRef"
    }

    // Give the synthetic trees the position of the original def, so later diagnostics point at it.
    atPos(dd.pos) {
      ClassDef(
        Modifiers(flags | Flag.IMPLICIT, privateWithin, Nil),
        name.toTypeName,
        // Copies for the class; the original type parameters stay on the method below.
        tparams.map(_.duplicate),
        Template(
          List(Select(Ident("scala": TermName), parentName)),
          // Template self definition: an unnamed (`_`), untyped private ValDef, i.e. no explicit self-type.
          ValDef(Modifiers(Flag.PRIVATE), "_": TermName, emptyTpt, EmptyTree),
          List(
            // <paramaccessor> val self: T = _
            // Only PARAMACCESSOR is set (no PRIVATE/LOCAL), so the field is public, as a
            // Scala 2.10 value class requires of its single parameter.
            ValDef(Modifiers(Flags.PARAMACCESSOR), "self", targetTpt, EmptyTree),
            // def <init>(self: T) = { super.<init>(); () }
            DefDef(
              NoMods,
              nme.CONSTRUCTOR,
              Nil,
              List(List(ValDef(NoMods, "self", targetTpt, EmptyTree))),
              emptyTpt,
              Block(
                // super.<init>()
                Apply(Select(Super(This("": TypeName), "": TypeName), nme.CONSTRUCTOR), Nil),
                Literal(Constant(()))
              )
            ),
            // The original def, copied over without its annotation.
            DefDef(Modifiers(flags, privateWithin, Nil), name, tparams, vparamss, tpt, rhs)
          )
        )
      )
    }
  }

  /** Creates the phase that replaces every `@extend(T)`-annotated def in a unit by an implicit class. */
  def newPhase(prev: Phase): StdPhase = new StdPhase(prev) {
    def apply(unit: CompilationUnit): Unit = {
      val twister = new TypingTransformer(unit) {
        override def transform(tree: Tree): Tree = tree match {
          case dd @ DefDef(Modifiers(_, _, ExtendAnnotation(targetName)), _, _, _, _, _) =>
            desugar(dd, targetName)
          case _ =>
            super.transform(tree)
        }
      }
      unit.body = twister.transform(unit.body)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment