-
-
Save ayende/60a037ff9fc14aa62aabe28090f7d790 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import psycopg2 | |
| from pgvector.psycopg2 import register_vector | |
| import numpy as np | |
| from openai import OpenAI | |
| import os | |
| from dotenv import load_dotenv | |
# Load settings from a .env file (if present) into the process environment.
load_dotenv()

# Keyword arguments handed to psycopg2.connect(); any variable that is not set
# comes back from os.getenv() as None.
DB_PARAMS = dict(
    dbname=os.getenv("DB_NAME"),
    user=os.getenv("DB_USER"),
    password=os.getenv("DB_PASSWORD"),
    host=os.getenv("DB_HOST"),
    port=os.getenv("DB_PORT"),
)

# Single OpenAI client shared by every embedding request in this module.
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def prepare_embeddings():
    """Generate embeddings for products that lack one, then build the vector index.

    Assumes a ``Products`` table exists with columns ``id``, ``name`` and
    ``description``, and that the pgvector ``vector`` type is available in the
    database (``register_vector`` needs it).

    Everything runs in one transaction:
      1. add an ``embedding vector(1536)`` column if it is missing;
      2. embed every row whose ``embedding`` is NULL with
         ``text-embedding-3-small`` and store the result;
      3. create an IVFFlat cosine-distance index on ``embedding`` if missing.

    On success the transaction is committed. If an ``Exception`` occurs, its
    message is printed and the transaction is rolled back; the exception is
    not propagated. The connection is always closed.
    """
    # Pre-bind so the except/finally paths work even if connect() itself fails.
    conn = None
    try:
        conn = psycopg2.connect(**DB_PARAMS)
        # Register the pgvector adapters (numpy arrays <-> `vector`) on this connection.
        register_vector(conn)
        with conn.cursor() as cur:
            # The column width must match the dimensionality of the embedding model below.
            cur.execute("""
                ALTER TABLE Products
                ADD COLUMN IF NOT EXISTS embedding vector(1536)
            """)

            # Only rows still missing an embedding are processed.
            cur.execute("SELECT id, name, description FROM Products WHERE embedding IS NULL")
            products = cur.fetchall()
            print(f"Found {len(products)} products to process")

            for product_id, name, description in products:
                # Combine name and description; skip the description when it is
                # NULL/empty so the literal text "None" is never embedded.
                text = f"{name}: {description}" if description else f"{name}"
                response = openai_client.embeddings.create(
                    model="text-embedding-3-small",
                    input=text
                )
                # register_vector() adapts numpy arrays; converting here sends the
                # value as a pgvector literal rather than an SQL ARRAY.
                embedding = np.array(response.data[0].embedding)
                cur.execute("""
                    UPDATE Products
                    SET embedding = %s
                    WHERE id = %s
                """, (embedding, product_id))
                print(f"Processed product ID: {product_id}")

            # Build the index after the data is loaded.
            cur.execute("""
                CREATE INDEX IF NOT EXISTS products_embedding_idx
                ON Products
                USING ivfflat (embedding vector_cosine_ops)
                WITH (lists = 100)
            """)
        conn.commit()
        print("Embedding preparation completed successfully")
    except Exception as e:
        print(f"Error in prepare_embeddings: {str(e)}")
        # Nothing to roll back if the connection was never opened or is already gone.
        if conn is not None and not conn.closed:
            conn.rollback()
    finally:
        if conn is not None:
            conn.close()
def query_products(search_text, limit=5):
    """Return the products most similar to *search_text* by cosine distance.

    The query text is embedded with ``text-embedding-3-small`` and compared
    against ``Products.embedding`` with pgvector's ``<=>`` operator. Rows
    without an embedding are skipped.

    Args:
        search_text: Free-text query. Empty or None input returns ``[]``
            without calling the API or the database.
        limit: Maximum number of rows to return.

    Returns:
        List of ``(id, name, description, similarity)`` tuples, ordered by
        ascending ``<=>`` distance (most similar first), where
        ``similarity = 1 - distance``. If an ``Exception`` occurs, its message
        is printed and ``[]`` is returned.
    """
    if not search_text:
        return []

    # Pre-bind so the finally block works even if connect() itself fails.
    conn = None
    try:
        conn = psycopg2.connect(**DB_PARAMS)
        # Register the pgvector adapters (numpy arrays <-> `vector`) on this connection.
        register_vector(conn)

        response = openai_client.embeddings.create(
            model="text-embedding-3-small",
            input=search_text
        )
        # register_vector() adapts numpy arrays to `vector`. A plain Python list
        # would be sent as an SQL ARRAY, and `embedding <=> ARRAY[...]` does not
        # resolve to pgvector's vector-vector operator.
        query_embedding = np.array(response.data[0].embedding)

        with conn.cursor() as cur:
            cur.execute("""
                SELECT id, name, description,
                       1 - (embedding <=> %s) as similarity
                FROM Products
                WHERE embedding IS NOT NULL
                ORDER BY embedding <=> %s
                LIMIT %s
            """, (query_embedding, query_embedding, limit))
            return cur.fetchall()
    except Exception as e:
        print(f"Error in query_products: {str(e)}")
        return []
    finally:
        if conn is not None:
            conn.close()
if __name__ == "__main__":
    # Step 1: embed any products that still lack a vector and ensure the index exists.
    prepare_embeddings()

    # Step 2: run a sample semantic search and print each hit.
    search_query = "red cotton t-shirt"
    results = query_products(search_query)

    print("\nSearch results:")
    for hit_id, hit_name, hit_desc, hit_score in results:
        print(f"ID: {hit_id}")
        print(f"Name: {hit_name}")
        print(f"Description: {hit_desc}")
        print(f"Similarity: {hit_score:.4f}")
        print("---")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment