Skip to content

Instantly share code, notes, and snippets.

@Habitats
Created May 9, 2016 14:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Habitats/d6abd0032b46dfe3b9b590aae7e89e6f to your computer and use it in GitHub Desktop.
Save Habitats/d6abd0032b46dfe3b9b590aae7e89e6f to your computer and use it in GitHub Desktop.
/**
 * Writes the document vectors of `rdd` out as comma-separated text lines.
 *
 * Every element is rendered as its `id` followed by the float values of its
 * `docVec`, all joined with commas. The resulting lines are passed to
 * `saveAsText` under a name derived from `loader.confidence`.
 *
 * @param rdd elements exposing an `id` and a `docVec`
 * @return whatever `saveAsText` returns for the written data
 */
def cache(rdd: RDD): String = {
  val csvLines = rdd.map { doc =>
    // Flatten the vector's backing buffer to floats and join them with commas.
    val values = doc.docVec.data().asFloat().mkString(",")
    s"${doc.id},${values}"
  }
  val targetName = s"document_vectors_${loader.confidence}"
  saveAsText(csvLines, targetName)
}
/**
 * Loads cached document vectors from a comma-separated text file.
 *
 * Each line is expected to look like `id,v0,v1,...`: the first field becomes
 * the map key and the remaining fields are parsed as floats and turned into a
 * vector via `Nd4j.create`.
 *
 * @param vectorFile path of the cached vector file
 * @return map from id to its vector
 * @throws FileNotFoundException if `vectorFile` does not exist according to `java.io.File`
 * @throws IllegalStateException if any loaded vector does not report size 1000 along dimension 1
 */
def loadVectors(vectorFile: String): Map[String, INDArray] = {
  Log.v(s"Loading cached W2V vectors ($vectorFile) ...")
  val cacheFile = new File(vectorFile)
  if (!cacheFile.exists) {
    throw new FileNotFoundException(s"No cached vectors: ${vectorFile}")
  }
  val startedAt = System.currentTimeMillis

  val loaded = sc.textFile(vectorFile)
    .map { line =>
      val fields = line.split(",")
      // First field is the id; everything after it is a vector component.
      val id = fields(0)
      val components = fields.drop(1).map(_.toFloat)
      (id, Nd4j.create(components))
    }
    .collect
    .toMap

  // SHOULD ALWAYS BE EQUAL!
  // Filtering must not drop anything: a smaller result means some vector had an unexpected size.
  val loadedCount = loaded.size
  val wellFormed = loaded.filter { case (_, vector) => vector.size(1) == 1000 }
  if (loadedCount != wellFormed.size) {
    throw new IllegalStateException(s"Vector filtering went wrong. Should be ${loadedCount}, was ${wellFormed.size}!")
  }

  Log.v(s"Loaded vectors in ${System.currentTimeMillis() - startedAt} ms")
  wellFormed
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment