Last active
January 28, 2017 21:45
-
-
Save gbersac/c71f96379b76f6de1304d35a2b275556 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
trait RNG { | |
def nextInt: (Int, RNG) // Should generate a random `Int`. We'll later define other functions in terms of `nextInt`. | |
} | |
object RNG { | |
// NB - this was called SimpleRNG in the book text | |
case class Simple(seed: Long) extends RNG { | |
def nextInt: (Int, RNG) = { | |
val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL // `&` is bitwise AND. We use the current seed to generate a new seed. | |
val nextRNG = Simple(newSeed) // The next state, which is an `RNG` instance created from the new seed. | |
val n = (newSeed >>> 16).toInt // `>>>` is right binary shift with zero fill. The value `n` is our new pseudo-random integer. | |
(n, nextRNG) // The return value is a tuple containing both a pseudo-random integer and the next `RNG` state. | |
} | |
} | |
case class DummyRNG(i: Int) extends RNG { def nextInt: (Int, RNG) = { (i, RNG.Simple(i)) } } | |
// We need to be quite careful not to skew the generator. | |
// Since `Int.Minvalue` is 1 smaller than `-(Int.MaxValue)`, | |
// it suffices to increment the negative numbers by 1 and make them positive. | |
// This maps Int.MinValue to Int.MaxValue and -1 to 0. | |
def nonNegativeInt(rng: RNG): (Int, RNG) = { | |
val (i, r) = rng.nextInt | |
(if (i < 0) -(i + 1) else i, r) | |
} | |
// We generate an integer >= 0 and divide it by one higher than the | |
// maximum. This is just one possible solution. | |
def double(rng: RNG): (Double, RNG) = { | |
val (i, r) = nonNegativeInt(rng) | |
(i / (Int.MaxValue.toDouble + 1), r) | |
} | |
def boolean(rng: RNG): (Boolean, RNG) = | |
rng.nextInt match { case (i,rng2) => (i%2==0,rng2) } | |
def intDouble(rng: RNG): ((Int, Double), RNG) = { | |
val (i, r1) = rng.nextInt | |
val (d, r2) = double(r1) | |
((i, d), r2) | |
} | |
def doubleInt(rng: RNG): ((Double, Int), RNG) = { | |
val ((i, d), r) = intDouble(rng) | |
((d, i), r) | |
} | |
def double3(rng: RNG): ((Double, Double, Double), RNG) = { | |
val (d1, r1) = double(rng) | |
val (d2, r2) = double(r1) | |
val (d3, r3) = double(r2) | |
((d1, d2, d3), r3) | |
} | |
// There is something terribly repetitive about passing the RNG along | |
// every time. What could we do to eliminate some of this duplication | |
// of effort? | |
// A simple recursive solution | |
def ints(count: Int)(rng: RNG): (List[Int], RNG) = | |
if (count == 0) | |
(List(), rng) | |
else { | |
val (x, r1) = rng.nextInt | |
val (xs, r2) = ints(count - 1)(r1) | |
(x :: xs, r2) | |
} | |
// A tail-recursive solution | |
def ints2(count: Int)(rng: RNG): (List[Int], RNG) = { | |
def go(count: Int, r: RNG, xs: List[Int]): (List[Int], RNG) = | |
if (count == 0) | |
(xs, r) | |
else { | |
val (x, r2) = r.nextInt | |
go(count - 1, r2, x :: xs) | |
} | |
go(count, rng, List()) | |
} | |
type Rand[+A] = RNG => (A, RNG) | |
val int: Rand[Int] = _.nextInt | |
def unit[A](a: A): Rand[A] = rng => (a, rng) | |
def map[A,B](s: Rand[A])(f: A => B): Rand[B] = | |
rng => { | |
val (a, rng2) = s(rng) | |
(f(a), rng2) | |
} | |
val _double: Rand[Double] = | |
map(nonNegativeInt)(_ / (Int.MaxValue.toDouble + 1)) | |
// This implementation of map2 passes the initial RNG to the first argument | |
// and the resulting RNG to the second argument. It's not necessarily wrong | |
// to do this the other way around, since the results are random anyway. | |
// We could even pass the initial RNG to both `f` and `g`, but that might | |
// have unexpected results. E.g. if both arguments are `RNG.int` then we would | |
// always get two of the same `Int` in the result. When implementing functions | |
// like this, it's important to consider how we would test them for | |
// correctness. | |
def map2[A,B,C](ra: Rand[A], rb: Rand[B])(f: (A, B) => C): Rand[C] = | |
rng => { | |
val (a, r1) = ra(rng) | |
val (b, r2) = rb(r1) | |
(f(a, b), r2) | |
} | |
def both[A,B](ra: Rand[A], rb: Rand[B]): Rand[(A,B)] = | |
map2(ra, rb)((_, _)) | |
val randIntDouble: Rand[(Int, Double)] = | |
both(int, double) | |
val randDoubleInt: Rand[(Double, Int)] = | |
both(double, int) | |
// In `sequence`, the base case of the fold is a `unit` action that returns | |
// the empty list. At each step in the fold, we accumulate in `acc` | |
// and `f` is the current element in the list. | |
// `map2(f, acc)(_ :: _)` results in a value of type `Rand[List[A]]` | |
// We map over that to prepend (cons) the element onto the accumulated list. | |
// | |
// We are using `foldRight`. If we used `foldLeft` then the values in the | |
// resulting list would appear in reverse order. It would be arguably better | |
// to use `foldLeft` followed by `reverse`. What do you think? | |
def sequence[A](fs: List[Rand[A]]): Rand[List[A]] = | |
fs.foldRight(unit(List[A]()))((f, acc) => map2(f, acc)(_ :: _)) | |
// It's interesting that we never actually need to talk about the `RNG` value | |
// in `sequence`. This is a strong hint that we could make this function | |
// polymorphic in that type. | |
def _ints(count: Int): Rand[List[Int]] = | |
sequence(List.fill(count)(int)) | |
def flatMap[A,B](f: Rand[A])(g: A => Rand[B]): Rand[B] = | |
rng => { | |
val (a, r1) = f(rng) | |
g(a)(r1) // We pass the new state along | |
} | |
def nonNegativeLessThan(n: Int): Rand[Int] = { | |
flatMap(nonNegativeInt) { i => | |
val mod = i % n | |
if (i + (n-1) - mod >= 0) unit(mod) else nonNegativeLessThan(n) | |
} | |
} | |
def _map[A,B](s: Rand[A])(f: A => B): Rand[B] = | |
flatMap(s)(a => unit(f(a))) | |
def _map2[A,B,C](ra: Rand[A], rb: Rand[B])(f: (A, B) => C): Rand[C] = | |
flatMap(ra)(a => map(rb)(b => f(a, b))) | |
} | |
import State._ | |
case class State[S, +A](run: S => (A, S)) { | |
def map[B](f: A => B): State[S, B] = | |
flatMap(a => State.unit(f(a))) | |
def map2[B,C](sb: State[S, B])(f: (A, B) => C): State[S, C] = | |
flatMap(a => sb.map(b => f(a, b))) | |
def flatMap[B](f: A => State[S, B]): State[S, B] = State(s => { | |
val (a, s1) = run(s) | |
f(a).run(s1) | |
}) | |
} | |
object State { | |
type Rand[A] = State[RNG, A] | |
def unit[S, A](a: A): State[S, A] = | |
State(s => (a, s)) | |
// The idiomatic solution is expressed via foldRight | |
def sequenceViaFoldRight[S,A](sas: List[State[S, A]]): State[S, List[A]] = | |
sas.foldRight(unit[S, List[A]](List()))((f, acc) => f.map2(acc)(_ :: _)) | |
// This implementation uses a loop internally and is the same recursion | |
// pattern as a left fold. It is quite common with left folds to build | |
// up a list in reverse order, then reverse it at the end. | |
// (We could also use a collection.mutable.ListBuffer internally.) | |
def sequence[S, A](sas: List[State[S, A]]): State[S, List[A]] = { | |
def go(s: S, actions: List[State[S,A]], acc: List[A]): (List[A],S) = | |
actions match { | |
case Nil => (acc.reverse,s) | |
case h :: t => h.run(s) match { case (a,s2) => go(s2, t, a :: acc) } | |
} | |
State((s: S) => go(s,sas,List())) | |
} | |
// We can also write the loop using a left fold. This is tail recursive like the | |
// previous solution, but it reverses the list _before_ folding it instead of after. | |
// You might think that this is slower than the `foldRight` solution since it | |
// walks over the list twice, but it's actually faster! The `foldRight` solution | |
// technically has to also walk the list twice, since it has to unravel the call | |
// stack, not being tail recursive. And the call stack will be as tall as the list | |
// is long. | |
def sequenceViaFoldLeft[S,A](l: List[State[S, A]]): State[S, List[A]] = | |
l.reverse.foldLeft(unit[S, List[A]](List()))((acc, f) => f.map2(acc)( _ :: _ )) | |
def modify[S](f: S => S): State[S, Unit] = for { | |
s <- get // Gets the current state and assigns it to `s`. | |
_ <- set(f(s)) // Sets the new state to `f` applied to `s`. | |
} yield () | |
def get[S]: State[S, S] = State(s => (s, s)) | |
def set[S](s: S): State[S, Unit] = State(_ => ((), s)) | |
} | |
//////////////////////////////////////////////////////////////////////////////// | |
//////////////////////////////////////////////////////////////////////////////// | |
//////////////////////////////////////////////////////////////////////////////// | |
case class Gen[A](sample: State[RNG,A]) { | |
def flatMap[B](f: A => Gen[B]): Gen[B] = | |
Gen(sample.flatMap(a => f(a).sample)) | |
def listOfN(size: Int): Gen[List[A]] = | |
Gen(State.sequence(List.fill(size)(sample))) | |
// def listOfN(size: Gen[Int]): Gen[List[A]] = | |
// size.flatMap(i => Gen(State( s => { | |
// val (a, s2) = sample.run(s) | |
// ( List.fill(i)(a), s2) | |
// })) ) | |
def listOfN(size: Gen[Int]): Gen[List[A]] = | |
size flatMap (n => this.listOfN(n)) | |
} | |
object Gen { | |
def choose(start: Int, stopExclusive: Int): Gen[Int] = | |
Gen(State(RNG.nonNegativeInt).map(n => start + n % (stopExclusive-start))) | |
def unit[A](a: => A): Gen[A] = Gen(State(rng => (a, rng))) | |
def boolean: Gen[Boolean] = Gen(State(RNG.map(RNG.int)(i => i % 2 == 0))) | |
def listOfN[A](n: Int, g: Gen[A]): Gen[List[A]] = | |
Gen(State.sequence(List.fill(n)(g.sample))) | |
def weighted[A](g1: (Gen[A],Double), g2: (Gen[A],Double)): Gen[A] = | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment