Skip to content

Instantly share code, notes, and snippets.

@ShaneDelmore
Created January 17, 2017 06:12
Show Gist options
  • Save ShaneDelmore/38e414740dce27cfde76ad47f4624e30 to your computer and use it in GitHub Desktop.
Save ShaneDelmore/38e414740dce27cfde76ad47f4624e30 to your computer and use it in GitHub Desktop.
More fun with scalatest
package scalafix.util
import scala.collection.immutable.Seq
import scala.meta._
import syntax._
/** Tree rewrites that add or remove imports inside a block.
  *
  * Every rewrite is a `PartialFunction[Tree, Tree]` intended to be passed to `Tree.transform`.
  */
object Import {

  /** Builds a rewrite that removes the importee `term` of `prefix` from a block.
    *
    * The rewrite is defined on block trees and looks at every import statement directly
    * inside the block:
    *  - in an importer whose reference is exactly `prefix`, importees whose syntax equals
    *    `term` are removed, and the importer is dropped once no importees remain;
    *  - an importer that imports from inside `prefix.term` itself (e.g. `prefix.term._` or
    *    `prefix.term.{A => _, _}`) is dropped;
    *  - an import statement left without importers is dropped.
    * Unaffected statements are kept as they are. That includes unrelated importers that
    * share an import statement with an affected one, and sibling names that merely share
    * the prefix (removing `Xor` keeps `XorT`).
    *
    * @param prefix importer reference as written in source, e.g. `"cats.data"`
    * @param term   importee to remove, e.g. `"Xor"`
    */
  def remove(prefix: String, term: String): PartialFunction[Tree, Tree] = {
    case q"{ ..$stats }" =>
      val newStats: Seq[Stat] =
        stats.flatMap(stat => pruneImportStat(prefix, term, stat))
      q"{ ..$newStats }"
  }

  /** Rewrites a single statement for [[remove]]; `None` means the statement is deleted. */
  private def pruneImportStat(prefix: String, term: String, stat: Stat): Option[Stat] =
    stat match {
      case q"import ..$importers" if importers.exists(imp => isAffected(prefix, term, imp)) =>
        val kept: Seq[Importer] =
          importers.flatMap(imp => pruneImporter(prefix, term, imp))
        // An import statement must have at least one importer, so drop it when empty.
        if (kept.isEmpty) None else Some(q"import ..$kept")
      case unchanged =>
        // Return the original tree so unaffected statements keep their exact formatting.
        Some(unchanged)
    }

  /** Rewrites a single importer for [[remove]]; `None` means the importer is deleted. */
  private def pruneImporter(prefix: String, term: String, imp: Importer): Option[Importer] =
    if (importsFromInsideTerm(prefix, term, imp)) None
    else if (imp.ref.syntax != prefix) Some(imp)
    else {
      val remaining: Seq[Importee] = imp.importees.filterNot(_.syntax == term)
      if (remaining.isEmpty) None
      else if (remaining.length == imp.importees.length) Some(imp)
      else {
        val ref = imp.ref
        Some(importer"$ref.{..$remaining}")
      }
    }

  /** True when `imp` imports members of `prefix.term` itself, e.g. `prefix.term._`.
    *
    * The trailing-dot check keeps `prefix.termSuffix` (e.g. `cats.data.XorT` when removing
    * `Xor`) from being treated as nested under `prefix.term`.
    */
  private def importsFromInsideTerm(prefix: String, term: String, imp: Importer): Boolean = {
    val qualified = s"$prefix.$term"
    val ref = imp.ref.syntax
    ref == qualified || ref.startsWith(qualified + ".")
  }

  /** True when [[pruneImporter]] would change or delete `imp`. */
  private def isAffected(prefix: String, term: String, imp: Importer): Boolean =
    importsFromInsideTerm(prefix, term, imp) ||
      (imp.ref.syntax == prefix && imp.importees.exists(_.syntax == term))

  /** Applies [[remove]] once per entry of `terms`, in order.
    *
    * Each step is a separate `transform` over the result of the previous one.
    */
  def removeAll(prefix: String,
                terms: Seq[String]): PartialFunction[Tree, Tree] = {
    case tree =>
      terms.foldLeft(tree)((acc, term) => acc.transform(remove(prefix, term)))
  }

  /** Matches any importer whose reference is exactly `prefix`. */
  def hasImporter(prefix: String): PartialFunction[Tree, Tree] = {
    case imp: Importer if imp.ref.syntax == prefix => imp
  }

  /** Matches an importer for `prefix` that already brings `term` into scope.
    *
    * It matches when the importer lists `term` explicitly or contains a wildcard.
    */
  def hasImport(prefix: String, term: String): PartialFunction[Tree, Tree] = {
    case imp: Importer
        if imp.ref.syntax == prefix &&
          imp.importees.exists(i => i.syntax == term || i.syntax == "_") =>
      imp
  }

  /** Builds a rewrite that ensures `prefix.term` is imported in a block.
    *
    *  - If the block already imports `term` (explicitly or via a wildcard), it is returned
    *    unchanged.
    *  - Otherwise, if the block contains an importer for `prefix`, `term` is appended to the
    *    first such importer only, so the import is added a single time.
    *  - Otherwise a new `import prefix.term` statement is prepended to the block.
    *
    * Passing `"_"` as `term` adds a wildcard importee.
    */
  def add(prefix: String, term: String): PartialFunction[Tree, Tree] = {
    case b @ q"{ ..$stats }" if b.matches(hasImport(prefix, term)) =>
      b
    case b @ q"{ ..$stats }" if b.matches(hasImporter(prefix)) =>
      // Local flag so only the first matching importer (in traversal order) is extended.
      var added = false
      b.transform {
        case importer"$ref.{..$importees}" if !added && ref.syntax == prefix =>
          added = true
          val withImport = importees :+ newImportee(term)
          importer"$ref.{..$withImport}"
      }
    case q"{ ..$stats }" =>
      // prefix/term are supplied by rewrite authors; failing fast on a malformed path is
      // preferable to silently emitting a broken import.
      val withImport: Seq[Stat] =
        s"import $prefix.$term".parse[Stat].get +: stats
      q"{ ..$withImport }"
  }

  /** Importee tree for `term`: a wildcard for `"_"`, a plain name otherwise. */
  private def newImportee(term: String): Importee =
    if (term == "_") Importee.Wildcard()
    else Importee.Name(Name.Indeterminate(term))

  /** Applies [[add]] once per entry of `terms`, in order. */
  def addAll(prefix: String, terms: Seq[String]): PartialFunction[Tree, Tree] = {
    case tree =>
      terms.foldLeft(tree)((acc, term) => acc.transform(add(prefix, term)))
  }
}
package scalafix.util
import scala.meta._
import org.scalatest.FunSuiteLike
import Import._
import syntax._
import scala.collection.immutable.Seq
/** Tests for the import rewrites in `Import`.
  *
  * Most tests start from `ast`: a block with three import statements followed by a class
  * declaration.
  */
class ImportTest extends FunSuiteLike with DiffAssertions {
  val ast = q"""import scala.concurrent.Future
import cats.data.{ Xor, XorT, EitherT }
import cats.data.Xor._
sealed abstract class Done"""

  test("matches works") {
    val pred: PartialFunction[Tree, Tree] = {
      case t @ q"import ..$importers" => t
    }
    assert(ast.matches(pred))
  }

  test("removes a single import from multiple") {
    val expected = q"""import scala.concurrent.Future
import cats.data.{ XorT, EitherT }
sealed abstract class Done"""
    val result = ast.transform(remove("cats.data", "Xor"))
    assertNoDiff(result, expected)
  }

  test("removes multiple imports") {
    val expected = q"""import scala.concurrent.Future
import cats.data.{ EitherT }
sealed abstract class Done"""
    val result = ast.transform(removeAll("cats.data", Seq("Xor", "XorT")))
    assertNoDiff(result, expected)
  }

  test("removes importer if all importees are removed") {
    val expected = q"""import scala.concurrent.Future
sealed abstract class Done"""
    val result =
      ast.transform(removeAll("cats.data", Seq("Xor", "XorT", "EitherT")))
    assertNoDiff(result, expected)
  }

  test("remove does nothing if not matched") {
    val result = ast.transform(remove("cats.data", "NotFound"))
    assertNoDiff(result, ast)
  }

  test("add does nothing if already exists") {
    val result = ast.transform(add("cats.data", "EitherT"))
    assertNoDiff(result, ast)
  }

  test("adds the importee if importer exist") {
    val expected = q"""import scala.concurrent.Future
import cats.data.{ Xor, XorT, EitherT, OptionT }
import cats.data.Xor._
sealed abstract class Done"""
    val result = ast.transform(add("cats.data", "OptionT"))
    assertNoDiff(result, expected)
  }

  test("add does nothing if wildcard import already exists") {
    // `import cats.data.Xor._` already brings `Right` into scope.
    val result = ast.transform(add("cats.data.Xor", "Right"))
    assertNoDiff(result, ast)
  }

  // Not implemented yet: `add` treats the `_` in `{ Right => _, _ }` as already importing
  // `Right`, so it returns the input unchanged. Ignored (instead of failing) until
  // unimport handling is added.
  ignore("add removes unimport if found") {
    val subject = q"""import scala.concurrent.Future
import cats.data.{ Xor, XorT, EitherT }
import cats.data.Xor.{ Right => _, _ }
sealed abstract class Done"""
    val expected = q"""import scala.concurrent.Future
import cats.data.{ Xor, XorT, EitherT }
import cats.data.Xor._
sealed abstract class Done"""
    val result = subject.transform(add("cats.data.Xor", "Right"))
    assertNoDiff(result, expected)
  }

  test("adds the importer and importee if neither exist") {
    val expected = q"""import scala.util.Either
import scala.concurrent.Future
import cats.data.{ Xor, XorT, EitherT }
import cats.data.Xor._
sealed abstract class Done"""
    val result = ast.transform(add("scala.util", "Either"))
    assertNoDiff(result, expected)
  }
}
import scala.collection.immutable.Seq
import scala.meta._
import scalafix.util.Import._
import scalafix.util.ChangeType._
// REPL worksheet: runs the Xor -> Either migration over a sample tree. The printed result
// of the final expression is recorded below as `res0`.
// Sample input: Xor/XorT in imports, a type alias, type arguments (including a fully
// qualified `cats.data.Xor`), pattern matches and companion-object calls.
val ast =
q"""
import scala.concurrent.Future
import cats.data.{ Xor, XorT }
import cats.data.Xor._
trait A {
type MyDisjunction = Xor[Int, String]
val s: Xor[Int, String] = Xor.Left(1 /* comment */)
val t: Xor[Int, String] = r.map(_ + "!")
val nest: Seq[Xor[Int, cats.data.Xor[String, Int]]]
val u: XorT[Future, Int, String] = ???
val v: Xor[Int, String] = l match {
case Xor.Right(foo) => Xor.Right(foo)
case Xor.Left(bar) => Xor.Left(bar)
}
val l: MyDisjunction = left(321)
val r: MyDisjunction = Right.apply("")
val rfo: MyDisjunction = Xor.right("hi")
}
sealed abstract class Done
case object Done extends Done {
val rightDone = Xor.right(Done)
}
"""
// NOTE(review): in the recorded output, `nest` becomes `cats.data.Either[...]` and
// `Xor.right(...)` becomes `Either.right(...)`. Neither presumably exists (scala.util.Either's
// companion has no `right` method), so these call sites likely need their own rewrite. TODO confirm.
ast
// The rewrites work on untyped trees for now. Typed rewrites need the full scalafix pipeline,
// which was slow to iterate on without a REPL and unit tests. It was also still open whether
// rewrites should emit trees or patches.
.transform(
// `orElse` composes the two renames into one partial function, so a single traversal applies
// both. `replaceType` is not defined in `Import`; presumably it comes from ChangeType.
replaceType("XorT", "EitherT")
orElse replaceType("Xor", "Either")
)
// Fix up imports after the renames.
// TODO: only add imports that the rewritten code actually uses.
.transform(addAll("scala.util", Seq("Either", "Left", "Right")))
.transform(removeAll("cats.data", Seq("Xor", "XorT", "Either")))
.transform(add("cats.data", "EitherT"))
// NOTE(review): `cats.implicits.syntax` does not appear to be an importable path in cats.
// The usual syntax import is `cats.implicits._`. Confirm before relying on this generated import.
.transform(add("cats.implicits.syntax", "_"))
res0: scala.meta.Tree = {
import cats.implicits.syntax._
import cats.data.EitherT
import scala.util.{ Either, Left, Right }
import scala.concurrent.Future
trait A {
type MyDisjunction = Either[Int, String]
val s: Either[Int, String] = Either.Left(1)
val t: Either[Int, String] = r.map(_ + "!")
val nest: Seq[Either[Int, cats.data.Either[String, Int]]]
val u: EitherT[Future, Int, String] = ???
val v: Either[Int, String] = l match {
case Either.Right(foo) =>
Either.Right(foo)
case Either.Left(bar) =>
Either.Left(bar)
}
val l: MyDisjunction = left(321)
val r: MyDisjunction = Right.apply("")
val rfo: MyDisjunction = Either.right("hi")
}
sealed abstract class Done
case object Done extends Done { val rightDone = Either.right(Done) }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment