Created
August 26, 2010 21:46
-
-
Save daggerrz/2ee0b136fe17ff4414b7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package se.scalablesolutions.akka.dispatch | |
import java.util.Iterator | |
import org.specs.Specification | |
import java.util.concurrent.{CountDownLatch, TimeUnit, Semaphore} | |
import concurrent.forkjoin.LinkedTransferQueue | |
/**
 * A [[LinkedTransferQueue]] whose element count is capped by a counting
 * semaphore.
 *
 * Every insertion method first takes one permit from `guard`, waiting at most
 * `pushTimeout` for it, and every removal that yields an element gives one
 * permit back. `remainingCapacity` therefore reports the permits currently
 * available.
 *
 * NOTE(review): `put` and `transfer` return `Unit` and discard the result of
 * `enq`, so when no permit becomes available within `pushTimeout` they return
 * normally without having inserted the element. Confirm that callers accept
 * this silent drop.
 *
 * NOTE(review): the timed `offer` and `tryTransfer` overloads pass the
 * caller's timeout only to the superclass call; the wait for capacity still
 * uses `pushTimeout`.
 *
 * @param capacity     maximum number of permits; must be positive
 * @param pushTimeout  how long an insertion waits for a permit; must be positive
 * @param pushTimeUnit unit of `pushTimeout`; must not be null
 * @tparam E element type
 */
class BoundedTransferQueue[E <: AnyRef](
    val capacity: Int,
    val pushTimeout: Long,
    val pushTimeUnit: TimeUnit)
  extends LinkedTransferQueue[E] {

  require(capacity > 0)
  require(pushTimeout > 0)
  require(pushTimeUnit ne null)

  /** One permit per free slot. */
  protected val guard = new Semaphore(capacity)

  /**
   * Runs the insertion `f` while holding a freshly acquired permit.
   *
   * The permit is kept only when `f` returns `true`. It is handed back when
   * `f` returns `false` or throws; a throwable raised by `f` propagates to the
   * caller unchanged.
   *
   * @param f the insertion to attempt; evaluated at most once
   * @return `false` if no permit was obtained within the push timeout,
   *         otherwise the result of `f`
   */
  protected def enq(f: => Boolean): Boolean =
    if (!guard.tryAcquire(pushTimeout, pushTimeUnit)) false
    else {
      var inserted = false
      try {
        inserted = f
        inserted
      } finally {
        // Reached with inserted == false both when f said "no" and when it threw.
        if (!inserted) guard.release()
      }
    }

  /**
   * Gives a permit back for a removed element.
   *
   * @param e the value returned by a removal method; `null` means nothing was removed
   * @return `e`, unchanged
   */
  protected def deq(e: E): E = {
    if (e ne null) guard.release()
    e
  }

  override def take(): E = deq(super.take())

  override def poll(): E = deq(super.poll())

  override def poll(timeout: Long, unit: TimeUnit): E = deq(super.poll(timeout, unit))

  /** Number of permits currently available in `guard`. */
  override def remainingCapacity: Int = guard.availablePermits()

  /** Removes `o` and, only if it was present, frees its permit. */
  override def remove(o: AnyRef): Boolean = {
    val removed = super.remove(o)
    if (removed) guard.release()
    removed
  }

  override def offer(e: E): Boolean =
    enq(super.offer(e))

  override def offer(e: E, timeout: Long, unit: TimeUnit): Boolean =
    enq(super.offer(e, timeout, unit))

  override def add(e: E): Boolean =
    enq(super.add(e))

  /** Inserts `e` if a permit is obtained within the push timeout; otherwise does nothing. */
  override def put(e: E): Unit = {
    enq { super.put(e); true }
    ()
  }

  override def tryTransfer(e: E): Boolean =
    enq(super.tryTransfer(e))

  override def tryTransfer(e: E, timeout: Long, unit: TimeUnit): Boolean =
    enq(super.tryTransfer(e, timeout, unit))

  /** Transfers `e` if a permit is obtained within the push timeout; otherwise does nothing. */
  override def transfer(e: E): Unit = {
    enq { super.transfer(e); true }
    ()
  }

  /**
   * Wraps the superclass iterator so that a removal made through the iterator
   * also frees a permit.
   */
  override def iterator: Iterator[E] = {
    val underlying = super.iterator
    new Iterator[E] {
      override def hasNext: Boolean = underlying.hasNext
      override def next(): E = underlying.next()
      override def remove(): Unit = {
        underlying.remove()
        // Only reached when the underlying remove did not throw.
        guard.release()
      }
    }
  }
}
/**
 * Behavioural checks for [[BoundedTransferQueue]].
 *
 * The examples start worker threads that block inside the queue, sleep
 * briefly so the workers can reach their blocking point, inspect the queue,
 * and then interrupt the workers.
 */
class BoundedTransferQueueSpec extends Specification {

  /** One-shot flag backed by a latch with a count of one. */
  class Switch extends CountDownLatch(1) {
    def switch(): Unit = countDown()
    def isSwitched: Boolean = getCount == 0
  }

  val switch = new Switch
  val ITEM = new Object

  /**
   * Starts a thread that evaluates `f` once. An `InterruptedException` raised
   * by `f` ends the thread quietly, because interruption is how the examples
   * stop their workers.
   */
  def spawn(f: => Unit): Thread = {
    val worker = new Thread() {
      override def run(): Unit =
        try f
        catch { case _: InterruptedException => () }
    }
    worker.start()
    worker
  }

  /** Starts exactly `n` threads, each evaluating `f` once. */
  def spawn(n: Int)(f: => Unit): List[Thread] =
    List.fill(n)(spawn(f))

  /** Sleeps the calling thread for `millis` milliseconds. */
  def w(millis: Long): Unit = Thread.sleep(millis)

  /** Builds a queue whose push timeout is given in milliseconds. */
  def queue(capacity: Int, millis: Long) =
    new BoundedTransferQueue[Object](capacity, millis, TimeUnit.MILLISECONDS)

  /**
   * Interrupts every thread and then waits (up to one second each) for it to
   * finish, so that permits released on the way out are visible to the
   * assertions that follow.
   */
  private def interruptAndAwait(threads: List[Thread]): Unit = {
    threads.foreach(_.interrupt())
    threads.foreach(_.join(1000))
  }

  // These might be useless, they just confirm that the new functionality
  // hasn't broken the old one.
  // NOTE(review): the description below says consumers are available, yet
  // none of these examples starts a consumer - confirm the intended wording.
  "When capacity is not reached and consumers are available, queue" should {
    "return false immediately on tryTransfer without timeout specified" in {
      val q = queue(1, 1)
      q.tryTransfer(ITEM) must beFalse
    }
    "time out on tryTransfer with timeout" in {
      val q = queue(1, 1)
      val t = spawn {
        if (q.tryTransfer(ITEM, 1, TimeUnit.MILLISECONDS))
          switch.switch()
      }
      w(10)
      interruptAndAwait(List(t))
      switch.isSwitched must beFalse
    }
    "wait indefinitely on transfer()" in {
      val q = queue(1, 1)
      val t = spawn {
        q.transfer(ITEM)
        switch.switch()
      }
      w(100)
      interruptAndAwait(List(t))
      switch.isSwitched must beFalse
    }
  }

  // Tests for capacity, i.e. the new functionality
  "Queue" should {
    "block and grow up to the specified capacity" in {
      val q = queue(2, 100000)
      val ts = spawn(2) { q.transfer(ITEM) } // Will block
      w(100)
      q.remainingCapacity must_== 0
      interruptAndAwait(ts)
    }
    "block but not grow over specified capacity" in {
      val q = queue(2, 100000)
      val ts = spawn(10) { q.transfer(ITEM) } // Will block
      w(100)
      q.size must_== 2
      interruptAndAwait(ts)
    }
  }

  // Release checks
  "Queue" should {
    "release capacity as content is consumed" in {
      val q = queue(10, 100000)
      val prods = spawn(10) { q.transfer(ITEM) }
      w(100)
      q.remainingCapacity must_== 0
      val cons = spawn(5) { q.take() }
      w(100)
      q.remainingCapacity must_== 5
      // Wait for the interrupted producers to exit before counting permits.
      interruptAndAwait(prods ::: cons)
      q.remainingCapacity must_== 10
    }
    "release capacity even if an exception is thrown inside the guard" in {
      val q = queue(1, 10000)
      val ts = spawn(2) { q.transfer(ITEM) }
      w(100)
      // One thread is interrupted waiting for the semaphore,
      // the other is interrupted waiting for a consumer.
      interruptAndAwait(ts)
      q.remainingCapacity must_== 1
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment