Skip to content

Instantly share code, notes, and snippets.

@keskival
Last active September 10, 2023 11:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save keskival/d4188e9d64ed2cf6fe14175c8419da2c to your computer and use it in GitHub Desktop.
Save keskival/d4188e9d64ed2cf6fe14175c8419da2c to your computer and use it in GitHub Desktop.
Computing cosine similarities.
import pandas as pd
import seaborn as sns
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
def _cosine_similarity_matrix(a, b):
    """Return the pairwise cosine similarities between the rows of two matrices.

    Args:
        a: 2-D array with one vector per row.
        b: 2-D array with one vector per row and as many columns as ``a``.

    Returns:
        2-D array in which entry ``[i, j]`` is the cosine similarity between
        ``a[i]`` and ``b[j]``.
    """
    # Dot product of every row of `a` with every row of `b`.
    dot_products = np.tensordot(a, b, axes=[1, 1])
    # Row norms arranged as a column (for `a`) and a row (for `b`) so their
    # broadcast product lines up element-wise with `dot_products`.
    norm_products = (
        np.expand_dims(norm(a, axis=1), axis=1)
        * np.expand_dims(norm(b, axis=1), axis=0)
    )
    return dot_products / norm_products


data_df = pd.read_csv('data/embedded_1k_reviews.csv')
queries_df = pd.read_csv('data/queries.csv')

# The `ada_embedding` column is read as text; each cell is evaluated as a
# Python expression and the result converted to a NumPy array.
# NOTE(security): `eval` executes arbitrary code found in the CSV cells. This
# is only acceptable for files you produced yourself. If the cells are plain
# list literals (presumably they are -- confirm), `ast.literal_eval` is a safe
# drop-in replacement.
data_df['ada_embedding'] = data_df.ada_embedding.apply(eval).apply(np.array)
queries_df['ada_embedding'] = queries_df.ada_embedding.apply(eval).apply(np.array)

# Stack the per-row vectors into 2-D matrices with one embedding per row.
data_embeddings = np.stack(data_df['ada_embedding'])
queries_embeddings = np.stack(queries_df['ada_embedding'])

# Rows index the data embeddings, columns index the query embeddings.
cosine_similarities = _cosine_similarity_matrix(data_embeddings, queries_embeddings)

# Same computation using only the first `dims` components of every embedding.
dims = 128
cosine_similarities_cropped = _cosine_similarity_matrix(
    data_embeddings[:, :dims], queries_embeddings[:, :dims]
)

sns.histplot(cosine_similarities)
sns.histplot(cosine_similarities_cropped)

# For each query (column), the indices of the 10 data rows with the highest
# similarity, ordered from lower to higher similarity. These are bare
# expressions: their values are only shown in an interactive session.
np.argsort(cosine_similarities, axis=0)[-10:]
np.argsort(cosine_similarities_cropped, axis=0)[-10:]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment