@aminnj · Last active May 20, 2024
Fast quantized embedding search
import functools
import time

import jax
import jax.numpy as jnp
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
model = model.to("cpu")
sentences = [
"This is the first sentence.",
"Here is another sentence.",
"Cats are red",
"Dogs are blue",
"Venus is square",
"Math is hard",
"English is also hard",
"Need another sentence",
"Almost there.",
"Tenth sentence is the last one",
]
dbvecs = model.encode(sentences, precision="ubinary")
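# note (added for context): "ubinary" thresholds each of the 384 dims at zero
# and packbits the result, giving 48 uint8s per sentence, so hamming distance
# stands in for cosine similarity on the binarized vectors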
# pack the 48 uint8s into 12 uint32s (4x reduction in element count)
# population_count works per element, so wider lanes mean ~4x fewer popcounts
# uint64 would be better still, but jax disables 64-bit integers by default
dbvecs = dbvecs.view("uint32")
dbvecs = jnp.array(dbvecs)
# simulate having 1M vectors to search
dbvecs = jnp.vstack([dbvecs]*100_000)
sentences = sentences*100_000
@functools.partial(jax.jit, static_argnames=["k", "recall_target"])
def get_nearest_k(qvec, dbvecs, k=5, recall_target=0.95):
    xor_result = jax.lax.bitwise_xor(qvec, dbvecs)
    # Compute the population count (number of 1 bits) and sum along the last axis
    dists = jax.lax.population_count(xor_result).sum(axis=-1)
    dists = dists.astype(jnp.float32)
    # min was slow for some reason, so using max with flipped distances
    dists, indices = jax.lax.approx_max_k(-dists, k=k, recall_target=recall_target)
    return -dists, indices
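# Worked example of the metric (illustrative comment, not in the original):
#   popcount(xor(0b1100, 0b1010)) = popcount(0b0110) = 2 differing bits,
#   summed across the 12 packed uint32 lanes to give the hamming distance.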
t0 = time.time()
qvec = model.encode(["Difficult school subject"], precision="ubinary")
qvec = qvec.view("uint32")
qvec = jnp.array(qvec)
t1 = time.time()
print(f"Encoded query string into vector in {(t1-t0)*1000:.1f}ms")
# warmup
_ = get_nearest_k(qvec, dbvecs[:1000])
t0 = time.time()
dists, indices = get_nearest_k(qvec, dbvecs)
t1 = time.time()
print(f"Searched {len(dbvecs)} vectors in {(t1-t0)*1000:.1f}ms")
print(np.vstack([np.array(sentences)[indices], dists]))
"""
Encoded query string into vector in 12.8ms
Searched 1000000 vectors in 15.5ms
[['Math is hard' 'Math is hard' 'Math is hard' 'Math is hard' 'Math is hard']
['114.0' '114.0' '114.0' '114.0' '114.0']]
"""
aminnj commented May 20, 2024
Based on https://news.ycombinator.com/item?id=40379347

The additional trick for another speedup is to pack the uint8s into a wider integer type so that each popcount operation processes several bytes at once.
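For example, a minimal numpy sketch of the reinterpretation (names are illustrative; assumes the byte length is divisible by 4):

import numpy as np
v8 = np.random.randint(0, 256, size=48, dtype=np.uint8)  # 48 bytes = 384 bits
v32 = v8.view(np.uint32)                                 # same bytes, 12 lanes instead of 48
# the total number of set bits is unchanged by the view
assert np.unpackbits(v8).sum() == sum(bin(int(x)).count("1") for x in v32)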

aminnj commented May 20, 2024
A faster Julia implementation:

julia> using BenchmarkTools

julia> function hamming_distance_packed(x1::Vector{UInt64}, x2::Vector{UInt64})::Int
          s = 0
          @inbounds @simd for i in 1:6
              s += count_ones(x1[i] ⊻ x2[i])
          end
          s
       end;

julia> a = reinterpret(UInt64, (rand(UInt8, 48))) |> collect;

julia> b = reinterpret(UInt64, (rand(UInt8, 48))) |> collect;

julia> @benchmark hamming_distance_packed($a, $b)
BenchmarkTools.Trial: 10000 samples with 1000 evaluations.
 Range (min … max):  3.958 ns … 20.959 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.041 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.063 ns ±  0.320 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

 Memory estimate: 0 bytes, allocs estimate: 0.
julia> dbvecs = repeat(Vector{Vector{UInt64}}([b,b,b,a,a,b,a,a,b,b]),100_000);

julia> length(dbvecs)
1000000

julia> @benchmark hamming_distance_packed.(Ref($a), $dbvecs)
BenchmarkTools.Trial: 1780 samples with 1 evaluation.
 Range (min … max):  2.713 ms …  3.478 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.755 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.802 ms ± 96.215 μs  ┊ GC (mean ± σ):  1.32% ± 2.78%

 Memory estimate: 7.63 MiB, allocs estimate: 3.

JAX takes 15 ms to compare a query vector against 1M vectors even without the approx_max_k step, so this Julia implementation is roughly 5x faster than the JAX one.
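For reference, a distance-only JAX variant (a sketch of what that 15 ms figure presumably measures: the popcount-of-xor reduction with no top-k selection):

@jax.jit
def all_dists(qvec, dbvecs):
    # hamming distance of the query against every packed row; no selection step
    return jax.lax.population_count(jax.lax.bitwise_xor(qvec, dbvecs)).sum(axis=-1)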
