// Gist zyxue/69100fa70e2b26abc46d229022f2d1ef (last active November 25, 2022)
// Compute per-reference coverage break points from read spans with a custom
// Spark Aggregator. Meant to be pasted into spark-shell, which provides the
// `spark` session needed for `spark.implicits._`.
// Mutable Map serves as the aggregation buffer; this import shadows the
// default immutable Map for the rest of the script.
import scala.collection.mutable.Map

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.types._
import spark.implicits._ // spark-shell only: `spark` is the active SparkSession
// One alignment span of a barcoded read on a reference sequence;
// [beg, end) is half-open.
case class Span(
    ref_name: String,
    bc: String,
    beg: Int,
    end: Int,
    read_count: Int)
// Schema for the tab-separated input file; column order must match Span.
val spanSchema = StructType(
  Array(
    StructField("ref_name", StringType, true),
    StructField("bc", StringType, true),
    StructField("beg", IntegerType, true),
    StructField("end", IntegerType, true),
    StructField("read_count", IntegerType, true)
  )
)
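
// Quick illustration, not part of the original pipeline: the same
// Dataset[Span] can be built in memory for a smoke test instead of reading
// the TSV file below. All values here are made up.
val toySpans = Seq(
  Span("chrI", "AAAC", 100, 250, 3),
  Span("chrI", "AAAC", 200, 300, 2)
).toDS()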
// Aggregates the spans of one reference into per-position coverage
// (Map[position -> depth]), then reduces the coverage to break points.
object CalcBreakPoints extends Aggregator[Span, Map[Int, Int], Array[Int]] {

  // The zero value for this aggregation; must satisfy b + zero = b for any b.
  def zero: Map[Int, Int] = Map[Int, Int]()

  // Fold one span into the buffer. For performance, the function may modify
  // `buffer` and return it instead of constructing a new object.
  def reduce(buffer: Map[Int, Int], span: Span): Map[Int, Int] = {
    // `until` because [beg, end) is half-open.
    (span.beg until span.end).foreach(
      i => buffer += (i -> (buffer.getOrElse(i, 0) + 1)))
    buffer
  }
  // Merge two intermediate coverage maps.
  def merge(b1: Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] = {
    b2.foreach {
      case (key, value) => b1 += (key -> (value + b1.getOrElse(key, 0)))
    }
    b1
  }
  // Transform the final coverage map into break points: positions where the
  // binarized coverage steps up or down across the cutoff.
  def finish(coverage: Map[Int, Int]): Array[Int] = {
    val cov_cutoff = 20
    // Binarize coverage: 1 at or above the cutoff, 0 below.
    val f = (i: Int) => if (i >= cov_cutoff) 1 else 0
    val coords = coverage.keys.toArray.sorted
    val bp = coords
      .slice(1, coords.length)
      .map { c =>
        val current = f(coverage(c))
        val previous = f(coverage.getOrElse(c - 1, 0))
        (c, current - previous)
      }
      .filter { case (c, d) => d != 0 }
      .map { case (c, d) => c }
    bp
  }
  // Encoder for the intermediate value type (Kryo-serialized).
  def bufferEncoder: Encoder[Map[Int, Int]] = Encoders.kryo[Map[Int, Int]]

  // Encoder for the final output value type (Kryo-serialized).
  def outputEncoder: Encoder[Array[Int]] = Encoders.kryo[Array[Int]]
}
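
// Minimal local sanity check (assumed values) exercising the Aggregator's
// contract directly, without a Spark job: two overlapping spans of depth 1
// never reach cov_cutoff = 20, so finish should report no break points.
{
  val buf = CalcBreakPoints.reduce(
    CalcBreakPoints.zero, Span("chrI", "AAAC", 0, 50, 1))
  val merged = CalcBreakPoints.merge(
    buf, CalcBreakPoints.reduce(CalcBreakPoints.zero, Span("chrI", "AAAC", 10, 60, 1)))
  assert(CalcBreakPoints.finish(merged).isEmpty)
}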
// Read the tab-separated spans, aggregate break points per reference, and
// write the result out as Parquet.
val ds = spark.read
  .option("sep", "\t")
  .schema(spanSchema)
  .csv("/projects/btl/zxue/assembly_correction/celegans/toy_cov.csv")
  .as[Span]
val cc = CalcBreakPoints.toColumn.name("bp")
val res = ds.groupByKey(a => a.ref_name).agg(cc)
res.write.format("parquet").save("./lele.parquet")
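
// Optional inspection step (not in the original gist): the bp column is
// Kryo-serialized, so collect to deserialize back into Array[Int] rather
// than relying on show(), which would only display raw bytes.
res.collect().take(5).foreach {
  case (ref, bps) => println(s"$ref -> ${bps.mkString(",")}")
}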