Last active
September 14, 2023 21:25
-
-
Save rsbohn/4f585dd91b350d307f9133e0d2adeb2f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"get embeddings from an LLM database" | |
# LICENSE https://www.apache.org/licenses/LICENSE-2.0.txt | |
# Copyright (C) Randall Bohn 2023 | |
# requires: llm>=0.9, sqlite_utils, numpy, umap-learn==0.5.3 | |
import sys
from typing import Dict

import llm
import numpy as np
from sqlite_utils import Database
from umap import UMAP
# Path to the SQLite database holding the "articles" embeddings collection
# that dig() searches; the Database handle is created once at import time.
database_file: str = "./local.db"
db: Database = Database(database_file)
def dig(query: str, n: int = 10) -> Dict:
    """Get the ids, scores and stored embeddings for a similarity search.

    Runs ``similar()`` against the "articles" collection in ``db``, then loads
    each hit's stored embedding vector from the ``embeddings`` table.

    Args:
        query: Text to search the collection for.
        n: Maximum number of similar items to return.

    Returns:
        A dict of three parallel lists, in the order returned by
        ``collection.similar``: ``id`` (item ids), ``score`` (the score
        reported by ``similar``) and ``embedding`` (1-D float32 numpy arrays).

    Raises:
        KeyError: if a returned id has no stored embedding row.
    """
    collection = llm.Collection("articles", db)
    articles = collection.similar(query, n)
    article_id = [article.id for article in articles]
    score = [article.score for article in articles]

    embeddings = []
    for item in article_id:
        # Bound parameters instead of string formatting: ids are arbitrary
        # strings and may contain quotes. Scope the lookup to this collection,
        # since the embeddings table can hold rows for several collections.
        rows = db.query(
            "select embedding from embeddings where collection_id = ? and id = ?",
            [collection.id, item],
        )
        row = next(rows, None)
        if row is None:
            raise KeyError(f"no stored embedding for id {item!r}")
        # Decode the BLOB as little-endian float32 values.
        embeddings.append(np.frombuffer(row["embedding"], "<f4"))

    return dict(id=article_id, score=score, embedding=embeddings)
def reduce(data: np.ndarray, *, n_components: int = 2, **umap_kwargs) -> np.ndarray:
    """Use UMAP to reduce vectors to a low-dimensional embedding.

    Args:
        data: Vectors to project, one per row. A list of equal-length 1-D
            arrays (as produced by ``dig``) is also accepted.
        n_components: Number of output dimensions; defaults to 2 for plotting.
        **umap_kwargs: Extra keyword arguments forwarded to ``UMAP``
            (for example ``n_neighbors`` or ``random_state``).

    Returns:
        The array produced by ``UMAP.fit_transform``, one row per input vector.
    """
    # Convert explicitly so list input becomes a 2-D (items, dims) matrix.
    matrix = np.asarray(data)
    reducer = UMAP(n_components=n_components, **umap_kwargs)
    return reducer.fit_transform(matrix)
def main(query: str, n: int = 128):
    """Plot a 2-D UMAP projection of the embeddings most similar to *query*.

    Args:
        query: Search text; also used as the plot title.
        n: Number of similar items to fetch and plot.

    Raises:
        ValueError: if the search returns no items.
    """
    # Plotting libraries are imported lazily so dig()/reduce() work without them.
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns

    data = dig(query, n)
    if not data['id']:
        raise ValueError(f"no results for query {query!r}")

    # Stack the per-item vectors into an (items, dims) matrix for UMAP.
    embeddings_2d = reduce(np.vstack(data['embedding']))
    df = pd.DataFrame({
        'id': data['id'],
        'score': data['score'],
        'x': embeddings_2d[:, 0],
        'y': embeddings_2d[:, 1],
    })
    sns.scatterplot(x='x', y='y', data=df)
    plt.title(query)
    plt.show()
if __name__ == "__main__":
    # Take the query from the command line (e.g. `python script.py New York`);
    # with no arguments, fall back to the original demo query.
    main(" ".join(sys.argv[1:]) or "Mexico")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment