Created January 17, 2017 06:12
Save ShaneDelmore/38e414740dce27cfde76ad47f4624e30 to your computer and use it in GitHub Desktop.
More fun with scalatest
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package scalafix.util | |
import scala.collection.immutable.Seq | |
import scala.meta._ | |
import syntax._ | |
/** Rewrites that add or remove imports in a block of statements.
  *
  * Importers and importees are compared by their printed syntax, so `prefix`
  * must be spelled the way it appears in an import (e.g. "cats.data") and
  * `term` the way the importee is written (e.g. "Xor", or "_" for a wildcard).
  */
object Import {

  /** Removes the importee `term` from the `prefix` importers of a block.
    *
    * Importers that import from `prefix.term` itself (for example
    * `import cats.data.Xor._` when removing "Xor" from "cats.data") are dropped
    * too. An importer left without importees is dropped, and an import
    * statement left without importers is dropped. Other importers, including
    * ones sharing an import statement with a removed one, are kept as they are.
    *
    * @param prefix qualifier of the importer, e.g. "cats.data"
    * @param term   importee to remove, e.g. "Xor"
    * @return a rewrite defined on blocks (`{ ..stats }`)
    */
  def remove(prefix: String, term: String): PartialFunction[Tree, Tree] = {
    case q"{ ..$stats }" =>
      val newStats: Seq[Stat] = stats.flatMap {
        case q"import ..$importersnel"
            if importersnel.exists(affects(_, prefix, term)) =>
          val kept: Seq[Importer] =
            importersnel.flatMap(withoutTerm(_, prefix, term))
          // An import statement needs at least one importer.
          if (kept.isEmpty) Nil else Seq(q"import ..$kept")
        case unchanged =>
          Seq(unchanged)
      }
      q"{ ..$newStats }"
  }

  /** True when `current` imports from `prefix.term` itself, e.g. `prefix.term._`. */
  private def importsFromTerm(current: Importer,
                              prefix: String,
                              term: String): Boolean = {
    val qualified = s"$prefix.$term"
    val ref = current.ref.syntax
    // Requiring the "." keeps "cats.data.XorT" from matching "cats.data.Xor".
    ref == qualified || ref.startsWith(qualified + ".")
  }

  /** True when `remove(prefix, term)` would change or drop `current`. */
  private def affects(current: Importer, prefix: String, term: String): Boolean =
    importsFromTerm(current, prefix, term) ||
      (current.ref.syntax == prefix && current.importees.exists(_.syntax == term))

  /** `current` without `term`: `Nil` when nothing of it remains, otherwise one importer. */
  private def withoutTerm(current: Importer,
                          prefix: String,
                          term: String): List[Importer] =
    current match {
      case _ if importsFromTerm(current, prefix, term) =>
        Nil
      case importer"$ref.{..$importeesnel}" if ref.syntax == prefix =>
        val remaining: Seq[Importee] = importeesnel.filterNot(_.syntax == term)
        if (remaining.isEmpty) Nil
        // Nothing removed: keep the original tree rather than rebuilding it.
        else if (remaining.length == importeesnel.length) List(current)
        else List(importer"$ref.{..$remaining}")
      case _ =>
        List(current)
    }

  /** Applies [[remove]] for each of `terms`, in order, to the whole tree.
    *
    * Each term gets its own pass, so later passes see the result of earlier ones.
    *
    * @param prefix qualifier shared by all `terms`, e.g. "cats.data"
    * @param terms  importees to remove
    */
  def removeAll(prefix: String,
                terms: Seq[String]): PartialFunction[Tree, Tree] = {
    case tree =>
      terms.foldLeft(tree)((acc, term) => acc.transform(remove(prefix, term)))
  }

  /** Matches any importer whose qualifier prints as `prefix` and returns it unchanged.
    *
    * Used as a predicate with `matches`.
    */
  def hasImporter(prefix: String): PartialFunction[Tree, Tree] = {
    case t @ importer"$ref.{..$importeesnel}" if ref.syntax == prefix => t
  }

  /** Matches a `prefix` importer that already names `term` or has a `_` wildcard.
    *
    * NOTE(review): a wildcard importer also matches when it unimports `term`
    * (e.g. `{ Right => _, _ }`), so `term` is reported as imported in that case.
    */
  def hasImport(prefix: String, term: String): PartialFunction[Tree, Tree] = {
    case t @ importer"$ref.{..$importeesnel}" if {
          ref.syntax == prefix &&
          importeesnel.map(_.syntax).intersect(Seq(term, "_")).nonEmpty
        } =>
      t
  }

  /** Makes `term` importable from `prefix` in a block.
    *
    * - If a `prefix` importer already brings `term` into scope, the block is returned as is.
    * - Otherwise, if some `prefix` importer exists, `term` is appended to one of them.
    * - Otherwise `import prefix.term` is prepended to the block's statements.
    *
    * @param prefix qualifier of the importer, e.g. "cats.data"
    * @param term   importee to add, e.g. "EitherT" or "_"
    * @return a rewrite defined on blocks (`{ ..stats }`)
    * @throws scala.meta.parsers.ParseException (via `Parsed.get`) if
    *         `import prefix.term` does not parse as a statement
    */
  def add(prefix: String, term: String): PartialFunction[Tree, Tree] = {
    // Already in scope.
    case b @ q"{ ..$stats }" if b.matches(hasImport(prefix, term)) =>
      b
    // An importer for `prefix` exists: extend only the first one the transform
    // reaches. Without the flag, every `prefix` importer would receive `term`
    // and the name would be imported more than once. The flag is local to this
    // single rewrite.
    case b @ q"{ ..$stats }" if b.matches(hasImporter(prefix)) =>
      var appended = false
      b.transform {
        case importer"$ref.{..$importeesnel}"
            if !appended && ref.syntax == prefix =>
          appended = true
          val withImport =
            importeesnel :+ Importee.Name(Name.Indeterminate(term))
          importer"$ref.{..$withImport}"
      }
    // No importer for `prefix`: prepend a new import statement. The first case
    // already returned for blocks where `term` is imported, so no guard is needed.
    case q"{ ..$stats }" =>
      // `.get` fails fast on malformed prefix/term instead of emitting a broken tree.
      val newImport = s"import $prefix.$term".parse[Stat].get
      val withImport: Seq[Stat] = newImport +: stats
      q"{ ..$withImport }"
  }

  /** Applies [[add]] for each of `terms`, in order, to the whole tree.
    *
    * @param prefix qualifier shared by all `terms`, e.g. "scala.util"
    * @param terms  importees to add
    */
  def addAll(prefix: String, terms: Seq[String]): PartialFunction[Tree, Tree] = {
    case tree =>
      terms.foldLeft(tree)((acc, term) => acc.transform(add(prefix, term)))
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package scalafix.util | |
import scala.meta._ | |
import org.scalatest.FunSuiteLike | |
import Import._ | |
import syntax._ | |
import scala.collection.immutable.Seq | |
/** Tests for the import add/remove rewrites in [[Import]]. */
class ImportTest extends FunSuiteLike with DiffAssertions {
  // Shared fixture: three import statements followed by one definition.
  val ast = q"""import scala.concurrent.Future
      import cats.data.{ Xor, XorT, EitherT }
      import cats.data.Xor._
      sealed abstract class Done"""

  test("matches works") {
    val pred: PartialFunction[Tree, Tree] = {
      case t @ q"import ..$importers" => t
    }
    assert(ast.matches(pred))
  }

  test("removes a single import from multiple") {
    val expected = q"""import scala.concurrent.Future
      import cats.data.{ XorT, EitherT }
      sealed abstract class Done"""
    val result = ast.transform(remove("cats.data", "Xor"))
    assertNoDiff(result, expected)
  }

  test("removes multiple imports") {
    val expected = q"""import scala.concurrent.Future
      import cats.data.{ EitherT }
      sealed abstract class Done"""
    val result = ast.transform(removeAll("cats.data", Seq("Xor", "XorT")))
    assertNoDiff(result, expected)
  }

  test("removes importer if all importees are removed") {
    val expected = q"""import scala.concurrent.Future
      sealed abstract class Done"""
    val result =
      ast.transform(removeAll("cats.data", Seq("Xor", "XorT", "EitherT")))
    assertNoDiff(result, expected)
  }

  test("remove does nothing if not matched") {
    val result = ast.transform(remove("cats.data", "NotFound"))
    assertNoDiff(result, ast)
  }

  test("add does nothing if already exists") {
    val result = ast.transform(add("cats.data", "EitherT"))
    assertNoDiff(result, ast)
  }

  test("adds the importee if importer exist") {
    val expected = q"""import scala.concurrent.Future
      import cats.data.{ Xor, XorT, EitherT, OptionT }
      import cats.data.Xor._
      sealed abstract class Done"""
    val result = ast.transform(add("cats.data", "OptionT"))
    assertNoDiff(result, expected)
  }

  test("add does nothing if wildcard import already exists") {
    val expected = q"""import scala.concurrent.Future
      import cats.data.{ Xor, XorT, EitherT }
      import cats.data.Xor._
      sealed abstract class Done"""
    val result = ast.transform(add("cats.data.Xor", "Right"))
    assertNoDiff(result, expected)
  }

  // Not implemented yet, so registered with `ignore` to keep the suite green.
  // `hasImport` treats the `_` wildcard as already importing `Right`, so `add`
  // leaves the `Right => _` unimport in place.
  ignore("add removes unimport if found") {
    val subject = q"""import scala.concurrent.Future
      import cats.data.{ Xor, XorT, EitherT }
      import cats.data.Xor.{ Right => _, _ }
      sealed abstract class Done"""
    val expected = q"""import scala.concurrent.Future
      import cats.data.{ Xor, XorT, EitherT }
      import cats.data.Xor._
      sealed abstract class Done"""
    val result = subject.transform(add("cats.data.Xor", "Right"))
    assertNoDiff(result, expected)
  }

  test("adds the importer and importee if neither exist") {
    val expected = q"""import scala.util.Either
      import scala.concurrent.Future
      import cats.data.{ Xor, XorT, EitherT }
      import cats.data.Xor._
      sealed abstract class Done"""
    val result = ast.transform(add("scala.util", "Either"))
    assertNoDiff(result, expected)
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.collection.immutable.Seq | |
import scala.meta._ | |
import scalafix.util.Import._ | |
import scalafix.util.ChangeType._ | |
// Sample program written against cats' Xor/XorT, to be migrated to Either/EitherT.
val ast =
  q"""
  import scala.concurrent.Future
  import cats.data.{ Xor, XorT }
  import cats.data.Xor._
  trait A {
    type MyDisjunction = Xor[Int, String]
    val s: Xor[Int, String] = Xor.Left(1 /* comment */)
    val t: Xor[Int, String] = r.map(_ + "!")
    val nest: Seq[Xor[Int, cats.data.Xor[String, Int]]]
    val u: XorT[Future, Int, String] = ???
    val v: Xor[Int, String] = l match {
      case Xor.Right(foo) => Xor.Right(foo)
      case Xor.Left(bar) => Xor.Left(bar)
    }
    val l: MyDisjunction = left(321)
    val r: MyDisjunction = Right.apply("")
    val rfo: MyDisjunction = Xor.right("hi")
  }
  sealed abstract class Done
  case object Done extends Done {
    val rightDone = Xor.right(Done)
  }
  """

// Type renames composed with PartialFunction.orElse: the XorT rewrite is tried
// first and the Xor rewrite is the fallback. replaceType is defined in
// ChangeType (not shown here).
val xorToEither =
  replaceType("XorT", "EitherT").orElse(replaceType("Xor", "Either"))

// Each step is a separate pass over the tree produced by the previous step.
// Imports are added unconditionally, without checking whether they are used;
// type information would make both the renames and the imports more precise.
ast
  .transform(xorToEither)
  .transform(addAll("scala.util", Seq("Either", "Left", "Right")))
  .transform(removeAll("cats.data", Seq("Xor", "XorT", "Either")))
  .transform(add("cats.data", "EitherT"))
  .transform(add("cats.implicits.syntax", "_"))
res0: scala.meta.Tree = { | |
import cats.implicits.syntax._ | |
import cats.data.EitherT | |
import scala.util.{ Either, Left, Right } | |
import scala.concurrent.Future | |
trait A { | |
type MyDisjunction = Either[Int, String] | |
val s: Either[Int, String] = Either.Left(1) | |
val t: Either[Int, String] = r.map(_ + "!") | |
val nest: Seq[Either[Int, cats.data.Either[String, Int]]] | |
val u: EitherT[Future, Int, String] = ??? | |
val v: Either[Int, String] = l match { | |
case Either.Right(foo) => | |
Either.Right(foo) | |
case Either.Left(bar) => | |
Either.Left(bar) | |
} | |
val l: MyDisjunction = left(321) | |
val r: MyDisjunction = Right.apply("") | |
val rfo: MyDisjunction = Either.right("hi") | |
} | |
sealed abstract class Done | |
case object Done extends Done { val rightDone = Either.right(Done) } | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.