Skip to content

Instantly share code, notes, and snippets.

@treo
Last active August 30, 2016 09:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save treo/f5a346d53f89566b51bf88a9a42c67c7 to your computer and use it in GitHub Desktop.
Save treo/f5a346d53f89566b51bf88a9a42c67c7 to your computer and use it in GitHub Desktop.
/*
Copyright 2016 Paul Dubs
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package com.example;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.stream.Collectors;
/**
 * Utility methods for working with binary DL4J word vector files.
 *
 * <p>A binary file consists of the following parts, in this order (all multi-byte values are
 * big-endian):
 * <ol>
 * <li>the ASCII magic string {@code dl4jw2v} (7 bytes)</li>
 * <li>the file format version (1 byte, currently {@code 1})</li>
 * <li>the word count (4-byte int)</li>
 * <li>the vector size, i.e. the number of floats per word (4-byte int)</li>
 * <li>the length of the strings section in bytes (4-byte int)</li>
 * <li>a flag that is {@code 1} when vectors are stored as doubles (1 byte; this class always
 *     writes {@code 0} and refuses to read {@code 1})</li>
 * <li>the strings section: for every word a 4-byte length followed by its UTF-8 bytes, sorted by
 *     {@link String#compareTo(String)}</li>
 * <li>the vectors section: for every word, in the same order, its vector as 4-byte floats</li>
 * </ol>
 *
 * @author Paul Dubs
 */
public class BinaryWordVectorSerializer {
    private static final Logger log = LoggerFactory.getLogger(BinaryWordVectorSerializer.class);

    /** Magic string at the start of every binary word vector file. */
    private static final String MAGIC_STRING = "dl4jw2v";

    /** The only file format version this class reads and writes. */
    private static final int FILE_FORMAT_VERSION = 1;

    /** Header size in bytes: magic + version + word count + vector size + strings size + doubles flag. */
    private static final long HEADER_LENGTH = MAGIC_STRING.length() + 1 + 4 + 4 + 4 + 1;

    /**
     * Convenience overload of {@link #loadWordVectors(Path)} taking the path as a string.
     *
     * @param vectorFilePath path of the word vector file
     * @return whatever {@link #loadWordVectors(Path)} returns (currently always {@code null})
     * @throws IOException declared for API compatibility
     */
    public static WordVectors loadWordVectors(String vectorFilePath) throws IOException {
        return loadWordVectors(Paths.get(vectorFilePath));
    }

    /**
     * Placeholder that is not implemented yet: it ignores its argument and always returns
     * {@code null}. Use {@link #loadBinaryWordVectors(Path)} to actually read a binary file.
     *
     * @param vectorFilePath path of the word vector file (currently unused)
     * @return always {@code null}
     * @throws IOException declared for API compatibility, never thrown at the moment
     */
    public static WordVectors loadWordVectors(Path vectorFilePath) throws IOException {
        return null;
    }

    /**
     * Opens a binary word vector file, reads its vocabulary into memory and memory-maps its
     * vectors.
     *
     * @param vectorsDir path of the binary word vector file (despite the name, a file and not a
     *                   directory)
     * @return a vectorizer backed by read-only memory mappings of the file
     * @throws IOException              if the file cannot be read or is shorter than its header
     *                                  claims
     * @throws IllegalArgumentException if the file is not a supported DL4J binary word vector file
     */
    public static BinaryVectorizer loadBinaryWordVectors(Path vectorsDir) throws IOException {
        // "r" is sufficient because the file is only read and mapped READ_ONLY; "rw" would demand
        // write permission and create the file if it does not exist. The mappings stay valid
        // after the channel is closed, so the file handle can be released before returning.
        try (RandomAccessFile file = new RandomAccessFile(vectorsDir.toFile(), "r")) {
            byte[] magicBytes = new byte[MAGIC_STRING.length()];
            file.readFully(magicBytes);
            if (!MAGIC_STRING.equals(new String(magicBytes, StandardCharsets.US_ASCII))) {
                throw corruptFile();
            }

            int fileFormatVersion = file.read();
            if (FILE_FORMAT_VERSION != fileFormatVersion) {
                throw new IllegalArgumentException("Not supported file format version.");
            }

            // readInt() is big-endian, matching the ByteBuffer based writer.
            int wordCount = file.readInt();
            int vectorLength = file.readInt();
            int stringsSize = file.readInt();
            boolean usingDoubles = file.read() == 1;

            // Every word occupies at least its 4-byte length prefix in the strings section.
            if (wordCount < 0 || vectorLength <= 0 || stringsSize < 0 || 4L * wordCount > stringsSize) {
                throw corruptFile();
            }
            if (usingDoubles) {
                // The vectors are exposed through FloatBuffers only; reading doubles as floats
                // would silently produce garbage.
                throw new IllegalArgumentException("Word vector files using doubles are not supported.");
            }

            String[] words = new String[wordCount];
            MappedByteBuffer buffer = file.getChannel().map(FileChannel.MapMode.READ_ONLY, HEADER_LENGTH, stringsSize);
            for (int i = 0; i < wordCount; i++) {
                if (buffer.remaining() < 4) {
                    throw corruptFile();
                }
                int size = buffer.getInt();
                if (size < 0 || size > buffer.remaining()) {
                    throw corruptFile();
                }
                byte[] word = new byte[size];
                buffer.get(word);
                words[i] = new String(word, StandardCharsets.UTF_8);
            }

            log.info("Loaded {} words", words.length);
            return new BinaryVectorizer(file, words, vectorLength, HEADER_LENGTH + stringsSize);
        }
    }

    /** Creates the exception used for every "this is not a valid file" condition. */
    private static IllegalArgumentException corruptFile() {
        return new IllegalArgumentException("The file you provided is either not a DL4J binary word vectors file or corrupted.");
    }

    /**
     * Converts a text word vector file (one {@code word v1 v2 ... vn} entry per line, separated by
     * single spaces) into the binary format read by {@link #loadBinaryWordVectors(Path)}.
     *
     * @param textSource   path of the UTF-8 encoded text file to read
     * @param binaryTarget path of the binary file to create or overwrite
     * @throws IOException              if reading or writing fails
     * @throws IllegalArgumentException if the text file contains no vectors or vectors of
     *                                  differing sizes
     */
    public static void convertTextToBinary(String textSource, String binaryTarget) throws IOException {
        convertTextToBinary(Paths.get(textSource), Paths.get(binaryTarget));
    }

    /**
     * Path based implementation of {@link #convertTextToBinary(String, String)}. The target file is
     * only opened after the source has been parsed and validated, so a broken source does not
     * leave a truncated target behind.
     */
    private static void convertTextToBinary(Path textSource, Path binaryTarget) throws IOException {
        LinkedList<Pair<byte[], float[]>> pairs = readTextVectors(textSource);
        if (pairs.isEmpty()) {
            throw new IllegalArgumentException("No word vectors found in " + textSource);
        }

        int wordCount = pairs.size();
        int vectorSize = pairs.getFirst().getRight().length;
        if (vectorSize == 0) {
            throw new IllegalArgumentException("Word vectors in " + textSource + " have no components.");
        }

        // Strings section = 4-byte length prefix per word + the UTF-8 bytes of every word.
        long stringsSectionLength = 4L * wordCount;
        for (Pair<byte[], float[]> pair : pairs) {
            if (pair.getRight().length != vectorSize) {
                throw new IllegalArgumentException("Inconsistent vector size for word '"
                        + new String(pair.getLeft(), StandardCharsets.UTF_8) + "': expected " + vectorSize
                        + " but found " + pair.getRight().length);
            }
            stringsSectionLength += pair.getLeft().length;
        }
        if (stringsSectionLength > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("The vocabulary is too large for the binary format.");
        }

        // The reader looks words up with Arrays.binarySearch, so they must be stored in the
        // natural String order of their decoded form.
        pairs.sort((o1, o2) -> new String(o1.getLeft(), StandardCharsets.UTF_8)
                .compareTo(new String(o2.getLeft(), StandardCharsets.UTF_8)));

        try (BufferedOutputStream output = new BufferedOutputStream(new FileOutputStream(binaryTarget.toFile()))) {
            writeHeader(output, wordCount, vectorSize, (int) stringsSectionLength, false);
            writeAll(output, pairs);
        }
    }

    /**
     * Parses the text file into (UTF-8 word bytes, vector) pairs in file order.
     *
     * <p>A first line without any space is treated as a header and skipped. A first line made of
     * two or three integers (such as a {@code <wordCount> <vectorSize>} header) is dropped as well
     * when the following entry has a different vector size. Blank lines are ignored.
     */
    private static LinkedList<Pair<byte[], float[]>> readTextVectors(Path textSource) throws IOException {
        LinkedList<Pair<byte[], float[]>> pairs = new LinkedList<>();
        boolean firstEntryLooksLikeHeader = false;

        // Decode explicitly as UTF-8: the words are re-encoded as UTF-8 below, and the platform
        // default charset used by FileReader may differ.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(textSource.toFile()), StandardCharsets.UTF_8))) {
            boolean firstLine = true;
            String line;
            while ((line = reader.readLine()) != null) {
                boolean isFirstLine = firstLine;
                firstLine = false;

                if (isFirstLine && !line.contains(" ")) {
                    continue; // header line without spaces
                }
                if (line.trim().isEmpty()) {
                    continue;
                }

                String[] split = line.split(" ");
                if (isFirstLine) {
                    firstEntryLooksLikeHeader = looksLikeNumericHeader(split);
                }

                float[] vector = new float[split.length - 1];
                for (int i = 1; i < split.length; i++) {
                    vector[i - 1] = Float.parseFloat(split[i]);
                }
                pairs.add(Pair.of(split[0].getBytes(StandardCharsets.UTF_8), vector));
            }
        }

        // An all-integer header parses as a word with a very short vector. Only drop it when the
        // next entry proves that the real vectors have a different size.
        if (firstEntryLooksLikeHeader && pairs.size() > 1
                && pairs.get(1).getRight().length != pairs.getFirst().getRight().length) {
            pairs.removeFirst();
        }
        return pairs;
    }

    /** Returns {@code true} if the tokens are two or three non-negative integers. */
    private static boolean looksLikeNumericHeader(String[] tokens) {
        if (tokens.length < 2 || tokens.length > 3) {
            return false;
        }
        for (String token : tokens) {
            if (!token.matches("\\d+")) {
                return false;
            }
        }
        return true;
    }

    /**
     * Writes the strings section (length-prefixed UTF-8 words) followed by the vectors section
     * (4-byte floats), both in the order of {@code pairs}.
     */
    private static void writeAll(BufferedOutputStream output, LinkedList<Pair<byte[], float[]>> pairs) throws IOException {
        for (Pair<byte[], float[]> pair : pairs) {
            output.write(ByteBuffer.allocate(4).putInt(pair.getLeft().length).array());
            output.write(pair.getLeft());
        }
        for (Pair<byte[], float[]> pair : pairs) {
            float[] vector = pair.getRight();
            ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4);
            // The float view writes into the backing array of the byte buffer.
            buffer.asFloatBuffer().put(vector);
            output.write(buffer.array());
        }
    }

    /** Writes the fixed-size file header; see the class documentation for the layout. */
    private static void writeHeader(BufferedOutputStream output, int wordCount, int vectorSize, int stringsSectionLength, boolean usingDoubles) throws IOException {
        output.write(MAGIC_STRING.getBytes(StandardCharsets.US_ASCII)); // Magic String to make file recognition easier
        output.write(FILE_FORMAT_VERSION); // File Format Version
        output.write(ByteBuffer.allocate(4).putInt(wordCount).array());
        output.write(ByteBuffer.allocate(4).putInt(vectorSize).array());
        output.write(ByteBuffer.allocate(4).putInt(stringsSectionLength).array());
        output.write(usingDoubles ? 1 : 0);
    }

    /**
     * Looks up word vectors in the memory-mapped vectors section of a binary word vector file.
     *
     * <p>A single mapping cannot exceed {@link Integer#MAX_VALUE} bytes, so the vectors section is
     * split into partitions that each hold a whole number of vectors.
     */
    public static class BinaryVectorizer {
        private final int vectorSize;
        /** The vocabulary in sorted order; index {@code i} belongs to the {@code i}-th vector. */
        public final String[] words;
        private final FloatBuffer[] parts;
        private final int maxVectorsPerPartition;

        /**
         * Maps the vectors section of {@code file}. The file may be closed afterwards.
         *
         * @param file              the open binary word vector file
         * @param words             the sorted vocabulary read from the file
         * @param vectorSize        number of floats per vector, must be positive
         * @param vectorStartOffset byte offset at which the vectors section starts
         * @throws IOException              if mapping the file fails
         * @throws IllegalArgumentException if {@code vectorSize} is not supported
         */
        BinaryVectorizer(RandomAccessFile file, String[] words, int vectorSize, long vectorStartOffset) throws IOException {
            int maxFloatsPerPartition = Integer.MAX_VALUE / Float.BYTES;
            if (vectorSize <= 0 || vectorSize > maxFloatsPerPartition) {
                throw new IllegalArgumentException("Unsupported vector size: " + vectorSize);
            }
            this.words = words;
            this.vectorSize = vectorSize;
            this.maxVectorsPerPartition = maxFloatsPerPartition / vectorSize;

            long bytesPerVector = (long) vectorSize * Float.BYTES;
            long maxPartitionSizeBytes = maxVectorsPerPartition * bytesPerVector;

            int numVectors = words.length;
            int neededPartitions = numVectors / maxVectorsPerPartition;
            if (numVectors % maxVectorsPerPartition > 0) {
                neededPartitions += 1; // trailing, partially filled partition
            }

            this.parts = new FloatBuffer[neededPartitions];
            FileChannel channel = file.getChannel();
            for (int i = 0; i < neededPartitions; i++) {
                long start = vectorStartOffset + i * maxPartitionSizeBytes;
                // All partitions are full except possibly the last one.
                int vectorsInPartition = Math.min(maxVectorsPerPartition, numVectors - i * maxVectorsPerPartition);
                long length = vectorsInPartition * bytesPerVector;
                parts[i] = channel.map(FileChannel.MapMode.READ_ONLY, start, length).asFloatBuffer();
            }
        }

        /**
         * Returns the vector of the given word.
         *
         * <p>Words that are not part of the vocabulary yield the vector of the first word in
         * {@link #words}.
         *
         * @param word the word to look up
         * @return a newly created array holding a copy of the vector
         * @throws IOException           declared for API compatibility
         * @throws IllegalStateException if the vocabulary is empty
         */
        public INDArray vectorize(String word) throws IOException {
            if (words.length == 0) {
                throw new IllegalStateException("The vocabulary is empty, there is no vector to return.");
            }
            int vectorIdx = Arrays.binarySearch(words, word);
            if (vectorIdx < 0) {
                vectorIdx = 0; // unknown word: fall back to the first vector
            }
            int partitionIdx = vectorIdx / maxVectorsPerPartition;
            int relativeVectorIdx = vectorIdx % maxVectorsPerPartition;

            // Read through a duplicate so the position of the shared buffer is never modified;
            // concurrent lookups would otherwise interfere with each other.
            FloatBuffer part = this.parts[partitionIdx].duplicate();
            part.position(relativeVectorIdx * vectorSize);
            float[] vector = new float[vectorSize];
            part.get(vector);
            return Nd4j.create(vector);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment