Skip to content

Instantly share code, notes, and snippets.

@rsbohn
Last active September 14, 2023 21:25
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rsbohn/4f585dd91b350d307f9133e0d2adeb2f to your computer and use it in GitHub Desktop.
Save rsbohn/4f585dd91b350d307f9133e0d2adeb2f to your computer and use it in GitHub Desktop.
"get embeddings from an LLM database"
# LICENSE https://www.apache.org/licenses/LICENSE-2.0.txt
# Copyright (C) Randall Bohn 2023
# requires: llm>=0.9, sqlite_utils, numpy, umap-learn==0.5.3
from typing import Dict
import llm
import numpy as np
from sqlite_utils import Database
from umap import UMAP
# SQLite file that holds the llm collections and the `embeddings` table
# queried below; the path is relative to the current working directory.
database_file = "./local.db"
db = Database(filename_or_conn=database_file)
def dig(query: str, n=10, collection_name: str = "articles") -> Dict:
    """Return the stored embeddings for the ``n`` entries most similar to *query*.

    Runs a similarity search against the llm collection *collection_name*
    (``"articles"`` by default) stored in the module-level ``db``, then loads
    each hit's raw embedding blob from the ``embeddings`` table.

    Args:
        query: Text to search for.
        n: Maximum number of similar entries to return.
        collection_name: Name of the llm collection to search.

    Returns:
        A dict of three parallel lists, in the order ``similar`` returned them:
        ``id`` (entry ids), ``score`` (similarity scores) and ``embedding``
        (1-D numpy arrays decoded as little-endian float32).

    Raises:
        KeyError: if a returned id has no stored row in ``embeddings``.
    """
    collection = llm.Collection(collection_name, db)
    articles = collection.similar(query, n)
    score = [article.score for article in articles]
    article_id = [article.id for article in articles]
    ae = []
    for item in article_id:
        # Bound parameters instead of string formatting: ids are arbitrary
        # strings and may contain quotes. Also filter by collection_id, since
        # ids are only unique within a collection in this shared table.
        row = next(
            db.query(
                "select embedding from embeddings"
                " where collection_id = ? and id = ?",
                [collection.id, item],
            ),
            None,
        )
        if row is None:
            raise KeyError(f"no stored embedding for id {item!r}")
        # Decode the packed blob as little-endian 32-bit floats.
        ae.append(np.frombuffer(row["embedding"], "<f4"))
    return dict(id=article_id, score=score, embedding=ae)
def reduce(data: np.ndarray, n_components: int = 2, **umap_kwargs) -> np.ndarray:
    """Project *data* into a low-dimensional space with UMAP.

    Args:
        data: Array-like of shape (n_samples, n_features). A list of
            equal-length 1-D vectors is accepted as well.
        n_components: Output dimensionality; 2 by default, as before.
        **umap_kwargs: Extra keyword arguments forwarded to ``umap.UMAP``
            (for example ``n_neighbors`` or ``random_state``).

    Returns:
        The fitted embedding, one row per input sample and
        ``n_components`` columns.
    """
    # Normalise list-of-vectors input to a 2-D array before fitting.
    embedding = UMAP(n_components=n_components, **umap_kwargs).fit_transform(
        np.asarray(data)
    )
    return embedding
def main(query: str, n: int = 128):
    """Show a 2-D UMAP scatter plot of the top *n* matches for *query*.

    Args:
        query: Text to search for; also used as the plot title.
        n: Number of similar entries to fetch (128 by default).

    Raises:
        ValueError: if the search returns no entries.
    """
    # Imported here rather than at module top, so only this plotting
    # entry point requires them.
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns

    data = dig(query, n)
    if not data["id"]:
        # UMAP cannot be fitted on an empty matrix; fail with a clear message.
        raise ValueError(f"no entries found for query {query!r}")
    embeddings_2d = reduce(np.asarray(data["embedding"]))
    df = pd.DataFrame({
        "id": data["id"],
        "score": data["score"],
        "x": embeddings_2d[:, 0],
        "y": embeddings_2d[:, 1],
    })
    sns.scatterplot(x="x", y="y", data=df)
    plt.title(query)
    plt.show()
if __name__ == "__main__":
    import sys

    # Take the query from the command line; without arguments fall back to
    # the original hard-coded "Mexico".
    main(" ".join(sys.argv[1:]) or "Mexico")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment