@sam2332
Created May 8, 2024 01:22
RAG Server
from contextlib import closing
from pathlib import Path
from queue import Queue
import sqlite3
import threading
import time

import numpy as np
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sklearn.metrics.pairwise import cosine_similarity

app = FastAPI()

embeddings_model = "mxbai-embed-large"
# Other chat models tried: llama3, dolphin-mixtral, mixtral:latest, dolphin-mixtral:latest
chat_model = "dolphin-mistral:latest"
ollama_host = "http://localhost:11434"
embeddings_model_db = "default"
# Database connection utility
# NOTE: assumes the ./embeddings/ directory already exists
def get_db_connection():
    global embeddings_model_db
    conn = sqlite3.connect(f"./embeddings/{embeddings_model_db}.db")
    conn.row_factory = sqlite3.Row
    return conn
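# Note (added): the ingestion endpoints below run several writer threads against
# this SQLite file. A minimal hardening sketch, assuming WAL mode is acceptable
# here; the busy_timeout value is an illustrative guess, not from the original gist.
def get_db_connection_wal():
    conn = sqlite3.connect(f"./embeddings/{embeddings_model_db}.db")
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode=WAL")   # allow concurrent readers during writes
    conn.execute("PRAGMA busy_timeout=5000")  # wait up to 5s on a locked database
    return conn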
# API models
class EmbeddingRequest(BaseModel):
    source: str
    content: str

class ChatRequest(BaseModel):
    messages: list

class RagRequest(BaseModel):
    prompt: str
    related_count: int
    max_tokens: int

class ChangeEmbeddingDBFilename(BaseModel):
    name: str
# Database setup
def setup_database():
    with get_db_connection() as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS embeddings (
                    id INTEGER PRIMARY KEY,
                    source TEXT,
                    content TEXT,
                    embedding TEXT,
                    last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
            """)
            conn.commit()

setup_database()
@app.post("/api/change_embedding_db/")
async def change_embedding_db(data: ChangeEmbeddingDBFilename):
global embeddings_model_db
embeddings_model_db = data.name
setup_database()
return {"status": "success"}
# Endpoint to get all embeddings
@app.get("/api/embeddings/")
async def get_embeddings():
    with get_db_connection() as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute("SELECT * FROM embeddings")
            rows = cursor.fetchall()
    return [dict(row) for row in rows]
def make_embeddings_safe_for_db(embedding):
    # Store the embedding list as a brace-delimited "{1.0, 2.0, ...}" string
    return str(embedding).replace('[', '{').replace(']', '}')
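# Note (added): the RAG endpoints below undo this with np.fromstring on the
# brace-stripped string. A minimal round-trip sketch of that convention:
#
#   stored = make_embeddings_safe_for_db([0.1, 0.2, 0.3])   # "{0.1, 0.2, 0.3}"
#   vec = np.fromstring(stored[1:-1], sep=',')              # array([0.1, 0.2, 0.3])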
def insert_embedding(content, source):
    print(f"Inserting embedding for {len(content)} bytes from {source}")
    response = requests.post(
        ollama_host + "/api/embeddings",
        json={"model": embeddings_model, "prompt": content},
    )
    if response.status_code == 200:
        embedding = response.json()['embedding']
        with get_db_connection() as conn:
            with closing(conn.cursor()) as cursor:
                embedding = make_embeddings_safe_for_db(embedding)
                # TODO: check whether this (source, content) pair already exists
                cursor.execute(
                    "INSERT INTO embeddings (source, content, embedding) VALUES (?, ?, ?)",
                    (source, content, embedding),
                )
                conn.commit()
        return {"status": "success", "content": content, "embedding": embedding}
    else:
        raise HTTPException(status_code=response.status_code, detail="Error processing embeddings")
# Endpoint to insert text and embeddings
@app.post("/api/insert_text_embeddings/")
async def insert_text_embeddings(data: EmbeddingRequest):
    # Delegates to the Ollama embeddings API and stores the result
    return insert_embedding(data.content, data.source)
# Import all files in the "ingress" folder, marking each filename as the source
@app.post("/api/ingress_file_embeddings/")
async def ingress_file_embeddings():
    # Get all files in the ingress folder
    ingress_folder = Path("ingress")
    for file in ingress_folder.iterdir():
        if file.is_file():
            if file.suffix == ".txt":
                with open(file, "r") as f:
                    content = f.read()
                # Chunk content into 255-character pieces
                for i in range(0, len(content), 255):
                    insert_embedding(content[i:i + 255], file.name + " - chunk " + str(i))
            elif file.suffix == ".csv":
                data = file.read_text()
                lines = data.split("\n")
                avg_list = []
                # Walk the file in non-overlapping 5-line windows (see the sketch below)
                for index in range(2, len(lines) - 1, 5):
                    start = time.time()
                    content = ""
                    if index - 2 > 0:
                        content += lines[index - 2] + "\n"
                    if index - 1 > 0:
                        content += lines[index - 1] + "\n"
                    content += lines[index] + "\n"
                    if index + 1 < len(lines):
                        content += lines[index + 1] + "\n"
                    if index + 2 < len(lines):
                        content += lines[index + 2] + "\n"
                    insert_embedding(content, file.name + " - line " + str(index))
                    end = time.time()
                    avg_list.append(end - start)
                    avg = sum(avg_list) / len(avg_list)
                    print(f"Average time for processing 5 lines: {avg} seconds, "
                          f"time remaining for {len(lines) - index} lines: {avg * (len(lines) - index)} seconds")
                    avg_list = avg_list[-10:]
    return {"status": "success"}
def ingress_thread(queue):
    # Give up after 5 consecutive empty-queue checks or failures
    failout = 5
    while failout > 0:
        while not queue.empty():
            file, lines = queue.get()
            content = "\n".join(lines)
            try:
                insert_embedding(content, file)
            except Exception as e:
                print(e)
                failout -= 1
        time.sleep(1)  # give the producer a moment to enqueue more work
        failout -= 1
@app.post("/api/fast_csv_ingress/")
async def fast_csv_ingress():
queue = Queue()
for file in Path('ingress').glob("*.csv"):
lines = file.read_text().split("\n")
threadding.Thread(target=ingress_thread, args=(queue,)).start()
threadding.Thread(target=ingress_thread, args=(queue,)).start()
threadding.Thread(target=ingress_thread, args=(queue,)).start()
threadding.Thread(target=ingress_thread, args=(queue,)).start()
threadding.Thread(target=ingress_thread, args=(queue,)).start()
threadding.Thread(target=ingress_thread, args=(queue,)).start()
threadding.Thread(target=ingress_thread, args=(queue,)).start()
for index in range(2, len(lines) - 1, 5):
queue.put((f"{file.name} - lines {index-2} - {index+2}", lines[index-2:index+3]))
while queue.qsize() > 0:
time.sleep(1)
return {"status": "success"}
def generate_embedding(prompt):
    response = requests.post(
        ollama_host + "/api/embeddings",
        json={"model": embeddings_model, "prompt": prompt},
    )
    if response.status_code == 200:
        return response.json()['embedding']
    else:
        raise Exception("Error generating embeddings")
# Retrieval-Augmented Generation using embeddings
@app.post("/api/rag_test")
async def perform_ragtest(data: RagRequest):
    with get_db_connection() as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute("SELECT content, embedding FROM embeddings")
            embeddings = cursor.fetchall()
    query_emb = generate_embedding(data.prompt)
    # Strip the surrounding braces and parse the stored strings back into vectors
    db_embs = [np.fromstring(row['embedding'][1:-1], sep=',') for row in embeddings]
    cos_sims = cosine_similarity([query_emb], db_embs)[0]
    indices = np.argsort(cos_sims)[::-1][:3]  # top 3 most similar chunks
    related_prompts = " ".join(embeddings[i]['content'] for i in indices)
    system_prompt = "You are helpful, here is some info related to the user's question:\n" + related_prompts
    return {"system_prompt": system_prompt, "related_prompts": related_prompts}
@app.post("/api/reset_embeddings_db")
async def reset_embeddings_db():
with get_db_connection() as conn:
with closing(conn.cursor()) as cursor:
cursor.execute("DELETE FROM embeddings")
conn.commit()
return {"status": "success"}
@app.post("/api/rag")
async def perform_rag(data: RagRequest):
# Create a connection to the database
with get_db_connection() as conn:
with closing(conn.cursor()) as cursor:
# Retrieve all embeddings from the database
cursor.execute("SELECT source, content, embedding FROM embeddings")
embeddings = cursor.fetchall()
# Generate the embedding for the prompt
query_emb = array([generate_embedding(data.prompt)])
# Convert stored embeddings from strings back to numpy arrays
db_embs = array([fromstring(emb['embedding'][1:-1], sep=',') for emb in embeddings])
# Compute cosine similarities
cos_sims = cosine_similarity(query_emb, db_embs)[0]
indices = argsort(cos_sims)[::-1][:data.related_count] # Top 3 related prompts
# Construct related prompts text
related_prompts = ""
#"\n".join(embeddings[i]['content'] for i in indices)
for i in indices:
related_prompts += f"""
#{embeddings[i]['source']}
```
{embeddings[i]['content']}
```"""
system_prompt = f"You are helpful, here is some info related to the user's question:\n{related_prompts}\nThe next message is the users question"
print()
print(system_prompt)
print(data.prompt)
# Query an external chat model
response = requests.post(ollama_host+"/api/chat", json={
"stream": False,
"model": chat_model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": data.prompt}
],
"max_tokens": data.max_tokens
})
print(response.status_code)
print(response.text)
if response.status_code == 200:
print(response.json())
print()
return response.json()
else:
raise HTTPException(status_code=response.status_code, detail="Error processing chat with model")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=11435)
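# Note (added): an end-to-end usage sketch against this server; the source name,
# content, and prompt are illustrative. Assumes Ollama is serving
# mxbai-embed-large and dolphin-mistral:latest on localhost:11434 as configured above.
#
#   curl -X POST http://localhost:11435/api/insert_text_embeddings/ \
#        -H "Content-Type: application/json" \
#        -d '{"source": "faq.txt", "content": "The office opens at 9am."}'
#
#   curl -X POST http://localhost:11435/api/rag \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "When does the office open?", "related_count": 3, "max_tokens": 256}'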