@johnynek
Last active November 10, 2015 19:47
Mutable variables via a state thread in Scala. This is just meant as an illustration of how ST in Haskell works and, more generally, of how you can implement mutation behind an immutable API.
object STRef {
  /**
   * Here is a container where we have a single "state thread" = ST.
   * When we run this thread of type ST[T] we get a result of type T.
   */
  sealed abstract class ST[+T] {
    def map[R](fn: T => R): ST[R] = flatMap(t => Const(fn(t)))
    def flatMap[R](fn: T => ST[R]): ST[R] = FlatMapped(this, fn)
  }

  /**
   * This is a mutable reference, but we can only take actions
   * on it inside the ST container.
   */
  sealed abstract class Ref[T] {
    def read: ST[T] = Read(this)
    def write(t: T): ST[Unit] = Write(t, this)
    def modify(fn: T => T): ST[Unit] = read.flatMap { t => write(fn(t)) }
  }

  // The primitive instructions that run interprets
  private case class Const[T](t: T) extends ST[T]
  private case class Read[T](ref: Ref[T]) extends ST[T]
  private case class Write[T](value: T, ref: Ref[T]) extends ST[Unit]
  private case class FlatMapped[T, R](in: ST[T], fn: T => ST[R]) extends ST[R]

  /**
   * Here is how we make new computations; the rest are created
   * with flatMap/map or for { }.
   */
  def const[T](t: T): ST[T] = Const(t)
  // lzy defers evaluating t until the state thread is run
  def lzy[T](t: => T): ST[T] = Const(()).flatMap(_ => Const(t))

  /**
   * Here is how we create new mutable references.
   * Note: code outside this object cannot construct a bare Ref[T]; it can
   * only get one from a state thread. Unlike Haskell's ST, nothing stops a
   * Ref from being returned out of run, and reading it in a different run
   * before writing it there throws, because that run's map has no entry
   * for it.
   */
  def newRef[T](t: T): ST[Ref[T]] = {
    val ref = new Ref[T] { }
    ref.write(t).map(_ => ref)
  }
  /**
   * Here is how we run a state thread. There is mutation, but it is
   * confined to this method: no other thread or function can see
   * where the mutation is happening. So this is a pure function, even
   * though internally it is implemented using mutation.
   */
  def run[T](st: ST[T]): T = {
    // This is mutable, but it is local to this method and no closure
    // outside this method can see it, so no reference to it can escape
    val mmap = collection.mutable.Map[Ref[_], Any]()
    // exercise: implement this in a way that can't overflow the stack.
    def go[R](in: ST[R]): R = in match {
      case Const(t) => t
      case FlatMapped(m, fn) => go(fn(go(m)))
      case Read(ref) =>
        // we only write using a Write, so we are sure the type of the Ref
        // matches what is in mmap(ref), so this cast is safe
        mmap(ref).asInstanceOf[R]
      case Write(v, ref) =>
        mmap += (ref -> v)
        // if in is Write, then R must be Unit, but Scala can't see that,
        // so this cast is safe
        ().asInstanceOf[R]
    }
    go(st)
  }
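  // One possible answer to the exercise above, as a sketch rather than the
  // author's solution. The name runStackSafe is an assumption, not part of
  // the original gist. Instead of recursing on FlatMapped, it keeps the
  // pending continuations in a heap-allocated List, so the interpreter's
  // stack depth stays constant however deeply the flatMaps are nested.
  def runStackSafe[T](st: ST[T]): T = {
    val mmap = collection.mutable.Map[Ref[_], Any]()
    // Types are erased to Any inside the loop; the casts are safe for the
    // same reasons as in run: values only get in through a typed Write.
    @annotation.tailrec
    def loop(cur: ST[Any], conts: List[Any => ST[Any]]): Any = cur match {
      case FlatMapped(in, fn) =>
        // Defer the continuation and evaluate the inner computation first
        loop(in, fn.asInstanceOf[Any => ST[Any]] :: conts)
      case leaf =>
        val value: Any = leaf match {
          case Const(t) => t
          case Read(ref) => mmap(ref)
          case Write(v, ref) =>
            mmap += (ref -> v)
            ()
          case FlatMapped(_, _) => sys.error("unreachable: handled above")
        }
        conts match {
          case Nil => value // nothing left to do: this is the final result
          case k :: rest => loop(k(value), rest) // feed the next continuation
        }
    }
    loop(st, Nil).asInstanceOf[T]
  }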
  // Short example:
  def example(): Unit = {
    val st = for {
      ref <- newRef(40)
      _ <- ref.modify(_ + 2)
      result <- ref.read
    } yield result
    // prints 42
    println(run(st))
  }
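  // A further sketch, not part of the original gist: summing a list with a
  // Ref as the accumulator. The name sumWithRef is an assumption made for
  // illustration. Callers see a pure function from List[Int] to Int; the
  // mutation happens only while run interprets the state thread. With the
  // recursive run above, a very long list could presumably overflow the
  // stack, so treat this as a small-input illustration.
  def sumWithRef(xs: List[Int]): Int = {
    val st = for {
      acc <- newRef(0)
      // Chain one modify per element; nothing executes until run is called
      _ <- xs.foldLeft(const(())) { (prev, x) =>
        prev.flatMap(_ => acc.modify(_ + x))
      }
      total <- acc.read
    } yield total
    run(st)
  }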
}