Created
August 21, 2013 04:21
-
-
Save smarden1/6290286 to your computer and use it in GitHub Desktop.
Is this crazy?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Job.scala b/scalding-core/src/main/scala/com/twitter/scalding/Job.scala | |
index 367f3f3..ab3c7f8 100644 | |
--- a/scalding-core/src/main/scala/com/twitter/scalding/Job.scala | |
+++ b/scalding-core/src/main/scala/com/twitter/scalding/Job.scala | |
@@ -18,7 +18,7 @@ package com.twitter.scalding | |
import com.twitter.chill.config.{ScalaMapConfig, ConfiguredInstantiator} | |
import cascading.pipe.assembly.AggregateBy | |
-import cascading.flow.{Flow, FlowDef, FlowProps, FlowListener, FlowSkipStrategy, FlowStepStrategy} | |
+import cascading.flow.{Flow, FlowDef, FlowProps, FlowListener, FlowSkipStrategy, FlowStepStrategy, FlowProcess} | |
import cascading.pipe.Pipe | |
import cascading.property.AppProps | |
import cascading.tuple.collect.SpillableProps | |
@@ -42,6 +42,20 @@ object Job { | |
asInstanceOf[Job] | |
} | |
+/** | |
+ * Mutable holder for the Cascading FlowProcess of the operation that is | |
+ * currently running, so that code which only has a reference to this wrapper | |
+ * can reach the FlowProcess (for example to increment counters). | |
+ * | |
+ * The wrapper is Serializable so that an object holding it as a field can | |
+ * itself be serialized. The wrapped FlowProcess is never serialized: it is | |
+ * transient and has to be supplied again through setFlowProcess. | |
+ */ | |
+class FlowProcessWrapper() extends java.io.Serializable { | |
+  // Stored as a plain nullable reference instead of an Option: a transient | |
+  // field is restored as null (not None) by Java deserialization, so the | |
+  // Option handed to callers is built on demand from the raw reference. | |
+  @transient | |
+  private var flowProcess : FlowProcess[_] = null | |
+ | |
+  /** Remembers fp as the current FlowProcess; passing null clears it. */ | |
+  def setFlowProcess(fp : FlowProcess[_]) : Unit = | |
+    flowProcess = fp | |
+ | |
+  /** The current FlowProcess, or None when none has been set (yet). */ | |
+  def getFlowProcess() : Option[FlowProcess[_]] = | |
+    Option(flowProcess) | |
+ | |
+  /** | |
+   * Adds amount to the named counter of the current FlowProcess. | |
+   * Does nothing when no FlowProcess has been set. | |
+   */ | |
+  def incrementCounter(group : String, counter : String, amount : Int) : Unit = | |
+    getFlowProcess().foreach( _.increment(group, counter, amount) ) | |
+} | |
+ | |
class Job(val args : Args) extends FieldConversions with java.io.Serializable { | |
// Set specific Mode | |
implicit def mode : Mode = Mode.getMode(args).getOrElse(sys.error("No Mode defined")) | |
@@ -79,6 +93,11 @@ class Job(val args : Args) extends FieldConversions with java.io.Serializable { | |
fd | |
} | |
+ //val context = JobContext(mode, args, flowDef) | |
+ | |
+ // Wrapper through which code in this job can reach the running FlowProcess. | |
+ // Typed explicitly so that implicit resolution never depends on inferring | |
+ // the type of this val. | |
+ // NOTE(review): because the field is @transient it is not restored when a | |
+ // Job instance is deserialized; confirm that closures reading it only run | |
+ // where the original Job instance is still alive. | |
+ @transient | |
+ implicit val flowProcess : FlowProcessWrapper = new FlowProcessWrapper() | |
+ | |
/** Copy this job | |
* By default, this uses reflection and the single argument Args constructor | |
*/ | |
diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala b/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala | |
index b7b8aec..f99d0a6 100644 | |
--- a/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala | |
+++ b/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala | |
@@ -42,6 +42,21 @@ import com.twitter.scalding.mathematics.Poisson | |
} | |
} | |
+ /** | |
+  * Cascading Function that applies fn to each incoming tuple and, when it is | |
+  * prepared, publishes the FlowProcess it receives to flowProcessWrapper. | |
+  * | |
+  * Takes the same arguments as MapFunction plus the wrapper: | |
+  *  - fn: the mapping applied to every converted argument tuple | |
+  *  - fields: the fields declared by this operation | |
+  *  - conv / set: conversions from the argument tuple to S and from T to a tuple | |
+  *  - flowProcessWrapper: receives the FlowProcess passed to prepare | |
+  * | |
+  * NOTE(review): only this operation's own wrapper reference is updated in | |
+  * prepare; confirm that fn observes the same wrapper instance at runtime. | |
+  */ | |
+ class MapFunctionWithFlowProcess[S,T](@transient fn : S => T, fields : Fields, | |
+     conv : TupleConverter[S], set : TupleSetter[T], flowProcessWrapper : FlowProcessWrapper) | |
+     extends BaseOperation[Any](fields) with Function[Any] { | |
+   // fn itself is @transient; the operation carries it inside a MeatLocker. | |
+   val lockedFn = MeatLocker(fn) | |
+ | |
+   /** Hands the FlowProcess given by Cascading to the wrapper. */ | |
+   override def prepare(flowProcess: FlowProcess[_], operationCall: OperationCall[Any]) : Unit = { | |
+     flowProcessWrapper.setFlowProcess(flowProcess) | |
+   } | |
+ | |
+   /** Converts the arguments, applies fn and emits the result as one tuple. */ | |
+   def operate(flowProcess : FlowProcess[_], functionCall : FunctionCall[Any]) : Unit = { | |
+     val mapped = lockedFn.get(conv(functionCall.getArguments)) | |
+     val collector = functionCall.getOutputCollector | |
+     collector.add(set(mapped)) | |
+   } | |
+ } | |
+ | |
class MapFunction[S,T](@transient fn : S => T, fields : Fields, | |
conv : TupleConverter[S], set : TupleSetter[T]) | |
extends BaseOperation[Any](fields) with Function[Any] { | |
diff --git a/scalding-core/src/main/scala/com/twitter/scalding/RichPipe.scala b/scalding-core/src/main/scala/com/twitter/scalding/RichPipe.scala | |
index d601807..1b17cd2 100644 | |
--- a/scalding-core/src/main/scala/com/twitter/scalding/RichPipe.scala | |
+++ b/scalding-core/src/main/scala/com/twitter/scalding/RichPipe.scala | |
@@ -31,6 +31,7 @@ import scala.util.Random | |
import java.util.concurrent.atomic.AtomicInteger | |
+ | |
object RichPipe extends java.io.Serializable { | |
private val nextPipe = new AtomicInteger(-1) | |
@@ -375,6 +376,14 @@ class RichPipe(val pipe : Pipe) extends java.io.Serializable with JoinAlgorithms | |
setter.assertArityMatches(fs._2) | |
each(fs)(new MapFunction[A,T](fn, _, conv, setter)) | |
} | |
+ | |
+ /** | |
+  * Applies fn to the fields fs._1 and stores the result in fs._2, using a | |
+  * MapFunctionWithFlowProcess that is given the implicit FlowProcessWrapper. | |
+  * | |
+  * The arity of both field sets is checked against the implicit converter | |
+  * and setter before the operation is added to the pipe. | |
+  */ | |
+ def mapFunctionWithFlowProcess[A,T](fs : (Fields,Fields))(fn : A => T) | |
+     (implicit conv : TupleConverter[A], setter : TupleSetter[T], flowProcess : FlowProcessWrapper) : Pipe = { | |
+   val (argFields, resultFields) = fs | |
+   conv.assertArityMatches(argFields) | |
+   setter.assertArityMatches(resultFields) | |
+   each(fs) { declared => | |
+     new MapFunctionWithFlowProcess[A,T](fn, declared, conv, setter, flowProcess) | |
+   } | |
+ } | |
+ | |
def mapTo[A,T](fs : (Fields,Fields))(fn : A => T) | |
(implicit conv : TupleConverter[A], setter : TupleSetter[T]) : Pipe = { | |
conv.assertArityMatches(fs._1) | |
diff --git a/scalding-core/src/test/scala/com/twitter/scalding/CoreTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/CoreTest.scala | |
index 44dce06..ebe818f 100644 | |
--- a/scalding-core/src/test/scala/com/twitter/scalding/CoreTest.scala | |
+++ b/scalding-core/src/test/scala/com/twitter/scalding/CoreTest.scala | |
@@ -1602,3 +1602,36 @@ class SampleWithReplacementTest extends Specification { | |
} | |
} | |
+ | |
+/** | |
+ * Test job for mapFunctionWithFlowProcess: reads ('letter, 'x) rows from | |
+ * Tsv("input"), computes 'yPrime = x + 1 for each row and writes the result | |
+ * to Tsv("output"). For every row the closure also asks the job's | |
+ * flowProcess wrapper to add 10 to counter "b" in group "a". | |
+ */ | |
+class FlowProcessJob(args : Args) extends Job(args) { | |
+  Tsv("input",('letter, 'x)) | |
+    .read | |
+    .mapFunctionWithFlowProcess(('letter, 'x) -> 'yPrime){ fields : (String, Int) => | |
+      flowProcess.incrementCounter("a", "b", 10) | |
+      fields._2 + 1 | |
+    } // the closure must end here so that .write applies to the pipe, not to the Int | |
+    .write(Tsv("output")) | |
+} | |
+ | |
+/** | |
+ * Runs FlowProcessJob through JobTest and checks the rows it writes to | |
+ * Tsv("output"). FlowProcessJob writes no other sink, so none is declared. | |
+ */ | |
+class FlowProcessWrapperTest extends Specification { | |
+  import Dsl._ | |
+ | |
+  noDetailedDiffs() //Fixes an issue with scala 2.9 | |
+  "A FlowProcess" should { | |
+    val input = List(("a", 1),("b", 2), ("c", 3), ("d", 1), ("e", 2)) | |
+    // Assumes the mapped field 'yPrime is appended to the existing | |
+    // ('letter, 'x) fields, as with RichPipe.map -- TODO confirm. | |
+    val expected = input.map { case (letter, x) => (letter, x, x + 1) } | |
+ | |
+    JobTest(new FlowProcessJob(_)) | |
+      .source(Tsv("input",('letter,'x)), input) | |
+      .sink[(String, Int, Int)](Tsv("output")) { outBuf => | |
+        "must keep every input row and add yPrime = x + 1" in { | |
+          outBuf.toList.sorted must be_==(expected.sorted) | |
+        } | |
+      } | |
+      .run | |
+      .finish | |
+  } | |
+} | |
\ No newline at end of file |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment