import scala.collection.mutable.Map
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import spark.implicits._
import org.apache.spark.sql.types._
/** One input record: an interval on a reference sequence.
  *
  * Field names and order mirror the columns of `spanSchema`, which is what
  * lets a DataFrame read with that schema be viewed as `Dataset[Span]`.
  *
  * @param ref_name   name of the reference sequence; used as the grouping key
  * @param bc         identifier attached to the span (presumably a barcode -- TODO confirm)
  * @param beg        start coordinate; treated as inclusive by `CalcBreakPoints.reduce`
  * @param end        end coordinate; treated as exclusive by `CalcBreakPoints.reduce`
  * @param read_count read count carried with the span; not consulted by `CalcBreakPoints`
  */
case class Span(
    ref_name: String,
    bc: String,
    beg: Int,
    end: Int,
    read_count: Int
)
// Explicit schema for the span input file. Column names and order match the
// fields of `Span`; every column is declared nullable.
// `StructType.apply` takes a single collection of fields, hence the `Seq`.
val spanSchema = StructType(
  Seq(
    StructField("ref_name", StringType, nullable = true),
    StructField("bc", StringType, nullable = true),
    StructField("beg", IntegerType, nullable = true),
    StructField("end", IntegerType, nullable = true),
    StructField("read_count", IntegerType, nullable = true)
  )
)
/** Reduces a group of [[Span]]s to per-position coverage, then to break points.
  *
  *  - Input: one `Span` per record.
  *  - Buffer: a mutable map from coordinate to the number of spans covering it.
  *  - Output: the sorted coordinates at which coverage crosses the cutoff,
  *    i.e. the position and its predecessor lie on opposite sides of it.
  */
object CalcBreakPoints extends Aggregator[Span, Map[Int, Int], Array[Int]] {

  /** Minimum number of covering spans for a position to count as covered. */
  private val CoverageCutoff = 20

  /** A zero value for this aggregation: an empty coverage map.
    *
    * This is a `def`, so each call yields a fresh map; that matters because
    * `reduce` and `merge` mutate their first argument in place.
    */
  def zero: Map[Int, Int] = Map.empty[Int, Int]

  /** Adds one span to the running coverage.
    *
    * Every coordinate in the half-open range `[span.beg, span.end)` is
    * incremented by exactly one; `span.read_count` is not used. For
    * performance `buffer` is modified and returned rather than copied.
    *
    * @param buffer coverage accumulated so far (mutated)
    * @param span   the span to add
    * @return the same `buffer` instance, updated
    */
  def reduce(buffer: Map[Int, Int], span: Span): Map[Int, Int] = {
    (span.beg until span.end).foreach { pos =>
      buffer(pos) = buffer.getOrElse(pos, 0) + 1
    }
    buffer
  }

  /** Merges two intermediate coverage maps by summing counts per coordinate.
    *
    * @param b1 first partial coverage (mutated and returned)
    * @param b2 second partial coverage (read only)
    * @return `b1` with the counts of `b2` added in
    */
  def merge(b1: Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] = {
    b2.foreach { case (pos, depth) =>
      b1(pos) = b1.getOrElse(pos, 0) + depth
    }
    b1
  }

  /** Converts the final coverage map into break points.
    *
    * A coordinate is a break point when it is covered (count >= cutoff) and
    * its predecessor is not, or the other way round. Coordinates missing
    * from the map count as zero coverage.
    *
    * Only coordinates present in the map are examined, and the smallest one
    * is skipped. So the first covered coordinate is never reported, and
    * neither is a drop to a coordinate that no span touches.
    *
    * @param coverage coordinate -> number of covering spans
    * @return break-point coordinates in ascending order
    */
  def finish(coverage: Map[Int, Int]): Array[Int] = {
    def isCovered(pos: Int): Boolean =
      coverage.getOrElse(pos, 0) >= CoverageCutoff

    coverage.keys.toArray.sorted
      .drop(1) // the smallest coordinate is never tested
      .filter(pos => isCovered(pos) != isCovered(pos - 1))
  }

  /** Encoder for the intermediate coverage map (Kryo-serialized). */
  def bufferEncoder: Encoder[Map[Int, Int]] = Encoders.kryo[Map[Int, Int]]

  /** Encoder for the final break-point array (Kryo-serialized). */
  def outputEncoder: Encoder[Array[Int]] = Encoders.kryo[Array[Int]]
}
// Read the tab-separated span file with the explicit schema and view each row
// as a `Span`.
val ds = spark.read
  .option("sep", "\t")
  .schema(spanSchema)
  .csv("/projects/btl/zxue/assembly_correction/celegans/toy_cov.csv")
  .as[Span]

// Typed column that runs the break-point aggregator; the result column is named "bp".
val cc = CalcBreakPoints.toColumn.name("bp")

// One break-point array per reference sequence.
val res = ds.groupByKey(a => a.ref_name).agg(cc)
