Skip to content

Instantly share code, notes, and snippets.

@gigamonkey
Forked from johnynek/twophase.scala
Last active December 22, 2015 13:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save gigamonkey/6479430 to your computer and use it in GitHub Desktop.
Save gigamonkey/6479430 to your computer and use it in GitHub Desktop.
What I came up with while mucking around trying to understand your code.
// Run this with scala <filename>
import java.util.concurrent.atomic.AtomicLong
// Stand-in transaction-id source for this demo: one shared counter that
// starts at 0. Each nextTxId call atomically increments it and returns the
// new value (1, 2, 3, ...), so ids are distinct even across threads.
val txIds = new AtomicLong(0L)
def nextTxId: Long = {
  val fresh = txIds.incrementAndGet()
  fresh
}
/**
 * A two-phase-commit monad.
 *
 * Phase one is `prepare`: under a transaction id it either asks for a retry
 * (`Left`, carrying the transaction to retry) or yields a [[Prepared]]
 * (`Right`) that can then be committed or rolled back.
 */
trait Transaction[+T] {
  /** Transforms the eventual result; built from `flatMap` and [[Constant]]. */
  def map[U](fn: T => U): Transaction[U] =
    flatMap(t => Constant(fn(t)))

  /** Sequences a dependent transaction; implemented by [[FlatMapped]]. */
  def flatMap[U](fn: T => Transaction[U]): Transaction[U] =
    FlatMapped(this, fn)

  /**
   * Phase one: try to prepare under transaction `id`.
   * In a real system, this should be wrapped in a Future.
   */
  def prepare(id: Long): Either[Transaction[T], Prepared[T]]

  /**
   * Prepares under `txId` and commits. Every time preparation asks for a
   * retry, the returned transaction is re-prepared under a fresh id from
   * `nextTxId`. Returns the committed value; there is no retry limit, so
   * this loops for as long as preparation keeps failing.
   */
  final def run(txId: Long): T = {
    @annotation.tailrec
    def attempt(candidate: Transaction[T], attemptId: Long): T = {
      val outcome = candidate.prepare(attemptId)
      outcome match {
        case Right(ready) => ready.commit.get
        case Left(retry)  => attempt(retry, nextTxId)
      }
    }
    attempt(this, txId)
  }
}
/**
 * The state reached once phase one has succeeded: the transaction can now
 * be committed or rolled back. In a real system `commit` and `rollback`
 * should return Futures of the types they return here.
 */
trait Prepared[+T] {
  /** Phase two, success path: finalize and wrap the committed value. */
  def commit: Committed[T]

  /** Phase two, abort path: undo the preparation and return a transaction that can be retried. */
  def rollback: Transaction[T]

  /** The value that would be committed. */
  def value: T
}
/**
 * A type wrapper used only to mark a successful commit.
 * `get` is the value that was committed.
 */
final case class Committed[+T](get: T)
/**
 * The most trivial transaction: a value that is already known.
 * Preparing always succeeds regardless of the id. Committing just wraps
 * `get`, and rolling back hands back this same transaction.
 */
case class Constant[+T](get: T) extends Transaction[T] { outer =>
  def prepare(id: Long) = {
    val ready = new Prepared[T] {
      def commit = Committed(outer.get)
      def rollback = outer
      def value = outer.get
    }
    Right(ready)
  }
}
/**
 * Implementation of flatMap: prepare `init`, feed its prepared value to `fn`,
 * then prepare the resulting transaction under the same id.
 * Non-trampolined, so super big transactions can fail (deep recursion).
 *
 * If the second stage cannot be prepared, or if `fn` / the second stage's
 * `prepare` throws, the first stage is rolled back first. Otherwise its claim
 * (e.g. an Atomic left in its `Left(id, ...)` state) would never be released.
 */
case class FlatMapped[R, T](init: Transaction[R], fn: R => Transaction[T]) extends Transaction[T] {
  def prepare(id: Long): Either[Transaction[T], Prepared[T]] = init.prepare(id) match {
    case Left(ninit) =>
      // Couldn't even prepare the transaction that was going to provide our value.
      Left(FlatMapped(ninit, fn))
    case Right(iprep) =>
      // Release the first stage before propagating an exception from the
      // second stage, so a failing user function cannot leave it claimed.
      val attempt =
        try fn(iprep.value).prepare(id)
        catch {
          case scala.util.control.NonFatal(e) =>
            iprep.rollback
            throw e
        }
      attempt match {
        case Left(_) =>
          // Second stage wants a retry: undo the first stage and retry the whole thing.
          Left(FlatMapped(iprep.rollback, fn))
        case Right(rprep) => Right(new Prepared[T] {
          def value = rprep.value
          def commit = {
            // See how we need a Future (or some Monad) to sequence these
            iprep.commit
            rprep.commit
          }
          lazy val rollback = {
            // See how we need a Future (or some Monad) to sequence these
            rprep.rollback
            FlatMapped(iprep.rollback, fn)
          }
        })
      }
  }
}
import java.util.concurrent.atomic.{AtomicReference => Atom}
// Contents of an Atomic's reference cell:
//  - Right(v): no transaction holds the cell; v is the current value.
//  - Left((txId, old, nu)): transaction txId has prepared a change; `old` is
//    what rollback restores and `nu` is what commit installs.
type AtomicState[T] = Either[Tuple3[Long, T, T], T]
/**
 * A mutable cell whose two operations, read and modify, are transactions
 * over one shared atomic reference (initially holding `Right(value)`).
 */
class Atomic[T](value: T) {
  private val cell = new Atom[AtomicState[T]](Right(value))

  /** Transaction that yields the current value, staging it unchanged. */
  def read: Transaction[T] = modify(identity[T])

  /** Transaction that stages `fn(current)` and yields that new value. */
  def modify(fn: T => T): Transaction[T] = new AtomicAction(cell, fn)
}
/**
 * Transaction behind Atomic.read / Atomic.modify: claims `atom` for the
 * preparing transaction id, staging `fn` applied to the current value.
 *
 * Preparing asks for a retry (Left) when another transaction id holds the
 * cell, or when the compare-and-set loses a race. When this same id already
 * holds the cell, `fn` is applied on top of the staged value, and the
 * pre-transaction value is kept for rollback.
 */
class AtomicAction[T](atom: Atom[AtomicState[T]], fn: T => T) extends Transaction[T] {
  def prepare(id: Long) = {
    val seen = atom.get
    seen match {
      case Right(current) =>
        claimForTx(seen, id, current, fn(current))
      case Left((holder, original, staged)) =>
        if (holder == id) claimForTx(seen, id, original, fn(staged))
        else Left(new AtomicAction(atom, fn))
    }
  }

  def claimForTx(expectedState: AtomicState[T], id: Long, old: T, nu: T) = {
    // The claim fails if any other transaction changed the cell after
    // `expectedState` was read.
    val won = atom.compareAndSet(expectedState, Left((id, old, nu)))
    if (!won) Left(new AtomicAction(atom, fn))
    else Right(new Prepared[T] {
      def value = nu

      // Install whatever this id currently has staged (possibly a later value
      // from another action in the same transaction) and release the cell.
      lazy val commit = {
        atom.get match {
          case held @ Left((holder, _, staged)) if holder == id =>
            atom.compareAndSet(held, Right(staged))
          case _ => ()
        }
        Committed(value)
      }

      // Restore the pre-transaction value, release the cell, and hand back a
      // fresh action to retry.
      lazy val rollback = {
        atom.get match {
          case held @ Left((holder, original, _)) if holder == id =>
            atom.compareAndSet(held, Right(original))
          case _ => ()
        }
        new AtomicAction(atom, fn)
      }
    })
  }
}
object Test {
  /**
   * Builds (but does not start) a thread that runs `count` read-then-modify
   * transactions against `mem`. Each one sets the cell to the value it read
   * plus `delta`. The thread folds each committed result into an accumulator
   * (starting at 42) with `pick`, then prints `label` followed by the result.
   */
  private def worker(mem: Atomic[Int], count: Int, delta: Int,
                     pick: (Int, Int) => Int, label: String): Thread =
    new Thread {
      override def run(): Unit = {
        @annotation.tailrec
        def go(cnt: Int, m: Int): Int =
          if (cnt > 0) {
            // Compose monadically -- the whole point of this exercise.
            val tx = for {
              x <- mem.read
              y <- mem.modify { _ => x + delta }
            } yield y
            go(cnt - 1, pick(m, tx.run(nextTxId)))
          } else m
        println(label + go(count, 42))
      }
    }

  /**
   * Races a subtracting thread against an adding thread (`count`
   * transactions each) on a cell starting at 42, then prints the final value.
   * Each read+modify pair is one transaction, so the final value should
   * come back to 42.
   */
  def test(count: Int): Unit = {
    val mem = new Atomic[Int](42)
    // Thread that just does subtraction
    val sub = worker(mem, count, -1, _ max _, "max seen by subber: ")
    // Thread that just does addition
    val add = worker(mem, count, +1, _ min _, "min seen by adder: ")
    sub.start()
    add.start()
    sub.join()
    add.join()
    println("42 is always the answer (no race conditions) ==> " + mem.read.run(-count))
  }
}
Test.test(1000000)
object Test2 {
  /** Number of remaining `prepare` calls on `flakey` transactions that will ask for a retry. */
  var flakes = 10

  /**
   * Wraps `tx` so that, while `flakes` is positive, preparing it decrements
   * `flakes` and asks for a retry (Left) instead of preparing `tx`. This
   * forces an enclosing transaction through its rollback/retry path.
   */
  def flakey[T](tx: Transaction[T]): Transaction[T] = new Transaction[T] {
    def prepare(id: Long): Either[Transaction[T], Prepared[T]] =
      if (flakes > 0) {
        flakes -= 1
        Left(this)
      } else {
        tx.prepare(id)
      }
  }

  /**
   * Runs six `+ 1` modifications of one cell as a single transaction, with
   * the fifth wrapped in `flakey`, and prints the committed result. Each
   * forced retry must roll back the partial work, so the result should be 6.
   */
  def test(): Unit = {
    val mem = new Atomic[Int](0)
    val tx = for {
      _ <- mem.modify { _ + 1 }
      _ <- mem.modify { _ + 1 }
      _ <- mem.modify { _ + 1 }
      _ <- mem.modify { _ + 1 }
      _ <- flakey(mem.modify { _ + 1 })
      f <- mem.modify { _ + 1 }
    } yield f
    val got = tx.run(0)
    println("got: " + got + "; expected: 6")
  }
}
Test2.test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment