Skip to content

Instantly share code, notes, and snippets.

@ahmedsalim3
Last active November 3, 2024 20:45
Show Gist options
  • Save ahmedsalim3/3f0e1c6e65b36cd4d97354561e7b04a0 to your computer and use it in GitHub Desktop.
Save ahmedsalim3/3f0e1c6e65b36cd4d97354561e7b04a0 to your computer and use it in GitHub Desktop.
Integrate SQLite with Google Generative AI to store and query vector embeddings for semantic search and retrieval-augmented generation (RAG).
# pip install sqlite-vec google-generativeai
# sqlite-vec #
# inspired by https://alexgarcia.xyz/sqlite-vec/python.html
# adapted from https://github.com/asg017/sqlite-vec/blob/main/examples/simple-python/demo.py
import sqlite3
import sqlite_vec
import struct
from typing import List
from google.generativeai.embedding import embed_content
import google.generativeai as genai
import subprocess
import getpass
import os
# setup configuration
DATABASE_DIR = "adventureworks_vdb.db"
# Use the "raw" endpoint: the ".../blob/..." URL serves GitHub's HTML viewer
# page, not the SQLite file itself.
DATABASE_URL = "https://github.com/ahmedsalim3/AdventureWorks-Database/raw/main/data/database/adventureworks.db"
if not os.path.exists(DATABASE_DIR):
    # Download to a temporary name and rename only on success, so that the
    # file ends up at DATABASE_DIR (the path opened later) and a failed or
    # partial download is never mistaken for a valid database on the next run.
    partial_download = f"{DATABASE_DIR}.part"
    completed = subprocess.run(["wget", "-O", partial_download, DATABASE_URL])
    if completed.returncode == 0:
        os.replace(partial_download, DATABASE_DIR)
    else:
        print(f"Failed to download {DATABASE_URL} (wget exit code {completed.returncode})")
# Ask for the API key interactively only when the environment lacks it.
if "GOOGLE_API_KEY" not in os.environ:
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Input your Google API key: ")
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
EMBEDDING_MODEL = "models/embedding-001"
# Natural-language question / SQL answer pairs for the AdventureWorks schema.
# Only the "input" text is embedded and searched by this script; "query" is
# the SQL that answers the corresponding question.
EXAMPLES = [
    {
        "input": "Calculate the average age of all customers",
        "query": "SELECT AVG(strftime('%Y', '2023-01-17') - strftime('%Y', BirthDate)) AS average_age FROM customers;",
    },
    {
        "input": "Find minimum product profit",
        "query": "SELECT ProductName, ProductPrice - ProductCost AS profit FROM products WHERE (ProductPrice - ProductCost) = (SELECT MIN(ProductPrice - ProductCost) FROM products);",
    },
    {
        # The SQL lists products by their key in ascending order, which is
        # what the question asks for (not a profit ranking).
        "input": "Find all the products and identify them by their unique key values in ascending order.",
        "query": "SELECT ProductKey, ProductName FROM products ORDER BY ProductKey ASC;",
    },
    {
        "input": "Find the 10 most expensive products in descending order.",
        "query": "SELECT ProductName, ProductPrice FROM products ORDER BY ProductPrice DESC LIMIT 10;",
    },
    {
        "input": "List all customers who are homeowners",
        "query": "SELECT FirstName, LastName, EmailAddress FROM customers WHERE HomeOwner = 'Y';",
    },
    {
        "input": "Get the total number of orders for each product",
        "query": "SELECT ProductKey, SUM(OrderQuantity) AS TotalOrders FROM sales_2015 GROUP BY ProductKey;",
    },
    {
        "input": "Find all products in the 'Bikes' category",
        "query": "SELECT p.ProductName FROM products p JOIN product_subcategories ps ON p.ProductSubcategoryKey = ps.ProductSubcategoryKey JOIN product_categories pc ON ps.ProductCategoryKey = pc.ProductCategoryKey WHERE pc.CategoryName = 'Bikes';",
    },
    {
        "input": "Identify customers born in the 1980s",
        "query": "SELECT FirstName, LastName, BirthDate FROM customers WHERE BirthDate BETWEEN '1980-01-01' AND '1989-12-31';",
    },
    {
        "input": "Find the most recent order date",
        "query": "SELECT MAX(OrderDate) AS MostRecentOrder FROM sales_2016;",
    },
    {
        "input": "Retrieve the names of all products that cost more than $30",
        "query": "SELECT ProductName FROM products WHERE ProductPrice > 30.00;",
    },
    {
        "input": "Count the number of customers with a Bachelors degree",
        "query": "SELECT COUNT(*) AS BachelorsCount FROM customers WHERE EducationLevel = 'Bachelors';",
    },
    {
        "input": "Find the total quantity of products sold in 2017",
        "query": "SELECT SUM(OrderQuantity) AS TotalSold FROM sales_2017;",
    },
    {
        "input": "List customers who are married and have more than two children",
        "query": "SELECT FirstName, LastName FROM customers WHERE MaritalStatus = 'M' AND TotalChildren > 2;",
    },
    {
        "input": "Get the product names and their corresponding subcategory names",
        "query": "SELECT p.ProductName, ps.SubcategoryName FROM products p JOIN product_subcategories ps ON p.ProductSubcategoryKey = ps.ProductSubcategoryKey;",
    },
    {
        "input": "Calculate the average income of customers by marital status",
        "query": "SELECT MaritalStatus, AVG(AnnualIncome) AS AverageIncome FROM customers GROUP BY MaritalStatus;",
    },
    {
        "input": "Find the name and category of products returned in January 2015",
        "query": "SELECT p.ProductName, pc.CategoryName FROM returns r JOIN products p ON r.ProductKey = p.ProductKey JOIN product_subcategories ps ON p.ProductSubcategoryKey = ps.ProductSubcategoryKey JOIN product_categories pc ON ps.ProductCategoryKey = pc.ProductCategoryKey WHERE strftime('%Y-%m', r.ReturnDate) = '2015-01';",
    },
    {
        "input": "List all orders with a quantity greater than 2 from 2017",
        "query": "SELECT OrderNumber, OrderDate, OrderQuantity FROM sales_2017 WHERE OrderQuantity > 2;",
    },
    {
        "input": "Count the total number of orders made by each customer in 2015",
        "query": "SELECT CustomerKey, COUNT(OrderNumber) AS TotalOrders FROM sales_2015 GROUP BY CustomerKey;",
    },
    {
        "input": "Find all unique product colors available",
        "query": "SELECT DISTINCT ProductColor FROM products;",
    },
    {
        "input": "Identify the number of male and female customers",
        "query": "SELECT Gender, COUNT(*) AS TotalCustomers FROM customers GROUP BY Gender;",
    },
    {
        "input": "List products with a profit margin over $15",
        "query": "SELECT ProductName, ProductPrice - ProductCost AS ProfitMargin FROM products WHERE (ProductPrice - ProductCost) > 15;",
    },
    {
        "input": "Get the first name, last name, and email of customers whose last name starts with 'H'",
        "query": "SELECT FirstName, LastName, EmailAddress FROM customers WHERE LastName LIKE 'H%';",
    },
    {
        "input": "Retrieve all orders made by customers from the 'United States' region",
        "query": "SELECT s.OrderNumber, s.OrderDate, c.FirstName, c.LastName FROM sales_2016 s JOIN customers c ON s.CustomerKey = c.CustomerKey JOIN territories t ON s.TerritoryKey = t.TerritoryKey WHERE t.Country = 'United States';",
    },
    {
        "input": "Find the top 3 categories with the most product returns",
        "query": "SELECT pc.CategoryName, COUNT(r.ProductKey) AS TotalReturns FROM returns r JOIN products p ON r.ProductKey = p.ProductKey JOIN product_subcategories ps ON p.ProductSubcategoryKey = ps.ProductSubcategoryKey JOIN product_categories pc ON ps.ProductCategoryKey = pc.ProductCategoryKey GROUP BY pc.CategoryName ORDER BY TotalReturns DESC LIMIT 3;",
    },
    {
        "input": "List all customers who have no children",
        "query": "SELECT FirstName, LastName FROM customers WHERE TotalChildren = 0;",
    },
]
# end configuration
# connect to database and load extensions
# `db` is the module-level connection used by every function in this script.
db = sqlite3.connect(DATABASE_DIR)
# Extension loading is switched on only for the sqlite_vec.load() call and
# switched off again immediately afterwards.
db.enable_load_extension(True)
# Registers sqlite-vec on this connection (the vec0 virtual table created in
# create_tables relies on it).
sqlite_vec.load(db)
db.enable_load_extension(False)
# to store embeddings as binary
def serialize(vector: List[float]) -> bytes:
    """Serialize a sequence of floats into a compact "raw bytes" format.

    Every value is packed with struct's ``f`` format code (a 4-byte C float
    in native byte order), so the result is ``4 * len(vector)`` bytes long.

    Args:
        vector: The floats to pack. Any finite iterable is accepted (list,
            tuple, generator, ...); it is materialized once so its length
            is known before packing.

    Returns:
        The packed bytes; ``b""`` for an empty input.

    Raises:
        struct.error: If an element cannot be packed as a float.
    """
    values = tuple(vector)
    return struct.pack(f"{len(values)}f", *values)
def create_embeddings():
    """Add an ``embeddings`` column to each AdventureWorks table (best effort).

    Tables that already have the column are skipped silently, so the function
    can be called repeatedly. Any other SQLite failure (for example a table
    that does not exist in the database) is reported on stdout instead of
    being swallowed, and the remaining tables are still processed.
    """
    tables = [
        "customers", "calendar", "products", "product_subcategories",
        "product_categories", "territories", "returns", "sales_2015",
        "sales_2016", "sales_2017",
    ]
    for table in tables:
        try:
            # NOTE(review): the declared size 1536 differs from the FLOAT[768]
            # used for vec_sentences -- confirm which dimension is intended.
            db.execute(f"ALTER TABLE {table} ADD COLUMN embeddings FLOAT[1536];")
        except sqlite3.OperationalError as exc:
            if "duplicate column name" in str(exc).lower():
                continue  # column already present from an earlier run
            # Anything else is unexpected; surface it without aborting.
            print(f"Could not add embeddings column to {table}: {exc}")
# prepare sentences table and vec_sentences virtual table for RAG retrieval
def create_tables():
    """Create the ``sentences`` table and the ``vec_sentences`` vec0 table.

    Both statements use ``IF NOT EXISTS``, so calling this again on a database
    that already has the tables changes nothing.
    """
    # The SQL text is flush-left on purpose: it keeps the statement text free
    # of Python indentation.
    ddl_statements = (
        # plain table: one example sentence per id
        """
CREATE TABLE IF NOT EXISTS sentences (
id INTEGER PRIMARY KEY,
sentence TEXT
);
""",
        # vec0 virtual table: one 768-float embedding per id
        """
CREATE VIRTUAL TABLE IF NOT EXISTS vec_sentences USING vec0(
id INTEGER PRIMARY KEY,
sentence_embedding FLOAT[768]
);
""",
    )
    for statement in ddl_statements:
        db.execute(statement)
# insert examples into sentences table and generate embeddings
def embed_sentences():
    """Store the example sentences and rebuild their embedding index.

    Steps:
      1. Insert every ``EXAMPLES[i]["input"]`` into ``sentences`` with id ``i``
         (ids that already exist are left untouched).
      2. Embed all rows of ``sentences`` in one ``embed_content`` call.
      3. Replace the contents of ``vec_sentences`` with the new vectors.

    The embeddings are generated *before* the old vectors are deleted, so a
    failed API call no longer leaves ``vec_sentences`` emptied. On failure a
    message is printed and the function returns without touching the index.
    """
    with db:
        db.executemany(
            "INSERT OR IGNORE INTO sentences(id, sentence) VALUES(?, ?)",
            [(i, example["input"]) for i, example in enumerate(EXAMPLES)],
        )

    sentence_rows = db.execute("SELECT id, sentence FROM sentences").fetchall()
    content = [sentence for _, sentence in sentence_rows]

    try:
        response = embed_content(model=EMBEDDING_MODEL, content=content, task_type="RETRIEVAL_DOCUMENT")
        embeddings = response["embedding"]
    except Exception as e:
        print(f"Error generating embeddings: {e}")
        return

    # zip() would silently drop sentences if the API returned fewer vectors.
    if len(embeddings) != len(sentence_rows):
        print(f"Error generating embeddings: expected {len(sentence_rows)} vectors, got {len(embeddings)}")
        return

    # insert embeddings into vec_sentences table
    # The connection context manager commits on success and rolls back if an
    # exception escapes, so the delete and the inserts are applied together.
    with db:
        db.execute("DELETE FROM vec_sentences")
        db.executemany(
            "INSERT OR REPLACE INTO vec_sentences(id, sentence_embedding) VALUES(?, ?)",
            [
                (sentence_id, serialize(embedding))
                for (sentence_id, _), embedding in zip(sentence_rows, embeddings)
            ],
        )
def retrieval(query_text: str, k: int = 3):
    """Find the stored sentences closest to ``query_text`` and print them.

    The query is embedded with ``embed_content`` and matched against
    ``vec_sentences``; each hit is joined to ``sentences`` to recover its text.

    Args:
        query_text: Natural-language query to search for.
        k: Value bound to the ``k = ?`` constraint of the vector search,
            i.e. the maximum number of matches requested.

    Returns:
        A list of ``(id, distance, sentence)`` rows ordered by ascending
        distance, or ``[]`` if the query embedding could not be generated.
    """
    try:
        query_response = embed_content(model=EMBEDDING_MODEL, content=[query_text], task_type="RETRIEVAL_QUERY")
        query_embedding = query_response["embedding"][0]  # first embedding
    except Exception as e:
        print(f"Error generating query embedding: {e}")
        return []

    # Search for similar sentences
    results = db.execute(
        """
SELECT
vec_sentences.id,
distance,
sentence
FROM vec_sentences
LEFT JOIN sentences ON sentences.id = vec_sentences.id
WHERE sentence_embedding MATCH ?
AND k = ?
ORDER BY distance
""",
        [serialize(query_embedding), k],
    ).fetchall()

    # output results
    # Column order is (id, distance, sentence); label each value accordingly.
    for sentence_id, distance, sentence in results:
        print(f"ID: {sentence_id}, Distance: {distance}, Sentence: {sentence}")
    return results
# run setup functions
# Order matters: the tables must exist before embed_sentences() fills them.
create_embeddings()
create_tables()
embed_sentences()
# test query for RAG
query_text = "how many customers in the database?"
number_of_retrieval = 100 # retrieve up to 100 sentences sorted by distance
print("Query Results:")
# retrieval() prints each match and also returns the rows for later use.
results = retrieval(query_text, number_of_retrieval)
# =======
# OUTPUT
# =======
# Input your Google API key:
# Query Results:
# ID: 17, Distance: 0.7819159626960754, Sentence: Count the total number of orders made by each customer in 2015
# ID: 12, Distance: 0.785915195941925, Sentence: List customers who are married and have more than two children
# ID: 24, Distance: 0.8111039400100708, Sentence: List all customers who have no children
# ID: 4, Distance: 0.8175591826438904, Sentence: List all customers who are homeowners
# ID: 10, Distance: 0.8231719732284546, Sentence: Count the number of customers with a Bachelors degree
# ID: 19, Distance: 0.8273471593856812, Sentence: Identify the number of male and female customers
# ID: 0, Distance: 0.8401814103126526, Sentence: Calculate the average age of all customers
# ID: 21, Distance: 0.8487322330474854, Sentence: Get the first name, last name, and email of customers whose last name starts with 'H'
# ID: 16, Distance: 0.856723427772522, Sentence: List all orders with a quantity greater than 2 from 2017
# ID: 5, Distance: 0.8622673153877258, Sentence: Get the total number of orders for each product
# ID: 11, Distance: 0.8809165358543396, Sentence: Find the total quantity of products sold in 2017
# ID: 22, Distance: 0.8917487263679504, Sentence: Retrieve all orders made by customers from the 'United States' region
# ID: 14, Distance: 0.893153727054596, Sentence: Calculate the average income of customers by marital status
# ID: 7, Distance: 0.8941826224327087, Sentence: Identify customers born in the 1980s
# ID: 9, Distance: 0.9029486179351807, Sentence: Retrieve the names of all products that cost more than $30
# ID: 2, Distance: 0.9082536697387695, Sentence: Find all the products and identify them by their unique key values in ascending order.
# ID: 8, Distance: 0.918032705783844, Sentence: Find the most recent order date
# ID: 13, Distance: 0.9238193035125732, Sentence: Get the product names and their corresponding subcategory names
# ID: 15, Distance: 0.9384433031082153, Sentence: Find the name and category of products returned in January 2015
# ID: 23, Distance: 0.9451109170913696, Sentence: Find the top 3 categories with the most product returns
# ID: 20, Distance: 0.9506728649139404, Sentence: List products with a profit margin over $15
# ID: 18, Distance: 0.9698100090026855, Sentence: Find all unique product colors available
# ID: 3, Distance: 0.9702053666114807, Sentence: Find the 10 most expensive products in descending order.
# ID: 1, Distance: 0.9767876267433167, Sentence: Find minimum product profit
# ID: 6, Distance: 0.9809509515762329, Sentence: Find all products in the 'Bikes' category
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment