Last active
January 4, 2016 15:49
-
-
Save guojc/8643741 to your computer and use it in GitHub Desktop.
Spark ExternalJoin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import AssemblyKeys._

// Project coordinates.
name := "SparkTest"

version := "0.1"

scalaVersion := "2.10.3"

// Compiler flags: enable the optimizer and deprecation warnings.
scalacOptions ++= Seq("-optimize", "-deprecation")

// Trove primitive collections and Scallop CLI parsing.
libraryDependencies ++= Seq(
  "net.sf.trove4j" % "trove4j" % "3.0.3",
  "org.rogach" %% "scallop" % "0.9.4"
)

// sbt-assembly fat-jar settings.
assemblySettings

// Spark itself is provided on the cluster, so keep its jars out of the assembly.
excludedJars in assembly <<= (fullClasspath in assembly) map { cp =>
  cp filter {_.data.getName.contains("spark")}
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.baidu.msa.model; | |
import com.baidu.msa.model.utility._ | |
import org.apache.spark._ | |
import org.apache.spark.SparkContext | |
import org.apache.spark.SparkContext._ | |
import org.rogach.scallop._; | |
import gnu.trove.set.hash.TLongHashSet | |
import gnu.trove.map.hash.TLongFloatHashMap | |
import scala.util.Sorting | |
import org.apache.spark.storage.StorageLevel | |
import com.esotericsoftware.kryo.Kryo | |
import org.apache.spark.serializer.KryoRegistrator | |
/** Kryo registrator for the external-join job: registers the instance record
  * type plus the two Trove collections, each with its custom serializer.
  */
class SparkExternalJoinRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    val mapSerializer = new TLongFloatHashMapSerializer()
    val setSerializer = new TLongHashSetSerializer()
    // Registration order is kept stable: Kryo assigns class ids by order,
    // so reordering would break compatibility with previously serialized data.
    kryo.register(classOf[Ins])
    kryo.register(classOf[TLongFloatHashMap], mapSerializer)
    kryo.register(classOf[TLongHashSet], setSerializer)
  }
}
/** Spark job that joins a model weight table against the distinct feature ids
  * occurring in an instance file, writing the joined (feaId, (weight, partitionId))
  * records to the output path.
  *
  * NOTE(review): originally `extends App` — replaced with an explicit `main`
  * because the `App` trait's DelayedInit semantics make field initialization
  * order fragile for non-trivial entry points. Behavior when run as a program
  * is unchanged.
  */
object SparkExternalJoin {

  /** Collects the distinct feature ids of one partition's instances and pairs
    * each id with the partition id.
    *
    * @param dataBlock iterator of (partitionId, instance) tuples; all tuples in
    *                  one block are expected to share the same partition id
    * @return iterator of (featureId, partitionId) pairs, one per distinct id
    */
  def generateKeyPair(dataBlock: Iterator[(Int, Ins)]): Iterator[(Long, Int)] = {
    val feaSet = new TLongHashSet()
    // Last partition id seen; stays 0 for an empty block, which is harmless
    // because the feature set is then empty and the iterator yields nothing.
    var partitionId: Int = 0
    for ((id, ins) <- dataBlock) {
      partitionId = id
      var i = 0
      while (i < ins.featureSet.length) {
        feaSet.add(ins.featureSet(i))
        i += 1
      }
    }
    val iterator = feaSet.iterator
    // Lazy view over the Trove set: avoids materializing all pairs at once.
    new Iterator[(Long, Int)] {
      def hasNext = iterator.hasNext
      def next = (iterator.next, partitionId)
    }
  }

  def main(args: Array[String]): Unit = {
    // Spark 0.x-style configuration via system properties: Kryo serialization
    // with our registrator, small shuffle buffers, Snappy block compression.
    System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    System.setProperty("spark.kryo.registrator", "com.baidu.msa.model.SparkExternalJoinRegistrator")
    System.setProperty("spark.shuffle.file.buffer.kb", "4")
    System.setProperty("spark.storage.compression.codec", "spark.storage.SnappyCompressionCodec")

    // Command-line options: cluster master URL, instance file, model file, output path.
    val opts = new ScallopConf(args) {
      version("SparkExternalJoin 1.0 ")
      banner("""Usage: SparkExternalJoin [OPTION]... [pet-name]
               |Options:
               |""".stripMargin)
      val cluster = opt[String]("cluster", short = 'c', required = true)
      val ins = opt[String]("ins", required = true)
      val model = opt[String]("model", required = true)
      val output = opt[String]("output", required = true)
    }

    val sc = new SparkContext(opts.cluster(), "SparkExternalJoin", System.getenv("SPARK_HOME"))
    // Fixed parallelism for both inputs — TODO confirm this matches cluster size.
    val workerNumber = 160

    // Parse each instance line (first tab-separated field) and tag it with the
    // partition index so generateKeyPair can emit (featureId, partitionId).
    val insSample = new rdd.CoalescedRDD(sc.textFile(opts.ins()), workerNumber)
      .mapPartitionsWithIndex((id: Int, x: Iterator[String]) => {
        new Iterator[(Int, Ins)] {
          def hasNext = x.hasNext
          def next = (id, Ins(x.next.split("\t")(0)))
        }
      }, true)

    val insWeightSet: rdd.RDD[(Long, Int)] = insSample.mapPartitions(generateKeyPair)

    // Model file format: "<feaId>\t<weight>" per line.
    val weightTable: rdd.RDD[(Long, Float)] =
      new rdd.CoalescedRDD(sc.textFile(opts.model()), workerNumber).map(line => {
        val parts = line.split("\t")
        (FeaidConverter.fromString(parts(0)), parts(1).toFloat)
      })

    // Inner join keeps only weights whose feature id occurs in the instances.
    val weightTableJoined = weightTable.join(insWeightSet)
    weightTableJoined.saveAsTextFile(opts.output())
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.baidu.msa.model.utility; | |
import com.esotericsoftware.kryo.Kryo | |
import com.esotericsoftware.kryo.KryoSerializable | |
import com.esotericsoftware.kryo.Serializer | |
import com.esotericsoftware.kryo.io.{Output => KryoOutput,Input =>KryoInput} | |
import org.apache.spark.serializer.KryoRegistrator | |
import gnu.trove.set.hash.TLongHashSet | |
import gnu.trove.map.hash.TLongFloatHashMap | |
import scala.collection.mutable.ArrayBuilder | |
import scala.io.Source | |
/** Kryo serializer for Trove's primitive long->float hash map.
  *
  * Wire format: an int entry count, followed by (long key, float value) pairs.
  */
class TLongFloatHashMapSerializer extends Serializer[TLongFloatHashMap] {

  override def write(kryo: Kryo, output: KryoOutput, map: TLongFloatHashMap): Unit = {
    output.writeInt(map.size)
    var written = 0
    val entries = map.iterator
    while (entries.hasNext) {
      entries.advance()
      output.writeLong(entries.key)
      output.writeFloat(entries.value)
      written += 1
    }
    // Sanity check: the iterator must have produced exactly the declared count.
    assert(written == map.size)
  }

  override def read(kryo: Kryo, input: KryoInput, _type: Class[TLongFloatHashMap]): TLongFloatHashMap = {
    val size = input.readInt
    // Pre-size the map so it never rehashes while filling.
    val map = new TLongFloatHashMap(size)
    for (_ <- 0 until size) {
      val key = input.readLong
      val value = input.readFloat
      map.put(key, value)
    }
    map
  }
}
/** One training instance: show/click counts plus parallel arrays of feature
  * ids and their slots.
  */
class Ins(val show: Int, val click: Int, val featureSet: Array[Long], val slotSet: Array[Short]) extends java.io.Serializable

object Ins {

  /** Parses an instance from its plain-text form:
    * "<show> <click> <feaId>:<slot> <feaId>:<slot> ...".
    */
  def parseFromPlainText(line: String): Ins = {
    val tokens = line.split(" ")
    val show = tokens(0).toInt
    val click = tokens(1).toInt
    val features = ArrayBuilder.make[Long]()
    val slots = ArrayBuilder.make[Short]()
    // Remaining tokens are "feaId:slot" pairs.
    for (pair <- tokens.drop(2)) {
      val Array(feaIdStr, slotStr) = pair.split(":")
      features += FeaidConverter.fromString(feaIdStr)
      slots += slotStr.toShort
    }
    new Ins(show, click, features.result, slots.result)
  }

  /** Convenience constructor delegating to [[parseFromPlainText]]. */
  def apply(line: String): Ins = parseFromPlainText(line)
}
/** Converts feature ids between their unsigned 64-bit decimal string form and
  * the signed Long used internally (ids above Long.MaxValue wrap to negatives).
  */
object FeaidConverter {

  // 2^64 as a BigInt: the offset that maps a negative Long back to its
  // unsigned decimal representation.
  private val TwoPow64 = BigInt(2).pow(64)

  /** Parses a decimal string (possibly > Long.MaxValue) into a signed Long. */
  def fromString(s: String): Long = BigInt(s).toLong

  /** Renders a signed Long as its unsigned decimal string. */
  def toString(id: Long): String =
    if (id < 0) (BigInt(id) + TwoPow64).toString else id.toString
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment