Skip to content

Instantly share code, notes, and snippets.

@jim-collins
Last active August 29, 2015 14:18
Show Gist options
  • Save jim-collins/ee2a7e0ed1510002dbf3 to your computer and use it in GitHub Desktop.
Save jim-collins/ee2a7e0ed1510002dbf3 to your computer and use it in GitHub Desktop.
package state
import State._
/** A state action: a function that consumes a state of type `S` and produces a value
  * of type `A` together with the next state.
  *
  * @param run the state transition; returns `(value, nextState)`
  * @tparam S the type of the state threaded through the computation
  * @tparam A the type of the value produced
  */
case class State[S, +A](run: S => (A, S)) {

  /** Transforms the produced value with `f`; the state transition is unchanged. */
  def map[B](f: A => B): State[S, B] =
    State { s =>
      val (a, next) = run(s)
      (f(a), next)
    }

  /** Runs this action, then `sb` on the resulting state, and combines both values with `f`. */
  def map2[B, C](sb: State[S, B])(f: (A, B) => C): State[S, C] =
    for {
      a <- this
      b <- sb
    } yield f(a, b)

  /** Runs this action, then runs the action chosen by `f` from its value on the next state. */
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { s =>
      val (a, next) = run(s)
      f(a).run(next)
    }

  /** An action that yields the current state as its value and leaves the state unchanged.
    * It does not use `this`; every instance returns an equivalent action.
    */
  def get: State[S, S] = State[S, S](s => (s, s))

  /** An action that replaces the state with `s` and yields `()`.
    * It does not use `this` or the incoming state.
    */
  def set(s: S): State[S, Unit] = State(_ => ((), s))
}

object State {

  /** Lifts a plain value into an action that yields `a` and leaves the state untouched. */
  def unit[S, A](a: A): State[S, A] =
    State((s: S) => (a, s))

  /** Runs the actions in `l` left to right, threading the state through each one,
    * and collects their values in the same order. An empty list yields `unit(Nil)`.
    */
  def sequence[S, A](l: List[State[S, A]]): State[S, List[A]] =
    // List.foldRight is implemented as reverse + foldLeft, so building the chain is stack-safe.
    l.foldRight(unit[S, List[A]](Nil))((sa, acc) => sa.map2(acc)(_ :: _))
}
/** Random-number generators expressed as `State` actions over a `scala.util.Random`. */
object RandState {
  import scala.util.Random

  /** A state action that draws from a `Random`. The `Random` is mutable: each draw
    * advances it in place and the same instance is passed along as the "next" state.
    */
  type Rand[A] = State[Random, A]

  /** Draws the next `Int` from the full `Int` range (may be negative). */
  def nextInt: Rand[Int] = State(r => (r.nextInt(), r))

  /** Draws an `Int` in `[0, Int.MaxValue]`.
    * Negative draws are mapped with `-(a + 1)`, which sends `Int.MinValue` to
    * `Int.MaxValue` without the overflow that a plain negation would cause.
    */
  def nonNegative(): Rand[Int] =
    nextInt.map(a => if (a < 0) -(a + 1) else a)

  /** Draws an `Int` in `[0, n)` by rejection sampling.
    *
    * @param n exclusive upper bound; must be positive
    * @throws IllegalArgumentException if `n <= 0`
    */
  def nonNegLessThan(n: Int): Rand[Int] = {
    require(n > 0, s"n must be positive, got $n")
    nonNegative().flatMap { i =>
      val mod = i % n
      // `i - mod` starts the n-sized bucket containing i. If the bucket's last element
      // (i - mod + n - 1) overflows Int, that bucket is incomplete and taking `mod`
      // would bias results toward small values, so reject the draw and retry.
      if (i + (n - 1) - mod >= 0) unit(mod) else nonNegLessThan(n)
    }
  }

  def main(args: Array[String]): Unit = {
    val seed = 22L
    // Each demo gets a fresh Random with the same seed so every line is reproducible on its own.
    println(nextInt.run(new Random(seed))._1)
    println(nonNegative().run(new Random(seed))._1)
    println(sequence(List.fill(4)(nonNegative())).run(new Random(seed))._1)
    println(nonNegLessThan(10).run(new Random(seed))._1)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment