Skip to content

Instantly share code, notes, and snippets.

@lmyyao
Created August 25, 2023 01:49
Show Gist options
  • Save lmyyao/5b2c42722fcc4a7ddac2968c9238d8e8 to your computer and use it in GitHub Desktop.
Save lmyyao/5b2c42722fcc4a7ddac2968c9238d8e8 to your computer and use it in GitHub Desktop.
Product quantization similarity search demo
from scipy.cluster.vq import vq, kmeans
from scipy.spatial.distance import cdist
import numpy as np
class PQ(object):
    """Product quantizer for approximate nearest-neighbour search.

    Each D-dimensional vector is cut into ``M`` contiguous sub-vectors of
    equal length.  ``fit`` learns one k-means codebook per sub-space,
    ``encode`` maps vectors to ``M`` codebook indices, ``query_dis_table``
    pre-computes the distances from a query to every centroid, and ``topk``
    ranks encoded vectors by the sum of their table entries.

    Attributes:
        M: number of sub-spaces.
        k: number of centroids requested per sub-space.
        D: input dimensionality (``None`` until ``fit`` is called).
        step: length of one sub-vector, ``D // M`` (``None`` until fitted).
        centers_book: dict mapping sub-space index to the centroid array
            returned by ``scipy.cluster.vq.kmeans`` (``None`` until fitted).
    """

    def __init__(self, M, k):
        self.M = M
        self.k = k
        # Filled in by fit(); None marks the quantizer as untrained so the
        # other methods can fail with a clear ValueError instead of an
        # AttributeError.
        self.D = None
        self.step = None
        self.centers_book = None

    def _check_fitted(self):
        """Raise ``ValueError`` when no codebooks have been learned yet."""
        if getattr(self, "centers_book", None) is None:
            raise ValueError("call fit first")

    def _sub(self, array, index):
        """Return the columns of 2-D ``array`` belonging to sub-space ``index``."""
        return array[:, index * self.step: (index + 1) * self.step]

    def fit(self, vectors, iter=None):
        """Learn one k-means codebook per sub-space.

        Args:
            vectors: 2-D array-like of shape ``(N, D)``; ``D`` must be
                divisible by ``M``.
            iter: number of k-means restarts forwarded to
                ``scipy.cluster.vq.kmeans``; ``None`` keeps scipy's default.

        Raises:
            AssertionError: if ``vectors`` is not 2-D or ``D % M != 0``.
        """
        vectors = np.asarray(vectors)
        assert vectors.ndim == 2, "vectors must be 2 dim array"
        _, D = vectors.shape
        assert D % self.M == 0, "input dimension must be dividable by M"
        step = D // self.M

        centers_book = {}
        for index in range(self.M):
            vec_sub = vectors[:, index * step: (index + 1) * step]
            if iter is None:
                centers, _ = kmeans(vec_sub, self.k)
            else:
                centers, _ = kmeans(vec_sub, self.k, iter=iter)
            centers_book[index] = centers

        # Publish the new state only once every sub-space trained, so a
        # failure part-way through does not leave a half-updated quantizer.
        self.D = D
        self.step = step
        self.centers_book = centers_book

    def encode(self, vectors):
        """Quantize vectors to codebook indices.

        Args:
            vectors: 1-D array-like of length ``D`` (a single vector) or
                2-D array-like of shape ``(N, D)``.

        Returns:
            Integer array of shape ``(N, M)``; entry ``[n, m]`` is the index
            of the centroid nearest to sub-vector ``m`` of vector ``n``.

        Raises:
            ValueError: if ``fit`` has not been called, the input is not
                1-D/2-D, or its dimensionality differs from ``D``.
        """
        self._check_fitted()
        vectors = np.asarray(vectors)
        if vectors.ndim == 1:
            vectors = vectors.reshape(1, -1)
        elif vectors.ndim != 2:
            raise ValueError(
                f"vectors must be a 1 or 2 dim array, got {vectors.ndim} dims"
            )
        if vectors.shape[1] != self.D:
            raise ValueError(f"vector must {self.D} dim array")

        codes = [
            vq(self._sub(vectors, index), self.centers_book[index])[0]
            for index in range(self.M)
        ]
        return np.array(codes).T

    def query_dis_table(self, query, metric="euclidean"):
        """Pre-compute distances from a query to every centroid.

        Args:
            query: 1-D array-like of length ``D``.
            metric: metric name forwarded to
                ``scipy.spatial.distance.cdist``.  The default
                ``"euclidean"`` keeps the historical behaviour;
                ``"sqeuclidean"`` makes the summed table entries equal the
                squared Euclidean distance between the query and the
                quantized vector.

        Returns:
            dict mapping sub-space index to a 1-D array holding the distance
            from the query's sub-vector to each centroid of that sub-space.

        Raises:
            ValueError: if ``fit`` has not been called, ``query`` is not
                1-D, or its length differs from ``D``.
        """
        self._check_fitted()
        query = np.asarray(query)
        if query.ndim != 1:
            raise ValueError(f"query must be vector")
        if query.shape[0] != self.D:
            raise ValueError(f"vector must {self.D} dim array")

        row = query.reshape(1, -1)  # cdist needs 2-D inputs
        return {
            index: cdist(
                self._sub(row, index), self.centers_book[index], metric=metric
            )[0]
            for index in range(self.M)
        }

    def topk(self, vectors, k=1, query_dis_table=None):
        """Return indices of the ``k`` vectors closest to the query.

        Args:
            vectors: vectors to rank; same accepted shapes as ``encode``.
            k: number of indices to return.
            query_dis_table: table produced by ``query_dis_table(query)``.
                Required despite the ``None`` default, which is kept only
                for signature compatibility.

        Returns:
            1-D integer array with the positions (rows of ``vectors``) of
            the ``k`` smallest approximate distances, nearest first.

        Raises:
            ValueError: if ``query_dis_table`` is missing or ``encode``
                rejects ``vectors``.
        """
        if query_dis_table is None:
            raise ValueError(
                "query_dis_table is required; build it with "
                "query_dis_table(query)"
            )
        codes = self.encode(vectors)
        # Gather, per sub-space, the table entry selected by each vector's
        # code; column m holds the sub-space-m distance for every vector.
        per_subspace = np.stack(
            [
                np.asarray(query_dis_table[index])[codes[:, index]]
                for index in range(self.M)
            ],
            axis=1,
        )
        totals = per_subspace.sum(axis=1)
        return np.argsort(totals)[:k]
# Synthetic demo data: N database vectors of dimension D and a single query
# vector, both generated with np.random.random and stored as float32.
N = 10000
D = 128
X = np.random.random(size=(N, D)).astype(np.float32)
query = np.random.random(size=(D,)).astype(np.float32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment