Skip to content

Instantly share code, notes, and snippets.

@treo
Created May 2, 2016 09:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save treo/2947b21a55c1b175ac5ed24a8673924d to your computer and use it in GitHub Desktop.
Save treo/2947b21a55c1b175ac5ed24a8673924d to your computer and use it in GitHub Desktop.
Binary Word Vector Serializer
package com.example;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import static org.deeplearning4j.models.embeddings.loader.WordVectorSerializer.fromPair;
/**
 * Utility methods for working with binary DL4J word vector files.
 *
 * <p>File layout, as written by {@link #convertTextToBinary(String, String)} and read by
 * {@link #loadBinaryWordVectors(Path)}. Everything is written through an
 * {@link ObjectOutputStream}, so the file starts with the Java serialization stream header:
 * <ol>
 *   <li>magic string {@code "dl4jw2v"} ({@code writeUTF})</li>
 *   <li>file format version ({@code int}, currently {@code 1})</li>
 *   <li>word count ({@code int})</li>
 *   <li>vector size ({@code int})</li>
 *   <li>"using doubles" flag ({@code boolean}, always written as {@code false})</li>
 *   <li>per word: the word ({@code writeUTF}) followed by {@code vectorSize} floats,
 *       4 bytes each, in {@link ByteBuffer}'s default (big-endian) byte order</li>
 * </ol>
 *
 * @author Paul Dubs
 */
public class BinaryWordVectorSerializer {
    private static final Logger log = LoggerFactory.getLogger(BinaryWordVectorSerializer.class);

    /** Magic string at the start of every file, used to recognise the format. */
    private static final String MAGIC_STRING = "dl4jw2v";
    /** The only file format version this class reads and writes. */
    private static final int FILE_FORMAT_VERSION = 1;
    /** Size of one serialized float in bytes. */
    private static final int BYTES_PER_FLOAT = Float.SIZE / Byte.SIZE;
    /** Buffer size used when reading binary files (16 MiB). */
    private static final int READ_BUFFER_SIZE = 16 * 1024 * 1024;
    /** Separator between the word and the vector components in the text format. */
    private static final String FIELD_SEPARATOR = " ";

    /**
     * Loads word vectors from a binary word vector file.
     *
     * @param vectorFilePath path of the binary file, as a string
     * @return the loaded word vectors
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if the file content is not a supported binary word vector file
     */
    public static WordVectors loadWordVectors(String vectorFilePath) throws IOException {
        return loadWordVectors(Paths.get(vectorFilePath));
    }

    /**
     * Loads word vectors from a binary word vector file.
     *
     * @param vectorFilePath path of the binary file
     * @return the loaded word vectors
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if the file content is not a supported binary word vector file
     */
    public static WordVectors loadWordVectors(Path vectorFilePath) throws IOException {
        Pair<InMemoryLookupTable, VocabCache> pair = loadBinaryWordVectors(vectorFilePath);
        return fromPair(pair);
    }

    /**
     * Reads a binary word vector file into a lookup table and its vocabulary cache.
     *
     * <p>Words receive their position in the file as index. Every {@link VocabWord} is created
     * with {@code 1.0} as its first constructor argument, because the file stores no per-word
     * statistics.
     *
     * @param vectorsDir path of the binary word vector <em>file</em> (despite the parameter name,
     *                   this is opened as a regular file, not as a directory)
     * @return pair of the lookup table (with {@code syn0} set to the loaded matrix) and the cache
     * @throws IOException if the file cannot be read, is truncated, or is not a Java
     *                     serialization stream at all
     * @throws IllegalArgumentException if the magic string or version does not match, the header
     *                                  contains impossible sizes, or the file stores doubles
     */
    public static Pair<InMemoryLookupTable, VocabCache> loadBinaryWordVectors(Path vectorsDir) throws IOException {
        try (ObjectInputStream inputStream = new ObjectInputStream(
                new BufferedInputStream(new FileInputStream(vectorsDir.toFile()), READ_BUFFER_SIZE))) {
            String magicString = inputStream.readUTF();
            if (!MAGIC_STRING.equals(magicString)) {
                throw new IllegalArgumentException("The file you provided is either not a DL4J binary word vectors file or corrupted.");
            }

            int fileFormatVersion = inputStream.readInt();
            if (FILE_FORMAT_VERSION != fileFormatVersion) {
                throw new IllegalArgumentException("Not supported file format version.");
            }

            int wordCount = inputStream.readInt();
            int vectorLength = inputStream.readInt();
            boolean usingDoubles = inputStream.readBoolean();

            // Reject unsupported files right after the header, before anything is allocated.
            if (usingDoubles) {
                throw new IllegalArgumentException("Using Doubles not Implemented!");
            }
            // A corrupt header would otherwise surface as a negative array size or an overflow
            // in the byte buffer size computation below.
            if (wordCount < 0 || vectorLength < 0 || vectorLength > Integer.MAX_VALUE / BYTES_PER_FLOAT) {
                throw new IllegalArgumentException("Corrupted header: wordCount=" + wordCount
                        + ", vectorLength=" + vectorLength);
            }

            INDArray syn = Nd4j.create(wordCount, vectorLength);
            AbstractCache<VocabWord> cache = new AbstractCache<>();

            // Scratch buffers reused for every word.
            byte[] bytes = new byte[vectorLength * BYTES_PER_FLOAT];
            float[] vector = new float[vectorLength];
            int[] pos = new int[2];

            for (int wordIdx = 0; wordIdx < wordCount; wordIdx++) {
                pos[0] = wordIdx;

                String word = inputStream.readUTF();
                VocabWord vocabWord = new VocabWord(1.0, word);
                vocabWord.setIndex(wordIdx);
                cache.addToken(vocabWord);
                cache.addWordToIndex(wordIdx, word);
                cache.putVocabWord(word);

                // Decode the raw bytes of one vector and copy it into row wordIdx of the matrix.
                inputStream.readFully(bytes);
                FloatBuffer floatBuffer = ByteBuffer.wrap(bytes).asFloatBuffer();
                floatBuffer.get(vector);
                for (int vecIdx = 0; vecIdx < vector.length; vecIdx++) {
                    pos[1] = vecIdx;
                    syn.putScalar(pos, vector[vecIdx]);
                }
            }

            InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder()
                    .vectorLength(syn.columns())
                    .useAdaGrad(false).cache(cache)
                    .build();
            lookupTable.setSyn0(syn);
            return new Pair<>(lookupTable, cache);
        }
    }

    /**
     * Converts a text word vector file into the binary format.
     *
     * @param textSource   path of the text file to read, as a string
     * @param binaryTarget path of the binary file to create or overwrite, as a string
     * @throws IOException if reading the source or writing the target fails
     * @throws IllegalArgumentException if a line does not have the expected number of components
     *                                  or a component is not a valid float
     * @see #convertTextToBinary(Path, Path)
     */
    public static void convertTextToBinary(String textSource, String binaryTarget) throws IOException {
        convertTextToBinary(Paths.get(textSource), Paths.get(binaryTarget));
    }

    /**
     * Converts a text word vector file into the binary format.
     *
     * <p>Each line of the text file is expected to be a word followed by its vector components,
     * separated by single spaces. A first line that contains no space at all is treated as a
     * header and skipped. Blank lines are ignored. The source is decoded as UTF-8.
     *
     * <p>The source is read twice: once to determine the word count and vector size needed for
     * the binary header, and once to write the vectors. The target file is only opened after the
     * first pass succeeded. If the second pass fails, a partially written target is left behind.
     *
     * @param textSource   path of the text file to read
     * @param binaryTarget path of the binary file to create or overwrite
     * @throws IOException if reading the source or writing the target fails
     * @throws IllegalArgumentException if a line does not have the same number of components as
     *                                  the first vector line, or a component is not a valid float
     */
    public static void convertTextToBinary(Path textSource, Path binaryTarget) throws IOException {
        boolean hasHeader = false;
        int wordCount = 0;
        int vectorSize = 0;

        // Pass 1: detect the header line, count the words and take the vector size from the
        // first vector line.
        try (BufferedReader reader = openTextReader(textSource)) {
            String line = reader.readLine();
            if (line != null && !line.contains(FIELD_SEPARATOR)) {
                hasHeader = true;
                line = reader.readLine();
            }
            for (; line != null; line = reader.readLine()) {
                if (isBlank(line)) {
                    continue;
                }
                if (wordCount == 0) {
                    vectorSize = line.split(FIELD_SEPARATOR).length - 1;
                }
                wordCount++;
            }
        }

        // Pass 2: write the binary header followed by every word vector.
        try (BufferedReader reader = openTextReader(textSource);
             ObjectOutputStream output = new ObjectOutputStream(
                     new BufferedOutputStream(new FileOutputStream(binaryTarget.toFile())))) {
            int lineNumber = 0;
            if (hasHeader) {
                reader.readLine();
                lineNumber++;
            }

            writeHeader(output, wordCount, vectorSize, false);

            float[] vector = new float[vectorSize];
            String line;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (isBlank(line)) {
                    continue;
                }
                String[] split = line.split(FIELD_SEPARATOR);
                // The vector array is reused, so a shorter line would silently keep values of
                // the previous word and a longer one would overflow it.
                if (split.length - 1 != vectorSize) {
                    throw new IllegalArgumentException("Line " + lineNumber + " of " + textSource
                            + " has " + (split.length - 1) + " vector components, expected " + vectorSize);
                }
                for (int i = 1; i < split.length; i++) {
                    vector[i - 1] = Float.parseFloat(split[i]);
                }
                writeWordVector(output, split[0], vector);
            }
        }
    }

    /** Opens the text source for reading, decoding it as UTF-8 independent of the platform default. */
    private static BufferedReader openTextReader(Path textSource) throws IOException {
        return new BufferedReader(
                new InputStreamReader(new FileInputStream(textSource.toFile()), StandardCharsets.UTF_8));
    }

    /** Returns whether the line consists of whitespace only. */
    private static boolean isBlank(String line) {
        return line.trim().isEmpty();
    }

    /**
     * Writes the file header: magic string, format version, word count, vector size and the
     * "using doubles" flag.
     */
    private static void writeHeader(ObjectOutputStream output, int wordCount, int vectorSize, boolean usingDoubles) throws IOException {
        output.writeUTF(MAGIC_STRING); // Magic String to make file recognition easier
        output.writeInt(FILE_FORMAT_VERSION); // File Format Version
        output.writeInt(wordCount);
        output.writeInt(vectorSize);
        output.writeBoolean(usingDoubles);
    }

    /**
     * Writes a single entry: the word, followed by the raw bytes of all floats in {@code vector}.
     */
    private static void writeWordVector(ObjectOutputStream output, String word, float[] vector) throws IOException {
        output.writeUTF(word);
        byte[] backingArray = new byte[vector.length * BYTES_PER_FLOAT];
        // The float view writes straight into backingArray.
        FloatBuffer floatBuffer = ByteBuffer.wrap(backingArray).asFloatBuffer();
        floatBuffer.put(vector);
        output.write(backingArray);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment