Skip to content

Instantly share code, notes, and snippets.

@krishnanraman
Created August 19, 2014 01:14
Show Gist options
  • Save krishnanraman/2581b14691074d5a12c9 to your computer and use it in GitHub Desktop.
Save krishnanraman/2581b14691074d5a12c9 to your computer and use it in GitHub Desktop.
import com.twitter.scalding._
import com.twitter.scalding.ExecutionContext._
import com.twitter.algebird.monad._
import com.twitter.scalding.typed.MemorySink
/**
 * A typed pipe paired with the number of elements it is expected to hold and a
 * numeric sort key, so the pipe can be split into a lower and an upper half.
 *
 * @param pipe   the elements to split
 * @param size   caller-supplied element count for `pipe` (it is not measured here)
 * @param sortBy key that orders the elements before splitting
 * @tparam T     element type of the pipe
 */
case class BisectPipe[T](pipe:TypedPipe[T], size:Int, sortBy: T => Double) {
  // Number of elements assigned to `top`; integer division rounds down.
  private def half: Int = size / 2

  // All elements gathered into a single group and ordered by the sort key.
  private def sorted = pipe.groupAll.sortBy { elem => sortBy(elem) }

  /** The first `size / 2` elements in ascending key order. */
  def top: BisectPipe[T] = BisectPipe(sorted.take(half).values, half, sortBy)

  /**
   * Everything after the first `size / 2` elements in ascending key order.
   * The reported size is `size - size / 2`, so an odd `size` still accounts
   * for the extra element that lands in this half.
   */
  def bottom: BisectPipe[T] = BisectPipe(sorted.drop(half).values, size - half, sortBy)
}
/**
 * Generates `n` labelled points around a wiggly curve, fits a linear separator
 * to them with a perceptron-style update rule, and recursively bisects the
 * point set (by x coordinate) whenever more than 10% of the points remain
 * misclassified after 100 passes.
 *
 * Requires the `n` argument (number of points).
 */
class bisect(args:Args)extends ExecutionContextJob(args) {
  /** A point or a weight vector; points built by this job have the form (x, y, 1). */
  type Weights = Seq[Double]

  /** Dot product over the pairwise-zipped elements of `a` and `b`. */
  def dot(a:Weights,b:Weights): Double = a.zip(b).map { case (c,d) => c * d }.sum

  /** Element-wise sum of `a` and `b` (truncated to the shorter one by `zip`). */
  def plus(a:Weights,b:Weights): Weights = a.zip(b).map { case (c,d) => c + d }

  /** Multiplies every element of `a` by the scalar `b`. */
  def prod(a:Weights, b:Double): Weights = a.map { _ * b }

  /**
   * Divides every element by the minimum element when that minimum's magnitude
   * exceeds 0.1; otherwise returns `a` unchanged. Not called within this class.
   */
  def scale(a:Weights): Weights = if (math.abs(a.min) > 0.1) prod(a, 1.0/a.min) else a

  /** Perceptron update: w = w + sign(x) * x, where `x` is (point, sign). */
  def update(w:Weights, x:(Weights,Int)): Weights = x match {
    case (weight, sign) => plus(w, prod(weight, sign))
  }

  /**
   * True when the sign of `w . point` differs from the point's label.
   * A dot product of exactly zero has signum 0 and therefore counts as misclassified.
   */
  def misclassified(w:Weights, x:(Weights,Int)): Boolean = x match {
    case (weight, sign) => math.signum(dot(w, weight)) != sign
  }

  /** Some very wiggly function; the square root is real only for x in [0, 10]. */
  def curvy(x:Double): Double = math.sin(x*x/2)*(1+math.sqrt(25-(x-5)*(x-5)))*math.exp(-x/10) + 5

  /**
   * Builds the point (x, y, 1) where y lies above (or below) `curvy(x)` by a
   * random offset in [0.3, 1.3).
   */
  def pointAboveBelowLine(x:Double, above:Boolean):Weights = {
    val sign = if (above) 1.0 else -1.0
    val y = curvy(x) + sign * (math.random + 0.3)
    Seq(x, y, 1.0)
  }

  /**
   * Runs the whole classify-or-bisect procedure eagerly while this method is
   * evaluated, then returns a reader that prints "exiting" and terminates the JVM.
   */
  override def job: Reader[ExecutionContext, Nothing] = {
    val n = args("n").toInt // number of points

    // Generate exactly n points with x evenly spaced over [0, 10): even indices
    // above the wiggly curve (label 1), odd indices below it (label -1).
    // Index-based generation avoids a fractional Double range, whose element
    // count is subject to floating-point rounding.
    // NOTE(review): the random offset is drawn inside the pipe's map; confirm
    // whether the points are re-sampled on each execution.
    val pts = TypedPipe.from((0 until n).map { i => i * 10.0 / n }.toList.zipWithIndex)
      .map { case (x, idx) =>
        val above = idx % 2 == 0
        val sign = if (above) 1 else -1
        (pointAboveBelowLine(x, above), sign)
      }

    // Sort key: the x coordinate, so bisecting walks left to right in ascending order.
    def xCoordinate(p: (Weights, Int)): Double = p._1.head

    val bisectPipe = BisectPipe(pts, n, xCoordinate)

    /*
    ALGO:
    Run the perceptron over a pipe 100 times, carrying the weights forward.
    If more than 10% of the points are still misclassified, bisect the pipe
    and repeat the procedure on each half.
    */
    val passesPerPipe = 100
    val maxIterations = 10000
    val maxMisclassifiedPercent = 10.0

    // Global pass counter across all recursive calls; also numbers the weights files.
    var iteration = 1

    def classifyOrBisect(points: BisectPipe[(Weights, Int)]):Unit = {
      var guess:Weights = Seq(1.0, 1.0, 1.0) // x + y + 1 = 0
      for (pass <- 1 to passesPerPipe) {
        // define some sinks
        val pointsSink = TypedTsv[(Weights, Int)]("points")
        val weightsSink = TypedTsv[Weights]("weights" + iteration)
        val sink = new MemorySink[Weights] // save the weights W here as well

        // One perceptron pass over all points, starting from the current guess.
        Execution.waitFor(Config.default, Local(false)) { implicit ec: ExecutionContext =>
          points.pipe
            .write(pointsSink)(flowDefFromContext, modeFromContext) // points
            .groupAll
            .foldLeft( guess ) { (currWeights, point) =>
              if (misclassified(currWeights, point)) update(currWeights, point)
              else currWeights
            }
            .values
            .write(sink)(flowDefFromContext, modeFromContext) // model
            .write(weightsSink)(flowDefFromContext, modeFromContext) // model
        }
        iteration += 1

        // Keep the previous guess when the pass produced no weights (empty pipe).
        sink.readResults.toList.headOption.foreach { learned => guess = learned }

        // The misclassification rate is only consulted after the final pass,
        // so the counting job is only run then.
        if (pass == passesPerPipe && iteration < maxIterations) {
          val learnedWeights = guess
          val misClassifiedSink = new MemorySink[Seq[Int]] // Seq(misclassified, total)
          Execution.waitFor(Config.default, Local(false)) { implicit ec: ExecutionContext =>
            points.pipe.groupAll
              .foldLeft(Seq(0,0)) { (stats, point) => // Seq(0,0) = Seq(misclassified, total)
                val wrong = if (misclassified(learnedWeights, point)) 1 else 0
                Seq(stats.head + wrong, stats.last + 1)
              }
              .values
              .write(misClassifiedSink)(flowDefFromContext, modeFromContext)
          }

          misClassifiedSink.readResults.toList.headOption match {
            case Some(stats) if stats.last > 0 =>
              val misClassified = stats.head*100.0/stats.last
              // A pipe of size <= 1 cannot be split further: its top half would be empty.
              if (misClassified > maxMisclassifiedPercent && points.size > 1) {
                println("--------------> Bisecting pipe, misclassified: " + misClassified + " iteration: " + iteration)
                classifyOrBisect(points.top)
                classifyOrBisect(points.bottom)
              } else {
                println("************* NOT Bisecting pipe, misclassified: " + misClassified + " iteration: " + iteration)
              }
            case _ =>
              // No statistics were produced, i.e. the pipe held no points.
              println("************* NOT Bisecting pipe, no points to classify, iteration: " + iteration)
          }
        }
      } // end for
    } // end classifyOrBisect

    classifyOrBisect(bisectPipe)

    // NOTE(review): exits with status 1 even after a normal run — confirm this is intended.
    ReaderFn(ec => { println("exiting"); sys.exit(1) } )
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment