Skip to content

Instantly share code, notes, and snippets.

@clintval
Last active January 2, 2022 22:56
Show Gist options
  • Save clintval/f37798af8c4572aa6999c38dfc124567 to your computer and use it in GitHub Desktop.
Save clintval/f37798af8c4572aa6999c38dfc124567 to your computer and use it in GitHub Desktop.
package io.cvbio.collection
import io.cvbio.io.Io
import com.fulcrumgenomics.commons.collection.SelfClosingIterator
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
import java.util.concurrent._
import scala.concurrent.duration.{Duration, DurationInt}
import scala.concurrent.{Await, ExecutionContext, Future}
/** Helpers for parallel work over iterators. */
object ParIterator {

  /** The default maximum size for capacity queues used in parallel iteration. */
  val DefaultQueueCapacity: Int = 128000

  /** A helper function for mapping over objects in an iterator in parallel while caching only <capacity> result objects
    * at a time. It is recommended to only pass thread-safe functions to the <fn> parameter. When more than one thread
    * is requested, this function uses a fixed thread pool and a blocking queue of size <capacity> to buffer results;
    * with exactly one thread the source is mapped lazily on the calling thread. If more than one thread is requested,
    * then we assume the main thread calling this method will not be under load and will be kept waiting. Any exception
    * caused by the source <iterator> or by the applied function <fn> will be raised even if running in a
    * multi-threaded context.
    *
    * @param iterator the source elements to map over.
    * @param fn the function to apply to every element.
    * @param threads the number of threads to use; must be at least one.
    * @param capacity the number of results to buffer at a time, or `None` for an unbounded buffer.
    * @return the mapped elements, in the same order as the source <iterator>.
    * @throws IllegalArgumentException if <threads> is less than one.
    */
  def map[A, B](
    iterator: Iterator[A],
    fn: A => B,
    threads: Int = Io.AvailableProcessors,
    capacity: Option[Int] = Some(DefaultQueueCapacity)
  ): Iterator[B] = {
    require(threads > 0, s"The number of threads must be at least one, found: $threads")
    if (threads == 1) iterator.map(fn) else {
      val pool    = Executors.newFixedThreadPool(threads)
      val results = iterator.parMap(fn, capacity = capacity)(ExecutionContext.fromExecutorService(pool))
      // Shutting the pool down is delegated to SelfClosingIterator.
      // NOTE(review): presumably it fires once `results` is exhausted or closed; confirm against fgbio-commons.
      new SelfClosingIterator[B](results, pool.shutdown)
    }
  }

  /** Implicitly add parallel operations onto Scala's base iterator class. */
  implicit class ParIteratorImpl[A](private val iterator: Iterator[A]) {

    /** Parallelize work over an iterator using a given execution context, buffering <capacity> results at a time.
      *
      * An additional single-thread execution context is created to submit all work to the primary <executor>. Results
      * can therefore be consumed while work is still being submitted, so even a single-threaded <executor> cannot
      * deadlock against a full buffer. Each result is awaited for at most <timeOut> (default 1 hour).
      *
      * If <capacity> is `None`, a dynamically expanding linked blocking queue is used. If <capacity> is a fixed size,
      * an array blocking queue is used.
      *
      * Any exception caused by the source <iterator> or by the applied function <fn> will be raised even if running in
      * a multi-threaded context. Once a result fails (because <fn> raised or <timeOut> elapsed), the exception is raised
      * from `next()` and submission of further work stops. The iterator is empty from then on. Work already handed to
      * <executor> is not interrupted.
      *
      * @param fn the method to map over the elements in the iterator.
      * @param capacity the number of results to buffer at a time in the underlying blocking queue.
      * @param timeOut how long to await each individual result before raising a `TimeoutException`; the result's
      *                computation itself is not interrupted.
      * @param executor the execution context that runs <fn>.
      * @return the mapped elements, in the same order as the source iterator.
      */
    def parMap[B](fn: A => B, capacity: Option[Int] = Some(DefaultQueueCapacity), timeOut: Duration = 1.hour)(
      implicit executor: ExecutionContext
    ): Iterator[B] = {
      val throwable = new AtomicReference[Throwable](null) // A place for any exceptions raised in the source iterator.
      val cancelled = new AtomicBoolean(false) // Set by the consumer to tell the producer to stop submitting work.
      val finished  = new CountDownLatch(1) // Set this to zero when we have finished sending jobs to the thread pool.
      val ioPool    = Executors.newSingleThreadExecutor
      val ioContext = ExecutionContext.fromExecutorService(ioPool)

      // Use a dynamically-expanding queue if no capacity was explicitly asked for, otherwise pre-allocate an array.
      val queue: BlockingQueue[Option[Future[B]]] = capacity match {
        case Some(size) => new ArrayBlockingQueue(size)
        case None       => new LinkedBlockingQueue()
      }

      // Use the IO execution context to fill the queue with results. Terminate the queue with a final `None` to
      // indicate that the input iterator is exhausted and all Futures have been scheduled.
      //
      // Exceptions raised by the source iteration itself, rather than by `fn`, are saved to `throwable` so the
      // consumer can re-raise them once it reaches the end of the queue. Without this, the output could be silently
      // truncated and data lost.
      //
      // `Throwable` is caught deliberately: any failure, fatal or not, must reach the consumer.
      //
      // The `cancelled` check before each element lets the consumer stop this loop after a failed result.
      Future {
        try {
          try {
            while (!cancelled.get && iterator.hasNext) {
              val elem = iterator.next()
              queue.put(Some(Future(fn(elem))(executor)))
            }
          } finally {
            // After cancellation nobody reads the queue any more, and on a full queue this put would block forever.
            if (!cancelled.get) queue.put(None)
          }
        } catch { case thr: Throwable => throwable.compareAndSet(null, thr) }
        finally { finished.countDown() }
      } (ioContext)

      // The producer task was submitted above. `shutdown()` only rejects new tasks, so the pool's single thread exits
      // as soon as the producer returns, even if this iterator is never fully drained.
      ioPool.shutdown()

      // Build the return iterator which will await results from the queue until the queue is empty.
      new Iterator[B] {
        /** Whether or not the queue may still yield results. */
        private var alive: Boolean = true

        /** The next result pulled from the queue but not yet returned by `next()`. */
        private var nextFuture: Option[Future[B]] = None

        /** Stop consuming for good: tell the producer to stop, and unblock it if it is waiting on a full queue. */
        private def cancel(): Unit = {
          alive      = false
          nextFuture = None
          cancelled.set(true)
          queue.clear()
        }

        /** If the iterator still has more objects to yield. */
        override def hasNext: Boolean = {
          if (alive && nextFuture.isEmpty) {
            queue.take() match {
              case some @ Some(_) => nextFuture = some
              case None           =>
                alive = false
                // Before finishing, wait on the latch, which signals that submission has ended and that any source
                // exception has been saved to `throwable`. Then raise that exception so it is not silently dropped.
                //
                // The order matters. The producer puts the `None` sentinel before it records a source exception.
                // Without the latch we could observe the end of the queue before the exception is visible.
                finished.await()
                Option(throwable.get).foreach(throw _)
            }
          }
          alive
        }

        /** Return the next object in the iterator or raise an exception if there are no more objects. */
        override def next(): B = {
          if (!hasNext) Iterator.empty.next() else {
            val future = nextFuture.get // `hasNext` returned true, so a pending result is always present here.
            nextFuture = None
            // On any failure, stop the producer so its thread is not left blocked on a queue nobody drains.
            try Await.result(future, atMost = timeOut)
            catch { case thr: Throwable => cancel(); throw thr }
          }
        }
      }
    }
  }
}
package io.cvbio.collection
import io.cvbio.collection.ParIterator.ParIteratorImpl
import io.cvbio.io.Io
import io.cvbio.testing.UnitSpec
import java.util.concurrent.Executors
import scala.concurrent.ExecutionContext
/** Unit tests for [[ParIterator]]. */
class ParIteratorTest extends UnitSpec {

  /** The number of threads to use in all thread pools. */
  private val ThreadCount = Io.AvailableProcessors

  /** Run `fn` against an execution context backed by a fresh fixed-size thread pool.
    *
    * The pool is always shut down afterwards, even if `fn` raises, so tests never leak threads. Results must be fully
    * materialized inside `fn` (e.g. with `.toList`) because the pool is gone once `fn` returns.
    */
  private def withPool[T](threads: Int = ThreadCount)(fn: ExecutionContext => T): T = {
    val pool = Executors.newFixedThreadPool(threads)
    try fn(ExecutionContext.fromExecutorService(pool)) finally pool.shutdown()
  }

  /** Add ten to an integer. */
  private def addTen(int: Int): Int = int + 10

  /** Always raise an [[IllegalArgumentException]] whose message is the given number. */
  private def raise(num: Int): Int = throw new IllegalArgumentException(num.toString)

  "ParIterator.map" should "return elements in the correct order when parallelized" in {
    val expected = Range(1, 1000).inclusive
    val actual = ParIterator.map[Int, Int](
      expected.iterator,
      identity,
      threads  = ThreadCount,
      capacity = Some(100)
    ).toList
    actual should contain theSameElementsInOrderAs expected
  }

  if (ThreadCount > 1) { // To run these tests you must have more than one available processor.
    it should "raise exceptions that occur within passed function running in threads, but only if multiple threads are used" in {
      an[IllegalArgumentException] shouldBe thrownBy {
        ParIterator.map(
          iterator = Range(1, 10).iterator,
          fn       = raise,
          threads  = ThreadCount
        ).toList
      }
    }

    it should "raise exceptions that occur within the input iterator running in threads, but only if multiple threads are used" in {
      an[IllegalArgumentException] shouldBe thrownBy {
        ParIterator.map(
          iterator = Range(1, 10).iterator.map(raise),
          fn       = identity[Int],
          threads  = ThreadCount
        ).toList
      }
    }
  }

  "ParIterator.parMap" should "map over elements using a fixed size thread pool and a near-unlimited buffer" in {
    val integers = Range(1, 10)
    val actual   = withPool() { context => integers.iterator.parMap(addTen, capacity = None)(context).toList }
    actual should contain theSameElementsInOrderAs integers.map(addTen)
  }

  it should "not deadlock if a fixed thread pool with one thread is requested" in {
    val integers = Range(1, 10)
    val actual   = withPool(threads = 1) { context => integers.iterator.parMap(addTen, capacity = None)(context).toList }
    actual should contain theSameElementsInOrderAs integers.map(addTen)
  }

  it should "map over elements using a fixed size thread pool and a buffer of a size smaller than the collection" in {
    val integers = Range(1, 10)
    val actual   = withPool() { context => integers.iterator.parMap(addTen, capacity = Some(1))(context).toList }
    actual should contain theSameElementsInOrderAs integers.map(addTen)
  }

  it should "map over elements using the user-defined execution context and a right-sized buffer" in {
    val integers = Range(1, 10)
    val actual   = withPool() { context =>
      integers.iterator.parMap(addTen, capacity = Some(integers.length))(context).toList
    }
    actual should contain theSameElementsInOrderAs integers.map(addTen)
  }

  if (ThreadCount > 1) { // To run these tests you must have more than one available processor.
    it should "raise exceptions that occur within passed function running in threads, but only if multiple threads are used" in {
      an[IllegalArgumentException] shouldBe thrownBy {
        withPool() { context => Range(1, 10).iterator.parMap(raise)(context).toList }
      }
    }

    it should "raise exceptions that occur within the input iterator running in threads, but only if multiple threads are used" in {
      an[IllegalArgumentException] shouldBe thrownBy {
        withPool() { context => Range(1, 10).iterator.map(raise).parMap(identity[Int])(context).toList }
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment