Skip to content

Instantly share code, notes, and snippets.

@kurain
Created January 3, 2024 14:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kurain/3fb7b79cbc0fa8061fac5dd20e2e983b to your computer and use it in GitHub Desktop.
Save kurain/3fb7b79cbc0fa8061fac5dd20e2e983b to your computer and use it in GitHub Desktop.
お手軽ベクトル検索
import mlx.core as mx
import numpy as np
import faiss
import time
class MLXVecSearch():
    """Brute-force inner-product search whose matmul and sort run on an MLX stream.

    ``stream`` is forwarded to ``mx.matmul`` and ``mx.argsort``; it defaults to
    ``mx.gpu`` (the benchmark below also uses ``mx.cpu``).
    """

    def __init__(self, stream=mx.gpu):
        self._stream = stream

    def add(self, ndarray):
        """Store the database vectors (one per row) as an MLX float32 array."""
        self._ndarray = mx.array(ndarray, dtype=mx.float32)

    def search(self, _vec, topk=10):
        """Return ``(scores, indices)`` for the *topk* rows with the largest dot product with *_vec*.

        Results are ordered best first; scores and indices are Python scalars
        obtained via ``.item()``.
        """
        on = self._stream
        query = mx.array(_vec, dtype=mx.float32)
        similarity = mx.matmul(self._ndarray, query, stream=on)
        # argsort is ascending, so walk it backwards to list the best matches first.
        ranked = mx.argsort(similarity, stream=on)
        best = ranked[-1:-(topk + 1):-1]
        best_scores = [similarity[pos].item() for pos in best]
        best_rows = [pos.item() for pos in best]
        return best_scores, best_rows
class SimpleVecSearch():
    """Brute-force inner-product search implemented with plain NumPy."""

    def add(self, ndarray):
        """Keep a reference to the database matrix (one vector per row)."""
        self._ndarray = ndarray

    def search(self, vec, topk=10):
        """Return ``(scores, indices)`` for the *topk* rows with the largest dot product with *vec*.

        Results are ordered best first. Scores are Python floats; indices are
        NumPy integer scalars.
        """
        similarity = np.matmul(self._ndarray, vec)
        ascending = np.argsort(similarity)
        # Reverse the ascending order, then keep the first topk entries.
        best = ascending[::-1][:topk]
        best_scores = [similarity[row].item() for row in best]
        return best_scores, list(best)
class FaissVecSearch():
    """Exact inner-product search backed by ``faiss.IndexFlatIP``."""

    def add(self, ndarray):
        """Build a flat inner-product index over *ndarray* (one vector per row).

        The index dimensionality is taken from the data rather than fixed at
        1536. The data is converted to C-contiguous float32 first; this is a
        no-op when it already has that layout, and it keeps the call working on
        faiss versions that do not convert input themselves.
        """
        data = np.ascontiguousarray(ndarray, dtype=np.float32)
        self._faiss_index = faiss.IndexFlatIP(data.shape[1])
        self._faiss_index.add(data)

    def search(self, _vec, topk=10):
        """Return ``(scores, indices)`` for the *topk* best matches of the single query *_vec*.

        Results are ordered best first as returned by faiss, converted to
        Python floats / ints.
        """
        query = np.ascontiguousarray(np.asarray(_vec).reshape(1, -1), dtype=np.float32)
        distances, labels = self._faiss_index.search(query, topk)
        # Faiss pads with label -1 when fewer than topk vectors are indexed;
        # drop that padding so at most ntotal hits are returned, like the other backends.
        found = labels[0] >= 0
        return distances[0][found].tolist(), labels[0][found].tolist()
def gen_vecs(n, dtype='float32', dim=1536):
    """Return an ``(n, dim)`` array of random row vectors scaled to unit L2 norm.

    Entries are drawn with ``np.random.rand`` (uniform on [0, 1)) before each
    row is normalized, so the dot product of two rows equals their cosine
    similarity.

    Args:
        n: number of vectors (rows); 0 yields an empty ``(0, dim)`` array.
        dtype: dtype of the returned array.
        dim: length of each vector; defaults to 1536, the previously
            hard-coded size.
    """
    vecs = np.random.rand(n, dim)
    # keepdims yields an (n, 1) column that broadcasts across each row (and
    # works for n == 0); dividing in place avoids a second full-size float64 array.
    vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
    return vecs.astype(dtype)
if __name__ == '__main__':
    dtype = 'float32'
    n_test = 100

    # Constructors for each benchmarked backend, keyed by the label printed in the results.
    index_factories = {
        'mlx(gpu)': MLXVecSearch,
        'mlx(cpu)': lambda: MLXVecSearch(stream=mx.cpu),
        'simple': SimpleVecSearch,
        'faiss': FaissVecSearch,
    }

    def calc(index_type, index_size, n_test=n_test, raw_index=None, queries=None):
        """Benchmark one backend; print and return ``(index_type, index_size, avg_seconds_per_query)``.

        ``raw_index`` / ``queries`` may be supplied so every backend is timed on
        the same data; when omitted they are generated here.

        Raises:
            ValueError: if *index_type* is not a known backend or there are no queries.
        """
        try:
            factory = index_factories[index_type]
        except KeyError:
            raise ValueError(f'unknown index_type: {index_type!r}') from None
        if raw_index is None:
            raw_index = gen_vecs(index_size, dtype=dtype)
        if queries is None:
            queries = gen_vecs(n_test, dtype=dtype)
        if len(queries) == 0:
            raise ValueError('at least one query is required')
        index = factory()
        index.add(raw_index)
        # One untimed query so one-off first-call costs (such as kernel setup in a
        # backend) are not folded into the per-query average.
        index.search(queries[0], 1)
        start = time.perf_counter()
        for q in queries:
            index.search(q, 1)
        avg = (time.perf_counter() - start) / len(queries)
        print(f'{index_type:10}', index_size, avg)
        return (index_type, index_size, avg)

    for index_size in [10**exp for exp in range(1, 7)]:
        # Generate data once per size so all backends are measured on identical vectors.
        raw_index = gen_vecs(index_size, dtype=dtype)
        queries = gen_vecs(n_test, dtype=dtype)
        for index_type in index_factories:
            calc(index_type, index_size, raw_index=raw_index, queries=queries)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment