Skip to content

Instantly share code, notes, and snippets.

@steinybot
Last active May 8, 2016 23:37
Show Gist options
  • Save steinybot/23411c5adfe86864697c16d73d0dff9c to your computer and use it in GitHub Desktop.
Save steinybot/23411c5adfe86864697c16d73d0dff9c to your computer and use it in GitHub Desktop.
Class for providing variable substitution in a command at runtime.
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
/**
 * Macro bundle for case class macros.
 *
 * @param c the (blackbox) macro context that the macros are expanded in
 */
class CaseClassMacrosImpl(val c: blackbox.Context) {

  import c.universe._

  /**
   * Creates a new conversion class that provides evidence for converting a case class to a sequence of `String`.
   *
   * The expansion defines a fresh final class extending `CaseClassToSeq` whose `toSeq` is generated by
   * [[fieldsOf]], and then instantiates it.
   *
   * @tparam T the type of the case class
   * @return the conversion type
   */
  def fieldConverter[T: c.WeakTypeTag]: c.Expr[CaseClassToSeq[T]] = {
    // Splice the full type rather than its symbol so that the type arguments of a generic case class are kept
    // (the bare symbol of `Foo[Int]` is the type constructor `Foo`, which is not a valid argument here).
    val classType = c.weakTypeOf[T]
    val traitSymbol = c.weakTypeOf[CaseClassToSeq[T]].typeSymbol
    val name = c.freshName(traitSymbol.name.toTypeName)
    // `caseClass` refers to the parameter of the generated `toSeq` method below.
    val fields = fieldsOf(c.Expr[T](q"caseClass"))
    val result =
      q"""final class $name extends $traitSymbol[$classType] {
            override def toSeq(caseClass: $classType): Seq[String] = $fields
          }
          new $name
       """
    c.Expr[CaseClassToSeq[T]](result)
  }

  /**
   * Retrieves the fields of a case class as a sequence in the order that they were defined.
   *
   * Aborts the expansion (via [[caseClassOf]]) if `T` is not a case class.
   *
   * @param caseClass the expression that evaluates to the case class instance
   * @tparam T the type of the case class
   * @return the fields of the case class, each converted with `toString`
   */
  def fieldsOf[T: c.WeakTypeTag](caseClass: c.Expr[T]): c.Expr[Seq[String]] = {
    val classType = caseClassOf[T]
    val getters = caseFieldGetters(classType)
    val result = withCachedExpr(caseClass) { term =>
      val fieldsAsStrings = getters.map(param => q"$term.$param.toString")
      q"Seq(..$fieldsAsStrings)"
    }
    c.Expr[Seq[String]](result)
  }

  /**
   * Determines the `Type` of the given case class or aborts if it is not a case class.
   *
   * @tparam T the case class type tag
   * @return the weak type of the case class
   */
  def caseClassOf[T: c.WeakTypeTag]: c.Type = {
    val classType = c.weakTypeOf[T]
    val symbol = classType.typeSymbol
    if (!symbol.isClass || !symbol.asClass.isCaseClass)
      c.abort(c.enclosingPosition, s"$symbol is not a case class")
    classType
  }

  /**
   * Finds the getters of each field that is also a parameter in the first parameter list of the primary constructor of
   * a case class.
   *
   * @param classType the type of the case class
   * @return the terms in the same order that they were declared
   */
  def caseFieldGetters(classType: c.Type): List[TermSymbol] =
    classType.decls.sorted.collect {
      case sym if sym.isTerm && sym.asTerm.isCaseAccessor && sym.asTerm.isGetter => sym.asTerm
    }

  /**
   * Assigns the result of evaluating an expression to a value and then provides the term of that value to a
   * function that can then reuse that cached value in constructing another tree.
   *
   * This is part of good macro hygiene: the expression is evaluated once no matter how often the term is used.
   *
   * @param expr the expression to evaluate
   * @param other the function that uses the value
   * @tparam T the type of the expression
   * @return a block containing the value and result of the function
   */
  def withCachedExpr[T](expr: c.Expr[T])(other: (TermName) => c.Tree): c.Tree = {
    // A fresh name guarantees that the cached value cannot clash with a name at the call site.
    val resolvedTerm = c.freshName(TermName("cached"))
    val resolvedExpr = q"val $resolvedTerm = $expr"
    val exprs = List(resolvedExpr, other(resolvedTerm))
    q"..$exprs"
  }
}
/**
* Macros for working with case classes.
*/
object CaseClassMacros {
/**
* Creates a sequence of the fields of a case class.
*
* The fields that are returned are the parameters in the first parameter list of the primary constructor.
*
* Each field is accessed via its getter and is converted into a `String` via its `toString` method.
*
* The expansion is implemented by `CaseClassMacrosImpl.fieldsOf`, which aborts compilation if `T` is not a case
* class.
*
* @param caseClass the case class instance
* @tparam T the type of the case class
* @return a sequence of the fields in the order that they were declared
*/
def fieldsOf[T](caseClass: T): Seq[String] = macro CaseClassMacrosImpl.fieldsOf[T]
}
/**
* Provides implicit instances of the `CaseClassToSeq` type class, which converts a case class to a sequence of
* `String`.
*
* This uses a technique known as
* <a href="http://docs.scala-lang.org/overviews/macros/implicits#fundep-materialization">Fundep materialization</a>
* and is useful when combined with Type Class Pattern and/or
* <a href="http://docs.scala-lang.org/tutorials/FAQ/finding-implicits.html#context-bounds">Context Bounds</a>.
*/
object CaseClassToSeq {
/**
* Materializes an implicit `CaseClassToSeq` instance for a case class.
*
* Note that this is not an implicit conversion of the case class itself; it supplies the type class instance
* (generated by `CaseClassMacrosImpl.fieldConverter`) that performs the conversion.
*
* This should be used to provide evidence in Context Bounds.
*
* @tparam T the type of the case class
* @return the conversion type
*/
implicit def materializeCaseClassToSeq[T]: CaseClassToSeq[T] = macro CaseClassMacrosImpl.fieldConverter[T]
}
/**
* Represents a conversion from a case class to a sequence of `String`.
*
* Instances are normally not written by hand: the companion object materializes them implicitly with a macro.
*
* @tparam T the type of the case class
*/
trait CaseClassToSeq[T] {
/**
* Converts the instance to a sequence of `String`.
*
* @param caseClass the case class instance
* @return the case class as a sequence
*/
def toSeq(caseClass: T): Seq[String]
}
copy {
# This is the command for copying a file (here using PuTTY's pscp).
# If using the Vagrant VM the settings can be obtained from ssh-config.
# The host key is substituted from the agent.test.remote.hostkey setting.
# "$source" and "$target" are the variable names declared under `variables` below.
command = ["""C:\Program Files (x86)\PuTTY\pscp.exe""", "-P", "2222",
"-i", ".vagrant/machines/default/virtualbox/private_key.ppk", "-batch",
"-hostkey", ${agent.test.remote.hostkey},
"-l", "vagrant", "$source", "127.0.0.1:/home/vagrant/$target"]
# Names of the variables that appear in the command.
# Each value is the literal text to search for in the command.
variables {
# The name of the variable holding the local path of the file to be copied.
source = "$source"
# The name of the variable holding the target path of the file on the remote.
target = "$target"
}
}
/**
* A command for copying files from the local machine to a remote machine.
*
* Both parameters are passed straight on to `SafeVariableCommand`; `variables` is used there as the variable
* names, so each of its fields holds the name to search for in the command.
*
* @param command the command to be executed
* @param variables the variable names in the command
*/
case class CopyCommand(override val command: Seq[String], variables: CopyVariables)
extends SafeVariableCommand(command, variables) {
}
/**
* The variables for copying.
*
* The same type is used both for the variable names (when a command is created) and for the values that replace
* them (when calling `substitute`).
*
* @param source the source variable
* @param target the target variable
*/
case class CopyVariables(source: String, target: String)
// Example usage: read the copy command from configuration, fill in the variables and run it.
// NOTE(review): assumes ConfigFactory (Typesafe Config), a config-to-case-class `get[...]` extension and
// scala.sys.process._ are imported by the surrounding file - confirm.
val configs = ConfigFactory.load("command")
val copyCmd = configs.get[CopyCommand]("copy").value
// The values to substitute, in the same shape as the variable names.
val variables = CopyVariables("""C:\Users\me\test.txt""", "~/test.txt")
val cmd = copyCmd.substitute(variables)
// Run the command. Dot notation is used instead of the postfix `cmd !`, which needs the postfixOps language
// feature and can swallow the following line.
cmd.!
import scala.annotation.tailrec
/**
 * This represents a command which contains named variables to be substituted at runtime.
 *
 * The command is a sequence of strings where each string is a command or argument. Each command or argument may
 * contain zero, one or more variables.
 *
 * Variable substitution works by replacing a variable name with another value. There is no special escaping since the
 * caller is free to choose whatever name of the variable that is guaranteed to be unambiguous. If one name is a
 * prefix of another then the shorter name always wins. Empty names can never match and are ignored.
 *
 * Determining where variables are and what they need to be replaced with is done when the command is created.
 *
 * This is a generic implementation so the number of variables permitted is not fixed. However this means that the
 * caller is responsible for providing the correct number of values for substitution.
 *
 * Initialisation is worst case the number of commands multiplied by the number of variable names. However this is
 * intended to be created once and used multiple times so that cost is amortised. Substitution is linear time,
 * proportional to the number of commands and arguments plus the number of variables. (At least this is the idea)
 *
 * This is similar to Scala's string interpolation and [[StringContext]] however it works for variable strings as
 * opposed to string literals.
 *
 * Implementation note: the substitution rules are built while this trait is initialised, so implementations must
 * make `command`, `names` and `toSeq` usable at that point (for example as constructor parameters or lazy values).
 *
 * @tparam T type of variables that are accepted
 */
trait VariableCommand[T] {

  /**
   * The command that contains variables.
   */
  val command: Seq[String]

  /**
   * The names of each variable to be replaced.
   */
  val names: T

  /**
   * Converts the variables (either the names or the values) to a sequence of `String`.
   *
   * The position of a name in the result identifies which value replaces it.
   *
   * @param variables the variables to convert
   * @return the variables as a sequence
   */
  implicit def toSeq(variables: T): Seq[String]

  private type VariableSubstitution = Seq[String] => String

  // Each non-empty name paired with its position amongst all the names. Computed once instead of for every
  // position of every argument. Empty names are dropped: they would match everywhere without consuming any
  // input, which would make the search below loop forever.
  private val candidateNames: List[(String, Int)] =
    toSeq(names).toList.zipWithIndex.filter { case (name, _) => name.nonEmpty }

  // One substitution rule per command element. Must be initialised after candidateNames.
  private val substitutions: Seq[VariableSubstitution] = command.map(createSubstitution)

  /**
   * Substitutes the variable names in the command with the values.
   *
   * The order of the values must match the order of the variable names.
   *
   * @param values the values to use as substitutions
   * @return the command with the substituted values
   * @throws IllegalArgumentException if the number of values does not match the number of variable names
   *                                  (when the implementation of `toSeq` checks this)
   */
  def substitute(values: T): Seq[String] = {
    // Convert once up front rather than once per command element.
    val valueSeq = toSeq(values)
    substitutions.map(build => build(valueSeq))
  }

  /**
   * Builds the substitution rule for a single command element.
   *
   * @param arg the command or argument which may contain variable names
   * @return a function that produces the element with the values substituted
   */
  private def createSubstitution(arg: String): VariableSubstitution = {
    // Recursively go through each position in the argument and check to see if there is a variable at that position.
    // If there is then accumulate the part between where the last match ended and where the current match starts,
    // followed by the index of the variable (to be used for substitution with the values later). Then continue
    // searching from where the variable ends.
    // Once we reach the end then accumulate any remaining characters and reverse (since we accumulated by prepending
    // to the list) and then convert it to a substitution rule.
    @tailrec
    def loop(currentPos: Int, lastMatchEnd: Int, accum: List[Either[String, Int]]): VariableSubstitution =
      if (currentPos >= arg.length) {
        val finalAccum: List[Either[String, Int]] =
          if (lastMatchEnd < arg.length) Left(arg.substring(lastMatchEnd)) :: accum else accum
        mergeChoices(finalAccum.reverse)
      } else {
        checkVariable(currentPos, arg) match {
          case Some((nextPos, varIndex)) =>
            val withLiteral: List[Either[String, Int]] =
              if (currentPos > lastMatchEnd) Left(arg.substring(lastMatchEnd, currentPos)) :: accum else accum
            loop(nextPos, nextPos, Right(varIndex) :: withLiteral)
          case None => loop(currentPos + 1, lastMatchEnd, accum)
        }
      }

    loop(0, 0, Nil)
  }

  /**
   * Checks whether a variable name starts at the given position of the argument.
   *
   * @param startPos the position in the argument to check
   * @param arg      the argument being searched
   * @return the position after the match together with the index of the variable, or `None` if no name matches
   */
  private def checkVariable(startPos: Int, arg: String): Option[(Int, Int)] = {
    // Recursively check each position from the start until we find:
    // - a complete match, or
    // - there are no more names left that could be a match, or
    // - we have gone past the end of the argument
    // If a match is found then return both the position after the match (which becomes the next position to search
    // from) and also the index of the variable (this is used for substitution later on).
    @tailrec
    def loop(currentPos: Int, depth: Int, possibleNames: List[(String, Int)]): Option[(Int, Int)] =
      // If the depth is the length of the name then we have already matched every character.
      possibleNames.find { case (name, _) => name.length == depth } match {
        case Some((_, index)) => Some((currentPos, index))
        case None if possibleNames.isEmpty || currentPos >= arg.length => None
        case None =>
          // Every remaining name is longer than depth so charAt(depth) is always in range.
          val c = arg.charAt(currentPos)
          loop(currentPos + 1, depth + 1, possibleNames.filter { case (name, _) => name.charAt(depth) == c })
      }

    loop(startPos, 0, candidateNames)
  }

  /**
   * Turns the "choices" for an argument into a substitution rule.
   *
   * A choice is either the part of the argument to be copied verbatim (the left) or the index of the variable to be
   * substituted (the right).
   *
   * @param choices the choices in the order that they appear in the argument
   * @return a function which, given the variable values (that have the same indices as the names), builds the
   *         resulting argument
   */
  private def mergeChoices(choices: List[Either[String, Int]]): VariableSubstitution = choices match {
    // Slight optimisation (both now and for substitutions) for when there are no substitutions.
    case Left(literal) :: Nil => (_: Seq[String]) => literal
    case _ =>
      (values: Seq[String]) => {
        // Each invocation needs its own builder.
        val builder = new StringBuilder
        choices.foreach {
          case Left(str) => builder.append(str)
          case Right(index) => builder.append(values(index))
        }
        builder.toString
      }
  }
}
/**
 * A [[VariableCommand]] whose variable names and values are plain sequences of `String`, so nothing about the
 * variables is checked at compile time.
 *
 * @param command the command that contains variables
 * @param names   the names of each variable to be replaced
 */
class UnsafeVariableCommand(val command: Seq[String], val names: String*) extends VariableCommand[Seq[String]] {

  /**
   * Returns the variables unchanged after checking that there is exactly one per variable name.
   *
   * @param variables the variables to check
   * @return the same variables
   * @throws IllegalArgumentException if the number of variables differs from the number of variable names
   */
  override implicit def toSeq(variables: Seq[String]): Seq[String] = {
    val supplied = variables.length
    val expected = names.length
    require(
      supplied == expected,
      s"The number of variables ($supplied) must be equal to the number of variable names ($expected)"
    )
    variables
  }
}
/**
* An extension to [[VariableCommand]] which provides compile time checking of variables.
*
* The variables are described by a type `T` for which a `CaseClassToSeq` instance must be available, so both the
* names and the values passed for substitution have the same shape.
*
* @param command the command that contains variables
* @param names the names of each variable to be replaced
* @tparam T the type of variables
*/
class SafeVariableCommand[T: CaseClassToSeq](val command: Seq[String], val names: T) extends VariableCommand[T] {
// Be careful with initialisation order in here.
// The VariableCommand constructor needs to use toSeq so the fields it uses need to be initialised lazily.
// Conjure up the implicit converter from the context bounds.
private lazy val converter = implicitly[CaseClassToSeq[T]]
// The variable names as a sequence, computed on first use.
protected lazy val nameSeq = converter.toSeq(names)
// Delegates to the converter, so the order of the result is whatever order the converter produces.
override implicit def toSeq(variables: T): Seq[String] = converter.toSeq(variables)
}
@steinybot
Copy link
Author

steinybot commented Apr 30, 2016

Modified to include a type safe version and an example of how it can be used.

@steinybot
Copy link
Author

The safe version is truly safe now. No need to implement a toSeq method. Macros to the rescue!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment