Created
November 13, 2018 15:59
-
-
Save huyhoang17/6ad54db6d861c10b36a2da047f746d33 to your computer and use it in GitHub Desktop.
Product quantization (PQ) example code, based on the Product Quantization paper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.cluster.vq import vq, kmeans2 | |
from scipy.spatial.distance import cdist | |
def train(vec, M, Ks=256):
    """Learn one k-means codebook per sub-space for product quantization.

    The D columns of ``vec`` are split into ``M`` contiguous groups of
    ``Ds = D // M`` columns, and ``kmeans2`` is run independently on each
    group with ``Ks`` clusters (using SciPy's default initialisation).

    Args:
        vec: 2-D array of training vectors, shape (N, D).
        M: Number of sub-spaces. Must be positive and divide D exactly.
        Ks: Number of centroids learned per sub-space.

    Returns:
        float32 array of shape (M, Ks, Ds); ``codeword[m]`` holds the
        centroids of the m-th sub-space.

    Raises:
        ValueError: If ``M`` is not positive, exceeds D, or does not divide D.
    """
    n_dims = vec.shape[1]
    # Without this check the trailing ``D % M`` columns would be silently
    # ignored (and M > D would yield empty sub-vectors).
    if M <= 0 or M > n_dims or n_dims % M != 0:
        raise ValueError(
            "M must be a positive divisor of the vector dimension "
            "(got M={}, D={})".format(M, n_dims)
        )
    Ds = n_dims // M
    codeword = np.empty((M, Ks, Ds), np.float32)
    for m in range(M):
        vec_sub = vec[:, m * Ds: (m + 1) * Ds]
        # kmeans2 returns (centroids, labels); only the centroids are kept.
        codeword[m], _ = kmeans2(vec_sub, Ks)
    return codeword
def encode(codeword, vec):
    """Quantize vectors to one codeword index per sub-space.

    Each row of ``vec`` is split into M sub-vectors of Ds columns, and every
    sub-vector is replaced by the index of its nearest centroid in the
    corresponding codebook (as computed by ``scipy.cluster.vq.vq``).

    Args:
        codeword: Array of shape (M, Ks, Ds) holding the per-sub-space
            centroids.
        vec: 2-D array of vectors to encode, shape (N, D); only the first
            ``M * Ds`` columns are read.

    Returns:
        Unsigned integer array of shape (N, M) with the centroid indices.
        The dtype is the smallest unsigned type able to hold ``Ks - 1``,
        i.e. ``uint8`` whenever ``Ks <= 256``.
    """
    M, Ks, Ds = codeword.shape
    # A fixed uint8 buffer would wrap indices >= 256 modulo 256 when the
    # codebook has more than 256 centroids, so size the dtype from Ks.
    code_dtype = np.min_scalar_type(max(Ks - 1, 0))
    pqcode = np.empty((vec.shape[0], M), code_dtype)
    for m in range(M):
        vec_sub = vec[:, m * Ds: (m + 1) * Ds]
        # vq returns (codes, distances); only the codes are stored.
        pqcode[:, m], _ = vq(vec_sub, codeword[m])
    return pqcode
def search(codeword, pqcode, query):
    """Compute the distance from a query to every PQ-encoded vector.

    For each sub-space a lookup table of squared Euclidean distances between
    the query's sub-vector and all centroids is built. The distance to an
    encoded vector is then the sum, over sub-spaces, of the table entries
    selected by that vector's codes.

    Args:
        codeword: Array of shape (M, Ks, Ds) holding the centroids.
        pqcode: Integer array of shape (N, M) of centroid indices.
        query: 1-D query vector; its first ``M * Ds`` entries are used.

    Returns:
        float32 array of shape (N,) with one summed distance per encoded
        vector.
    """
    n_sub, n_centroids, sub_dim = codeword.shape

    # lookup[i, k] = squared distance between the i-th query slice and
    # centroid k of sub-space i.
    lookup = np.empty((n_sub, n_centroids), np.float32)
    for idx, centroids in enumerate(codeword):
        start = idx * sub_dim
        segment = query[start:start + sub_dim]
        lookup[idx] = cdist([segment], centroids, 'sqeuclidean')[0]

    # Broadcasting the sub-space indices against the (N, M) codes picks, for
    # every vector, the table entry of its code in each sub-space.
    sub_indices = np.arange(n_sub)
    per_subspace = lookup[sub_indices, pqcode]
    return per_subspace.sum(axis=1)
if __name__ == '__main__':
    # Demo: train codebooks on random data, encode a database, run one query.
    N, Nt, D = 10000, 2000, 128
    M = 8  # number of sub-spaces

    # The three random draws happen in the order database, training, query.
    vec = np.random.random((N, D)).astype(np.float32)  # 10,000 vectors to index
    vec_train = np.random.random((Nt, D)).astype(np.float32)  # 2,000 for training
    query = np.random.random((D,)).astype(np.float32)  # a single query vector

    codeword = train(vec_train, M)
    pqcode = encode(codeword, vec)
    dist = search(codeword, pqcode, query)
    print(dist)

    # Report the ten encoded vectors with the smallest distance to the query.
    nearest_ids = dist.argsort()[:10]
    for id_, nearest_dist in zip(nearest_ids, dist[nearest_ids]):
        print("Id: {} -> Dist: {}".format(id_, nearest_dist))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment