Skip to content

Instantly share code, notes, and snippets.

@retronym
Created January 16, 2022 23:42
Show Gist options
  • Save retronym/79ff0e6e18bf114b9e3458ae2cdedd52 to your computer and use it in GitHub Desktop.
Save retronym/79ff0e6e18bf114b9e3458ae2cdedd52 to your computer and use it in GitHub Desktop.
iterator escape
package demo.compat
import scala.collection.{GenIterableLike, mutable}
import scala.tools.nsc.plugins.{Plugin, PluginComponent}
import scala.tools.nsc.{Global, Phase, Settings}
/** Demo of a scalac plugin that warns when the result of a collection's `iterator` escapes.
  *
  * A call to `iterator` (declared in, or overriding the one in, `GenIterableLike`) is accepted when
  * it starts a method chain of `Apply` / `TypeApply` / `Select` trees whose outermost link has a
  * non-`Iterator` type, e.g. `xs.iterator.map(f).toList`. If the outermost link of that chain is
  * still `Iterator`-typed, the linter reports "iterator usage discouraged". This happens when the
  * chain ends in a statement, an `if` branch, a call argument, or any other non-chain tree.
  *
  * `main` builds a `Global` with the linter plugin loaded and compiles a small inline sample.
  */
object IteratorLinter {
  def main(args: Array[String]): Unit = {
    println("go!")
    val g = new Global(new Settings) {
      self =>
      // Compile against the JVM classpath so the sample can see scala-library.
      settings.usejavacp.value = true

      object linter extends Plugin {
        override val name: String = "linter"
        override val components: List[PluginComponent] = linterComponent :: Nil
        override val description: String = ""
        override val global: self.type = self

        object linterComponent extends PluginComponent {
          override val global: self.type = self
          override val phaseName: String = "linter"
          // Runs on attributed trees: the checks below rely on symbols and types from typer.
          override val runsAfter: List[String] = "typer" :: Nil
          override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
            override def apply(unit: linterComponent.global.CompilationUnit): Unit =
              escapeTraverser.traverse(unit.body)
          }
        }

        /** Walks a compilation unit while tracking the enclosing trees needed to decide whether
          * the result of an `iterator` call escapes its method chain.
          */
        object escapeTraverser extends Traverser {

          /** Enclosing trees, innermost on top.
            *
            * `Apply` and `Select` trees are pushed as chain links. Every other tree (except the
            * transparent `TypeApply` and `Annotated`) pushes `EmptyTree`, which terminates a chain.
            */
          private val enclosingTrees = mutable.ArrayStack[Tree]()

          // Resolved once, on first use during the linter phase, instead of at every `iterator` call.
          private lazy val genIterableLikeClass = symbolOf[GenIterableLike[Any, Any]]

          /** Pushes `tree` for the duration of `body` and always pops it afterwards. */
          private def withEnclosing(tree: Tree)(body: => Unit): Unit = {
            enclosingTrees.push(tree)
            try body
            finally enclosingTrees.pop()
          }

          /** Trees that continue a method chain rather than terminating it. */
          private def isChainLink(tree: Tree): Boolean = tree match {
            case _: Apply | _: TypeApply | _: Select => true
            case _                                 => false
          }

          /** Whether `tree` is typed as `Iterator` or one of its subclasses (false for untyped trees). */
          private def hasIteratorType(tree: Tree): Boolean =
            tree.tpe != null && tree.tpe.typeSymbol.isSubClass(definitions.IteratorClass)

          /** Whether `tree` refers to `iterator` declared in, or overriding the one in, `GenIterableLike`. */
          private def callsCollectionIterator(tree: Tree): Boolean = {
            val sym = tree.symbol
            sym.name.string_==("iterator") && sym.overriddenSymbol(genIterableLikeClass) != NoSymbol
          }

          /** True when the chain starting at the top of `enclosingTrees` ends in an `Iterator`-typed tree.
            *
            * Only the innermost chain is inspected, i.e. the trees up to the nearest boundary. Chains
            * further out belong to other expressions and are checked at their own `iterator` calls.
            * Returns false if the stack contains no boundary at all.
            */
          def iteratorEscapes(): Boolean = {
            // ArrayStack iterates in LIFO order: innermost tree first.
            val chain = enclosingTrees.iterator.takeWhile(isChainLink).toList
            val boundaryFound = chain.size < enclosingTrees.size
            boundaryFound && chain.lastOption.exists(hasIteratorType)
          }

          override def traverse(t: Tree): Unit = t match {
            case Apply(fun, args) =>
              withEnclosing(t)(traverse(fun))
              // An argument flows into the call, not into whatever consumes the call's result.
              // Arguments therefore start a fresh chain, and an iterator passed directly as an
              // argument counts as escaping.
              withEnclosing(EmptyTree)(traverseTrees(args))
            case _: Select =>
              withEnclosing(t) {
                if (callsCollectionIterator(t) && iteratorEscapes())
                  reporter.warning(t.pos, "iterator usage discouraged")
                super.traverse(t)
              }
            case _: Annotated | _: TypeApply =>
              // Transparent: neither extends nor terminates a chain.
              super.traverse(t)
            case _ =>
              withEnclosing(EmptyTree)(super.traverse(t))
          }
        }
      }

      // Use only the linter instead of the plugins configured through settings.
      override protected def loadPlugins(): List[Plugin] = linter :: Nil
    }
    import g._
    val code = """
class C {
  Nil.iterator.map(x => x).map(x => x).toList // ok
  Nil.iterator.map(x => x).map(x => x).toList.map(x => x) // ok
  Nil.iterator // nok
  Nil.iterator.map(x => x) // nok
  (if (true) Nil.iterator else ???).map(x => x).toList // nok
  (if (true) Nil.iterator.toList else Nil).iterator // nok: outer iterator only
  List(Nil.iterator.toList).iterator // nok: outer iterator only
}
"""
    val r = new Run
    r.compileSources(newSourceFile(code) :: Nil)
  }
}
@lrytz
Copy link

lrytz commented Jan 17, 2022

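Prefix the file with the following directives to run it with scala-cli:

//> using scala 2.12.15
//> using dep org.scala-lang:scala-compiler:2.12.15

(This is the current scala-cli directive syntax. Early-2022 scala-cli releases also accepted the older spelling `// using scala 2.12.15` and `// using lib org.scala-lang:scala-compiler:2.12.15`.)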

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment