Skip to content

Instantly share code, notes, and snippets.

@libratiger
Created June 13, 2014 12:40
Show Gist options
  • Save libratiger/7e3b2c9eb33ff0037622 to your computer and use it in GitHub Desktop.
Save libratiger/7e3b2c9eb33ff0037622 to your computer and use it in GitHub Desktop.
package mliib
import scala.util.Random
import org.jblas.DoubleMatrix
import org.apache.spark._
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
/**
 * Spark job that generates a synthetic binary-classification dataset
 * (labels 0.0 / 1.0) and writes it out with `MLUtils.saveLabeledData`.
 *
 * Every example has `nfeatures` features drawn uniformly from [-1.0, 1.0).
 * Its label is 1.0 when `x . w + noise >= 0` and 0.0 otherwise, where `w` is a
 * fixed weight vector drawn from a standard Gaussian (seed 94720) and `noise`
 * is Gaussian with standard deviation 0.1. Example `idx` uses its own RNG
 * seeded with `42 + idx`, so each example's values depend only on its index.
 *
 * Usage (all positional arguments are optional; missing ones fall back to the
 * built-in defaults):
 * {{{
 *   DataGenerator [sparkMaster] [outputPath] [nexamples] [nfeatures] [parts]
 * }}}
 */
object DataGenerator {
  private val DefaultSparkMaster: String = "spark://192.168.35.10:7077"
  private val DefaultOutputPath: String = "hdfs://192.168.35.10:54310/kmeans/data2"
  private val DefaultNExamples: Int = 500
  private val DefaultNFeatures: Int = 100000
  private val DefaultParts: Int = 100

  /** Returns `args(i)` if that positional argument was supplied, else `default`. */
  private def argOrElse(args: Array[String], i: Int, default: String): String =
    if (i < args.length) args(i) else default

  /**
   * Entry point: builds the Spark context, generates the dataset and saves it.
   *
   * @param args optional `[sparkMaster] [outputPath] [nexamples] [nfeatures] [parts]`
   * @throws NumberFormatException if a numeric argument is not a valid Int
   * @throws IllegalArgumentException if a numeric argument is out of range
   */
  def main(args: Array[String]): Unit = {
    val sparkMaster: String = argOrElse(args, 0, DefaultSparkMaster)
    val outputPath: String = argOrElse(args, 1, DefaultOutputPath)
    // A malformed number fails fast with NumberFormatException.
    val nexamples: Int = argOrElse(args, 2, DefaultNExamples.toString).toInt
    val nfeatures: Int = argOrElse(args, 3, DefaultNFeatures.toString).toInt
    val parts: Int = argOrElse(args, 4, DefaultParts.toString).toInt
    require(nexamples >= 0, s"nexamples must be non-negative, got $nexamples")
    require(nfeatures > 0, s"nfeatures must be positive, got $nfeatures")
    require(parts > 0, s"parts must be positive, got $parts")

    val conf = new SparkConf()
    conf.setMaster(sparkMaster)
    conf.setAppName("svm_generator")
    conf.setJars(Seq("target/scala-2.10/tinyfish_2.10-1.0.jar"))
    conf.set("spark.executor.memory", "60g")

    val sc = new SparkContext(conf)
    try {
      val globalRnd = new Random(94720)
      // NOTE(review): the weight vector has nfeatures + 1 entries while each
      // example has only nfeatures; the extra entry looks like an intended
      // intercept term. Confirm how DoubleMatrix.dot treats the length mismatch.
      val trueWeights = new DoubleMatrix(1, nfeatures + 1,
        Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()): _*)
      // Ship the (potentially large) weight vector once per executor instead of
      // serializing it into every task closure.
      val trueWeightsBc = sc.broadcast(trueWeights)

      val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx =>
        // Per-example seed: the generated values depend only on idx, not on
        // which partition or executor processes it.
        val rnd = new Random(42 + idx)
        val x = Array.fill[Double](nfeatures)(rnd.nextDouble() * 2.0 - 1.0)
        val yD = new DoubleMatrix(1, x.length, x: _*).dot(trueWeightsBc.value) +
          rnd.nextGaussian() * 0.1
        val y = if (yD < 0) 0.0 else 1.0
        LabeledPoint(y, x)
      }

      MLUtils.saveLabeledData(data, outputPath)
    } finally {
      // Release cluster resources even if generation or the write fails.
      sc.stop()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment