Skip to content

Instantly share code, notes, and snippets.

@Habitats
Created April 28, 2016 14:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Habitats/19fbf3a1d550b8be10b74a10a7af7e6c to your computer and use it in GitHub Desktop.
Save Habitats/19fbf3a1d550b8be10b74a10a7af7e6c to your computer and use it in GitHub Desktop.
/**
  * Fits a Barnes-Hut t-SNE model on the document vectors of `rdd` and writes the resulting
  * 2-dimensional, labelled coordinates to `Config.dataPath + name`.
  *
  * Each article contributes one (label, vector) pair per entry of its `iptc` collection, so an
  * article with several entries shows up once per entry. All pairs are collected to the driver
  * before the model is fitted.
  *
  * Side effect: switches the global ND4J data type to `Type.DOUBLE`.
  *
  * @param rdd  articles to embed; must produce at least one (label, vector) pair
  * @param name output file name, appended to `Config.dataPath`
  * @throws IllegalArgumentException if `rdd` produces no (label, vector) pairs
  */
def create(rdd: RDD[Article], name: String = "tsne.csv"): Unit = {
  val iterations: Int = 100

  // Global ND4J setting: every array created after this point uses double precision.
  Nd4j.dtype = Type.DOUBLE
  Nd4j.factory.setDType(Type.DOUBLE)

  Log.v("Load & Vectorize data....")
  // NOTE(review): presumably warms up the word2vec data used by toND4JDocumentVector — confirm.
  W2VLoader.preLoad()
  val weights = rdd
    .map(a => (a.iptc, a.toND4JDocumentVector))
    .flatMap { case (iptc, v) => iptc.map(c => (c, v)) }
    .collect()

  // Fail fast with a readable message instead of an "empty.reduceLeft" error after the
  // model has already been built.
  require(weights.nonEmpty, "Cannot run t-SNE: the input RDD produced no (label, vector) pairs")

  Log.v("Build model....")
  val tsne: BarnesHutTsne = new BarnesHutTsne.Builder()
    .setMaxIter(iterations)
    .theta(0.5)
    .normalize(false)
    .learningRate(500)
    .useAdaGrad(false)
    .usePca(false)
    .build()

  Log.v("Store TSNE Coordinates for Plotting....")
  val outputFile: String = Config.dataPath + name
  val arr: Array[INDArray] = weights.map(_._2)
  val labels: util.List[String] = weights.map(_._1).toList.asJava
  // Stack all rows in a single call; pairwise reduce(vstack) re-copies the growing matrix
  // on every step and is quadratic in the number of vectors.
  val matrix = Nd4j.vstack(arr: _*)
  tsne.plot(matrix, 2, labels, outputFile)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment