@tanishiking
Created March 10, 2023 07:57

package dotty.tools.dotc
package transform
import ast.{TreeTypeMap, tpd}
import config.Printers.tailrec
import core.*
import Contexts.*, Flags.*, Symbols.*
import Constants.Constant
import NameKinds.{TailLabelName, TailLocalName, TailTempName}
import StdNames.nme
import reporting.*
import transform.MegaPhase.MiniPhase
import util.LinearSet
import dotty.tools.uncheckedNN
/** A Tail Rec Transformer.
 *
 *  What it does:
 *
 *  Finds method calls in tail-position and replaces them with jumps.
 *  A call is in tail-position if it is the last instruction to be
 *  executed in the body of a method. This includes being in
 *  tail-position of a `return` from a `Labeled` block which is itself
 *  in tail-position (which is critical for tail-recursive calls in the
 *  cases of a `match`). To identify tail positions, we recurse over
 *  the trees that may contain calls in tail-position (trees that can't
 *  contain such calls are not transformed).
 *
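 *  For illustration, here is a minimal sketch of calls that are and are not in
 *  tail-position. The object and method names are hypothetical. The methods are
 *  members of an object, so they are effectively final and this phase considers
 *  them. The recursive calls in `contains` and `lastOf` should become jumps. The
 *  call in `length` is not in tail-position, because `1 + _` still has to run
 *  after it returns.
 *
 *  ```scala
 *  object TailPositionDemo {
 *    // The right operand of `&&`/`||` in tail-position is itself in
 *    // tail-position, so the call to `contains` can become a jump.
 *    def contains(xs: List[Int], x: Int): Boolean =
 *      xs.nonEmpty && (xs.head == x || contains(xs.tail, x))
 *
 *    // Each `case` body of a `match` in tail-position is in tail-position.
 *    def lastOf(xs: List[Int]): Int = xs match {
 *      case Nil         => throw new NoSuchElementException("empty list")
 *      case last :: Nil => last
 *      case _ :: rest   => lastOf(rest)
 *    }
 *
 *    // Not a tail call: the addition runs after the recursive call returns,
 *    // so this method stays ordinary recursion (one frame per element).
 *    def length(xs: List[Int]): Int =
 *      if (xs.isEmpty) 0 else 1 + length(xs.tail)
 *  }
 *  ```
 *
 *  With a 1,000,000-element list, `contains` and `lastOf` are expected to run
 *  without a `StackOverflowError`. `length` on such a list may overflow the
 *  default JVM stack.
 *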
 *  When a method contains at least one tail-recursive call, its rhs
 *  is wrapped in the following structure:
 *
 *  ```
 *  var localForParam1: T1 = param1
 *  ...
 *  while (<empty>) {
 *    tailLabelN[Unit]: {
 *      return {
 *        // original rhs with tail recursive calls transformed (see below)
 *      }
 *    }
 *  }
 *  ```
 *
 *  Self-recursive calls in tail-position are then replaced by (a)
 *  reassigning the local `var`s substituting formal parameters and
 *  (b) a `return` from the `tailLabelN` labeled block, which has the
 *  net effect of looping back to the beginning of the method.
 *  If the receiver is modified in a recursive call, an additional `var`
 *  is used to replace `this`.
 *
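 *  Below is a minimal sketch of a recursive call on a different receiver. It
 *  assumes a hypothetical final linked-list class `Node`. `n.sumFrom(...)` is
 *  called on `n` rather than `this`, so the rewrite also needs the extra `var`
 *  standing in for `this` (`varForRewrittenThis` in `TailRecElimination`).
 *
 *  ```scala
 *  final class Node(val value: Long, val next: Option[Node]) {
 *    // `Node` is final, so `sumFrom` is effectively final and is a candidate.
 *    // In the `Some` case the receiver changes from `this` to `n`.
 *    def sumFrom(acc: Long): Long = next match {
 *      case Some(n) => n.sumFrom(acc + value)
 *      case None    => acc + value
 *    }
 *  }
 *
 *  object NodeDemo {
 *    // Builds a chain of `length` nodes, each holding 1, without recursion.
 *    def chain(length: Int): Node =
 *      (1 until length).foldLeft(new Node(1, None))((tail, _) => new Node(1, Some(tail)))
 *  }
 *  ```
 *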
 *  As a complete example of the transformation, the classical `fact`
 *  function, defined as:
 *
 *  ```
 *  def fact(n: Int, acc: Int): Int =
 *    if (n == 0) acc
 *    else fact(n - 1, acc * n)
 *  ```
 *
 *  is rewritten as:
 *
 *  ```
 *  def fact(n: Int, acc: Int): Int = {
 *    var acc$tailLocal1: Int = acc
 *    var n$tailLocal1: Int = n
 *    while (<empty>) {
 *      tailLabel1[Unit]: {
 *        return {
 *          if (n$tailLocal1 == 0)
 *            acc$tailLocal1
 *          else {
 *            val n$tailLocal1$tmp1: Int = n$tailLocal1 - 1
 *            val acc$tailLocal1$tmp1: Int = acc$tailLocal1 * n$tailLocal1
 *            n$tailLocal1 = n$tailLocal1$tmp1
 *            acc$tailLocal1 = acc$tailLocal1$tmp1
 *            (return[tailLabel1] ()): Int
 *          }
 *        }
 *      }
 *    }
 *  }
 *  ```
 *
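 *  The rewritten form of `fact` can be mimicked in plain Scala, which may help
 *  when reading it. This is a hand-written sketch, not compiler output. Scala
 *  source has no `return` to a label, so here the `tailLabel1` jump is modelled
 *  by reaching the end of the `while` body. The inner `return` still leaves the
 *  method. The names (`FactLoop`, `nLocal`, `nTmp`, ...) are hypothetical.
 *
 *  ```scala
 *  object FactLoop {
 *    def fact(n: Int, acc: Int): Int = {
 *      // Mutable copies of the parameters, like the `tailLocal` vars.
 *      var nLocal = n
 *      var accLocal = acc
 *      while (true) {
 *        // Original body: the base case leaves the method.
 *        if (nLocal == 0) return accLocal
 *        // Compute every new argument before assigning any of them, like the
 *        // temporary vals, so `accLocal * nLocal` still sees the old n.
 *        val nTmp = nLocal - 1
 *        val accTmp = accLocal * nLocal
 *        nLocal = nTmp
 *        accLocal = accTmp
 *        // Reaching the end of the body loops back, like `return[tailLabel1] ()`.
 *      }
 *      throw new AssertionError("unreachable")
 *    }
 *  }
 *  ```
 *
 *  Assigning `nLocal` before computing `accTmp` would multiply by the
 *  already-decremented value and produce a different result. This is why the
 *  temporaries are needed.
 *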
 *  As the JVM provides no way to jump from one method to another,
 *  non-recursive calls in tail-position are not optimized.
 *
 *  A method call is self-recursive if it calls the current method and
 *  the method is effectively final, for example a `final` or `private`
 *  method, or a member of a final class or object (otherwise, it could be
 *  a call to an overridden method in a subclass). Recursive calls on a
 *  different instance are optimized as well.
 *
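 *  The following sketch shows why only effectively final methods are rewritten.
 *  It uses hypothetical classes `Countdown` and `CountingCountdown`. In
 *  `Countdown` the self-call is a virtual call. Once a subclass overrides
 *  `countDown`, that call reaches the override, and a jump back to the start of
 *  `Countdown.countDown` would skip it.
 *
 *  ```scala
 *  class Countdown {
 *    // Neither final nor private, in an open class: not effectively final, so
 *    // this phase leaves the call alone (annotating it with `@tailrec` would be
 *    // reported as an error).
 *    def countDown(n: Int): Int = if (n == 0) 0 else countDown(n - 1)
 *  }
 *
 *  class CountingCountdown extends Countdown {
 *    var calls = 0
 *    override def countDown(n: Int): Int = {
 *      calls += 1
 *      super.countDown(n)
 *    }
 *  }
 *  ```
 *
 *  On a fresh `CountingCountdown`, `countDown(3)` leaves `calls == 4`, because
 *  the override runs for 3, 2, 1 and 0. If the inner call had become a jump,
 *  `calls` would stay at 1.
 *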
 *  This phase has been moved after erasure to allow the use of vars
 *  for the parameters combined with a `WhileDo`. This is also
 *  beneficial to support polymorphic tail-recursive calls.
 *
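 *  Here is a minimal sketch of a polymorphic tail-recursive call; the object and
 *  method names are hypothetical. Each recursive call passes a different type
 *  argument (`List[A]` instead of `A`), which scalac's same-type-arguments
 *  requirement rules out. After erasure the type argument is gone, so this phase
 *  should still turn the call into a jump.
 *
 *  ```scala
 *  object PolyRec {
 *    // Wraps `a` in `depth` layers of `List`. The recursive call is
 *    // `nest[List[A]](...)`, not `nest[A](...)`.
 *    def nest[A](a: A, depth: Int): Any =
 *      if (depth <= 0) a else nest(List(a), depth - 1)
 *  }
 *  ```
 *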
 *  In scalac, if the method has type parameters, the recursive call must
 *  pass the same type parameters as type arguments. This is no longer the
 *  case in dotc, because this phase runs after erasure.
 *  In scalac, this phase is named tailCall, but it only optimizes
 *  self-recursive functions; that is why it is renamed to tailrec here.
 *
 *  @author
 *    Erik Stenman, Iulian Dragos,
 *    ported and heavily modified for dotty by Dmitry Petrashko
 *    moved after erasure and adapted to emit `Labeled` blocks by Sébastien Doeraene
 */
class TailRec extends MiniPhase {
  import tpd._

  override def phaseName: String = TailRec.name

  override def description: String = TailRec.description

  override def runsAfter: Set[String] = Set(Erasure.name) // tailrec assumes erased types

  override def transformDefDef(tree: DefDef)(using Context): Tree = {
    val method = tree.symbol
    val mandatory = method.hasAnnotation(defn.TailrecAnnot)
    def noTailTransform(failureReported: Boolean) = {
      // FIXME: want to report this error on `tree.nameSpan`, but
      // because extension methods get a weird position, it is
      // better to report on the method symbol so there's no overlap.
      // We don't report a new error if failures were reported
      // during the transformation.
      if (mandatory && !failureReported)
        report.error(TailrecNotApplicable(method), method.srcPos)

      tree
    }

    val isCandidate = method.isEffectivelyFinal &&
      !(method.is(Accessor) || tree.rhs.eq(EmptyTree))

    if (isCandidate) {
      val enclosingClass = method.enclosingClass.asClass

      // Note: this could be split into two separate transforms (in different groups):
      // the first one would collect info about which transformations and rewritings
      // should be applied, and the second one would actually apply them.
      // For now, this speculatively transforms the tree and throws away the result in many cases.
      val transformer = new TailRecElimination(method, enclosingClass, tree.termParamss.head.map(_.symbol), mandatory)
      val rhsSemiTransformed = transformer.transform(tree.rhs)

      if (transformer.rewrote) {
        val varForRewrittenThis = transformer.varForRewrittenThis
        val rewrittenParamSyms = transformer.rewrittenParamSyms
        val varsForRewrittenParamSyms = transformer.varsForRewrittenParamSyms

        val initialVarDefs = {
          val initialParamVarDefs = rewrittenParamSyms.lazyZip(varsForRewrittenParamSyms).map {
            (param, local) => ValDef(local.asTerm, ref(param))
          }
          varForRewrittenThis match {
            case Some(local) => ValDef(local.asTerm, This(enclosingClass)) :: initialParamVarDefs
            case None => initialParamVarDefs
          }
        }

        val rhsFullyTransformed = varForRewrittenThis match {
          case Some(localThisSym) =>
            val thisRef = localThisSym.termRef
            val substitute = new TreeTypeMap(
              typeMap = _.substThisUnlessStatic(enclosingClass, thisRef)
                .subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef)),
              treeMap = {
                case tree: This if tree.symbol == enclosingClass => Ident(thisRef)
                case tree => tree
              }
            )
            // The previous map will map `This` references to `Ident`s even under `Super`.
            // This violates super's contract. We fix this by cleaning up `Ident`s under
            // super, mapping them back to the original `This` reference. This is not
            // very elegant, but I did not manage to find a cleaner way to handle this.
            // See pos/tailrec-super.scala for a test case.
            val cleanup = new TreeMap:
              override def transform(t: Tree)(using Context) = t match
                case Super(qual: Ident, mix) if !qual.tpe.isInstanceOf[Types.ThisType] =>
                  cpy.Super(t)(This(enclosingClass), mix)
                case _ =>
                  super.transform(t)
            cleanup.transform(substitute.transform(rhsSemiTransformed))

          case None =>
            new TreeTypeMap(
              typeMap = _.subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef))
            ).transform(rhsSemiTransformed)
        }

        /** Is the RHS a direct recursive tailcall, possibly with swapped arguments or modified pure arguments.
         *  ```
         *  def f(<params>): T = f(<args>)
         *  ```
         *  where `<args>` are pure arguments or references to parameters in `<params>`.
         */
        def isInfiniteRecCall(tree: Tree): Boolean = {
          def tailArgOrPureExpr(stat: Tree): Boolean = stat match {
            case stat: ValDef if stat.name.is(TailTempName) || !stat.symbol.is(Mutable) => tailArgOrPureExpr(stat.rhs)
            case Assign(lhs: Ident, rhs) if lhs.symbol.name.is(TailLocalName) => tailArgOrPureExpr(rhs)
            case Assign(lhs: Ident, rhs: Ident) => lhs.symbol == rhs.symbol
            case stat: Ident if stat.symbol.name.is(TailLocalName) => true
            case _ => tpd.isPureExpr(stat)
          }
          tree match {
            case Typed(expr, _) => isInfiniteRecCall(expr)
            case Return(Literal(Constant(())), label) => label.symbol == transformer.continueLabel
            case Block(stats, expr) => stats.forall(tailArgOrPureExpr) && isInfiniteRecCall(expr)
            case _ => false
          }
        }

        if isInfiniteRecCall(rhsFullyTransformed) then
          report.warning("Infinite recursive call", tree.srcPos)

        cpy.DefDef(tree)(rhs =
          Block(
            initialVarDefs,
            WhileDo(EmptyTree, {
              Labeled(transformer.continueLabel.asTerm, {
                Return(rhsFullyTransformed, method)
              })
            })
          )
        )
      }
      else noTailTransform(failureReported = transformer.failureReported)
    }
    else noTailTransform(failureReported = false)
  }
  class TailRecElimination(method: Symbol, enclosingClass: ClassSymbol, paramSyms: List[Symbol], isMandatory: Boolean) extends TreeMap {

    var rewrote: Boolean = false
    var failureReported: Boolean = false

    /** The `tailLabelN` label symbol, used to encode a `continue` from the infinite `while` loop. */
    private var myContinueLabel: Symbol | Null = _
    def continueLabel(using Context): Symbol = {
      if (myContinueLabel == null)
        myContinueLabel = newSymbol(method, TailLabelName.fresh(), Label, defn.UnitType)
      myContinueLabel.uncheckedNN
    }

    /** The local `var` that replaces `this`, if it is modified in at least one recursive call. */
    var varForRewrittenThis: Option[Symbol] = None
    /** The subset of `paramSyms` that are modified in at least one recursive call, and which therefore need a replacement `var`. */
    var rewrittenParamSyms: List[Symbol] = Nil
    /** The replacement `var`s for the params in `rewrittenParamSyms`. */
    var varsForRewrittenParamSyms: List[Symbol] = Nil

    private def getVarForRewrittenThis()(using Context): Symbol =
      varForRewrittenThis match {
        case Some(sym) => sym
        case None =>
          val tpe =
            if (enclosingClass.is(Module)) enclosingClass.thisType
            else enclosingClass.classInfo.selfType
          val sym = newSymbol(method, TailLocalName.fresh(nme.SELF), Synthetic | Mutable, tpe)
          varForRewrittenThis = Some(sym)
          sym
      }

    private def getVarForRewrittenParam(param: Symbol)(using Context): Symbol =
      rewrittenParamSyms.indexOf(param) match {
        case -1 =>
          val sym = newSymbol(method, TailLocalName.fresh(param.name.toTermName), Synthetic | Mutable, param.info)
          rewrittenParamSyms ::= param
          varsForRewrittenParamSyms ::= sym
          sym
        case index => varsForRewrittenParamSyms(index)
      }

    /** Symbols of Labeled blocks that are in tail position. */
    private var tailPositionLabeledSyms = LinearSet.empty[Symbol]

    private var inTailPosition = true

    /** Rewrite this tree to contain no tail recursive calls */
    def transform(tree: Tree, tailPosition: Boolean)(using Context): Tree =
      if (inTailPosition == tailPosition) transform(tree)
      else {
        val saved = inTailPosition
        inTailPosition = tailPosition
        try transform(tree)
        finally inTailPosition = saved
      }

    def yesTailTransform(tree: Tree)(using Context): Tree =
      transform(tree, tailPosition = true)

    def noTailTransform(tree: Tree)(using Context): Tree =
      transform(tree, tailPosition = false)

    def noTailTransforms[Tr <: Tree](trees: List[Tr])(using Context): List[Tr] =
      trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]]

    override def transform(tree: Tree)(using Context): Tree = {
      /* Rewrite an Apply to be considered for tail call transformation. */
      def rewriteApply(tree: Apply): Tree = {
        val arguments = noTailTransforms(tree.args)

        def continue =
          cpy.Apply(tree)(noTailTransform(tree.fun), arguments)

        def fail(reason: String) = {
          if (isMandatory) {
            failureReported = true
            report.error(s"Cannot rewrite recursive call: $reason", tree.srcPos)
          }
          else
            tailrec.println("Cannot rewrite recursive call at: " + tree.span + " because: " + reason)
          continue
        }

        val calledMethod = tree.fun.symbol
        val prefix = tree.fun match {
          case Select(qual, _) => qual
          case x: Ident if x.symbol eq method => EmptyTree
          case x => x
        }

        val isRecursiveCall = calledMethod eq method
        def isRecursiveSuperCall = (method.name eq calledMethod.name) &&
          method.matches(calledMethod) &&
          enclosingClass.appliedRef.widen <:< prefix.tpe.widenDealias

        if (isRecursiveCall)
          if (inTailPosition) {
            tailrec.println("Rewriting tail recursive call: " + tree.span)
            rewrote = true

            val assignParamPairs = for {
              (param, arg) <- paramSyms.zip(arguments)
              if (arg match {
                case arg: Ident => arg.symbol != param
                case _ => true
              })
            }
            yield
              (getVarForRewrittenParam(param), arg)

            val assignThisAndParamPairs = prefix match
              case EmptyTree =>
                assignParamPairs
              case prefix: This if prefix.symbol == enclosingClass =>
                // Avoid assigning `this = this`
                assignParamPairs
              case _ =>
                (getVarForRewrittenThis(), noTailTransform(prefix)) :: assignParamPairs

            val assignments = assignThisAndParamPairs match {
              case (lhs, rhs) :: Nil =>
                Assign(ref(lhs), rhs) :: Nil
              case _ :: _ =>
                val (tempValDefs, assigns) = (for ((lhs, rhs) <- assignThisAndParamPairs) yield {
                  val temp = newSymbol(method, TailTempName.fresh(lhs.name.toTermName), Synthetic, lhs.info)
                  (ValDef(temp, rhs), Assign(ref(lhs), ref(temp)).withSpan(tree.span))
                }).unzip
                tempValDefs ::: assigns
              case Nil =>
                Nil
            }

            /* The `Typed` node is necessary to perfectly preserve the type of the node.
             * Without it, lubbing in enclosing if/else or match can infer a different type,
             * which can cause Ycheck errors.
             */
            val tpt = TypeTree(method.info.resultType)
            seq(assignments, Typed(Return(unitLiteral.withSpan(tree.span), continueLabel), tpt))
          }
          else fail("it is not in tail position")
        else if (isRecursiveSuperCall)
          fail("it targets a supertype")
        else
          continue
      }

      tree match {
        case tree @ Apply(fun, args) =>
          val meth = fun.symbol
          if (meth == defn.Boolean_|| || meth == defn.Boolean_&&)
            cpy.Apply(tree)(noTailTransform(fun), transform(args))
          else
            rewriteApply(tree)

        case tree @ Select(qual, name) =>
          cpy.Select(tree)(noTailTransform(qual), name)

        case tree @ Block(stats, expr) =>
          cpy.Block(tree)(
            noTailTransforms(stats),
            transform(expr)
          )

        case tree @ If(cond, thenp, elsep) =>
          cpy.If(tree)(
            noTailTransform(cond),
            transform(thenp),
            transform(elsep)
          )

        case tree: CaseDef =>
          cpy.CaseDef(tree)(body = transform(tree.body))

        case tree @ Match(selector, cases) =>
          cpy.Match(tree)(
            noTailTransform(selector),
            transformSub(cases)
          )

        case tree: Try =>
          val expr = noTailTransform(tree.expr)
          if (tree.finalizer eq EmptyTree)
            // SI-1672 Catches are in tail position when there is no finalizer
            cpy.Try(tree)(expr, transformSub(tree.cases), EmptyTree)
          else cpy.Try(tree)(
            expr,
            noTailTransforms(tree.cases),
            noTailTransform(tree.finalizer)
          )

        case tree @ WhileDo(cond, body) =>
          cpy.WhileDo(tree)(
            noTailTransform(cond),
            noTailTransform(body)
          )

        case _: Alternative | _: Bind =>
          assert(false, "We should never have gotten inside a pattern")
          tree

        case tree: ValOrDefDef =>
          if (isMandatory) noTailTransform(tree.rhs)
          tree

        case _: Super | _: This | _: Literal | _: TypeTree | _: TypeDef | EmptyTree =>
          tree

        case Labeled(bind, expr) =>
          if (inTailPosition)
            tailPositionLabeledSyms += bind.symbol
          try cpy.Labeled(tree)(bind, transform(expr))
          finally
            if (inTailPosition)
              tailPositionLabeledSyms -= bind.symbol

        case Return(expr, from) =>
          val fromSym = from.symbol
          val inTailPosition = !fromSym.is(Label) || tailPositionLabeledSyms.contains(fromSym)
          cpy.Return(tree)(transform(expr, inTailPosition), from)

        case _ =>
          super.transform(tree)
      }
    }
  }
}
object TailRec {
  val name: String = "tailrec"
  val description: String = "rewrite tail recursion to loops"
}