Skip to content

Instantly share code, notes, and snippets.

@kmader
Last active August 29, 2015 14:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save kmader/4e9e99e7bc8c23607b8c to your computer and use it in GitHub Desktop.
Save kmader/4e9e99e7bc8c23607b8c to your computer and use it in GitHub Desktop.
Spark Streaming with Images
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.PairRDDFunctions._
import tipl.spark.IOOps._
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._
// Wrap the existing SparkContext in a StreamingContext with 10-second micro-batches.
// NOTE(review): toStreaming/tiffFolder/loadAsValues come from tipl.spark.IOOps —
// presumably tiffFolder emits (filename, imageHandle) pairs; confirm against that library.
val ssc = sc.toStreaming(10) // check the folder every 10s for new images
// read in a directory of tiffs (as a live stream)
// keep only entries whose filename (the _1 of each pair) contains ".tif"
val tiffSlices = ssc.tiffFolder("/Volumes/WORKDISK/WorkData/StreamingTests/tif").filter(_._1 contains ".tif")
// read the values as arrays of doubles
val doubleSlices = tiffSlices.loadAsValues
// Summary statistics (min / mean / max intensity) for one image slice.
case class imstats(min: Double, mean: Double, max: Double)
// Reduce a pixel array to its summary statistics using the builtin aggregates.
def arrStats(inArr: Array[Double]) = {
  val total = inArr.sum
  imstats(inArr.min, total / inArr.length, inArr.max)
}
// structure for statSlices is (filename,(imstats,imArray))
val statSlices = doubleSlices.mapValues { cArr =>
  // Evaluate the container once: the original called cArr.get twice, which
  // recomputes/reloads the slice if `get` is not memoized — TODO confirm
  // against tipl.spark.IOOps; single evaluation is safe either way.
  val pixels = cArr.get
  (arrStats(pixels), pixels)
}
// Classify slices by mean intensity. Thresholds were hard-coded three times
// in the original (and again in labelSlice below); name them once here.
// NOTE(review): 700/1750 look scanner-specific — confirm for other detectors.
val darkThresh = 700
val flatThresh = 1750
val darks = statSlices.filter{cImg => cImg._2._1.mean < darkThresh}
val flats = statSlices.filter{cImg => cImg._2._1.mean > flatThresh}
// use short-circuit && (idiomatic for Booleans) rather than bitwise &
val projs = statSlices.filter{cImg => cImg._2._1.mean >= darkThresh && cImg._2._1.mean <= flatThresh}
// save just the summary statistics (no image array)
val stripImArray=(inval: (imstats,Array[Double])) => { inval._1 }
// saveAsTextFiles writes one directory per batch, using the given path as a
// prefix (Spark appends the batch time), so repeated batches do not collide.
darks.mapValues(stripImArray).saveAsTextFiles("/Volumes/WORKDISK/WorkData/StreamingTests/darks.txt")
flats.mapValues(stripImArray).saveAsTextFiles("/Volumes/WORKDISK/WorkData/StreamingTests/flats.txt")
projs.mapValues(stripImArray).saveAsTextFiles("/Volumes/WORKDISK/WorkData/StreamingTests/projs.txt")
// start the streaming computation; the output operations above run every batch
ssc.start
import breeze.linalg._
import org.apache.spark.streaming.dstream.DStream
// Keep only the pixel data of each slice as a breeze vector (drops filename and stats).
val stripNameAndStats = (inval: (String, (imstats, Array[Double]))) => {
  val (_, (_, pixels)) = inval
  DenseVector(pixels)
}
/** Per-batch mean image of a slice stream: pair every pixel vector with a
  * count of 1, sum both across the batch, then divide the vector sum by the
  * count to obtain the average image as a plain Array[Double]. */
def calcAvgImg(inDS: DStream[(String, (imstats, Array[Double]))]) = {
  val counted = inDS.map(stripNameAndStats).map(vec => (vec, 1))
  val summed = counted.reduce { (left, right) => (left._1 + right._1, left._2 + right._2) }
  summed.map { sumAndCount => (sumAndCount._1 / (1.0 * sumAndCount._2)).toArray }
}
// Per-category average image, recomputed independently for every batch.
val avgDark = calcAvgImg(darks)
val avgFlat = calcAvgImg(flats)
val avgProj = calcAvgImg(projs)
avgDark.saveAsTextFiles("/Volumes/WORKDISK/WorkData/StreamingTests/adarks.txt")
avgFlat.saveAsTextFiles("/Volumes/WORKDISK/WorkData/StreamingTests/aflats.txt")
avgProj.saveAsTextFiles("/Volumes/WORKDISK/WorkData/StreamingTests/aprojs.txt")
// NOTE(review): ssc.start already appears earlier in this script; starting the
// same StreamingContext twice throws. Presumably this gist shows alternative
// snippets rather than one runnable session — confirm before running end-to-end.
ssc.start
// MLlib-based implementation (sketch)
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.{Matrix, Matrices}
import org.apache.spark.mllib.linalg.distributed.RowMatrix
// Placeholder adapted from the MLlib docs. The original line used `...`,
// which is not valid Scala, and neither RDD nor the mllib Vector type is in
// scope at this point of the script — import them and supply a compilable
// (empty) stand-in until real image vectors are wired in.
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
val rows: RDD[Vector] = sc.emptyRDD[Vector] // TODO: replace with the real RDD of local vectors
// Create a RowMatrix from an RDD[Vector].
val mat: RowMatrix = new RowMatrix(rows)
/** Uses pattern matching to identify slice types and then processes reach subgroup accordingly **/
import breeze.linalg._
/** Label a slice by mean intensity: 0 = dark, 1 = flat field, 2 = projection.
  * Returns (label, (pixelVector, filename)) for downstream grouping.
  * The original guarded match was non-exhaustive: a NaN mean (corrupt slice)
  * satisfied none of the guards and threw a MatchError at runtime. The
  * if/else chain below is total; non-finite means fall into the flat bucket —
  * NOTE(review): confirm that is the desired handling for corrupt slices. */
def labelSlice(inSlice: (String,(imstats,Array[Double]))) = {
  val (fileName, (stats, pixels)) = inSlice
  val sliceType =
    if (stats.mean < 700) 0        // dark
    else if (stats.mean < 1750) 2  // proj
    else 1                         // flat field
  (sliceType, (DenseVector(pixels), fileName))
}
val groupedSlices = statSlices.map(labelSlice)
// the RDD-based code executed for each streaming time step
import org.apache.spark.rdd.RDD
/** Sum the pixel vectors of every slice in the RDD, together with a count.
  * Returns (vectorSum, sliceCount) — the caller divides to get the average.
  * NOTE: reduce is an action and throws on an empty RDD; callers must guard. */
def calcAvgImg(inRDD: RDD[(Int, (DenseVector[Double],String))]) = {
  val countedVectors = inRDD.map(entry => (entry._2._1, 1))
  countedVectors.reduce((left, right) => (left._1 + right._1, left._2 + right._2))
}
/** Flat-field correction: (proj - dark) / (flat - dark), element-wise.
  * The dark/flat references arrive as (vectorSum, imageCount) pairs from
  * calcAvgImg and are averaged here. When no reference images were seen
  * (count == 0) the dark falls back to all zeros and the flat to a constant
  * image at the projection's maximum intensity. */
def correctProj(curProj: DenseVector[Double], darkImg: (DenseVector[Double],Int), flatImg: (DenseVector[Double],Int)) = {
  val (darkSum, darkCount) = darkImg
  val (flatSum, flatCount) = flatImg
  val darkVec = if (darkCount > 0) darkSum/(1.0*darkCount) else curProj*0.0
  val flatVec = if (flatCount > 0) flatSum/(1.0*flatCount) else curProj*0.0+curProj.max
  (curProj-darkVec)/(flatVec-darkVec)
}
// Per-batch correction: average the dark/flat references seen in this batch,
// flat-field correct every projection, and save its summary statistics.
groupedSlices.foreachRDD { (rdd, time) =>
  val darkRDD = rdd.filter(_._1==0)
  val flatRDD = rdd.filter(_._1==1)
  // calcAvgImg uses reduce, which throws on an empty RDD; a batch with no
  // dark/flat frames is normal, so fall back to a zero-count accumulator and
  // let correctProj take its count==0 branch instead of crashing the job.
  val avgDark = if (darkRDD.count() > 0) calcAvgImg(darkRDD) else (DenseVector.zeros[Double](0), 0)
  val avgFlat = if (flatRDD.count() > 0) calcAvgImg(flatRDD) else (DenseVector.zeros[Double](0), 0)
  val projs = rdd.filter(_._1==2).map{evec => (evec._2._2,evec._2._1)}.
    mapValues{proj => arrStats(correctProj(proj,avgDark,avgFlat).toArray)}
  // saveAsTextFile fails if the target directory already exists, so the fixed
  // path in the original crashed on the second batch — suffix with batch time.
  projs.saveAsTextFile(s"/Volumes/WORKDISK/WorkData/StreamingTests/cor_projs_${time.milliseconds}.txt")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment