Skip to content

Instantly share code, notes, and snippets.

@lmyyao
Created August 24, 2023 11:04
Show Gist options
  • Save lmyyao/7c6f685750b43224a1a3f5e690baef84 to your computer and use it in GitHub Desktop.
Save lmyyao/7c6f685750b43224a1a3f5e690baef84 to your computer and use it in GitHub Desktop.
product_quantization demo
from scipy.cluster.vq import vq, kmeans
import numpy as np
class PQ(object):
    """Product quantizer built on ``scipy.cluster.vq``.

    Each D-dimensional vector is cut into ``M`` contiguous sub-vectors of
    ``D // M`` components. ``fit`` learns one k-means codebook per sub-space
    and ``encode`` replaces every sub-vector with the index of its nearest
    centroid in the matching codebook.

    Attributes set by ``fit`` (all ``None`` until then):
        D: dimensionality of the training vectors.
        step: width of one sub-vector (``D // M``).
        centers_book: dict mapping sub-space index -> centroid array.
    """

    def __init__(self, M, k):
        """Store the configuration.

        Args:
            M: number of sub-spaces each vector is split into.
            k: number of centroids requested from k-means per sub-space.
        """
        self.M = M
        self.k = k
        # Explicit "not trained" markers so encode() can report a clear
        # error instead of failing on a missing attribute.
        self.D = None
        self.step = None
        self.centers_book = None

    def fit(self, vectors, iter=None):
        """Learn one k-means codebook per sub-space.

        Args:
            vectors: 2-D array-like of shape (N, D); D must be a multiple
                of ``M``.
            iter: optional value forwarded to ``scipy.cluster.vq.kmeans``
                as its ``iter`` argument; when ``None`` scipy's own default
                is used. (The name shadows the builtin but is kept for
                backward compatibility with existing keyword callers.)

        Raises:
            AssertionError: if ``vectors`` is not 2-D or D is not divisible
                by ``M``.
        """
        vectors = np.asarray(vectors)
        assert vectors.ndim == 2, "vectors must be 2 dim array"
        _, D = vectors.shape
        assert D % self.M == 0, "input dimension must be dividable by M"
        step = D // self.M

        # Only override scipy's default when the caller asked for it.
        kmeans_kwargs = {} if iter is None else {"iter": iter}

        centers_book = {}
        for index in range(self.M):
            vec_sub = vectors[:, index * step:(index + 1) * step]
            # NOTE: per the scipy docs, kmeans may return fewer than k
            # centroids for a sub-space.
            centers, _ = kmeans(vec_sub, self.k, **kmeans_kwargs)
            centers_book[index] = centers

        # Commit state only after every sub-space trained successfully, so a
        # failed fit never leaves the quantizer half-initialised.
        self.D = D
        self.step = step
        self.centers_book = centers_book

    def encode(self, vectors):
        """Map vectors to their per-sub-space centroid indices.

        Args:
            vectors: a single vector of shape (D,) or a batch of shape
                (n, D), where D matches the data passed to ``fit``.

        Returns:
            Integer array of shape (n, M); a single input vector yields
            shape (1, M).

        Raises:
            ValueError: if ``fit`` has not been called, or the input is not
                1-D/2-D with D components per vector.
        """
        # Must come first: before fit() there is no D to validate against.
        if self.centers_book is None:
            raise ValueError("call fit first")

        vectors = np.asarray(vectors)
        if vectors.ndim == 1:
            if vectors.shape[0] != self.D:
                raise ValueError(f"vector must {self.D} dim array")
            vectors = vectors.reshape(1, -1)
        elif vectors.ndim == 2:
            if vectors.shape[-1] != self.D:
                raise ValueError(f"vector must {self.D} dim array")
        else:
            # The slicing below addresses axis 1, so other ranks cannot be
            # split into sub-vectors correctly.
            raise ValueError("vectors must be a 1 or 2 dim array")

        codes = []
        for index in range(self.M):
            vec_sub = vectors[:, index * self.step:(index + 1) * self.step]
            code, _ = vq(vec_sub, self.centers_book[index])
            codes.append(code)
        # codes is (M, n); transpose to one row of M codes per input vector.
        return np.array(codes).T
# Demo: train a product quantizer on random data, then encode one query.
N = 10000  # number of training vectors
D = 128    # dimensionality of each vector

# Uniform random float32 samples: the training set first, then one query.
X = np.random.random(size=(N, D)).astype(np.float32)
query = np.random.random(size=(D,)).astype(np.float32)

# 8 sub-spaces with 256 requested centroids each.
m = PQ(M=8, k=256)
m.fit(X)
m.encode(query)
@lmyyao
Copy link
Author

lmyyao commented Aug 24, 2023

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment