Spark ExternalJoin
// build.sbt: sbt-assembly provides AssemblyKeys, assemblySettings, and the assembly task
import AssemblyKeys._

name := "SparkTest"

version := "0.1"

scalaVersion := "2.10.3"

scalacOptions += "-optimize"

scalacOptions += "-deprecation"

libraryDependencies += "net.sf.trove4j" % "trove4j" % "3.0.3"

libraryDependencies += "org.rogach" %% "scallop" % "0.9.4"

assemblySettings

// Keep the Spark jars out of the fat jar; the cluster provides them at runtime.
excludedJars in assembly <<= (fullClasspath in assembly) map { cp =>
  cp filter { _.data.getName.contains("spark") }
}
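The build assumes the sbt-assembly plugin is on the build classpath. A minimal project/plugins.sbt sketch (the plugin version here is an assumption; use whichever release matches your sbt):

// project/plugins.sbt (hypothetical; version is a guess for a 2014-era build)
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.10.2")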
// SparkExternalJoin driver
package com.baidu.msa.model

import com.baidu.msa.model.utility._
import org.apache.spark._
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.rogach.scallop._
import gnu.trove.set.hash.TLongHashSet
import gnu.trove.map.hash.TLongFloatHashMap
import com.esotericsoftware.kryo.Kryo
import org.apache.spark.serializer.KryoRegistrator

// Register the instance class and the Trove collections with Kryo so that
// shuffled records use the compact custom serializers defined in the
// utility package below.
class SparkExternalJoinRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo) {
    kryo.register(classOf[Ins])
    kryo.register(classOf[TLongFloatHashMap], new TLongFloatHashMapSerializer())
    kryo.register(classOf[TLongHashSet], new TLongHashSetSerializer())
  }
}
object SparkExternalJoin extends App {
  // For one partition of tagged instances, collect the distinct feature ids it
  // touches and emit (featureId, partitionId) pairs. Joining the weight table
  // against these pairs ships each weight only to the partitions that need it.
  def generateKeyPair(dataBlock: Iterator[(Int, Ins)]): Iterator[(Long, Int)] = {
    val feaSet = new TLongHashSet()
    var partitionId: Int = 0
    for (insTuple <- dataBlock) {
      val (id, ins) = insTuple
      partitionId = id
      var i = 0
      while (i < ins.featureSet.length) {
        feaSet.add(ins.featureSet(i))
        i += 1
      }
    }
    val iterator = feaSet.iterator
    new Iterator[(Long, Int)] {
      def hasNext = iterator.hasNext
      def next = (iterator.next, partitionId)
    }
  }
System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
System.setProperty("spark.kryo.registrator", "com.baidu.msa.model.SparkExternalJoinRegistrator")
System.setProperty("spark.shuffle.file.buffer.kb","4")
System.setProperty("spark.storage.compression.codec","spark.storage.SnappyCompressionCodec")
  val opts = new ScallopConf(args) {
    version("SparkExternalJoin 1.0")
    banner("""Usage: SparkExternalJoin [OPTION]...
             |Options:
             |""".stripMargin)
    val cluster = opt[String]("cluster", short = 'c', required = true)
    val ins = opt[String]("ins", required = true)
    val model = opt[String]("model", required = true)
    val output = opt[String]("output", required = true)
  }
  val sc = new SparkContext(opts.cluster(), "SparkExternalJoin", System.getenv("SPARK_HOME"))
  val workerNumber = 160

  // Read the instances, repack them into workerNumber partitions, and tag each
  // instance with its partition id. coalesce() replaces the original direct use
  // of rdd.CoalescedRDD, whose constructor is private[spark].
  val insSample = sc.textFile(opts.ins()).coalesce(workerNumber)
    .mapPartitionsWithIndex((id: Int, x: Iterator[String]) => {
      new Iterator[(Int, Ins)] {
        def hasNext = x.hasNext
        def next = (id, Ins(x.next.split("\t")(0)))
      }
    }, true)
  val insWeightSet: rdd.RDD[(Long, Int)] = insSample.mapPartitions(generateKeyPair)
  // Model file lines are "<featureId>\t<weight>".
  val weightTable: rdd.RDD[(Long, Float)] = sc.textFile(opts.model()).coalesce(workerNumber)
    .map { line =>
      val parts = line.split("\t")
      (FeaidConverter.fromString(parts(0)), parts(1).toFloat)
    }
  val weightTableJoined = weightTable.join(insWeightSet)
  weightTableJoined.saveAsTextFile(opts.output())
}
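For orientation, a run might look like the following after `sbt assembly`; the jar name, master URL, and paths are placeholders rather than values from the gist, and the Scala and Spark jars must already be on the classpath:

java -cp SparkTest-assembly-0.1.jar:$SPARK_CLASSPATH com.baidu.msa.model.SparkExternalJoin \
  --cluster spark://master:7077 \
  --ins hdfs:///path/to/instances \
  --model hdfs:///path/to/model \
  --output hdfs:///path/to/joined-weights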
// Utility package: Kryo serializers for the Trove collections, the instance
// record, and the feature-id codec.
package com.baidu.msa.model.utility

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.{Output => KryoOutput, Input => KryoInput}
import gnu.trove.set.hash.TLongHashSet
import gnu.trove.map.hash.TLongFloatHashMap
import scala.collection.mutable.ArrayBuilder
// Compact serializer for Trove's primitive long -> float map: the size,
// then each (key, value) pair.
class TLongFloatHashMapSerializer extends Serializer[TLongFloatHashMap] {
  override def write(kryo: Kryo, output: KryoOutput, map: TLongFloatHashMap): Unit = {
    output.writeInt(map.size)
    var count = 0
    val it = map.iterator
    while (it.hasNext) {
      it.advance
      output.writeLong(it.key)
      output.writeFloat(it.value)
      count += 1
    }
    assert(count == map.size)
  }

  override def read(kryo: Kryo, input: KryoInput, _type: Class[TLongFloatHashMap]): TLongFloatHashMap = {
    val size = input.readInt
    val map = new TLongFloatHashMap(size)
    var count = 0
    while (count < size) {
      val key = input.readLong
      val value = input.readFloat
      map.put(key, value)
      count += 1
    }
    map
  }
}
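// The registrator registers a TLongHashSetSerializer, but the gist does not
// include its definition. The sketch below is an assumption, not the author's
// original code; it mirrors the map serializer above: the size, then each
// long element.
class TLongHashSetSerializer extends Serializer[TLongHashSet] {
  override def write(kryo: Kryo, output: KryoOutput, set: TLongHashSet): Unit = {
    output.writeInt(set.size)
    val it = set.iterator
    while (it.hasNext) {
      output.writeLong(it.next)
    }
  }

  override def read(kryo: Kryo, input: KryoInput, _type: Class[TLongHashSet]): TLongHashSet = {
    val size = input.readInt
    val set = new TLongHashSet(size)
    var count = 0
    while (count < size) {
      set.add(input.readLong)
      count += 1
    }
    set
  }
}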
// One training instance: show and click counts plus parallel arrays of
// feature ids and their slots.
class Ins(val show: Int, val click: Int, val featureSet: Array[Long], val slotSet: Array[Short]) extends java.io.Serializable

object Ins {
  // Expected line format: "<show> <click> <feaId>:<slot> <feaId>:<slot> ..."
  def parseFromPlainText(line: String): Ins = {
    val parts = line.split(" ")
    val show = parts(0).toInt
    val click = parts(1).toInt
    val fea_set = ArrayBuilder.make[Long]()
    val slot_set = ArrayBuilder.make[Short]()
    for (part <- parts.slice(2, parts.length)) {
      val Array(fea_id_str, slot_str) = part.split(":")
      fea_set += FeaidConverter.fromString(fea_id_str)
      slot_set += slot_str.toShort
    }
    new Ins(show, click, fea_set.result, slot_set.result)
  }

  def apply(line: String) = parseFromPlainText(line)
}
// Feature ids are unsigned 64-bit integers printed in decimal. Scala's Long is
// signed, so ids above Long.MaxValue wrap to negative values on the way in and
// must be shifted back by 2^64 on the way out.
object FeaidConverter {
  def fromString(s: String): Long = {
    // BigInt handles the full unsigned range; toLong keeps the low 64 bits.
    BigInt(s).toLong
  }

  def toString(id: Long): String = {
    if (id >= 0)
      id.toString
    else
      // (BigInt(Long.MaxValue) + 1) * 2 == 2^64
      (BigInt(id) + (BigInt(Long.MaxValue) + 1) * 2).toString
  }
}
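A quick round-trip check of the unsigned conversion (illustrative only; these assertions are not part of the gist):

// 18446744073709551615 is 2^64 - 1, which wraps to -1L
assert(FeaidConverter.fromString("18446744073709551615") == -1L)
assert(FeaidConverter.toString(-1L) == "18446744073709551615")
assert(FeaidConverter.toString(FeaidConverter.fromString("42")) == "42")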