Skip to content

Instantly share code, notes, and snippets.

@ahoy-jon
Last active January 5, 2018 23:51
Show Gist options
  • Save ahoy-jon/9f00b2e38ccfcf97e347d58836bcb165 to your computer and use it in GitHub Desktop.
package meta
import org.apache.spark.SparkConf
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.sat4j.core.{Vec, VecInt}
import org.sat4j.maxsat.{MinCostDecorator, SolverFactory, WeightedMaxSatDecorator}
import org.sat4j.pb.IPBSolver
import org.scalatest.FunSuite
import utils.Gouache

import scala.util.Try
object SummerizeDs {

  /** Unit-cost minimum set cover: keeps a sub-collection of the keys of `map`
    * that still covers every element, minimising the number of sets kept.
    * Equivalent to [[optimizeCoverW]] with every weight equal to 1.
    */
  def optimizeCover[K, T](map: Map[Set[K], T]): Map[Set[K], T] = {
    // Strict `map` instead of the lazy `mapValues` view, so the weighted
    // variant does not re-evaluate the pairing on every lookup.
    optimizeCoverW(map.map({ case (k, v) ⇒ k -> (v -> 1) }))
  }

  /** Weighted minimum set cover solved with SAT4J's min-cost decorator.
    *
    * Each key of `map` is a candidate set; its value carries a payload `T`
    * and an integer cost. Returns the selected sets (payload only) whose
    * union covers all elements, at minimal total cost.
    *
    * @throws IllegalStateException if the solver produced no model at all
    *         (should not happen for a well-formed cover instance, where
    *         selecting every set is always a feasible solution).
    */
  def optimizeCoverW[K, T](map: Map[Set[K], (T, Int)]): Map[Set[K], T] = {
    if (map.isEmpty) Map.empty // nothing to cover, nothing to select
    else {
      val default: IPBSolver = SolverFactory.newDefault()

      // SAT variables are 1-based: variable i+1 stands for the i-th set.
      val (index, reverseIndex) = {
        val source = map.keys.zipWithIndex.map({ case (s, i) ⇒ (s, i + 1) }).toSeq
        (source.toMap, source.map(_.swap).toMap)
      }

      val decorated = new MinCostDecorator(default)
      decorated.newVar(map.size)

      // Coverage constraints: each element must belong to at least one
      // selected set (one clause per element, over the sets containing it).
      index.toSeq.flatMap({ case (s, i) ⇒ s.map(_ -> i) }).groupBy(_._1).foreach({
        case (_, candidates) ⇒
          decorated.addClause(new VecInt(candidates.map(_._2).toArray))
      })

      // Objective: minimise the summed cost of the selected sets.
      index.foreach({ case (s, i) ⇒ decorated.setCost(i, map(s)._2) })

      // Enumerate increasingly better models, remembering the best one.
      // SAT4J signals "optimum proven" by throwing from the solver loop,
      // hence the Try wrapper: that exception is expected, not an error.
      var best: Option[Array[Int]] = None
      Try {
        while (decorated.admitABetterSolution()) {
          best = Some(decorated.model())
          decorated.discardCurrentSolution()
        }
      }

      best match {
        case Some(model) ⇒
          // Positive literals are the selected variables; map them back.
          model.filter(_ > 0).map(reverseIndex).map(x ⇒ x -> map(x)._1).toMap
        case None ⇒
          // The original code dereferenced a null array here (NPE); fail
          // loudly with a meaningful message instead.
          throw new IllegalStateException("SAT solver produced no model for the cover instance")
      }
    }
  }

  /** Counts the non-null leaf values of a Row, recursing into arrays of
    * structs. Used as the "information weight" when choosing which concrete
    * row best represents a structure fingerprint.
    */
  def dataWeight(r: Row): Int = {
    Option(r.schema) match {
      case Some(StructType(fields)) ⇒ fields.map({
        case StructField(name, ArrayType(_: StructType, _), _, _) ⇒
          // A null array field would NPE in the original; count it as 0.
          Option(r.getAs[Seq[Row]](name)).getOrElse(Seq.empty).map(dataWeight).sum
        case StructField(name, ArrayType(_, _), _, _) ⇒
          Option(r.getAs[Seq[Any]](name)).getOrElse(Seq.empty).count(_ != null)
        case StructField(name, _, _, _) ⇒
          if (r.getAs[Any](name) != null) 1 else 0
      }).sum
      // Schema-less row: just count the non-null top-level values.
      case None ⇒ r.toSeq.count(_ != null)
    }
  }

  /** Computes a structural fingerprint of a Row: the set of
    * (field path, state) facts it exhibits, where state is
    * "null"/"empty"/"notempty" for strings, "null"/"notnull" otherwise.
    * Arrays of structs are flattened with a "parent.child" path prefix.
    */
  def structureFingerPrint(r: Row): Set[(String, String)] = {
    val schema: Option[StructType] = Option(r.schema)
    val res: Seq[(String, String)] = schema match {
      case Some(StructType(fields)) ⇒ fields.flatMap(
        {
          case StructField(name, ArrayType(_: StructType, _), _, _) ⇒
            // A null array field would NPE in the original; treat it like
            // an empty array (it contributes no nested facts).
            Option(r.getAs[Seq[Row]](name)).getOrElse(Seq.empty)
              .flatMap(structureFingerPrint).map({ case (k, v) ⇒ (name + "." + k, v) })
          case StructField(name, StringType, _, _) ⇒
            Seq(name -> (r.getAs[String](name) match {
              case null ⇒ "null"
              case "" ⇒ "empty"
              case _ ⇒ "notempty"
            }))
          case StructField(name, _, _, _) ⇒
            Seq(name -> (if (r.getAs[Any](name) == null) "null" else "notnull"))
        }
      )
      case None ⇒
        // Schema-less row: use positional indices as field names.
        r.toSeq.zipWithIndex.map({ case (v, i) ⇒ i.toString -> (if (v == null) "null" else "notnull") })
    }
    res.toSet
  }

  /** Returns a one-partition DataFrame holding a small sample of `df` that
    * still exhibits every structure fact present in it: for each distinct
    * fingerprint the lightest row (by [[dataWeight]]) is kept, then a
    * minimum-weight cover of all facts is selected via [[optimizeCoverW]].
    */
  def sampleToRepresentStructures(df: DataFrame): DataFrame = {
    val schema = df.schema
    val ss = df.sparkSession
    val res = optimizeCoverW[(String, String), Row](
      df.rdd.keyBy(x ⇒ structureFingerPrint(x).toVector.sorted)
        .mapValues(x ⇒ x -> dataWeight(x)).reduceByKey((a, b) ⇒ if (a._2 < b._2) a else b)
        .collectAsMap().toMap.map({
          case (k, v) ⇒ (k.toSet, v)
        })
    )
    ss.createDataFrame(ss.sparkContext.makeRDD(res.values.toSeq, 1), schema)
  }

  /** Local smoke-test entry point: reads a sample parquet dataset, reduces
    * it to its structure-representative rows and writes them out.
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("test")
    conf.setMaster("local[*]")
    val ss = SparkSession.builder().config(conf).getOrCreate()
    val sqlContext = ss.sqlContext
    val df = sqlContext.read.load("file://" + Gouache.path("trafic-garanti/trafic-garanti-plan/src/test/resources/modelehat_sample").getAbsolutePath)
    sampleToRepresentStructures(df).write.mode(SaveMode.Overwrite).parquet("target/out/testModeleHAt")
  }
}
object Sizing {

  /** Prints the row count of the parquet output written by
    * `SummerizeDs.main`, as a quick sanity check of the summary's size.
    */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("test").setMaster("local[*]")
    val session = SparkSession.builder().config(sparkConf).getOrCreate()
    val summarized = session.read.parquet("target/out/testModeleHAt")
    println(summarized.count())
  }
}
class OptimizeCoverSuite extends FunSuite {

  test("disjoint") {
    // Two sets with no common element: both must be kept.
    val cover = Map(Set(1, 2) -> 1, Set(3) -> 2)
    assert(SummerizeDs.optimizeCover(cover) == cover)
  }

  test("totalOverlap") {
    // Set(3) is contained in Set(1, 2, 3): the larger set alone suffices.
    val cover = Map(Set(1, 2, 3) -> 1, Set(3) -> 2)
    assert(SummerizeDs.optimizeCover(cover) == Map(Set(1, 2, 3) -> 1))
  }

  test("disjoint x200") {
    // 201 disjoint singletons: must terminate quickly and keep them all.
    val cover = (0 to 200).map(n ⇒ Set(n) -> n).toMap
    assert(SummerizeDs.optimizeCover(cover) == cover)
  }

  test("weight test 1") {
    // Keeping {1,2} (cost 2) beats keeping {1} and {2} (cost 4).
    val selected = SummerizeDs.optimizeCoverW(Map(
      Set(1, 2) -> ("a", 2),
      Set(1) -> ("b", 2),
      Set(2) -> ("c", 2)))
    assert(selected == Map(Set(1, 2) -> "a"))
  }

  test("weight test 2") {
    // {1,2,3} at cost 6 beats the three singletons at cost 12.
    val selected = SummerizeDs.optimizeCoverW(Map(
      Set(1, 2, 3) -> ("a", 6),
      Set(3) -> ("d", 4),
      Set(1) -> ("b", 4),
      Set(2) -> ("c", 4)))
    assert(selected == Map(Set(1, 2, 3) -> "a"))
  }

  test("weight test 3") {
    // {1,2} (1) + {3} (4) = 5 beats {1,2,3} alone at cost 6.
    val selected = SummerizeDs.optimizeCoverW(Map(
      Set(1, 2, 3) -> ("a", 6),
      Set(1, 2) -> ("e", 1),
      Set(3) -> ("d", 4),
      Set(1) -> ("b", 4),
      Set(2) -> ("c", 4)))
    assert(selected == Map(Set(1, 2) -> "e", Set(3) -> "d"))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment