Last active
January 5, 2018 23:51
-
-
Save ahoy-jon/9f00b2e38ccfcf97e347d58836bcb165 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package meta | |
import org.apache.spark.SparkConf | |
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} | |
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession} | |
import org.sat4j.core.{Vec, VecInt} | |
import org.sat4j.maxsat.{MinCostDecorator, SolverFactory, WeightedMaxSatDecorator} | |
import org.sat4j.pb.IPBSolver | |
import utils.Gouache | |
import scala.util.Try | |
object SummerizeDs {

  /** Minimum set cover where every input set has unit weight.
    *
    * Convenience wrapper over [[optimizeCoverW]]: each key set costs 1,
    * so the solver minimizes the number of selected sets.
    *
    * @param map key sets, each carrying an arbitrary payload
    * @return the sub-map of selected sets that still covers every key
    */
  def optimizeCover[K, T](map: Map[Set[K], T]): Map[Set[K], T] =
    // Strict map (not mapValues, which is a lazy view in Scala 2.12 and
    // would re-evaluate on every access inside the solver setup).
    optimizeCoverW(map.map { case (k, v) => k -> (v, 1) })

  /** Weighted minimum set cover, solved as a min-cost SAT problem via sat4j.
    *
    * Selects a sub-collection of the key sets that still covers the union of
    * all keys while minimizing the sum of the attached integer weights, and
    * returns the payloads of the selected sets.
    *
    * @param map key sets, each carrying a (payload, weight) pair
    * @return selected key sets mapped to their payloads
    * @throws IllegalStateException if the solver produces no model at all
    */
  def optimizeCoverW[K, T](map: Map[Set[K], (T, Int)]): Map[Set[K], T] = {
    val solver: IPBSolver = SolverFactory.newDefault()
    // sat4j variables are 1-based: build set -> var and var -> set tables.
    val (index, reverseIndex) = {
      val source = map.keys.zipWithIndex.map { case (s, i) => (s, i + 1) }.toSeq
      (source.toMap, source.map(_.swap).toMap)
    }
    val decorated = new MinCostDecorator(solver)
    decorated.newVar(map.size)
    // Coverage constraints: every element must belong to at least one
    // selected set (one clause per distinct element).
    index.toSeq
      .flatMap { case (s, i) => s.map(_ -> i) }
      .groupBy(_._1)
      .foreach { case (_, occurrences) =>
        decorated.addClause(new VecInt(occurrences.map(_._2).toArray))
      }
    // Objective: selecting a set costs its declared weight.
    index.foreach { case (s, i) => decorated.setCost(i, map(s)._2) }
    // Iterate towards the optimum. sat4j signals exhaustion of the search
    // space by throwing (typically a ContradictionException), which the Try
    // absorbs on purpose; the last model seen is the optimal one.
    var best: Option[Array[Int]] = None
    Try {
      while (decorated.admitABetterSolution()) {
        best = Some(decorated.model())
        decorated.discardCurrentSolution()
      }
    }
    // Original code NPE'd here (`res` stayed null) when no solution was ever
    // admitted; fail with an explicit error instead.
    val model = best.getOrElse(
      throw new IllegalStateException(
        "sat4j produced no model: the cover problem is unsatisfiable"))
    // Positive literals denote the selected sets.
    model.filter(_ > 0).map(reverseIndex).map(s => s -> map(s)._1).toMap
  }

  /** Counts the non-null leaf values carried by a Row, recursing into arrays
    * of structs. A Row without an attached schema falls back to a flat count
    * of its non-null cells.
    */
  def dataWeight(r: Row): Int =
    Option(r.schema) match {
      case Some(StructType(fields)) =>
        fields.map {
          case StructField(name, ArrayType(_: StructType, _), _, _) =>
            // Array of structs: recurse and sum. Guard against a null array
            // column (the original code NPE'd on `.map` here).
            Option(r.getAs[Seq[Row]](name)).fold(0)(_.map(dataWeight).sum)
          case StructField(name, ArrayType(_, _), _, _) =>
            // Array of scalars: count non-null entries, null array counts 0.
            Option(r.getAs[Seq[Any]](name)).fold(0)(_.count(_ != null))
          case StructField(name, _, _, _) =>
            if (r.getAs[Any](name) != null) 1 else 0
        }.sum
      case None =>
        r.toSeq.count(_ != null)
    }

  /** Structural signature of a Row: for every (possibly nested) field, a
    * coarse value category ("null" / "empty" / "notempty" / "notnull").
    * Rows sharing a fingerprint exercise the same parts of the schema.
    */
  def structureFingerPrint(r: Row): Set[(String, String)] = {
    val entries: Seq[(String, String)] = Option(r.schema) match {
      case Some(StructType(fields)) =>
        fields.flatMap {
          case StructField(name, ArrayType(_: StructType, _), _, _) =>
            // Nested rows contribute dotted paths ("parent.child").
            // Null array columns contribute nothing (original NPE'd).
            Option(r.getAs[Seq[Row]](name)).getOrElse(Seq.empty)
              .flatMap(structureFingerPrint)
              .map { case (k, v) => (name + "." + k, v) }
          case StructField(name, StringType, _, _) =>
            // Strings get a three-way category so empty vs missing differ.
            val category = r.getAs[String](name) match {
              case null => "null"
              case ""   => "empty"
              case _    => "notempty"
            }
            Seq(name -> category)
          case StructField(name, _, _, _) =>
            Seq(name -> (if (r.getAs[Any](name) == null) "null" else "notnull"))
        }
      case None =>
        // Schema-less row: index positions stand in for field names.
        r.toSeq.zipWithIndex.map { case (v, i) =>
          i.toString -> (if (v == null) "null" else "notnull")
        }
    }
    entries.toSet
  }

  /** Shrinks a DataFrame down to a minimal sample of rows that still exhibits
    * every structure fingerprint present in the data, keeping the lightest
    * (smallest [[dataWeight]]) representative per fingerprint and then solving
    * a weighted set cover over the fingerprints.
    */
  def sampleToRepresentStructures(df: DataFrame): DataFrame = {
    val schema = df.schema
    val ss = df.sparkSession
    val selected = optimizeCoverW[(String, String), Row](
      df.rdd
        .keyBy(r => structureFingerPrint(r).toVector.sorted)
        .mapValues(r => r -> dataWeight(r))
        // Keep the lightest representative per fingerprint.
        .reduceByKey((a, b) => if (a._2 < b._2) a else b)
        .collectAsMap().toMap
        .map { case (k, v) => (k.toSet, v) }
    )
    // Single partition: the sample is small by construction.
    ss.createDataFrame(ss.sparkContext.makeRDD(selected.values.toSeq, 1), schema)
  }

  /** Entry point: samples a local parquet dataset and writes the result. */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName("test")
    conf.setMaster("local[*]")
    val ss = SparkSession.builder().config(conf).getOrCreate()
    val df = ss.sqlContext.read.load(
      "file://" + Gouache.path("trafic-garanti/trafic-garanti-plan/src/test/resources/modelehat_sample").getAbsolutePath)
    sampleToRepresentStructures(df).write.mode(SaveMode.Overwrite).parquet("target/out/testModeleHAt")
  }
}
object Sizing {

  /** Sanity check: prints the row count of the sampled parquet output
    * produced by [[SummerizeDs.main]].
    */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
    sparkConf.setAppName("test")
    sparkConf.setMaster("local[*]")
    val session = SparkSession.builder().config(sparkConf).getOrCreate()
    val rowCount = session.read.parquet("target/out/testModeleHAt").count()
    println(rowCount)
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class OptimizeCoverSuite extends FunSuite {

  test("disjoint") {
    // Two sets without a common element: both are required, cover unchanged.
    val input: Map[Set[Int], Int] = Map(Set(1, 2) -> 1, Set(3) -> 2)
    assert(SummerizeDs.optimizeCover(input) == input)
  }

  test("totalOverlap") {
    // Set(3) is subsumed by Set(1, 2, 3): only the larger set survives.
    val input: Map[Set[Int], Int] = Map(Set(1, 2, 3) -> 1, Set(3) -> 2)
    assert(SummerizeDs.optimizeCover(input) == Map(Set(1, 2, 3) -> 1))
  }

  test("disjoint x200") {
    // 201 pairwise-disjoint singletons: solver must scale without blowing up,
    // and every set stays in the cover.
    val input: Map[Set[Int], Int] = (0 to 200).map(n => Set(n) -> n).toMap
    assert(SummerizeDs.optimizeCover(input) == input)
  }

  test("weight test 1") {
    // One pair at cost 2 beats two singletons at total cost 4.
    val input = Map(
      Set(1, 2) -> ("a", 2),
      Set(1)    -> ("b", 2),
      Set(2)    -> ("c", 2))
    assert(SummerizeDs.optimizeCoverW(input) == Map(Set(1, 2) -> "a"))
  }

  test("weight test 2") {
    // The big set at cost 6 beats three singletons at total cost 12.
    val input = Map(
      Set(1, 2, 3) -> ("a", 6),
      Set(3)       -> ("d", 4),
      Set(1)       -> ("b", 4),
      Set(2)       -> ("c", 4))
    assert(SummerizeDs.optimizeCoverW(input) == Map(Set(1, 2, 3) -> "a"))
  }

  test("weight test 3") {
    // Cheapest cover mixes sets: {1,2} at 1 plus {3} at 4 (total 5) beats
    // the single {1,2,3} at 6.
    val input = Map(
      Set(1, 2, 3) -> ("a", 6),
      Set(1, 2)    -> ("e", 1),
      Set(3)       -> ("d", 4),
      Set(1)       -> ("b", 4),
      Set(2)       -> ("c", 4))
    assert(SummerizeDs.optimizeCoverW(input) ==
      Map(Set(1, 2) -> "e", Set(3) -> "d"))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment