// Author: Olivier Chafik (http://ochafik.com)
// Feel free to modify and reuse this for any purpose ("public domain / I don't care").
package scalaxy.pretyping.example

import scala.reflect.internal._
import scala.tools.nsc.CompilerCommand
import scala.tools.nsc.Global
import scala.tools.nsc.Phase
import scala.tools.nsc.plugins.Plugin
import scala.tools.nsc.plugins.PluginComponent
import scala.tools.nsc.Settings
import scala.tools.nsc.reporters.ConsoleReporter
import scala.tools.nsc.transform.TypingTransformers

/**
 *  This compiler plugin demonstrates how to do "useful" stuff before the typer phase.
 *  
 *  It defines a toy syntax that uses annotations to define implicit classes:
 *  
 *    @extend(Int) def toStr = self.toString
 *    
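 *  Which gets desugared to:
 *  
 *    implicit class toStr(val self: Int) extends AnyVal { 
 *      def toStr = self.toString
 *    }
 *
 *  (The wrapper extends AnyVal when the extended type is a primitive value type or AnyVal, and AnyRef otherwise.)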
 *
 *  This example code doesn't try to be hygienic: it assumes @extend and Int are not locally redefined to something else.
 *
 *  To see the AST before and after the rewrite, run the compiler with -Xprint:parser -Xprint:twister (the rewrite phase is named "twister").
 */
object TwistCompiler {
  /**
   *  Filesystem path of the jar (or directory) that provides the Scala library.
   *  It is passed to the embedded compiler as its boot classpath.
   *
   *  `URL.getFile` would return the URL-encoded form (e.g. `%20` for a space),
   *  which is not a usable file path, so the location is decoded through its URI.
   */
  private val scalaLibraryJar =
    new java.io.File(
      classOf[List[_]].getProtectionDomain.getCodeSource.getLocation.toURI
    ).getPath

  /**
   *  Compiles the files named in `args` with the twist phase installed.
   *
   *  Exit codes:
   *   - 1 for invalid command-line arguments or compilation errors,
   *   - 2 for an unexpected crash.
   *
   *  @param args regular scalac command-line arguments and source files
   */
  def main(args: Array[String]): Unit = {
    try {
      val settings = new Settings
      val command =
        new CompilerCommand(List("-bootclasspath", scalaLibraryJar) ++ args, settings)

      if (!command.ok)
        System.exit(1)

      val reporter = new ConsoleReporter(settings)
      val global = new Global(settings, reporter) {
        override protected def computeInternalPhases(): Unit = {
          super.computeInternalPhases
          phasesSet += new TwistComponent(this)
        }
      }
      new global.Run().compile(command.files)

      // Print "N errors found" like scalac does.
      reporter.printSummary()
      // Without this check the process would report success even when compilation failed.
      if (reporter.hasErrors)
        System.exit(1)
    } catch {
      // Top-level driver: report any crash and signal it through a distinct exit code.
      case ex: Throwable =>
        ex.printStackTrace
        System.exit(2)
    }
  }
}

/**
 *  To use this, just write the following in `src/main/resources/scalac-plugin.xml`:
 *  <plugin>
 *    <name>twist-plugin</name>
 *    <classname>scalaxy.pretyping.example.TwistPlugin</classname>
 *  </plugin>
 */
class TwistPlugin(override val global: Global) extends Plugin {
  // Plugin identifier, as used by scalac options such as -P:twist:... and -Xplugin-disable:twist.
  override val name: String = "twist"

  override val description: String =
    "Compiler plugin that adds a `@extend(Int) def toStr = self.toString` syntax to create extension methods."

  // The single phase contributed by this plugin, bound to the host compiler instance.
  override val components: List[PluginComponent] = new TwistComponent(global) :: Nil
}

/**
 *  To understand / reproduce this, you should use paulp's :power mode in the scala console:
 *  
 *  scala
 *  > :power
 *  > :phase parser // will show us ASTs just after parsing
 *  > val Some(List(ast)) = intp.parse("@extend(Int) def str = self.toString")
 *  > nodeToString(ast)
 *  > val DefDef(mods, name, tparams, vparamss, tpt, rhs) = ast // play with extractors to explore the tree and its properties.
 */ 
class TwistComponent(val global: Global)
    extends PluginComponent
    with TypingTransformers
{
  import global._
  import definitions._

  override val phaseName = "twister"

  // Run on the raw parser output, before any symbols or types exist.
  override val runsRightAfter = Some("parser")
  override val runsAfter = runsRightAfter.toList
  override val runsBefore = List[String]("typer")

  /**
   *  Creates the phase that rewrites `@extend(T) def m ... = rhs` into
   *  `implicit class m(val self: T) extends AnyVal|AnyRef { def m ... = rhs }`.
   */
  def newPhase(prev: Phase): StdPhase = new StdPhase(prev) {
    def apply(unit: CompilationUnit): Unit = {
      val onTransformer = new TypingTransformer(unit) {
        override def transform(tree: Tree): Tree = tree match {
          case dd @ DefDef(Modifiers(flags, privateWithin, annotations), name, tparams, vparamss, tpt, rhs) =>
            annotations match {
              case List(Apply(Select(New(Ident(annotationName)), initName), List(Ident(typeName))))
                  if annotationName.toString == "extend" && initName == nme.CONSTRUCTOR =>
                // Tree nodes must not be shared between positions in the AST: namer and typer
                // attach symbols/types to (and mutate) individual nodes. These are defs so
                // that every use site gets a fresh node.
                def targetTpt = Ident(typeName.toString: TypeName)
                def emptyTpt = TypeTree()

                // If the type being extended is an AnyVal, make the implicit class a value class :-)
                val parentTypeName: TypeName = typeName.toString match {
                  case "Int" | "Long" | "Short" | "Byte" | "Double" | "Float" | "Char" | "Boolean" | "AnyVal" =>
                    "AnyVal"
                  case _ =>
                    "AnyRef"
                }

                // Give the synthetic trees the def's position so diagnostics point at the source.
                atPos(dd.pos) {
                  ClassDef(
                    Modifiers(flags | Flag.IMPLICIT, privateWithin, Nil),
                    name.toString: TypeName,
                    // The def's type params stay on the method only: `self`'s type never
                    // refers to them, and reusing the same TypeDef trees here would share nodes.
                    Nil,
                    Template(
                      List(Select(Ident("scala": TermName), parentTypeName)),
                      // Self-type slot of the template ("no explicit self type"), mirroring what
                      // the original example emitted.
                      ValDef(Modifiers(Flag.PRIVATE), "_": TermName, emptyTpt, EmptyTree),
                      List(
                        // <paramaccessor> val self: T
                        // Only PARAMACCESSOR is set (not PRIVATE | LOCAL), so `self` is a
                        // public val, as value classes require.
                        ValDef(Modifiers(Flags.PARAMACCESSOR), "self", targetTpt, EmptyTree),
                        // def <init>(self: T) = { super.<init>(); () }
                        DefDef(
                          NoMods,
                          nme.CONSTRUCTOR,
                          Nil,
                          List(List(ValDef(NoMods, "self", targetTpt, EmptyTree))),
                          emptyTpt,
                          Block(
                            // super.<init>()
                            Apply(Select(Super(This("": TypeName), "": TypeName), nme.CONSTRUCTOR), Nil),
                            Literal(Constant(()))
                          )
                        ),
                        // The original def without its annotation. Its body is transformed too,
                        // so nested @extend defs get rewritten as well.
                        DefDef(Modifiers(flags, privateWithin, Nil), name, tparams, vparamss, tpt, transform(rhs))
                      )
                    )
                  )
                }
              case _ =>
                super.transform(tree)
            }
          case _ =>
            super.transform(tree)
        }
      }
      unit.body = onTransformer.transform(unit.body)
    }
  }
}