Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
b-bit Minwise hashing の Java 実装です。#DSIRNLP http://partake.in/events/76854228-ba38-4f6e-87b9-f79e30add75c での発表用に実装してみました。
/**
 * A Java implementation of b-bit minwise hashing.
 * <p>
 * Reference: <a href="http://research.microsoft.com/pubs/120078/wfc0398-lips.pdf">b-Bit Minwise Hashing</a>
 * </p>
 * <p>
 * Each set is summarised by the lowest {@code b} bits of {@code k} independent minimum hash
 * values, packed into a byte array. {@link #jaccardCoefficient(Value, Value)} estimates the
 * resemblance of two sets from the fraction of equal b-bit samples. It then corrects for
 * accidental low-bit collisions using the set sizes, following the referenced paper.
 * </p>
 *
 * @author KOMIYA Atsushi
 */
public class MinHash {
    /** b: number of low-order bits kept from each minimum hash value (1..8). */
    private final int numBits;
    /** k: number of samples, i.e. independent hash functions. */
    private final int numSamples;
    /** D: the denominator of r = f / D in the resemblance estimate. */
    private final int vectorLength;
    /** Mask selecting the low {@code numBits} bits. */
    private final int maskBits;
    /** One seed per sample; seed j defines the j-th hash function. */
    private final int[] seeds;

    /**
     * Creates a b-bit minwise hasher.
     *
     * @param b            number of low-order bits kept per sample; must be in [1, 8]
     * @param k            number of samples (hash functions); must be at least 1
     * @param vectorLength D, used as the denominator of r = f / D when estimating resemblance
     * @throws IllegalArgumentException if {@code b} or {@code k} is out of range
     */
    public MinHash(int b, int k, int vectorLength) {
        if (b < 1 || b > 8) {
            // Packing/unpacking lets a sample spill into at most one following byte, and a
            // packed sample is written through single-byte casts, so b is capped at 8.
            throw new IllegalArgumentException("b には 1 以上 8 以下の値を指定する必要があります : " + b);
        }
        if (k < 1) {
            throw new IllegalArgumentException("k には 1 以上の値を指定する必要があります : " + k);
        }
        this.numBits = b;
        this.numSamples = k;
        this.vectorLength = vectorLength;
        this.maskBits = (1 << numBits) - 1;
        // Fixed seed: every instance derives the same hash seeds, so Values produced by
        // different instances with the same k use the same hash functions.
        Random r = new Random(77811215);
        seeds = new int[k];
        for (int i = 0; i < k; i++) {
            seeds[i] = r.nextInt();
        }
    }

    /**
     * Starts collecting the elements of one set.
     *
     * @return a fresh, empty {@link Context}
     */
    public Context beginCalculation() {
        return new Context();
    }

    /**
     * Computes the packed b-bit minwise signature of the given values.
     *
     * @param values the elements to hash; {@code values.size()} is recorded as f
     * @return the packed signature together with {@code values.size()}
     */
    Value hashValues(TIntList values) {
        // Allocated per call so the instance holds no mutable scratch state.
        long[] minHashValues = new long[numSamples];
        // Hashes are unsigned 32-bit numbers stored in a long, so the sentinel must exceed
        // 2^32 - 1. Integer.MAX_VALUE (2^31 - 1) would survive whenever every hash of a
        // sample is >= 2^31 - 1, replacing the true minimum.
        Arrays.fill(minHashValues, Long.MAX_VALUE);
        for (TIntIterator i = values.iterator(); i.hasNext(); ) {
            int val = i.next();
            for (int j = 0; j < numSamples; j++) {
                long hashValue = XXHash.xxHash4ByteLE(val, seeds[j]) & 0xffffffffL;
                if (hashValue < minHashValues[j]) {
                    minHashValues[j] = hashValue;
                }
            }
        }
        return new Value(packToByteArray(minHashValues), values.size());
    }

    /**
     * Packs the lowest {@code numBits} bits of each of the first {@code numSamples} values into
     * an LSB-first bit stream. Sample i occupies bits [i*b, (i+1)*b), counting from the least
     * significant bit of byte 0.
     */
    byte[] packToByteArray(long[] values) {
        byte[] result = new byte[(numBits * numSamples + 7) / 8];
        int bitIndex = 0;
        for (int i = 0; i < numSamples; i++) {
            int lowBits = (int) (values[i] & maskBits);
            int byteIndex = bitIndex / 8;
            int bitShiftCount = bitIndex % 8;
            // The byte cast keeps only the bits that fit in the current byte...
            result[byteIndex] |= (byte) (lowBits << bitShiftCount);
            int bitsInFirstByte = 8 - bitShiftCount;
            if (bitsInFirstByte < numBits) {
                // ...and the remaining high bits go to the bottom of the next byte.
                result[byteIndex + 1] |= (byte) (lowBits >> bitsInFirstByte);
            }
            bitIndex += numBits;
        }
        return result;
    }

    /**
     * Estimates the Jaccard resemblance of the two sets that produced the given signatures.
     * Both signatures are decoded with this instance's b and k, so they are presumably produced
     * with the same settings. The estimate is not clamped and may fall slightly outside [0, 1].
     */
    public double jaccardCoefficient(Value value1, Value value2) {
        double eb = calculateEb(value1.bytes, value2.bytes);
        return estimateR(value1.f, value2.f, eb);
    }

    /**
     * Converts the observed b-bit match rate {@code eb} into a resemblance estimate
     * R = (eb - C1) / (1 - C2). C1 and C2 account for matches that occur by chance in the
     * lowest b bits and are computed from r1 = f1 / D and r2 = f2 / D.
     */
    double estimateR(int f1, int f2, double eb) {
        final double r1 = (double) f1 / vectorLength;
        final double r2 = (double) f2 / vectorLength;
        // The exponent in the paper is 2^b, not 2*b (the two only coincide for b = 1 and 2).
        final int twoToB = 1 << numBits;
        final double a1 = collisionCoefficient(r1, twoToB);
        final double a2 = collisionCoefficient(r2, twoToB);
        final double r1PlusR2 = r1 + r2;
        // When both sets are empty, a1 == a2, so the weights do not matter; split evenly
        // instead of dividing 0 by 0.
        final double w1 = r1PlusR2 > 0 ? r1 / r1PlusR2 : 0.5;
        final double w2 = r1PlusR2 > 0 ? r2 / r1PlusR2 : 0.5;
        final double c1 = a1 * w2 + a2 * w1;
        final double c2 = a1 * w1 + a2 * w2;
        return (eb - c1) / (1 - c2);
    }

    /**
     * Computes A = r (1 - r)^(2^b - 1) / (1 - (1 - r)^(2^b)), using its limit 1 / 2^b at r = 0
     * where the formula itself evaluates to 0/0.
     */
    private static double collisionCoefficient(double r, int twoToB) {
        if (r == 0) {
            return 1.0 / twoToB;
        }
        final double oneMinusR = 1 - r;
        return r * Math.pow(oneMinusR, twoToB - 1) / (1 - Math.pow(oneMinusR, twoToB));
    }

    /**
     * Returns the fraction of the {@code numSamples} packed b-bit samples that are equal in
     * both arrays.
     */
    double calculateEb(byte[] bytes1, byte[] bytes2) {
        int sameCount = 0;
        for (int i = 0, bitIndex = 0; i < numSamples; i++, bitIndex += numBits) {
            if (unpackSample(bytes1, bitIndex) == unpackSample(bytes2, bitIndex)) {
                sameCount++;
            }
        }
        return (double) sameCount / numSamples;
    }

    /** Reads the {@code numBits}-bit sample that starts at {@code bitIndex} (inverse of packing). */
    private int unpackSample(byte[] bytes, int bitIndex) {
        int byteIndex = bitIndex / 8;
        int bitShiftCount = bitIndex % 8;
        // Widen as unsigned: a sign-extended negative byte would fill the bit positions that
        // belong to the next byte with 1s, and OR-ing the next byte in cannot clear them.
        int bits = (bytes[byteIndex] & 0xff) >>> bitShiftCount;
        int bitsInFirstByte = 8 - bitShiftCount;
        if (bitsInFirstByte < numBits) {
            bits |= (bytes[byteIndex + 1] & 0xff) << bitsInFirstByte;
        }
        return bits & maskBits;
    }

    /**
     * Accumulates the elements of one set; call {@link #hash()} to obtain its signature.
     * NOTE(review): elements are kept in a list, so adding the same value twice inflates the
     * size f used by the estimator. Callers presumably add each element once.
     */
    public class Context {
        private final TIntList values = new TIntArrayList();

        /** Adds one element. */
        public void add(int val) {
            values.add(val);
        }

        /** Returns the b-bit minwise signature of the elements added so far. */
        public Value hash() {
            return hashValues(values);
        }
    }

    /** A packed b-bit signature together with the number of values f it was computed from. */
    public static class Value {
        /** Packed b-bit samples, LSB-first. */
        final byte[] bytes;
        /** Number of values hashed; used as f in r = f / D. */
        final int f;

        Value(byte[] bytes, int f) {
            this.bytes = bytes;
            this.f = f;
        }
    }
}
/**
 * Minimal xxHash32 specialised to a single 32-bit input.
 * As the method name indicates, {@code val} is treated as its 4 little-endian bytes.
 * All arithmetic is plain {@code int} arithmetic, which wraps modulo 2^32 exactly as the
 * algorithm requires.
 */
class XXHash {
    private static final int PRIME32_2 = (int) 2246822519L;
    private static final int PRIME32_3 = (int) 3266489917L;
    private static final int PRIME32_4 = 668265263;
    private static final int PRIME32_5 = 374761393;

    /**
     * Hashes one 4-byte value.
     *
     * @param val  the value to hash
     * @param seed the hash seed
     * @return the 32-bit hash, as a (possibly negative) int
     */
    public static int xxHash4ByteLE(int val, int seed) {
        // Short-input path: seed + PRIME5 + input length (4 bytes).
        int acc = seed + PRIME32_5 + 4;
        // Consume the single 4-byte lane.
        acc += val * PRIME32_3;
        acc = Integer.rotateLeft(acc, 17) * PRIME32_4;
        // Final avalanche.
        acc ^= acc >>> 15;
        acc *= PRIME32_2;
        acc ^= acc >>> 13;
        acc *= PRIME32_3;
        acc ^= acc >>> 16;
        return acc;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment