-
-
Save ahmedsalim3/3f0e1c6e65b36cd4d97354561e7b04a0 to your computer and use it in GitHub Desktop.
Integrate SQLite with Google Generative AI to store and query vector embeddings for semantic search and retrieval-augmented generation (RAG).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # pip install sqlite-vec google-generativeai | |
| # sqlite-vec # | |
| # inspired by https://alexgarcia.xyz/sqlite-vec/python.html | |
| # sourced from https://github.com/asg017/sqlite-vec/blob/main/examples/simple-python/demo.py | |
| import sqlite3 | |
| import sqlite_vec | |
| import struct | |
| from typing import List | |
| from google.generativeai.embedding import embed_content | |
| import google.generativeai as genai | |
| import subprocess | |
| import getpass | |
| import os | |
| # setup configuration | |
# Local SQLite file that holds the AdventureWorks tables and the vector tables.
DATABASE_DIR = "adventureworks_vdb.db"
# Use the "raw" endpoint: the "/blob/" URL serves GitHub's HTML viewer page,
# not the SQLite file itself.
DATABASE_URL = (
    "https://github.com/ahmedsalim3/AdventureWorks-Database"
    "/raw/main/data/database/adventureworks.db"
)
if not os.path.exists(DATABASE_DIR):
    # -O stores the download under DATABASE_DIR; without it wget writes
    # "adventureworks.db", a file the rest of the script never opens.
    download = subprocess.run(["wget", "-O", DATABASE_DIR, DATABASE_URL])
    if download.returncode != 0:
        print(f"Warning: could not download {DATABASE_URL}")
        # wget -O creates the output file even when the transfer fails;
        # remove it so the next run retries instead of reusing a broken file.
        if os.path.exists(DATABASE_DIR):
            os.remove(DATABASE_DIR)

# Ask for the API key interactively only when the environment lacks it.
if "GOOGLE_API_KEY" not in os.environ:
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Input your Google API key: ")
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

# Embedding model used for both the stored sentences and the queries.
EMBEDDING_MODEL = "models/embedding-001"
# Few-shot examples: each entry pairs a natural-language request ("input")
# with the SQL that answers it ("query"). The "input" texts are the sentences
# that get embedded and searched; list position is used as the sentence id.
EXAMPLES = [
    {
        "input": "Calculate the average age of all customers",
        "query": "SELECT AVG(strftime('%Y', '2023-01-17') - strftime('%Y', BirthDate)) AS average_age FROM customers;",
    },
    {
        "input": "Find minimum product profit",
        "query": "SELECT ProductName, ProductPrice - ProductCost AS profit FROM products WHERE (ProductPrice - ProductCost) = (SELECT MIN(ProductPrice - ProductCost) FROM products);",
    },
    {
        # The query previously sorted by profit (descending), which did not
        # match this request; it now lists products by ascending key.
        "input": "Find all the products and identify them by their unique key values in ascending order.",
        "query": "SELECT ProductKey, ProductName FROM products ORDER BY ProductKey ASC;",
    },
    {
        "input": "Find the 10 most expensive products in descending order.",
        "query": "SELECT ProductName, ProductPrice FROM products ORDER BY ProductPrice DESC LIMIT 10;",
    },
    {
        "input": "List all customers who are homeowners",
        "query": "SELECT FirstName, LastName, EmailAddress FROM customers WHERE HomeOwner = 'Y';",
    },
    {
        "input": "Get the total number of orders for each product",
        "query": "SELECT ProductKey, SUM(OrderQuantity) AS TotalOrders FROM sales_2015 GROUP BY ProductKey;",
    },
    {
        "input": "Find all products in the 'Bikes' category",
        "query": "SELECT p.ProductName FROM products p JOIN product_subcategories ps ON p.ProductSubcategoryKey = ps.ProductSubcategoryKey JOIN product_categories pc ON ps.ProductCategoryKey = pc.ProductCategoryKey WHERE pc.CategoryName = 'Bikes';",
    },
    {
        "input": "Identify customers born in the 1980s",
        "query": "SELECT FirstName, LastName, BirthDate FROM customers WHERE BirthDate BETWEEN '1980-01-01' AND '1989-12-31';",
    },
    {
        "input": "Find the most recent order date",
        "query": "SELECT MAX(OrderDate) AS MostRecentOrder FROM sales_2016;",
    },
    {
        "input": "Retrieve the names of all products that cost more than $30",
        "query": "SELECT ProductName FROM products WHERE ProductPrice > 30.00;",
    },
    {
        "input": "Count the number of customers with a Bachelors degree",
        "query": "SELECT COUNT(*) AS BachelorsCount FROM customers WHERE EducationLevel = 'Bachelors';",
    },
    {
        "input": "Find the total quantity of products sold in 2017",
        "query": "SELECT SUM(OrderQuantity) AS TotalSold FROM sales_2017;",
    },
    {
        "input": "List customers who are married and have more than two children",
        "query": "SELECT FirstName, LastName FROM customers WHERE MaritalStatus = 'M' AND TotalChildren > 2;",
    },
    {
        "input": "Get the product names and their corresponding subcategory names",
        "query": "SELECT p.ProductName, ps.SubcategoryName FROM products p JOIN product_subcategories ps ON p.ProductSubcategoryKey = ps.ProductSubcategoryKey;",
    },
    {
        "input": "Calculate the average income of customers by marital status",
        "query": "SELECT MaritalStatus, AVG(AnnualIncome) AS AverageIncome FROM customers GROUP BY MaritalStatus;",
    },
    {
        "input": "Find the name and category of products returned in January 2015",
        "query": "SELECT p.ProductName, pc.CategoryName FROM returns r JOIN products p ON r.ProductKey = p.ProductKey JOIN product_subcategories ps ON p.ProductSubcategoryKey = ps.ProductSubcategoryKey JOIN product_categories pc ON ps.ProductCategoryKey = pc.ProductCategoryKey WHERE strftime('%Y-%m', r.ReturnDate) = '2015-01';",
    },
    {
        "input": "List all orders with a quantity greater than 2 from 2017",
        "query": "SELECT OrderNumber, OrderDate, OrderQuantity FROM sales_2017 WHERE OrderQuantity > 2;",
    },
    {
        "input": "Count the total number of orders made by each customer in 2015",
        "query": "SELECT CustomerKey, COUNT(OrderNumber) AS TotalOrders FROM sales_2015 GROUP BY CustomerKey;",
    },
    {
        "input": "Find all unique product colors available",
        "query": "SELECT DISTINCT ProductColor FROM products;",
    },
    {
        "input": "Identify the number of male and female customers",
        "query": "SELECT Gender, COUNT(*) AS TotalCustomers FROM customers GROUP BY Gender;",
    },
    {
        "input": "List products with a profit margin over $15",
        "query": "SELECT ProductName, ProductPrice - ProductCost AS ProfitMargin FROM products WHERE (ProductPrice - ProductCost) > 15;",
    },
    {
        "input": "Get the first name, last name, and email of customers whose last name starts with 'H'",
        "query": "SELECT FirstName, LastName, EmailAddress FROM customers WHERE LastName LIKE 'H%';",
    },
    {
        "input": "Retrieve all orders made by customers from the 'United States' region",
        "query": "SELECT s.OrderNumber, s.OrderDate, c.FirstName, c.LastName FROM sales_2016 s JOIN customers c ON s.CustomerKey = c.CustomerKey JOIN territories t ON s.TerritoryKey = t.TerritoryKey WHERE t.Country = 'United States';",
    },
    {
        "input": "Find the top 3 categories with the most product returns",
        "query": "SELECT pc.CategoryName, COUNT(r.ProductKey) AS TotalReturns FROM returns r JOIN products p ON r.ProductKey = p.ProductKey JOIN product_subcategories ps ON p.ProductSubcategoryKey = ps.ProductSubcategoryKey JOIN product_categories pc ON ps.ProductCategoryKey = pc.ProductCategoryKey GROUP BY pc.CategoryName ORDER BY TotalReturns DESC LIMIT 3;",
    },
    {
        "input": "List all customers who have no children",
        "query": "SELECT FirstName, LastName FROM customers WHERE TotalChildren = 0;",
    },
]
| # end configuration | |
| # connect to database and load extensions | |
# Module-wide connection shared by every function below.
db = sqlite3.connect(DATABASE_DIR)
# Allow extension loading only for the duration of the sqlite-vec load.
db.enable_load_extension(True)
# Registers the sqlite-vec extension (provides the vec0 virtual table) on this connection.
sqlite_vec.load(db)
# Switch extension loading back off now that sqlite-vec is loaded.
db.enable_load_extension(False)
| # to store embeddings as binary | |
def serialize(vector: List[float]) -> bytes:
    """Serialize a list of floats into a compact "raw bytes" format.

    Every value is packed with the ``struct`` format code ``f`` (a 4-byte
    float, native byte order), so the result is ``4 * len(vector)`` bytes
    long. An empty list yields ``b""``.
    """
    return struct.pack(f"{len(vector)}f", *vector)
def create_embeddings():
    """Add an ``embeddings`` column to each AdventureWorks table that lacks one.

    The table schema is inspected first, so the function can be called on
    every run: tables that already have the column are left untouched, and
    tables that do not exist are reported and skipped instead of being
    silently ignored.
    """
    tables = [
        "customers", "calendar", "products", "product_subcategories",
        "product_categories", "territories", "returns", "sales_2015",
        "sales_2016", "sales_2017",
    ]
    for table in tables:
        # PRAGMA table_info yields one row per column (name at index 1) and
        # no rows at all when the table does not exist.
        column_names = [row[1].lower() for row in db.execute(f"PRAGMA table_info({table})")]
        if not column_names:
            print(f"Skipping embeddings column: table '{table}' not found")
            continue
        if "embeddings" in column_names:
            continue  # column was added by an earlier run
        # 768 matches the vector size declared for vec_sentences.sentence_embedding.
        db.execute(f"ALTER TABLE {table} ADD COLUMN embeddings FLOAT[768];")
| # prepare sentences table and vec_sentences virtual table for RAG retrieval | |
def create_tables():
    """Create the tables used for retrieval, if they do not exist yet.

    ``sentences`` holds the example texts; ``vec_sentences`` is a vec0
    virtual table holding one 768-float embedding per sentence id.
    """
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS sentences (
            id INTEGER PRIMARY KEY,
            sentence TEXT
        );
        """,
        """
        CREATE VIRTUAL TABLE IF NOT EXISTS vec_sentences USING vec0(
            id INTEGER PRIMARY KEY,
            sentence_embedding FLOAT[768]
        );
        """,
    )
    for ddl in ddl_statements:
        db.execute(ddl)
| # insert examples into sentences table and generate embeddings | |
def embed_sentences():
    """Store the example sentences and rebuild their embeddings.

    Steps:
      1. Insert every ``EXAMPLES`` input into ``sentences`` (existing ids are
         kept as they are).
      2. Request embeddings for all stored sentences in one API call.
      3. Replace the contents of ``vec_sentences`` with the new vectors in a
         single transaction.

    The old vectors are only removed after new ones have been obtained, so a
    failed or incomplete API response leaves the existing index intact. In
    that case an error is printed and the function returns without changes
    to ``vec_sentences``.
    """
    with db:
        db.executemany(
            "INSERT OR IGNORE INTO sentences(id, sentence) VALUES(?, ?)",
            ((i, example["input"]) for i, example in enumerate(EXAMPLES)),
        )

    sentence_rows = db.execute("SELECT id, sentence FROM sentences").fetchall()
    if not sentence_rows:
        # Nothing to embed: make sure no stale vectors remain, skip the API call.
        with db:
            db.execute("DELETE FROM vec_sentences")
        return

    content = [sentence for _, sentence in sentence_rows]
    try:
        response = embed_content(model=EMBEDDING_MODEL, content=content, task_type="RETRIEVAL_DOCUMENT")
        embeddings = response["embedding"]
    except Exception as e:
        print(f"Error generating embeddings: {e}")
        return

    # zip() would silently drop unmatched rows; refuse a partial result instead.
    if len(embeddings) != len(sentence_rows):
        print(
            "Error generating embeddings: "
            f"expected {len(sentence_rows)} vectors, got {len(embeddings)}"
        )
        return

    # Swap old vectors for new ones atomically: the delete and the inserts
    # are committed together, or rolled back together on error.
    with db:
        db.execute("DELETE FROM vec_sentences")
        for (sentence_id, _), embedding in zip(sentence_rows, embeddings):
            db.execute(
                "INSERT OR REPLACE INTO vec_sentences(id, sentence_embedding) VALUES(?, ?)",
                (sentence_id, serialize(embedding)),
            )
def retrieval(query_text: str, k: int = 3):
    """Print and return the stored sentences closest to ``query_text``.

    Args:
        query_text: Natural-language question to embed and search for.
        k: Maximum number of nearest neighbours to fetch.

    Returns:
        A list of ``(id, distance, sentence)`` rows ordered by ascending
        distance, or an empty list when the query embedding could not be
        generated (the error is printed).
    """
    try:
        query_response = embed_content(model=EMBEDDING_MODEL, content=[query_text], task_type="RETRIEVAL_QUERY")
        query_embedding = query_response["embedding"][0]  # first embedding
    except Exception as e:
        print(f"Error generating query embedding: {e}")
        return []

    # Search for similar sentences
    results = db.execute(
        """
        SELECT
            vec_sentences.id,
            distance,
            sentence
        FROM vec_sentences
        LEFT JOIN sentences ON sentences.id = vec_sentences.id
        WHERE sentence_embedding MATCH ?
            AND k = ?
        ORDER BY distance
        """,
        [serialize(query_embedding), k],
    ).fetchall()

    # Output results; unpack in SELECT order (id, distance, sentence) so each
    # label is printed next to the value it names.
    for sentence_id, distance, sentence in results:
        print(f"ID: {sentence_id}, Distance: {distance}, Sentence: {sentence}")
    return results
| # run setup functions | |
# Prepare the database in order: embedding columns, retrieval tables,
# then the sentence embeddings themselves.
for setup_step in (create_embeddings, create_tables, embed_sentences):
    setup_step()

# Sample question used to exercise the retrieval step.
query_text = "how many customers in the database?"
# Upper bound on the number of rows requested; results come back sorted by distance.
number_of_retrieval = 100
print("Query Results:")
results = retrieval(query_text, k=number_of_retrieval)
| # ======= | |
| # OUTPUT | |
| # ======= | |
| # Input your Google API key: | |
| # Query Results: | |
| # ID: 17, Distance: 0.7819159626960754, Sentence: Count the total number of orders made by each customer in 2015 | |
| # ID: 12, Distance: 0.785915195941925, Sentence: List customers who are married and have more than two children | |
| # ID: 24, Distance: 0.8111039400100708, Sentence: List all customers who have no children | |
| # ID: 4, Distance: 0.8175591826438904, Sentence: List all customers who are homeowners | |
| # ID: 10, Distance: 0.8231719732284546, Sentence: Count the number of customers with a Bachelors degree | |
| # ID: 19, Distance: 0.8273471593856812, Sentence: Identify the number of male and female customers | |
| # ID: 0, Distance: 0.8401814103126526, Sentence: Calculate the average age of all customers | |
| # ID: 21, Distance: 0.8487322330474854, Sentence: Get the first name, last name, and email of customers whose last name starts with 'H' | |
| # ID: 16, Distance: 0.856723427772522, Sentence: List all orders with a quantity greater than 2 from 2017 | |
| # ID: 5, Distance: 0.8622673153877258, Sentence: Get the total number of orders for each product | |
| # ID: 11, Distance: 0.8809165358543396, Sentence: Find the total quantity of products sold in 2017 | |
| # ID: 22, Distance: 0.8917487263679504, Sentence: Retrieve all orders made by customers from the 'United States' region | |
| # ID: 14, Distance: 0.893153727054596, Sentence: Calculate the average income of customers by marital status | |
| # ID: 7, Distance: 0.8941826224327087, Sentence: Identify customers born in the 1980s | |
| # ID: 9, Distance: 0.9029486179351807, Sentence: Retrieve the names of all products that cost more than $30 | |
| # ID: 2, Distance: 0.9082536697387695, Sentence: Find all the products and identify them by their unique key values in ascending order. | |
| # ID: 8, Distance: 0.918032705783844, Sentence: Find the most recent order date | |
| # ID: 13, Distance: 0.9238193035125732, Sentence: Get the product names and their corresponding subcategory names | |
| # ID: 15, Distance: 0.9384433031082153, Sentence: Find the name and category of products returned in January 2015 | |
| # ID: 23, Distance: 0.9451109170913696, Sentence: Find the top 3 categories with the most product returns | |
| # ID: 20, Distance: 0.9506728649139404, Sentence: List products with a profit margin over $15 | |
| # ID: 18, Distance: 0.9698100090026855, Sentence: Find all unique product colors available | |
| # ID: 3, Distance: 0.9702053666114807, Sentence: Find the 10 most expensive products in descending order. | |
| # ID: 1, Distance: 0.9767876267433167, Sentence: Find minimum product profit | |
| # ID: 6, Distance: 0.9809509515762329, Sentence: Find all products in the 'Bikes' category |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment