Skip to content

Instantly share code, notes, and snippets.

@stevekrenzel
Last active December 29, 2024 01:41
Show Gist options
  • Select an option

  • Save stevekrenzel/f9bb918ce823b5714efa6ea550706acc to your computer and use it in GitHub Desktop.

Select an option

Save stevekrenzel/f9bb918ce823b5714efa6ea550706acc to your computer and use it in GitHub Desktop.
# Python 3.11 or earlier is required for `torch` support
# pip install torch pandas sentence-transformers transformers faiss-cpu numpy
import pandas as pd
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import faiss
import numpy as np
from typing import List, Tuple, Optional
class RecipeSearch:
    """Two-stage semantic search over holiday cookie recipes.

    Stage 1 retrieves candidate recipes with a bi-encoder (SentenceTransformer)
    backed by a FAISS inner-product index over normalized embeddings (cosine
    similarity). Stage 2 rescores the candidates with a cross-encoder reranker.
    """

    def __init__(self, csv_path: str,
                 embedding_model: str = 'sentence-transformers/all-MiniLM-L6-v2',
                 reranker_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'):
        """
        Initialize the recipe search system.

        Args:
            csv_path: Path to the recipes CSV file
            embedding_model: Name of the sentence transformer model to use for embeddings
            reranker_model: Name of the cross-encoder model to use for reranking
        """
        self.df = self._load_and_filter_recipes(csv_path)
        self.embedding_model = SentenceTransformer(embedding_model)
        self.tokenizer = AutoTokenizer.from_pretrained(reranker_model)
        self.reranker = AutoModelForSequenceClassification.from_pretrained(reranker_model)
        self.reranker.eval()  # inference only: disable dropout for deterministic scores
        self.index = self._build_index()

    def _load_and_filter_recipes(self, csv_path: str) -> pd.DataFrame:
        """Load the CSV and keep only recipes named like holiday/christmas cookies.

        Filters on the 'Name' column only; assumes the CSV also has
        'Description' and 'RecipeInstructions' columns (used downstream).
        """
        df = pd.read_csv(csv_path)
        # na=False: a missing Name would otherwise put NaN into the boolean
        # mask and make the row selection raise; treat it as "no match".
        cookie_df = df[df['Name'].str.lower().str.contains('cookie', na=False)]
        holiday_mask = cookie_df['Name'].str.lower().str.contains('holiday|christmas', na=False)
        # Reset to a clean 0..n-1 index so DataFrame positions line up with
        # FAISS row ids by construction.
        return cookie_df[holiday_mask].reset_index(drop=True)

    def _build_index(self) -> faiss.IndexFlatIP:
        """Build a FAISS inner-product index from recipe instruction embeddings."""
        recipe_embeddings = self.embedding_model.encode(
            self.df['RecipeInstructions'].tolist(),
            convert_to_numpy=True,
            # Unit-length vectors make inner product equal cosine similarity;
            # without this, IndexFlatIP favors longer-text embeddings.
            normalize_embeddings=True,
        )
        index = faiss.IndexFlatIP(recipe_embeddings.shape[1])
        index.add(recipe_embeddings)
        return index

    def _get_relevance_score(self, query: str, recipe_text: str) -> float:
        """Calculate relevance score between query and recipe using the reranker.

        Returns a single float; higher means more relevant.
        """
        # Cross-encoder rerankers (ms-marco family) are trained on
        # (query, passage) *pairs* — pass two arguments so the tokenizer
        # inserts the separator token instead of scoring one fused string.
        inputs = self.tokenizer(
            query,
            recipe_text,
            return_tensors='pt',
            truncation=True,
            max_length=512
        )
        with torch.no_grad():
            logits = self.reranker(**inputs).logits[0]
        if len(logits) == 1:
            # Single-logit regression head (typical for ms-marco rerankers).
            return logits[0].item()
        if len(logits) == 2:
            # Binary classification head: logit of the "relevant" class.
            return logits[1].item()
        # Fallback for unexpected head sizes.
        return logits.max().item()

    def search(self, query: str, top_k: int = 5, top_j: int = 50) -> Tuple[List[dict], List[dict]]:
        """
        Search for recipes matching the query.

        Args:
            query: Search query string
            top_k: Number of top results to return
            top_j: Number of initial results to retrieve from the index

        Returns:
            Tuple of (unranked_results, ranked_results) where each is a list of
            dictionaries containing recipe information and scores
        """
        # Initial retrieval (normalized to match the index, see _build_index).
        query_emb = self.embedding_model.encode(
            [query], convert_to_numpy=True, normalize_embeddings=True
        )
        distances, indices = self.index.search(query_emb, top_j)
        # FAISS pads with -1 when fewer than top_j vectors are indexed;
        # iloc[-1] would silently return the last row, so drop those slots.
        hits = [(int(idx), float(score))
                for idx, score in zip(indices[0], distances[0]) if idx >= 0]

        # Format unranked (bi-encoder order) results.
        unranked_results = []
        for rank, (idx, score) in enumerate(hits, start=1):
            recipe = self.df.iloc[idx]
            unranked_results.append({
                'rank': rank,
                'score': score,
                'name': recipe['Name'],
                'description': recipe['Description'],
                'instructions': recipe['RecipeInstructions']
            })

        # Rerank candidates with the cross-encoder.
        recipe_scores = []
        for idx, _ in hits:
            recipe = self.df.iloc[idx]
            recipe_text = f"{recipe['Name']} {recipe['Description']} {recipe['RecipeInstructions']}"
            recipe_scores.append((idx, self._get_relevance_score(query, recipe_text)))
        recipe_scores.sort(key=lambda pair: pair[1], reverse=True)

        # Format reranked results.
        ranked_results = []
        for rank, (idx, score) in enumerate(recipe_scores, start=1):
            recipe = self.df.iloc[idx]
            ranked_results.append({
                'rank': rank,
                'score': score,
                'name': recipe['Name'],
                'description': recipe['Description'],
                'instructions': recipe['RecipeInstructions']
            })

        return unranked_results[:top_k], ranked_results[:top_k]
def print_results(results: List[dict], title: str, max_instructions: int = 200) -> None:
    """Pretty-print search results to stdout.

    Args:
        results: List of result dicts with 'rank', 'score', 'name',
            'description', and 'instructions' keys (as produced by
            RecipeSearch.search).
        title: Heading printed above the results.
        max_instructions: Maximum number of instruction characters to show
            before truncating with an ellipsis.
    """
    print(f"\n{title}\n")
    for result in results:
        print(f"#{result['rank']} (Score: {result['score']:.3f})")
        print(f"Recipe: {result['name']}")
        print(f"Description: {result['description']}")
        instructions = result['instructions']
        # Only append an ellipsis when the text was actually cut off,
        # instead of unconditionally suggesting truncation.
        if len(instructions) > max_instructions:
            instructions = instructions[:max_instructions] + "..."
        print(f"Instructions: {instructions}")
        print("-" * 80 + "\n")
def main():
    """Run the interactive holiday cookie recipe search loop.

    Loads the search system once, then repeatedly prompts for a query and a
    result count, printing both the bi-encoder (unranked) and cross-encoder
    (reranked) top matches until the user types 'quit'.
    """
    searcher = RecipeSearch("recipes.csv")
    print("Welcome to the Holiday Cookie Recipe Search!")
    print("Enter your search queries (or 'quit' to exit)")

    # Walrus binding: read, strip, and test the sentinel in one place.
    while (query := input("\nEnter your search query: ").strip()).lower() != 'quit':
        raw_count = input("How many results would you like to see? (default: 5) ") or "5"
        try:
            top_k = int(raw_count)
        except ValueError:
            print("Invalid input, using default value of 5")
            top_k = 5

        unranked, reranked = searcher.search(query, top_k)

        print_results(unranked, f"Unranked Top {top_k} matches for query: '{query}'")
        print_results(reranked, f"Reranked Top {top_k} matches for query: '{query}'")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment