Skip to content

Instantly share code, notes, and snippets.

@retronym
Created March 2, 2020 12:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save retronym/b7a6fc9f8f7bd6316c8927eeb2812417 to your computer and use it in GitHub Desktop.
Save retronym/b7a6fc9f8f7bd6316c8927eeb2812417 to your computer and use it in GitHub Desktop.
diff --git a/src/compiler/scala/tools/nsc/transform/async/ExprBuilder.scala b/src/compiler/scala/tools/nsc/transform/async/ExprBuilder.scala
index 4975a4e04f..389ca1637a 100644
--- a/src/compiler/scala/tools/nsc/transform/async/ExprBuilder.scala
+++ b/src/compiler/scala/tools/nsc/transform/async/ExprBuilder.scala
@@ -68,6 +68,17 @@ trait ExprBuilder extends TransformUtils {
s"AsyncState #$state, next = $nextState"
}
+ /** An async state whose statements are emitted as they are.
+ *
+ * The handler case for this state runs `stats` (adapted via `adaptToUnitIgnoringNothing`);
+ * no transition to another state is added here, so leaving the state is up to `stats`.
+ *
+ * @param stats the statements making up this state
+ * @param state the id of this state
+ * @param nextStates ids of the states `stats` may transfer control to
+ * @param symLookup not referenced by this class
+ */
+ final class OpaqueAsyncState(var stats: List[Tree], val state: Int, val nextStates: Array[Int], symLookup: SymLookup)
+ extends AsyncState {
+ override val toString: String = s"OpaqueAsyncState #$state"
+
+ def mkHandlerCaseForState[T]: CaseDef = mkHandlerCase(state, adaptToUnitIgnoringNothing(stats))
+ }
+
/** A sequence of statements with a conditional transition to the next state, which will represent
* a branch of an `if` or a `match`.
*/
@@ -104,7 +115,7 @@ trait ExprBuilder extends TransformUtils {
val ifTree =
If(Apply(null_ne, Ident(tempCompletedSym) :: Nil),
adaptToUnit(ifIsFailureTree[T](Ident(tempCompletedSym)) :: Nil),
- Block(toList(callOnComplete), Return(literalUnit)))
+ flattenBlock(callOnComplete, Return(literalUnit)))
initAwaitableTemp :: initTempCompleted :: ifTree :: Nil
} else {
val callOnComplete = futureSystemOps.onComplete[Any, Unit](awaitable.expr, fun, Ident(nme.execContext), definitions.AnyTpe)
@@ -153,7 +164,7 @@ trait ExprBuilder extends TransformUtils {
/*
* Builder for a single state of an async expression.
*/
- final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) {
+ private final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) {
/* Statements preceding an await call. */
private val stats = ListBuffer[Tree]()
/** The state of the target of a LabelDef application (while loop jump) */
@@ -194,6 +205,9 @@ trait ExprBuilder extends TransformUtils {
def resultSimple(nextState: Int): AsyncState = {
new SimpleAsyncState(stats.toList, state, effectiveNextState(nextState), symLookup)
}
+ /** Packages the statements collected by this builder as an [[OpaqueAsyncState]] for `state`.
+ *
+ * @param nextStates ids of the states those statements may jump to
+ */
+ def resultOpaque(nextStates: Array[Int]): OpaqueAsyncState =
+ new OpaqueAsyncState(stats.toList, state, nextStates, symLookup)
def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
def mkBranch(state: Int) = mkStateTree(state, symLookup)
@@ -246,6 +260,7 @@ trait ExprBuilder extends TransformUtils {
private val symLookup: SymLookup) {
val asyncStates = ListBuffer[AsyncState]()
+ var isCase: Boolean = false
var stateBuilder = new AsyncStateBuilder(startState, symLookup)
var currState = startState
@@ -279,7 +294,7 @@ trait ExprBuilder extends TransformUtils {
private def containsForeignLabelJump(t: Tree): Boolean = {
val labelDefs = t.collect { case ld: LabelDef => ld.symbol }.toSet
t.exists {
- case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol)
+ case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol) && !rt.symbol.name.startsWith("case") // NOTE(review): also ignores any label whose name starts with "case" (presumably pattern-matcher case labels); a name-prefix test is brittle, consider checking symbol flags instead
case _ => false
}
}
@@ -341,17 +356,30 @@ trait ExprBuilder extends TransformUtils {
currState = afterMatchState
stateBuilder = new AsyncStateBuilder(currState, symLookup)
- case ld @ LabelDef(name, params, rhs)
- if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) =>
-
- val startLabelState = stateIdForLabel(ld.symbol)
- val afterLabelState = afterState.getOrElse(nextState())
- asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
- labelDefStates(ld.symbol) = startLabelState
- val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
- asyncStates ++= builder.asyncStates
- currState = afterLabelState
- stateBuilder = new AsyncStateBuilder(currState, symLookup)
+ case ld @ LabelDef(_, _, rhs)
+ if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) =>
+ // Only a label that contains an await, or is adjacent to one that does, is split into
+ // states here. Any other LabelDef must fall through to the later cases of this match, as
+ // it did before this change; an unguarded case silently dropped it from the output.
+ val startLabelState = stateIdForLabel(ld.symbol)
+ val afterLabelState = afterState.getOrElse(nextState())
+ // Emit the statements collected so far as a state that transitions into this label,
+ // rather than discarding them.
+ asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
+ labelDefStates(ld.symbol) = startLabelState
+ // States of the adjacent labels that `rhs` jumps to. Lazy so that the await path below
+ // does not call `stateIdForLabel` for them.
+ lazy val jumpTargets: List[Int] = {
+ val adjacentLabels = directlyAdjacentLabelDefs(ld).map(_.symbol).toSet
+ rhs.collect {
+ case Apply(fun, _) if adjacentLabels.contains(fun.symbol) => stateIdForLabel(fun.symbol)
+ }.distinct
+ }
+ if (containsAwait(rhs) || jumpTargets.isEmpty) {
+ // Split `rhs` into states, with `afterLabelState` as the end state. A label that jumps
+ // to no adjacent label (e.g. one that just falls through) keeps this pre-existing path.
+ val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
+ asyncStates ++= builder.asyncStates
+ } else {
+ // No await in `rhs`, but it jumps to adjacent labels: emit it unchanged as one opaque
+ // state under the label's own id, which is the id recorded in `labelDefStates`.
+ // NOTE(review): an opaque state adds no transition of its own, so this assumes every
+ // path through `rhs` ends in one of those jumps; confirm.
+ val labelStateBuilder = new AsyncStateBuilder(startLabelState, symLookup)
+ labelStateBuilder += rhs
+ asyncStates += labelStateBuilder.resultOpaque(jumpTargets.toArray)
+ }
+ // Later statements start a fresh state instead of being appended to the label's state.
+ currState = afterLabelState
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case b @ Block(stats, expr) =>
for (stat <- stats) add(stat)
add(expr, afterState = Some(endState))
@@ -485,7 +513,10 @@ trait ExprBuilder extends TransformUtils {
seen.add(state.state)
for (i <- state.nextStates) {
if (i != Int.MaxValue && !seen.contains(i)) {
- loop(map(i))
+ map.get(i) match {
+ case Some(x) => loop(x)
+ case None =>
+ }
}
}
}
@@ -646,5 +677,9 @@ trait ExprBuilder extends TransformUtils {
case Block(stats, expr) if isLiteralUnit(expr) => stats
case _ => tree :: Nil
}
-
+ /** Builds a block that runs the statements of `stats` and then evaluates `expr`.
+ *
+ * When `stats` is itself a `Block`, its statements are spliced in rather than nested. A
+ * trailing literal-unit expression is dropped; any other trailing expression is kept as a
+ * statement. A non-`Block` tree becomes the sole statement. `stats` is passed as the
+ * original tree to `treeCopy.Block`, and the result's type is cleared.
+ */
+ private def flattenBlock(stats: Tree, expr: Tree): Tree = {
+ val prefix: List[Tree] = stats match {
+ case Block(inner, last) if isLiteralUnit(last) => inner
+ case Block(inner, last) => inner :+ last
+ case single => single :: Nil
+ }
+ treeCopy.Block(stats, prefix, expr).clearType()
+ }
}
diff --git a/src/compiler/scala/tools/nsc/transform/patmat/MatchCodeGen.scala b/src/compiler/scala/tools/nsc/transform/patmat/MatchCodeGen.scala
index f11d07ad98..440fdf3475 100644
--- a/src/compiler/scala/tools/nsc/transform/patmat/MatchCodeGen.scala
+++ b/src/compiler/scala/tools/nsc/transform/patmat/MatchCodeGen.scala
@@ -204,7 +204,15 @@ trait MatchCodeGen extends Interface {
// res: T
// returns MatchMonad[T]
def one(res: Tree): Tree = matchEnd APPLY (res) // a jump to a case label is special-cased in typedApply
- protected def zero: Tree = nextCase APPLY ()
+ /** The "no match" result: a jump to the next case label. `ifThenElseZero` below emits it in
+ * statement position, which is only correct if it never falls through; presumably that is
+ * why it is final.
+ */
+ protected final def zero: Tree = nextCase APPLY ()
+ /** Emits `{ if (!c) zero; thenp }`. When `thenp` is a `Block`, its statements are spliced in
+ * after the guard, so the result stays a single flat block.
+ */
+ override def ifThenElseZero(c: Tree, thenp: Tree): Tree = {
+ val bailOut = If(NOT(c), zero, EmptyTree)
+ thenp match {
+ case Block(stats, expr) => Block(bailOut :: stats, expr)
+ case single => Block(bailOut :: Nil, single)
+ }
+ }
// prev: MatchMonad[T]
// b: T
@@ -212,14 +220,21 @@ trait MatchCodeGen extends Interface {
// returns MatchMonad[U]
def flatMap(prev: Tree, b: Symbol, next: Tree): Tree = {
val prevSym = freshSym(prev.pos, prev.tpe, "o")
- BLOCK(
- ValDef(prevSym, prev),
- // must be isEmpty and get as we don't control the target of the call (prev is an extractor call)
+ val nextTree = // must be isEmpty and get as we don't control the target of the call (prev is an extractor call)
ifThenElseZero(
NOT(prevSym DOT vpmName.isEmpty),
Substitution(b, prevSym DOT vpmName.get)(next)
)
- )
+ nextTree match { // keep blocks flat: bind prevSym as the first statement of nextTree's block
+ case Block(stats, expr) =>
+ Block((ValDef(prevSym, prev) :: stats), expr)
+ case _ => // nextTree is not a Block: wrap it after the ValDef
+ BLOCK(
+ ValDef(prevSym, prev),
+ // nextTree already holds the isEmpty/get test built above
+ nextTree
+ )
+ }
}
// cond: Boolean
@@ -230,9 +245,14 @@ trait MatchCodeGen extends Interface {
def flatMapCond(cond: Tree, res: Tree, nextBinder: Symbol, next: Tree): Tree = {
val rest = (
// only emit a local val for `nextBinder` if it's actually referenced in `next`
- if (next.exists(_.symbol eq nextBinder))
- Block(ValDef(nextBinder, res) :: Nil, next)
- else next
+ if (next.exists(_.symbol eq nextBinder)) {
+ next match { // keep blocks flat: prepend the ValDef to next's statements when next is a Block
+ case Block(stats, expr) =>
+ Block(ValDef(nextBinder, res) :: stats, expr)
+ case _ =>
+ Block(ValDef(nextBinder, res) :: Nil, next)
+ }
+ } else next
)
ifThenElseZero(cond, rest)
}
diff --git a/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala b/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala
index ba25b21820..96438c766b 100644
--- a/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala
+++ b/src/compiler/scala/tools/nsc/transform/patmat/MatchOptimization.scala
@@ -591,7 +591,7 @@ trait MatchOptimization extends MatchTreeMaking with MatchAnalysis {
with CommonSubconditionElimination {
override def optimizeCases(prevBinder: Symbol, cases: List[List[TreeMaker]], pt: Type, selectorPos: Position): (List[List[TreeMaker]], List[Tree]) = {
// TODO: do CSE on result of doDCE(prevBinder, cases, pt)
- val optCases = doCSE(prevBinder, cases, pt, selectorPos)
+ val optCases = cases //doCSE(prevBinder, cases, pt, selectorPos)
val toHoist = (
for (treeMakers <- optCases)
yield treeMakers.collect{case tm: ReusedCondTreeMaker => tm.treesToHoist}
diff --git a/src/compiler/scala/tools/nsc/transform/patmat/MatchTreeMaking.scala b/src/compiler/scala/tools/nsc/transform/patmat/MatchTreeMaking.scala
index 4a6731744d..1c21959914 100644
--- a/src/compiler/scala/tools/nsc/transform/patmat/MatchTreeMaking.scala
+++ b/src/compiler/scala/tools/nsc/transform/patmat/MatchTreeMaking.scala
@@ -194,7 +194,13 @@ trait MatchTreeMaking extends MatchCodeGen with Debugging {
else {
// only store binders actually used
val (subPatBindersStored, subPatRefsStored) = stored.filter{case (b, _) => usedBinders(b)}.unzip
- Block(map2(subPatBindersStored.toList, subPatRefsStored.toList)(ValDef(_, _)), in)
+ val bindings = map2(subPatBindersStored.toList, subPatRefsStored.toList)(ValDef(_, _))
+ in match {
+ case Block(stats, expr) =>
+ Block(bindings ::: stats, expr)
+ case _ =>
+ Block(bindings, in)
+ }
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment