Create a gist now

Instantly share code, notes, and snippets.

@kmader /1-loadData.scala Secret forked from kmader/1-loadData.scala
Last active Aug 29, 2015

What would you like to do?
A streaming-free version of the projection-correction code (meant to operate on existing folders)
// glob of the input projection TIFF slices to process
val imgPath = "/Volumes/WORKDISK/WorkData/StreamingTests/tinytif/*.tif"
// directory all outputs (statistics, sinogram csv files) are written under
val savePath = "/Volumes/WORKDISK/WorkData/StreamingTests/"
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.PairRDDFunctions._
import tipl.spark.IOOps._
import tipl.util.TImgBlock
import tipl.util.TImgTools
// read in a directory of tiffs (as a live stream)
// NOTE(review): sc.tiffFolder is an extension from tipl.spark.IOOps — appears to
// return an RDD keyed per slice; confirm against the tipl API
val tiffSlices = sc.tiffFolder(imgPath)
// forces the folder listing; used below to size the repartition
val sliceCnt = tiffSlices.count
// read the values as arrays of doubles
// repartition to roughly two slices per partition and cache, since the RDD is reused
// NOTE(review): with fewer than 2 slices this calls repartition(0) — confirm inputs are never that small
val doubleSlices = tiffSlices.loadAs2D(false).repartition(sliceCnt.toInt/2).cache
// format for storing image statistics of a single slice
case class imstats(min: Double, mean: Double, max: Double)
/** Computes min/mean/max of a slice's pixel values.
  * Returns imstats(NaN, NaN, NaN) for an empty array instead of throwing:
  * Array.min / Array.max raise UnsupportedOperationException on empty input,
  * and the mean would otherwise be 0/0. */
def arrStats(inArr: Array[Double]) =
  if (inArr.isEmpty) imstats(Double.NaN, Double.NaN, Double.NaN)
  else imstats(inArr.min, inArr.sum/(1.0*inArr.length), inArr.max)
// structure for statSlices is (filename,(imstats,imArray))
// attach per-slice statistics while keeping the raw pixel array alongside
val statSlices = doubleSlices.mapValues{
// cArr.get materializes the slice's pixel array; it is invoked twice here
cArr => (arrStats(cArr.get),cArr.get)
}
/** Classifies slice types by mean intensity and labels each slice for grouping **/
import breeze.linalg._
import tipl.util.D3int
import org.apache.spark.rdd.RDD
// classify the slices based on their mean intensity
/** Labels one slice by its mean intensity: 0 = dark (<700), 2 = projection
  * (<1750), 1 = flat field (>=1750). Returns (label, (pixel vector, position)).
  * The original pattern match had no case covering NaN (all three guards are
  * false for NaN) and threw a MatchError; the if/else chain below is total,
  * so a NaN mean falls through to the flat-field bin instead of crashing. */
def labelSlice(inSlice: (D3int,(imstats,Array[Double]))) = {
  val meanVal = inSlice._2._1.mean
  val sliceType =
    if (meanVal < 700) 0       // dark
    else if (meanVal < 1750) 2 // proj
    else 1                     // flat field (also any NaN mean)
  (sliceType,(DenseVector(inSlice._2._2),inSlice._1))
}
val groupedSlices = statSlices.map(labelSlice)
// for averaging together flats and darks
/** Sums all image vectors in the group and counts them, returning (sum, count);
  * callers divide the sum by the count to obtain the average image.
  * NOTE(review): RDD.reduce throws on an empty RDD, so a run with no darks or
  * no flats crashes here before correctProj's count==0 fallback can ever
  * apply — confirm inputs always contain both slice types. */
def calcAvgImg(inRDD: RDD[(Int, (DenseVector[Double],D3int))]) = {
val allImgs = inRDD.map{cvec => cvec._2._1}.map(invec => (invec,1))
allImgs.reduce{(vec1,vec2) => (vec1._1+vec2._1,vec1._2+vec2._2)}
}
// for correcting projections and not crashing if the flats or darks are missing
/** Flat/dark-field correction: (proj - dark) / (flat - dark), element-wise.
  * Each reference image arrives as (summed vector, image count); a count of
  * zero marks a missing reference and a safe substitute is used instead:
  * all-zeros for the dark image, a constant at the projection's maximum for
  * the flat image. */
def correctProj(curProj: DenseVector[Double], darkImg: (DenseVector[Double],Int), flatImg: (DenseVector[Double],Int)) = {
  val (darkSum, darkCnt) = darkImg
  val (flatSum, flatCnt) = flatImg
  val darkVec =
    if (darkCnt > 0) darkSum/(1.0*darkCnt)
    else curProj*0.0
  val flatVec =
    if (flatCnt > 0) flatSum/(1.0*flatCnt)
    else curProj*0.0 + curProj.max
  (curProj - darkVec)/(flatVec - darkVec)
}
// average dark-current and flat-field references, each as (summed vector, count)
val avgDark = calcAvgImg(groupedSlices.filter(_._1==0))
val avgFlat = calcAvgImg(groupedSlices.filter(_._1==1))
// re-key the projections by their position (D3int) and apply the flat/dark correction
val projs = groupedSlices.filter(_._1==2).map{evec => (evec._2._2,evec._2._1)}.
mapValues{proj => correctProj(proj,avgDark,avgFlat)}
// just write out the statistics
// NOTE: saveAsTextFile produces a directory of part-files named "cor_projs.txt", not a single file
projs.mapValues{proj => arrStats(proj.toArray)}.saveAsTextFile(savePath+"cor_projs.txt")
// MLlib implementation ("#" is not a valid Scala comment marker)
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.{Matrix, Matrices}
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix, RowMatrix}
// slice dimensions taken from the first loaded slice (x/y sizes via getDim)
val objSize = doubleSlices.first()._2.getDim()
// highest z index among the corrected projections
// NOTE(review): used below as a 0-based maximum (the sinogram matrix gets
// projCount+1 columns) — confirm z indices start at 0
val projCount = projs.map(_._1.z).max
// re-key each projection by its z index and reshape the flat pixel vector
// into an (x by y) breeze DenseMatrix (breeze matrices are column-major)
val idProjs = projs.repartition(projCount).
map(inval => (inval._2,inval._1.z)).
map{inProj => (inProj._2,new DenseMatrix(objSize.x,objSize.y,inProj._1.toArray))}
// cap on the number of sinograms generated (the largest dimension of the output array)
val maxSinogramOut = 1000
// flatten each projection matrix into (columnIndex, (projectionId, columnVector))
// elements, emitting at most maxSinogramOut columns per projection
val idRows = idProjs.flatMap{ inProj =>
  val (projId, projData) = inProj
  val colLimit = math.min(projData.cols, maxSinogramOut)
  (0 until colLimit).map(c => (c, (projId, projData(::, c))))
}
// generate sinograms from idrows
// group all (projId, columnVector) pairs sharing the same column index;
// cached because the grouped RDD feeds the sinogram build below
val idPreSino = idRows.groupByKey.cache()
val idSino = idPreSino.mapValues{
inRows =>
// one sinogram column per projection; +1 because projCount is the max 0-based z index
val startMat = DenseMatrix.zeros[Double](objSize.x,projCount+1)
// combine rows into a single output array using fold
// NOTE: the accumulator matrix is mutated in place (:= assigns the column) and
// the same instance is threaded through every step
inRows.foldLeft(startMat)(
(accMat,newLine) => {
accMat(::,newLine._1) := newLine._2
accMat})
}
// write the sinograms to disk as csv files
// side effect runs on the executors: each element writes its own csv via breeze csvwrite,
// so savePath must be reachable from every executor
idSino.foreach{ csino => csvwrite(new java.io.File(savePath+"sino"+csino._1+".csv"),csino._2)}
// take advantage of mllibs distributed matrix (can calculate SVD, K-Means, etc)
// extract the single sinogram at line 100 and index its rows by projection id
// NOTE(review): the original used invec._1._2 / invec._1._1, but idRows elements
// are (columnIndex, (projId, vector)) — _._1 is an Int, so those selectors do
// not exist; the accessors below match the structure idRows actually produces.
val sinogramrdd = idRows.filter(_._1 == 100).
  map(invec => (invec._2._1, invec._2._2)).
  map(inv => new IndexedRow(inv._1, Vectors.dense(inv._2.t.toArray)))
// turn it into an indexedrowmatrix so we can take advantage of mllib
val sinogram: IndexedRowMatrix = new IndexedRowMatrix(sinogramrdd)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment