Created
August 25, 2023 01:49
-
-
Save lmyyao/5b2c42722fcc4a7ddac2968c9238d8e8 to your computer and use it in GitHub Desktop.
Product quantization similarity search demo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from scipy.cluster.vq import vq, kmeans | |
from scipy.spatial.distance import cdist | |
import numpy as np | |
class PQ(object):
    """Product quantizer for approximate similarity search.

    Each D-dimensional vector is split into ``M`` contiguous sub-vectors of
    equal width. ``fit`` learns one k-means codebook per sub-space, ``encode``
    maps vectors to per-sub-space codeword indices, ``query_dis_table``
    precomputes the distances from a query to every codeword, and ``topk``
    ranks vectors by the summed table distances of their codes.
    """

    def __init__(self, M, k):
        """Store the number of sub-spaces ``M`` and codewords per sub-space ``k``."""
        self.M = M
        self.k = k
        # Populated by fit(). Initialised here so that calling encode() or
        # query_dis_table() too early raises "call fit first" instead of an
        # AttributeError.
        self.D = None
        self.step = None
        self.centers_book = None

    def _require_fitted(self):
        """Raise ValueError unless fit() has produced the codebooks."""
        if self.centers_book is None:
            raise ValueError("call fit first")

    def _subspace(self, index):
        """Return the column slice that selects sub-space ``index``."""
        return slice(index * self.step, (index + 1) * self.step)

    def fit(self, vectors, iter=None):
        """Learn one k-means codebook per sub-space.

        Args:
            vectors: 2-D array of shape (N, D); D must be divisible by ``M``.
            iter: forwarded to ``scipy.cluster.vq.kmeans`` as its ``iter``
                argument. ``None`` keeps scipy's own default.

        Raises:
            AssertionError: if ``vectors`` is not 2-D or D is not divisible
                by ``M``.
        """
        assert vectors.ndim == 2, "vectors must be 2 dim array"
        _, D = vectors.shape
        assert D % self.M == 0, "input dimension must be dividable by M"
        step = D // self.M
        # Only forward ``iter`` when the caller supplied it, so the library
        # default stays in charge otherwise.
        kmeans_kwargs = {} if iter is None else {"iter": iter}
        centers_book = {}
        for index in range(self.M):
            vec_sub = vectors[:, index * step: (index + 1) * step]
            # NOTE: scipy documents that kmeans may return fewer than k
            # centroids when some end up with no observations.
            centers, _ = kmeans(vec_sub, self.k, **kmeans_kwargs)
            centers_book[index] = centers
        # Publish the fitted state only once every sub-space succeeded.
        self.D = D
        self.step = step
        self.centers_book = centers_book

    def encode(self, vectors):
        """Map vectors to codeword indices.

        Args:
            vectors: 1-D array of length D (treated as a single vector) or a
                2-D array of shape (N, D).

        Returns:
            Integer array of shape (N, M); entry ``[n, m]`` is the index of
            the codeword nearest to vector ``n`` in sub-space ``m``.

        Raises:
            ValueError: if fit() has not been called, or the input does not
                have the fitted dimensionality.
        """
        self._require_fitted()
        if vectors.ndim not in (1, 2):
            raise ValueError("vectors must be 1 or 2 dim array")
        if vectors.shape[-1] != self.D:
            raise ValueError(f"vector must {self.D} dim array")
        if vectors.ndim == 1:
            vectors = vectors.reshape(1, -1)
        codes = [
            vq(vectors[:, self._subspace(index)], self.centers_book[index])[0]
            for index in range(self.M)
        ]
        # codes is (M, N); transpose so that each row describes one vector.
        return np.array(codes).T

    def query_dis_table(self, query):
        """Precompute distances from ``query`` to every codeword.

        Args:
            query: 1-D array of length D.

        Returns:
            dict mapping sub-space index to a 1-D array holding the Euclidean
            distance from the query's sub-vector to each codeword of that
            sub-space.

        Raises:
            ValueError: if fit() has not been called, ``query`` is not 1-D,
                or its length differs from D.
        """
        self._require_fitted()
        if query.ndim != 1:
            raise ValueError("query must be vector")
        if query.shape[0] != self.D:
            raise ValueError(f"vector must {self.D} dim array")
        query = query.reshape(1, -1)
        # NOTE(review): cdist defaults to (unsquared) Euclidean distance, so
        # topk sums unsquared sub-distances. Textbook PQ sums squared
        # sub-distances; confirm which ranking is intended.
        return {
            index: cdist(query[:, self._subspace(index)],
                         self.centers_book[index])[0]
            for index in range(self.M)
        }

    def topk(self, vectors, k=1, query_dis_table=None):
        """Return indices of the ``k`` vectors closest to the query.

        Args:
            vectors: candidate vectors, as accepted by ``encode``.
            k: number of indices to return.
            query_dis_table: per-sub-space distance table as produced by
                ``query_dis_table(query)``. Required despite the ``None``
                default, which is kept for signature compatibility.

        Returns:
            1-D array of at most ``k`` row indices into ``vectors``, ordered
            by increasing summed table distance.

        Raises:
            ValueError: if ``query_dis_table`` is missing, or for the reasons
                listed under ``encode``.
        """
        if query_dis_table is None:
            raise ValueError(
                "query_dis_table is required; build it with "
                "query_dis_table(query)"
            )
        codes = self.encode(vectors)
        # Gather, per sub-space, the table distance of every vector's code.
        # Stacking along axis 1 yields the same (N, M) matrix a per-element
        # lookup loop would build.
        distances = np.stack(
            [
                np.asarray(query_dis_table[index])[codes[:, index]]
                for index in range(self.M)
            ],
            axis=1,
        )
        d = distances.sum(axis=1)
        return np.argsort(d)[:k]
# Demo data: N random database vectors of dimension D plus one random query,
# all drawn uniformly from [0, 1) and stored as float32.
N = 10000
D = 128
X = np.random.random(size=(N, D)).astype(np.float32)
query = np.random.random(size=D).astype(np.float32)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment