Last active
November 10, 2015 19:47
-
-
Save johnynek/2f2c00bf99dcf44feeff to your computer and use it in GitHub Desktop.
Mutable variables via a state thread in Scala. This is meant only as an illustration of how `ST` works in Haskell and, more generally, of how mutation can be implemented behind an immutable API.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Mutable variables via a "state thread" (ST), in the style of Haskell's
 * `Control.Monad.ST`.
 *
 * Programs are built as immutable descriptions (values of type `ST[T]`) and
 * are only executed by `run`, which interprets them against a private,
 * per-run mutable heap. This illustrates how mutation can sit behind an
 * immutable API.
 *
 * Caveat: unlike Haskell (which uses a rank-2 type for this), nothing here
 * stops a `Ref` from escaping a run. A `Ref` returned from `run` and then
 * read in a different run has no value in that run's heap, and the read
 * fails with a `NoSuchElementException`.
 */
object STRef {

  /**
   * A description of a computation in a single state thread.
   * Running an `ST[T]` with `run` yields a `T`.
   */
  sealed abstract class ST[+T] {
    /** Transforms the result of this computation with a pure function. */
    def map[R](fn: T => R): ST[R] = flatMap(t => Const(fn(t)))

    /** Sequences this computation with one that depends on its result. */
    def flatMap[R](fn: T => ST[R]): ST[R] = FlatMapped(this, fn)
  }

  /**
   * A mutable reference. Every operation on it returns an `ST` description;
   * nothing is read or written until the enclosing thread is run.
   */
  sealed abstract class Ref[T] {
    /** Reads the current value of this reference. */
    def read: ST[T] = Read(this)

    /** Overwrites this reference with `t`. */
    def write(t: T): ST[Unit] = Write(t, this)

    /** Replaces the current value `v` with `fn(v)`. */
    def modify(fn: T => T): ST[Unit] = read.flatMap(t => write(fn(t)))
  }

  private case class Const[T](t: T) extends ST[T]
  private case class Read[T](ref: Ref[T]) extends ST[T]
  private case class Write[T](value: T, ref: Ref[T]) extends ST[Unit]
  private case class FlatMapped[T, R](in: ST[T], fn: T => ST[R]) extends ST[R]

  /*
   * Here is how new computations are made; everything else is built
   * from these with map/flatMap or a for-comprehension.
   */

  /** Lifts an already computed value into `ST`. */
  def const[T](t: T): ST[T] = Const(t)

  /**
   * Lifts a by-name value into `ST`. `t` is evaluated each time the
   * computation is run, not when `lzy` is called.
   */
  def lzy[T](t: => T): ST[T] = Const(()).flatMap(_ => Const(t))

  /**
   * Allocates a new reference initialised to `t`.
   *
   * A bare `Ref` cannot be constructed outside this file (the class is
   * sealed). The `Ref` is created when the computation executes, not when
   * `newRef` is called, so binding the same `newRef(...)` value twice in one
   * thread yields two distinct references.
   */
  def newRef[T](t: T): ST[Ref[T]] =
    Const(()).flatMap { _ =>
      val ref = new Ref[T] {}
      ref.write(t).map(_ => ref)
    }

  /**
   * Runs a state thread and returns its result.
   *
   * The heap is a mutable map local to this call and never exposed, so runs
   * do not share state and `run` behaves as a pure function of its input.
   * Evaluation is a tail-recursive loop over an explicit list of pending
   * continuations, so deeply nested `flatMap` chains (including left-nested
   * ones built with `foldLeft`) do not grow the JVM stack.
   *
   * @throws NoSuchElementException if the thread reads a `Ref` that was not
   *         written during this run (for example, one that escaped an
   *         earlier run).
   */
  def run[T](st: ST[T]): T = {
    // Values of every Ref written during this run; local to this call.
    val heap = collection.mutable.Map[Ref[_], Any]()

    @scala.annotation.tailrec
    def loop(current: ST[Any], pending: List[Any => ST[Any]]): Any =
      current match {
        case FlatMapped(inner, fn) =>
          // Evaluate `inner` first and remember what to do with its result.
          loop(inner, fn.asInstanceOf[Any => ST[Any]] :: pending)
        case Read(ref) =>
          val value = heap.getOrElse(
            ref,
            throw new NoSuchElementException("Ref was not created in this state thread")
          )
          loop(Const(value), pending)
        case Write(value, ref) =>
          heap.update(ref, value)
          loop(Const(()), pending)
        case Const(value) =>
          pending match {
            case Nil          => value
            case next :: rest => loop(next(value), rest)
          }
      }

    // Each Write stores a value of its Ref's own type, and each Write
    // produces Unit, so the untyped result really has the type of `st`.
    loop(st, Nil).asInstanceOf[T]
  }

  /** Short example: prints 42. */
  def example(): Unit = {
    val st = for {
      ref    <- newRef(40)
      _      <- ref.modify(_ + 2)
      result <- ref.read
    } yield result
    println(run(st)) // prints 42
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment