Created
March 11, 2010 00:03
-
-
Save aboisvert/328621 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Interactive example: | |
Welcome to Scala version 2.8.0.Beta1-prerelease (Java HotSpot(TM) Server VM, Java 1.6.0_17). | |
Type in expressions to have them evaluated. | |
Type :help for more information. | |
scala> :load MapReduce.scala | |
Loading MapReduce.scala... | |
defined class MapReduce | |
scala> case class Input(val x: Int) | |
defined class Input | |
scala> case class Output(val x: Int) | |
defined class Output | |
scala> val inputs = Array.fill[Input](1000) { Input(1) } | |
inputs: Array[Input] = Array(Input(1), Input(1), Input(1), Input(1), Input(1), ... | |
scala> val mr = new MapReduce[Input, Output] { | |
| def map(input: Input) = { | |
| // double input value | |
| Output(input.x*2) | |
| } | |
| | |
| def reduce(o1: Output, o2: Output) = { | |
| // add up all outputs | |
| Output(o1.x+o2.x) | |
| } | |
| } | |
mr: MapReduce[Input,Output] = $anon$1@1c98c1b | |
scala> val future = mr.submit(inputs) | |
future: java.util.concurrent.Future[Output] = MapReduce$Job@1b31c23 | |
scala> future.get | |
res0: Output = Output(2000) | |
---------8<-------- MapReduce.scala -------8<------------------ | |
/** Minimal in-process map/reduce engine backed by a `java.util.concurrent` executor.
  *
  * Subclasses supply [[map]] and [[reduce]]. [[submit]] schedules one worker per
  * input; each worker maps its input and then folds the result into the job's
  * shared output slot using [[reduce]]. Because partial results are combined in
  * whatever order workers finish, `reduce` should be associative and commutative
  * for the final output to be deterministic.
  *
  * NOTE(review): the default executor is never shut down by this class; callers
  * that create many instances should call `executor.shutdown()` when finished.
  *
  * @tparam Input  type of the values handed to [[map]]
  * @tparam Output type produced by [[map]] and combined by [[reduce]]; `null` is
  *                used internally to mean "no partial output available"
  */
abstract class MapReduce[Input <: AnyRef, Output <: AnyRef] {
  import java.util.concurrent.{CancellationException, ExecutorService, Executors, Future, FutureTask, TimeUnit, TimeoutException}

  /** Executor service used for concurrent processing.
    * Defaults to a fixed thread pool with (availableProcessors + 1) threads.
    */
  val executor: ExecutorService = Executors.newFixedThreadPool(
    Runtime.getRuntime.availableProcessors + 1
  )

  /** Implementation should override this operation. */
  def map(input: Input): Output

  /** Implementation should override this operation. */
  def reduce(o1: Output, o2: Output): Output

  /** Optional callback invoked on the worker thread whenever a worker fails.
    * The default implementation does nothing.
    *
    * @param t the throwable raised by `map` or `reduce`
    */
  def reportException(t: Throwable): Unit = {
    // t.printStackTrace
  }

  /** Submit a number of inputs to be mapped and reduced.
    *
    * @param inputs the values to process; traversed exactly once
    * @return a Future with the eventual output. For an empty `inputs` the future
    *         is complete immediately and yields `null`.
    */
  final def submit(inputs: Traversable[Input]): Future[Output] = {
    // Materialize first so that the expected worker count always matches the
    // number of workers actually submitted, even for lazy collections.
    val items = inputs.toList
    val job = new Job(items.size)
    items.foreach { item => executor.submit(new Worker(item, job)) }
    job
  }

  /** Job holds completion status and computation output.
    *
    * All mutable state is guarded by the job's own monitor; `cancelled` is
    * additionally volatile so workers can poll it without locking.
    */
  private class Job(val workersExpected: Int) extends Future[Output] {
    @volatile private var cancelled = false
    private var workersCompleted = 0
    private var exception: Throwable = null

    /** Current partial (or final) result; `null` while a worker holds it for reducing. */
    var output: Output = null.asInstanceOf[Output]

    /** Signal completion of worker. */
    private[MapReduce] def workerCompleted(): Unit = synchronized {
      workersCompleted += 1
      if (workersCompleted == workersExpected) {
        notifyAll()
      }
    }

    /** Signal an exception during processing; the first reported failure is kept. */
    private[MapReduce] def reportException(t: Throwable): Unit = synchronized {
      if (exception eq null) {
        exception = t
      }
      notifyAll()
    }

    /** Attempts to cancel execution of this task.
      *
      * Cancellation is cooperative: workers poll [[isCancelled]] between steps,
      * so `mayInterruptIfRunning` is ignored and no thread is interrupted.
      *
      * @return `false` if the job had already completed, failed or been
      *         cancelled; `true` if this call cancelled it
      */
    override def cancel(mayInterruptIfRunning: Boolean): Boolean = synchronized {
      val alreadyDone =
        cancelled || (exception ne null) || (workersCompleted == workersExpected)
      if (alreadyDone) {
        false
      } else {
        cancelled = true
        // wake up any thread blocked in get() so it observes the cancellation
        notifyAll()
        true
      }
    }

    /** Waits if necessary for the computation to complete, and then retrieves its result. */
    override def get: Output = get(-1L, TimeUnit.MILLISECONDS)

    /** Waits if necessary for at most the given time for the computation to complete,
      * and then retrieves its result, if available.
      *
      * @param timeout maximum time to wait; a negative value waits indefinitely
      * @param unit    unit of `timeout`
      * @throws CancellationException if the job was cancelled
      * @throws TimeoutException      if the wait timed out
      * @throws Throwable             the first throwable raised by a worker, rethrown as-is
      */
    def get(timeout: Long, unit: TimeUnit): Output = synchronized {
      val bounded = timeout >= 0
      val now = System.currentTimeMillis
      val budget = if (bounded) unit.toMillis(timeout) else 0L
      // saturate instead of overflowing for very large timeouts
      val deadline = if (budget > Long.MaxValue - now) Long.MaxValue else now + budget

      // Loop guards against spurious wake-ups; wait() releases the monitor.
      while (!cancelled && (exception eq null) && workersCompleted != workersExpected) {
        if (bounded) {
          val remaining = deadline - System.currentTimeMillis
          if (remaining <= 0L) {
            throw new TimeoutException("MapReduce did not complete within the timeout")
          }
          wait(remaining)
        } else {
          wait()
        }
      }

      if (cancelled) {
        throw new CancellationException("MapReduce was cancelled")
      }
      // NOTE(review): java.util.concurrent.Future specifies ExecutionException
      // here; the raw throwable is kept for compatibility with existing callers.
      if (exception ne null) {
        throw exception
      }
      output
    }

    /** Returns true if this task was cancelled before it completed normally. */
    def isCancelled: Boolean = cancelled

    /** Returns true if this task completed. Completion may be due to normal termination,
      * an exception, or cancellation -- in all of these cases, this method will return true.
      */
    def isDone: Boolean = synchronized {
      (workersCompleted == workersExpected) || (exception ne null) || cancelled
    }
  }

  /** Map-reduce worker: maps one input, then merges the result into the job. */
  private class Worker(input: Input, job: Job) extends Runnable {
    override def run(): Unit = {
      try {
        if (!job.isCancelled) {
          // first perform the map() operation
          var pending = map(input)
          // reduce our value with the job's current output until we manage to
          // deposit a value into an empty slot (pending becomes null)
          while ((pending ne null) && !job.isCancelled) {
            val taken: Output = job.synchronized {
              val current = job.output
              if (current eq null) {
                job.output = pending
              } else {
                // take ownership of the partial result while we reduce it
                job.output = null.asInstanceOf[Output]
              }
              current
            }
            // reduce happens outside of the synchronized block, which means
            // other threads can provide additional output, or take, reduce
            // and update the output meanwhile.
            pending =
              if (taken eq null) null.asInstanceOf[Output]
              else reduce(taken, pending)
          }
        }
      } catch {
        // Every throwable must reach the job, otherwise get() would wait forever.
        case ex: Throwable =>
          try {
            MapReduce.this.reportException(ex)
          } finally {
            job.reportException(ex)
          }
      } finally {
        job.workerCompleted()
      }
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment