Skip to content

Instantly share code, notes, and snippets.

@gordinmitya
Created September 15, 2023 16:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gordinmitya/0d4f0b4dd168f1de049c4fd688a15e73 to your computer and use it in GitHub Desktop.
Save gordinmitya/0d4f0b4dd168f1de049c4fd688a15e73 to your computer and use it in GitHub Desktop.
Find the images closest to a set of query images by comparing their embeddings.
from dataclasses import dataclass
import numpy as np
@dataclass(frozen=True)
class Embeddings:
    """Item names paired with their embedding vectors.

    The two fields are parallel: ``find_close_to_many`` looks a name up in
    ``names`` and uses the resulting position as a row index into
    ``vectors``, so ``vectors[i]`` is treated as the embedding of
    ``names[i]``.

    ``frozen=True`` only prevents rebinding the attributes; the list and the
    array themselves can still be mutated in place.
    """

    # Item identifiers, one per row of ``vectors``.
    names: list[str]
    # Embedding matrix, indexed by row.
    # NOTE(review): presumably shaped (len(names), dim) -- confirm at the
    # place where instances are built; nothing here validates it.
    vectors: np.ndarray
def find_close_to_many(
    request: set[str],
    embeddings: Embeddings,
    target_count: int
) -> list[tuple[str, float]]:
    """Rank the items most similar to any of the requested items.

    Every stored vector is scored against each requested vector with a dot
    product; an item's score is its best (maximum) score over the request.
    The requested items themselves are never part of the result.

    Args:
        request: Names of the query items. Each must occur in
            ``embeddings.names``.
        embeddings: Names and their vectors; ``vectors[i]`` is used as the
            vector of ``names[i]``.
        target_count: Maximum number of results to return.

    Returns:
        At most ``target_count`` ``(name, score)`` pairs, ordered from the
        highest score to the lowest. Empty when ``request`` is empty or
        ``target_count`` is not positive.

    Raises:
        ValueError: If a requested name is not in ``embeddings.names``.
    """
    if not request:
        # Nothing to compare against; a max over zero queries is undefined.
        return []

    names = embeddings.names
    # Map each name to its first position (the same choice list.index makes)
    # in a single pass instead of one linear scan per requested name.
    first_index: dict[str, int] = {}
    for position, name in enumerate(names):
        first_index.setdefault(name, position)

    unknown = sorted(name for name in request if name not in first_index)
    if unknown:
        raise ValueError(f"names not found in embeddings: {unknown}")
    indices = [first_index[name] for name in request]

    if target_count <= 0:
        return []

    vectors = embeddings.vectors
    # scores[i, j] is the dot product of stored vector i with query vector j.
    # NOTE(review): a dot product is a cosine similarity only for unit-length
    # vectors -- confirm the embeddings are normalised upstream.
    scores = np.dot(vectors, vectors[indices].T)
    best = scores.max(axis=1)

    # Rank only the rows that were not requested. Selecting them explicitly
    # (rather than overwriting their scores with -inf) keeps them out of the
    # result even when target_count exceeds the number of other items, and
    # works for integer-valued vectors too.
    is_candidate = np.ones(len(best), dtype=bool)
    is_candidate[indices] = False
    candidates = np.flatnonzero(is_candidate)

    # Reverse first, then truncate: a tail slice such as [-0:] would select
    # everything instead of nothing.
    order = np.argsort(best[candidates])[::-1][:target_count]
    return [(names[i], float(best[i])) for i in candidates[order]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment