// Gist by @hobwekiva (hobwekiva/4412f775328ff25cbe29e63abb33393c), last active November 17, 2018.
/////////////////////////////////////////////////////////////
// A direct translation of this tweet
// https://twitter.com/tiarkrompf/status/963314799521222656
// to pure functional style.
/////////////////////////////////////////////////////////////
import cats.{Functor, Applicative, Monad}
import cats.effect.{ IO }
import cats.instances.list._
import cats.syntax.all._
/** A number participating in reverse-mode differentiation.
  *
  * @param x the value of this number
  * @param d mutable cell accumulating the derivative of the final result
  *          with respect to this number
  */
final case class Num(x: Double, d: IORef[Double]) {

  /** Sum of `this` and `that` as a continuation-passing computation.
    *
    * The continuation `k` stands for "the rest of the program". It is run
    * first, so that by the time it returns the result's own derivative cell
    * has been filled in and can be pushed back to both operands.
    */
  def +(that: Num): ContT[IO, Unit, Num] =
    ContT[IO, Unit, Num] { k =>
      for {
        y  <- Num.make(x + that.x)   // forward pass: fresh result with zeroed derivative
        _  <- k(y)                   // run everything that depends on the result
        yd <- y.d.get                // derivative that flowed back into the result
        _  <- this.d.modify(_ + yd)  // d(a + b)/da = 1
        _  <- that.d.modify(_ + yd)  // d(a + b)/db = 1
      } yield ()
    }

  /** Product of `this` and `that` as a continuation-passing computation.
    *
    * Same structure as `+`; the backward step scales the incoming derivative
    * by the other operand. When `this` and `that` are the same instance both
    * updates land in the same cell, so the two contributions add up.
    */
  def *(that: Num): ContT[IO, Unit, Num] =
    ContT[IO, Unit, Num] { k =>
      for {
        y  <- Num.make(x * that.x)            // forward pass
        _  <- k(y)                            // run everything that depends on the result
        yd <- y.d.get
        _  <- this.d.modify(_ + that.x * yd)  // d(a * b)/da = b
        _  <- that.d.modify(_ + this.x * yd)  // d(a * b)/db = a
      } yield ()
    }
}
object Num {
  /** Builds a `Num` holding the value `d`, paired with a fresh derivative cell initialised to 0.0. */
  def make(d: Double): IO[Num] =
    for (gradient <- IORef(0.0)) yield Num(d, gradient)
}
/** Evaluates the derivative of `f` at `x`.
  *
  * The input is wrapped in a `Num`, `f` is run with a final continuation that
  * seeds the output's derivative cell with 1.0, and the derivative accumulated
  * in the input's cell is returned.
  */
def grad(f: Num => ContT[IO, Unit, Num])(x: Double): IO[Double] =
  Num.make(x).flatMap { input =>
    f(input)(output => output.d.set(1.0)).flatMap(_ => input.d.get)
  }
/** Example function under differentiation: f(x) = x + x * x * x. */
def f(x: Num): ContT[IO, Unit, Num] =
  (x * x).flatMap { squared =>
    (squared * x).flatMap { cubed =>
      x + cubed
    }
  }
// For x = 0..9, print the derivative computed by `grad` beside the closed form 1 + 3 * x * x.
(0 until 10).toList
  .traverse_ { x =>
    grad(f)(x).flatMap(r => putStrLn(s"$r == ${1 + 3 * x * x}"))
  }
  .unsafeRunSync()
//////////////////////////////////////////////////////////////////////////
/** Describes printing `s` followed by a newline to standard output; nothing is printed until the `IO` is run. */
def putStrLn(s: String): IO[Unit] =
  IO.delay(println(s))
/** A mutable cell holding a value of type `A`, with every operation exposed as an `IO` value.
  *
  * The only implementation in this file is the one built by `IORef.apply`;
  * the descriptions below reflect that implementation.
  */
trait IORef[A] {
/** Describes overwriting the stored value with `a`. */
def set(a: A): IO[Unit]
/** Describes reading the currently stored value. */
def get: IO[A]
/** Describes replacing the stored value with the result of applying `f` to it. */
def modify(f: A => A): IO[Unit]
}
object IORef {
  /** Describes the allocation of a new reference cell whose initial content is `a`.
    *
    * The cell is created inside the returned `IO`, so the mutable state only
    * comes into existence when that `IO` is run.
    */
  def apply[A](a: A): IO[IORef[A]] = IO {
    // Backing storage, shared by the three closures below.
    var cell: A = a
    new IORef[A] {
      def get: IO[A] = IO(cell)
      def set(a: A): IO[Unit] = IO { cell = a }
      def modify(f: A => A): IO[Unit] = IO { cell = f(cell) }
    }
  }
}
/** Continuation-passing computation whose answer type may change from `R2` to `R1`.
  *
  * @param run given a continuation that turns an `A` into an `F[R2]`,
  *            produces the overall `F[R1]`
  * @tparam F  effect in which answers are produced
  * @tparam R1 answer type of the whole computation
  * @tparam R2 answer type expected from the continuation
  * @tparam A  value handed to the continuation
  */
final case class IndexedContT[F[_], R1, R2, +A](run: (A => F[R2]) => F[R1]) { self =>

  /** Runs this computation with `k` as its continuation. */
  def apply(k: A => F[R2]): F[R1] = run(k)

  /** Transforms the value passed to the continuation with `f`. */
  def map[B](f: A => B): IndexedContT[F, R1, R2, B] =
    IndexedContT[F, R1, R2, B](k => self.run(a => k(f(a))))

  /** Sequences this computation with the one produced by `f`.
    *
    * The continuation of the combined computation is handed to the computation
    * returned by `f`, so the inner answer type becomes `R3`.
    */
  def flatMap[R3, B](f: A => IndexedContT[F, R2, R3, B]): IndexedContT[F, R1, R3, B] =
    IndexedContT[F, R1, R3, B](k => self.run(a => f(a).run(k)))
}
object IndexedContT {
  /** Lifts a plain value: the resulting computation passes `a` straight to its continuation. */
  def pure[F[_], R, A](a: A): IndexedContT[F, R, R, A] =
    IndexedContT[F, R, R, A](k => k(a))
}
/** `IndexedContT` with both answer types fixed to the same `R`. */
type ContT[F[_], R, A] = IndexedContT[F, R, R, A]
object ContT {
  /** Wraps `k`, a function from a continuation to the final answer, as a `ContT`. */
  def apply[F[_], R, A](k: (A => F[R]) => F[R]): ContT[F, R, A] =
    new IndexedContT[F, R, R, A](k)
}
// (End of gist; GitHub page footer omitted.)