Created
January 12, 2018 14:07
-
-
Save ahoy-jon/a2540bc8d284df0bd895dd86b9ad8bbc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package recursion | |
import matryoshka.Algebra | |
import matryoshka.data.Fix | |
import org.scalatest.FunSuite | |
import slamdata.Predef.Int | |
import scala.annotation.tailrec | |
import scalaz.Free.Trampoline | |
/** Arithmetic expression tree: integer literals combined with `+` and `*`.
  *
  * Extending `Product with Serializable` keeps inferred least-upper-bound
  * types as plain `Expr` (e.g. for `if (c) IntW(1) else Plus(...)` or
  * `List(IntW(1), Plus(...))`) instead of `Product with Serializable with Expr`.
  */
sealed trait Expr extends Product with Serializable

/** Product of two sub-expressions. */
final case class Times(a: Expr, b: Expr) extends Expr

/** Sum of two sub-expressions. */
final case class Plus(a: Expr, b: Expr) extends Expr

/** Integer literal leaf. */
final case class IntW(i: Int) extends Expr
/** Demonstrations of folding an arithmetic [[Expr]] tree:
  *  - naive recursion (`evalExpr`), which overflows the JVM stack on deep trees;
  *  - two hand-written folds that keep pending work on an explicit heap-allocated
  *    `List` driven by a `@tailrec` loop;
  *  - generic folds via Matryoshka (`cata`, `gcataM`).
  *
  * Tests "2" and "5" are expected to fail with `java.lang.StackOverflowError`;
  * they document the problem the other approaches address.
  */
class Recursion extends FunSuite {

  /** Evaluates `expr` by direct (non-tail) recursion.
    *
    * Each nesting level adds a JVM stack frame, so a deeply nested tree throws
    * `StackOverflowError` (see test "2").
    */
  def evalExpr(expr: Expr): Int =
    expr match {
      // Short-circuit multiplication by zero (optimisations "for the lulz").
      case Times(IntW(0), _) => 0
      case Times(_, IntW(0)) => 0
      case Times(a, b) =>
        evalExpr(a) match {
          case 0 => 0 // `b` is never evaluated when the left operand is 0
          case x => x * evalExpr(b)
        }
      // The naive `case Times(a, b) => evalExpr(a) * evalExpr(b)` used to follow
      // here, but it was unreachable: every `Times` is already matched above.
      case Plus(a, b) => evalExpr(a) + evalExpr(b)
      case IntW(i)    => i
    }

  /** Lets integer literals stand in for [[IntW]] leaves in the tests. */
  implicit def intToExpr(i: Int): IntW = IntW(i)

  test("1") {
    // 1 + 2 * 3
    assert(evalExpr(Plus(1, Times(2, 3))) == 7)
  }

  test("2") {
    val seq = 1 to 100000
    // Left-nested chain: Plus(Plus(...Plus(0, 1)..., 99999), 100000).
    val res: Expr = seq.foldLeft[Expr](0)((e, i) => Plus(e, i))
    // Expected to fail with java.lang.StackOverflowError: evalExpr recurses once per level.
    // (The true sum 5000050000 exceeds Int.MaxValue; both sides wrap identically.)
    assert(seq.sum == evalExpr(res))
  }

  /** Folds `expr` bottom-up without growing the JVM call stack.
    *
    * Pending work lives in an explicit `List` of frames and the driver `go` is
    * tail-recursive, so the depth it can handle is bounded by heap, not stack.
    *
    * @param expr the tree to fold
    * @param fi   interpretation of an integer leaf
    * @param fp   combiner for `Plus`
    * @param ft   combiner for `Times`
    * @tparam B   result type of the fold
    * @return the folded value
    */
  def recursiveFold[B](expr: Expr)(fi: Int => B, fp: (B, B) => B, ft: (B, B) => B): B = {
    trait StackFrame
    // A fully evaluated sub-tree.
    case class UnaryFrame(b: B) extends StackFrame
    // A binary node whose operands are either still unevaluated (Left) or already folded (Right).
    case class BinaryFrame(lhs: Either[Expr, B], rhs: Either[Expr, B], op: (B, B) => B)
        extends StackFrame {
      // Stores a freshly folded operand in the first slot still pending (lhs before rhs).
      def add(b: B): BinaryFrame =
        (lhs, rhs) match {
          case (Left(_), _) => this.copy(lhs = Right(b))
          case (_, Left(_)) => this.copy(rhs = Right(b))
        }
    }

    def toStackFrame(expr: Expr): StackFrame = expr match {
      case Plus(a, b)  => BinaryFrame(Left(a), Left(b), fp)
      case Times(a, b) => BinaryFrame(Left(a), Left(b), ft)
      case IntW(i)     => UnaryFrame(fi(i))
    }

    @tailrec
    def go(stack: List[StackFrame]): B =
      stack match {
        // Both operands folded: collapse the node into a value.
        case BinaryFrame(Right(a), Right(b), op) :: tail => go(UnaryFrame(op(a, b)) :: tail)
        // Descend into the left operand first, keeping the parent frame underneath.
        case BinaryFrame(Left(a), _, _) :: _ => go(toStackFrame(a) :: stack)
        case BinaryFrame(_, Left(a), _) :: _ => go(toStackFrame(a) :: stack)
        // Only the root's value remains.
        case UnaryFrame(b) :: Nil => b
        // Hand a finished value back to the parent frame waiting for it.
        case UnaryFrame(b) :: (nextFrame: BinaryFrame) :: tail => go(nextFrame.add(b) :: tail)
      }

    go(List(toStackFrame(expr)))
  }

  test("3") {
    val seq = 1 to 100000
    val res: Expr = seq.foldLeft[Expr](0)((e, i) => Plus(e, i))
    assert(seq.sum == recursiveFold[Int](res)(x => x, _ + _, _ * _))
  }

  /** Same contract as [[recursiveFold]], but evaluates a reverse-Polish-style
    * stack of elements instead of explicit frames.
    *
    * @param expr the tree to fold
    * @param fi   interpretation of an integer leaf
    * @param fp   combiner for `Plus`
    * @param ft   combiner for `Times`
    * @tparam B   result type of the fold
    * @return the folded value
    */
  def recursiveFoldRPN[B](expr: Expr)(fi: Int => B, fp: (B, B) => B, ft: (B, B) => B): B = {
    sealed trait StackElement
    // Sub-tree not yet expanded.
    case class ExprW(expr: Expr) extends StackElement
    // Binary combiner waiting for its two operands just above it on the stack.
    case class BinaryOp(op: (B, B) => B) extends StackElement
    // Already-folded value.
    case class Unary(b: B) extends StackElement
    // Binary combiner already applied to its left operand, waiting for the right one.
    case class UnaryOp(op: B => B) extends StackElement

    def splitToStackPart(expr: Expr): List[StackElement] =
      expr match {
        case IntW(i)     => Unary(fi(i)) :: Nil
        case Plus(a, b)  => ExprW(a) :: ExprW(b) :: BinaryOp(fp) :: Nil
        case Times(a, b) => ExprW(a) :: ExprW(b) :: BinaryOp(ft) :: Nil
      }

    @tailrec
    def evalRpnStack(stack: List[StackElement]): B =
      stack match {
        case ExprW(a) :: tail =>
          evalRpnStack(splitToStackPart(a) ::: tail)
        case Unary(a) :: Unary(b) :: BinaryOp(op) :: tail =>
          evalRpnStack(Unary(op(a, b)) :: tail)
        // Left operand done, right one not: curry the op with the left value and expand the right.
        case Unary(a) :: ExprW(b) :: BinaryOp(op) :: tail =>
          evalRpnStack(splitToStackPart(b) ::: (UnaryOp(op.curried(a)) :: tail))
        case Unary(a) :: UnaryOp(op) :: tail =>
          evalRpnStack(Unary(op(a)) :: tail)
        case Unary(a) :: Nil =>
          a
      }

    evalRpnStack(splitToStackPart(expr))
  }

  test("4") {
    val seq = 1 to 100000
    val res: Expr = seq.foldLeft[Expr](0)((e, i) => Plus(e, i))
    assert(seq.sum == recursiveFoldRPN[Int](res)(x => x, _ + _, _ * _))
  }

  // The catch is that hand-writing this for every data type is tedious, so the
  // recursion can be abstracted generically using ... Matryoshka.
  test("5") {
    import WithMatryoshka._
    // `plus` and the implicit Int => Fix[Expr] conversion live in WithMatryoshka.Fixed,
    // which `import WithMatryoshka._` alone does not bring into scope.
    import WithMatryoshka.Fixed._
    val seq = 1 to 100000
    val res: Fixed = seq.foldLeft[Fixed](0)((e, i) => plus(e, i))
    import slamdata.Predef._
    import matryoshka._
    import matryoshka.data._
    import matryoshka.implicits._
    import scalaz._
    val evaluate: Algebra[WithMatryoshka.Expr, Int] = {
      case WithMatryoshka.IntW(v)      => v
      case WithMatryoshka.Plus(x, y)   => x + y
      case WithMatryoshka.Times(x, y)  => x * y
    }
    // Also expected to fail with StackOverflowError, just like evalExpr.
    assert(seq.sum == res.cata(evaluate))
  }

  test("6") {
    import WithMatryoshka._
    // See test "5": `plus` and the Int conversion come from WithMatryoshka.Fixed.
    import WithMatryoshka.Fixed._
    val seq = 1 to 100000
    val res: Fixed = seq.foldLeft[Fixed](0)((e, i) => plus(e, i))
    import slamdata.Predef._
    import matryoshka._
    import matryoshka.data._
    import matryoshka.implicits._
    import scalaz._

    // Children arrive as Trampoline values and are combined with flatMap.
    val evaluateMG: GAlgebraM[Trampoline, Trampoline, WithMatryoshka.Expr, Int] = {
      case WithMatryoshka.IntW(v) => Trampoline.done(v)
      case WithMatryoshka.Plus(x, y) =>
        for {
          xv <- x
          yv <- y
        } yield xv + yv
      case WithMatryoshka.Times(x, y) =>
        for {
          xv <- x
          yv <- y
        } yield xv * yv
    }

    // gcataM needs a distributive law Expr[Trampoline[?]] => Trampoline[Expr[?]];
    // for a Traverse pattern functor, `sequence` (distApplicative) is one.
    // NOTE(review): this also needs Comonad/Traverse instances for Trampoline in implicit
    // scope, and gcataM may still recurse on the JVM stack internally — confirm both
    // against the matryoshka/scalaz versions in use.
    val distribute: DistributiveLaw[WithMatryoshka.Expr, Trampoline] =
      distApplicative[WithMatryoshka.Expr, Trampoline]

    /*
    Would be nicer, but it still blows up the stack:

    val evaluateM: AlgebraM[Trampoline, WithMatryoshka.Expr, Int] = {
      case WithMatryoshka.IntW(v) => Trampoline.done(v)
      case WithMatryoshka.Plus(x, y) =>
        for {
          xv <- Trampoline.delay(x)
          yv <- Trampoline.delay(y)
        } yield xv + yv
      case WithMatryoshka.Times(x, y) =>
        for {
          xv <- Trampoline.delay(x)
          yv <- Trampoline.delay(y)
        } yield xv * yv
    }
    // blows up the same way
    val value = res.cataM(evaluateM)
    */

    // gcataM yields a Trampoline[Int]; it must be run before comparing with an Int,
    // otherwise the comparison is always false.
    val value = res.gcataM[Trampoline, Trampoline, Int](distribute, evaluateMG).run
    assert(seq.sum == value)
  }
}
/** The same arithmetic language as a pattern functor for Matryoshka.
  *
  * `Expr[A]` abstracts each recursive position as `A`; `Fix[Expr]` (alias
  * `Fixed`) ties the knot to recover ordinary trees.
  */
object WithMatryoshka {
  import slamdata.Predef._
  import matryoshka._
  import matryoshka.data._
  import matryoshka.implicits._
  import scalaz._

  sealed trait Expr[A]
  case class IntW[A](value: Int) extends Expr[A]
  case class Plus[A](x: A, y: A) extends Expr[A]
  case class Times[A](x: A, y: A) extends Expr[A]

  /** Applies `f` to every child position; an `IntW` leaf has none. */
  implicit val exprFunctor: Functor[Expr] = new Functor[Expr] {
    def map[A, B](expr: Expr[A])(f: A => B): Expr[B] =
      expr match {
        case Plus(lhs, rhs)  => Plus(f(lhs), f(rhs))
        case Times(lhs, rhs) => Times(f(lhs), f(rhs))
        case IntW(n)         => IntW[B](n)
      }
  }

  /** Effectful traversal of the child positions, left operand before right. */
  implicit val exprTraverse: Traverse[Expr] = new Traverse[Expr] {
    override def traverseImpl[G[_], A, B](fa: Expr[A])(f: (A) => G[B])(implicit app: Applicative[G]): G[Expr[B]] =
      fa match {
        case Plus(lhs, rhs)  => app.apply2(f(lhs), f(rhs))((l, r) => Plus(l, r): Expr[B])
        case Times(lhs, rhs) => app.apply2(f(lhs), f(rhs))((l, r) => Times(l, r): Expr[B])
        case IntW(n)         => app.pure(IntW[B](n))
      }
  }

  /** Smart constructors for `Fix[Expr]` trees; bring them in with `import WithMatryoshka.Fixed._`. */
  object Fixed {
    implicit def wrapInt(value: Int): Fix[Expr] = Fix(IntW(value))
    def times(x: Fix[Expr], y: Fix[Expr]): Fix[Expr] = Fix(Times(x, y))
    def plus(x: Fix[Expr], y: Fix[Expr]): Fix[Expr] = Fix(Plus(x, y))
  }

  type Fixed = Fix[Expr]
  type Mued = Mu[Expr]

  /** Placeholder for `Mu[Expr]` smart constructors; currently empty. */
  object Mued {
    // Draft kept for reference. It does not type-check as written: `times`/`plus`
    // declare a `Mued` result but build `Fix` values.
    //   implicit def wrapInt(value: Int): Mued = Mu(IntW(value))
    //   def times(x: Fix[Expr], y: Fix[Expr]): Mued = Fix(Times(x, y))
    //   def plus(x: Fix[Expr], y: Fix[Expr]): Mued = Fix(Plus(x, y))
  }
}
vil1
commented
Jan 12, 2018
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment