Skip to content

Instantly share code, notes, and snippets.

@ahoy-jon
Created January 12, 2018 14:07
Show Gist options
  • Save ahoy-jon/a2540bc8d284df0bd895dd86b9ad8bbc to your computer and use it in GitHub Desktop.
Save ahoy-jon/a2540bc8d284df0bd895dd86b9ad8bbc to your computer and use it in GitHub Desktop.
package recursion
import matryoshka.Algebra
import matryoshka.data.Fix
import org.scalatest.FunSuite
import slamdata.Predef.Int
import scala.annotation.tailrec
import scalaz.Free.Trampoline
// Arithmetic expression tree: integer literals combined with sums and products.
// The hierarchy is sealed and the cases are final, so matches over Expr can be
// checked for exhaustiveness and no further variants can be added elsewhere.
sealed trait Expr

// Product node: the value of `a` multiplied by the value of `b`.
final case class Times(a: Expr, b: Expr) extends Expr

// Sum node: the value of `a` added to the value of `b`.
final case class Plus(a: Expr, b: Expr) extends Expr

// Leaf node wrapping the integer literal `i`.
final case class IntW(i: Int) extends Expr
class Recursion extends FunSuite {
// Evaluates `expr` to an Int by direct structural recursion.
//
// The recursion is not tail-recursive: each nested operand costs one JVM stack
// frame, which is why test "2" notes a StackOverflowError on a very deep tree.
//
// expr: the expression to evaluate
// returns: the integer value of the expression
def evalExpr(expr: Expr): Int =
  expr match {
    // Short-circuits: a literal zero factor makes the product zero without
    // evaluating the other operand at all.
    case Times(IntW(0), _) => 0
    case Times(_, IntW(0)) => 0
    // General product: the right operand is only evaluated when the left one
    // is non-zero. This case matches every remaining Times, so the former
    // "naive" `evalExpr(a) * evalExpr(b)` case after it was unreachable and
    // has been removed.
    case Times(a, b) =>
      evalExpr(a) match {
        case 0 => 0
        case x => x * evalExpr(b)
      }
    case Plus(a, b) => evalExpr(a) + evalExpr(b)
    case IntW(i) => i
  }
// Implicit conversion wrapping an Int into an IntW leaf, so plain integer
// literals can be written where an Expr is expected (IntW extends Expr).
implicit def intToExpr(i:Int):IntW = IntW(i)
// Sanity check of evalExpr on a small tree: 1 + 2 * 3.
test("1") {
  // Integer literals become IntW leaves through the implicit intToExpr.
  val small: Expr = Plus(1, Times(2, 3))
  assert(evalExpr(small) == 7)
}
// Runs evalExpr on a left-nested chain of 100000 additions.
test("2") {
  val numbers = 1 to 100000
  // Each step wraps the accumulated expression in one more Plus node, so the
  // tree is as deep as the range is long.
  val deepSum: Expr = numbers.foldLeft[Expr](0)((acc, n) => Plus(acc, n))
  val expected = numbers.sum
  // Author's note: this call ends in java.lang.StackOverflowError.
  assert(expected == evalExpr(deepSum))
}
//we can do better
// Folds `expr` bottom-up into a B while keeping the pending work in an explicit
// List of frames, driven by a @tailrec loop, instead of recursing on the JVM
// call stack.
//
// expr: the expression to fold
// fi:   maps an integer literal to a B
// fp:   combines the folded operands of a Plus node (left, right)
// ft:   combines the folded operands of a Times node (left, right)
// returns: the folded value of the whole expression
def recursiveFold[B](expr: Expr)(fi: Int => B, fp: (B, B) => B, ft: (B, B) => B): B = {
  trait Frame

  // A sub-expression that has been completely folded.
  case class Done(value: B) extends Frame

  // A binary node; each operand is either still an unevaluated Expr (Left) or
  // an already folded value (Right).
  case class Pending(left: Either[Expr, B], right: Either[Expr, B], combine: (B, B) => B) extends Frame {
    // Stores `value` in the first operand slot that is still unevaluated,
    // trying the left slot before the right one.
    def fill(value: B): Pending =
      (left, right) match {
        case (Left(_), _) => copy(left = Right(value))
        case (_, Left(_)) => copy(right = Right(value))
      }
  }

  // Builds the initial frame for one expression node.
  def frameOf(e: Expr): Frame = e match {
    case IntW(i) => Done(fi(i))
    case Plus(a, b) => Pending(Left(a), Left(b), fp)
    case Times(a, b) => Pending(Left(a), Left(b), ft)
  }

  @tailrec
  def loop(frames: List[Frame]): B =
    frames match {
      // Only the final result is left.
      case Done(value) :: Nil => value
      // Hand a finished value to the binary node waiting underneath it.
      case Done(value) :: (parent: Pending) :: rest => loop(parent.fill(value) :: rest)
      // Both operands are known: collapse the node into its value.
      case Pending(Right(l), Right(r), combine) :: rest => loop(Done(combine(l, r)) :: rest)
      // Otherwise push the next unevaluated operand, left before right.
      case Pending(Left(next), _, _) :: _ => loop(frameOf(next) :: frames)
      case Pending(_, Left(next), _) :: _ => loop(frameOf(next) :: frames)
    }

  loop(frameOf(expr) :: Nil)
}
// Same 100000-deep sum as test "2", evaluated with the explicit-stack recursiveFold.
test("3") {
  val numbers = 1 to 100000
  val deepSum: Expr = numbers.foldLeft[Expr](0)((acc, n) => Plus(acc, n))
  val expected = numbers.sum
  val folded = recursiveFold[Int](deepSum)(n => n, (l, r) => l + r, (l, r) => l * r)
  assert(expected == folded)
}
// Folds `expr` into a B by rewriting it step by step into a flat work list
// (operand, operand, operator) that a @tailrec loop reduces from the front.
//
// expr: the expression to fold
// fi:   maps an integer literal to a B
// fp:   combines the folded operands of a Plus node (left, right)
// ft:   combines the folded operands of a Times node (left, right)
// returns: the folded value of the whole expression
def recursiveFoldRPN[B](expr: Expr)(fi: Int => B, fp: (B, B) => B, ft: (B, B) => B): B = {
  sealed trait Item
  // An expression that has not been expanded yet.
  case class Todo(e: Expr) extends Item
  // A folded value.
  case class Value(b: B) extends Item
  // A binary operator waiting for the two items in front of it.
  case class Combine(f: (B, B) => B) extends Item
  // A binary operator whose left operand is already fixed.
  case class Partial(f: B => B) extends Item

  // Expands one node into its value, or into its two operands followed by its operator.
  def expand(e: Expr): List[Item] = e match {
    case IntW(i) => List(Value(fi(i)))
    case Plus(l, r) => List(Todo(l), Todo(r), Combine(fp))
    case Times(l, r) => List(Todo(l), Todo(r), Combine(ft))
  }

  @tailrec
  def run(items: List[Item]): B =
    items match {
      case Todo(e) :: rest =>
        run(expand(e) ::: rest)
      case Value(l) :: Value(r) :: Combine(f) :: rest =>
        run(Value(f(l, r)) :: rest)
      // The left operand is known but the right one is not: remember the left
      // value inside a Partial and expand the right operand in front of it.
      case Value(l) :: Todo(e) :: Combine(f) :: rest =>
        run(expand(e) ::: (Partial(r => f(l, r)) :: rest))
      case Value(v) :: Partial(f) :: rest =>
        run(Value(f(v)) :: rest)
      case Value(v) :: Nil =>
        v
    }

  run(expand(expr))
}
// Same 100000-deep sum as test "2", evaluated with the work-list based recursiveFoldRPN.
test("4") {
  val numbers = 1 to 100000
  val deepSum: Expr = numbers.foldLeft[Expr](0)((acc, n) => Plus(acc, n))
  val expected = numbers.sum
  val folded = recursiveFoldRPN[Int](deepSum)(n => n, (l, r) => l + r, (l, r) => l * r)
  assert(expected == folded)
}
//The catch is that handling this by hand, specifically for each data type, is tedious every time, so we can abstract the recursion generically by using ...
//Matryoshka
// Same 100000-deep sum, encoded as Fix[WithMatryoshka.Expr] and folded with
// Matryoshka's `cata` and a plain Algebra.
test("5") {
  import WithMatryoshka._
  val seq = 1 to 100000
  // `plus` and the Int => Fix[Expr] conversion `wrapInt` are members of
  // WithMatryoshka.Fixed, which the wildcard import above does not open, so
  // they are referenced through the Fixed object explicitly.
  val res: Fixed = seq.foldLeft[Fixed](Fixed.wrapInt(0))((e, i) => Fixed.plus(e, Fixed.wrapInt(i)))
  import slamdata.Predef._
  import matryoshka._
  import matryoshka.data._
  import matryoshka.implicits._
  import scalaz._
  // One evaluation step: the recursive positions already hold Int results.
  // The constructors are qualified because the package-level Expr cases share
  // the same names.
  val evaluate: Algebra[WithMatryoshka.Expr, Int] = {
    case WithMatryoshka.IntW(v) => v
    case WithMatryoshka.Plus(x, y) => x + y
    case WithMatryoshka.Times(x, y) => x * y
  }
  // Author's note (translated from French; the sentence is unfinished in the
  // original): "this blows up in the same way if".
  assert(seq.sum == res.cata(evaluate))
}
// Work-in-progress attempt at a stack-safe fold of the 100000-deep sum by
// running the algebra in scalaz's Trampoline through Matryoshka's `gcataM`.
test("6") {
  import WithMatryoshka._
  val seq = 1 to 100000
  // `plus` and `wrapInt` are members of WithMatryoshka.Fixed, which the
  // wildcard import above does not open, so they are qualified explicitly.
  val res: Fixed = seq.foldLeft[Fixed](Fixed.wrapInt(0))((e, i) => Fixed.plus(e, Fixed.wrapInt(i)))
  import slamdata.Predef._
  import matryoshka._
  import matryoshka.data._
  import matryoshka.implicits._
  import scalaz._
  // One evaluation step where both the operands and the result are Trampolines.
  val evaluateMG: GAlgebraM[Trampoline, Trampoline, WithMatryoshka.Expr, Int] = {
    case WithMatryoshka.IntW(v) => Trampoline.done(v)
    case WithMatryoshka.Plus(x, y) =>
      for {
        xv <- x
        yv <- y
      } yield xv + yv
    case WithMatryoshka.Times(x, y) =>
      for {
        xv <- x
        yv <- y
      } yield xv * yv
  }
  // NOTE(review): the original had an unfinished declaration `val distribute:`
  // at this point (no type, no right-hand side), which is a syntax error. It is
  // kept here only as a comment. Presumably it was meant to hold the
  // distributive law that gcataM takes next to the algebra - TODO confirm
  // against the Matryoshka version in use and supply it in the call below.
  /*
  would be nice, but still explodes
  val evaluateM: AlgebraM[Trampoline, WithMatryoshka.Expr, Int] = {
    case WithMatryoshka.IntW(v) => Trampoline.done(v)
    case WithMatryoshka.Plus(x, y) => for {xv <- Trampoline.delay(x)
                                           yv <- Trampoline.delay(y)}
                                        yield xv + yv
    case WithMatryoshka.Times(x, y) => for {xv <- Trampoline.delay(x)
                                            yv <- Trampoline.delay(y)}
                                         yield xv * yv
  }
  // (translated from French, unfinished in the original) this blows up in the same way if
  val value = res.cataM(evaluateM)
  */
  // NOTE(review): `value` looks like a Trampoline[Int] rather than an Int, so
  // the comparison below may need the trampoline to be run first - confirm.
  val value = res.gcataM(evaluateMG)
  assert(seq.sum == value)
}
}
// Pattern-functor version of the expression language: the recursive positions
// are abstracted into a type parameter so the tree can be tied with a
// fixed-point type such as Fix and folded with Matryoshka.
object WithMatryoshka {
  import slamdata.Predef._
  import matryoshka._
  import matryoshka.data._
  import matryoshka.implicits._
  import scalaz._

  // One layer of an expression; `A` stands for the sub-expressions.
  sealed trait Expr[A]
  case class IntW[A](value: Int) extends Expr[A]
  case class Plus[A](x: A, y: A) extends Expr[A]
  case class Times[A](x: A, y: A) extends Expr[A]

  // Functor instance: applies `f` to every recursive position of one layer.
  implicit val exprFunctor: Functor[Expr] = new Functor[Expr] {
    def map[A, B](expr: Expr[A])(f: A => B): Expr[B] =
      expr match {
        case Plus(l, r) => Plus(f(l), f(r))
        case Times(l, r) => Times(f(l), f(r))
        case IntW(n) => IntW(n)
      }
  }

  // Traverse instance: runs the effectful `f` on the left position, then the
  // right one, and rebuilds the layer inside the applicative G.
  implicit val exprTraverse: Traverse[Expr] = new Traverse[Expr] {
    override def traverseImpl[G[_], A, B](fa: Expr[A])(f: (A) => G[B])(implicit app: Applicative[G]): G[Expr[B]] =
      fa match {
        case Plus(l, r) => app.apply2(f(l), f(r))((lb, rb) => Plus(lb, rb))
        case Times(l, r) => app.apply2(f(l), f(r))((lb, rb) => Times(lb, rb))
        case IntW(n) => app.pure(IntW(n))
      }
  }

  // Expression trees tied with Fix and with Mu respectively.
  type Fixed = Fix[Expr]
  type Mued = Mu[Expr]

  // Smart constructors producing Fix[Expr] trees.
  object Fixed {
    implicit def wrapInt(value: Int): Fix[Expr] = Fix(IntW(value))

    def plus(x: Fix[Expr], y: Fix[Expr]): Fix[Expr] = Fix(Plus(x, y))

    def times(x: Fix[Expr], y: Fix[Expr]): Fix[Expr] = Fix(Times(x, y))
  }

  // Placeholder for the Mu-based constructors; the draft below is disabled in
  // the original as well.
  object Mued {
    /* implicit def wrapInt(value: Int):Mued = Mu(IntW(value))
       def times(x: Fix[Expr], y: Fix[Expr]): Mued = Fix(Times(x, y))
       def plus(x: Fix[Expr], y: Fix[Expr]):Mued = Fix(Plus(x, y))
    */
  }
}
@vil1
Copy link

vil1 commented Jan 12, 2018

 // Suggested variant (from the page comment) of the Fixed constructors: instead of
 // building Fix[Expr] directly, each constructor is generic in the fixed-point type T
 // and embeds one Expr layer through an implicit CorecursiveT[T] instance.
 // The snippet is partial; the remaining constructors are elided by the author.
 object Fixed {
    implicit def wrapInt[T[_[_]]](value: Int)(implicit T: CorecursiveT[T]): T[Expr] = T.embedT[Expr](IntW(value))
    def times[T[_[_]]](x: T[Expr], y: T[Expr])(implicit T: CorecursiveT[T]): T[Expr] = T.embedT[Expr](Times(x, y))
    // etc ...
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment