-
-
Save stevekrenzel/f9bb918ce823b5714efa6ea550706acc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Python 3.11 or earlier is required for `torch` support | |
| # pip install torch pandas sentence-transformers transformers faiss-cpu numpy | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import faiss | |
| import numpy as np | |
| from typing import List, Tuple, Optional | |
class RecipeSearch:
    """Two-stage search over holiday cookie recipes.

    Stage one embeds every recipe's instructions with a sentence-transformer
    and retrieves candidates from a FAISS inner-product index.  Stage two
    rescores those candidates against the query with a cross-encoder and
    sorts them by that score.
    """

    def __init__(self, csv_path: str, embedding_model: str = 'sentence-transformers/all-MiniLM-L6-v2',
                 reranker_model: str = 'cross-encoder/ms-marco-MiniLM-L-12-v2'):
        """
        Initialize the recipe search system.

        Args:
            csv_path: Path to the recipes CSV file
            embedding_model: Name of the sentence transformer model to use for embeddings
            reranker_model: Name of the cross-encoder model to use for reranking

        Raises:
            ValueError: If no recipe in the CSV passes the holiday cookie filter.
        """
        self.df = self._load_and_filter_recipes(csv_path)
        self.embedding_model = SentenceTransformer(embedding_model)
        self.tokenizer = AutoTokenizer.from_pretrained(reranker_model)
        self.reranker = AutoModelForSequenceClassification.from_pretrained(reranker_model)
        self.index = self._build_index()

    def _load_and_filter_recipes(self, csv_path: str) -> pd.DataFrame:
        """Load the CSV and keep only holiday/Christmas cookie recipes.

        A recipe is kept when its lower-cased ``Name`` contains ``cookie`` and
        also matches ``holiday|christmas``.  Rows without a name are dropped
        rather than raising, and missing ``Description`` /
        ``RecipeInstructions`` cells become empty strings so later encoding
        and printing always receive text.

        Args:
            csv_path: Path (or readable buffer) handed to ``pandas.read_csv``.

        Returns:
            A new DataFrame holding only the matching rows.
        """
        df = pd.read_csv(csv_path)
        names = df['Name'].str.lower()
        # na=False: a missing name counts as "no match" instead of producing
        # a NaN mask entry, which boolean indexing would reject.
        is_cookie = names.str.contains('cookie', na=False)
        is_festive = names.str.contains('holiday|christmas', na=False)
        selected = df[is_cookie & is_festive].copy()
        for column in ('Description', 'RecipeInstructions'):
            if column in selected.columns:
                selected[column] = selected[column].fillna('')
        return selected

    def _build_index(self) -> "faiss.IndexFlatIP":
        """Build a FAISS inner-product index from the recipe instruction embeddings.

        Vector ``i`` in the index corresponds to row position ``i`` of
        ``self.df`` (rows are looked up with ``iloc`` in ``search``).

        Returns:
            The populated ``faiss.IndexFlatIP``.

        Raises:
            ValueError: If there are no recipes to index.
        """
        instructions = self.df['RecipeInstructions'].tolist()
        if not instructions:
            raise ValueError(
                "No recipes matched the holiday cookie filter; cannot build a search index."
            )
        recipe_embeddings = self.embedding_model.encode(
            instructions,
            convert_to_numpy=True
        )
        # FAISS expects contiguous float32 input.
        recipe_embeddings = np.ascontiguousarray(recipe_embeddings, dtype=np.float32)
        index = faiss.IndexFlatIP(recipe_embeddings.shape[1])
        index.add(recipe_embeddings)
        return index

    def _get_relevance_score(self, query: str, recipe_text: str) -> float:
        """Score how relevant ``recipe_text`` is to ``query`` using the reranker.

        The query and the recipe are passed to the tokenizer as a sentence
        *pair* so the cross-encoder sees them as two separate segments,
        rather than as one concatenated string.

        Args:
            query: The user's search query.
            recipe_text: Text describing one recipe.

        Returns:
            The single logit for one-output models, the second logit for
            two-output models, otherwise the largest logit.
        """
        inputs = self.tokenizer(
            query,
            recipe_text,
            return_tensors='pt',
            truncation=True,
            max_length=512
        )
        with torch.no_grad():
            outputs = self.reranker(**inputs)
        logits = outputs.logits[0]
        num_labels = len(logits)
        if num_labels == 1:
            return logits[0].item()
        if num_labels == 2:
            return logits[1].item()
        return logits.max().item()

    def _format_result(self, rank: int, idx: int, score: float) -> dict:
        """Build the result dictionary for the recipe at row position ``idx``."""
        recipe = self.df.iloc[idx]
        return {
            'rank': rank,
            'score': score,
            'name': recipe['Name'],
            'description': recipe['Description'],
            'instructions': recipe['RecipeInstructions']
        }

    def _recipe_text(self, idx: int) -> str:
        """Return the name, description and instructions of row ``idx`` as one string."""
        recipe = self.df.iloc[idx]
        return f"{recipe['Name']} {recipe['Description']} {recipe['RecipeInstructions']}"

    def search(self, query: str, top_k: int = 5, top_j: int = 50) -> Tuple[List[dict], List[dict]]:
        """
        Search for recipes matching the query.

        Args:
            query: Search query string
            top_k: Number of top results to return
            top_j: Number of initial results to retrieve from the index; at
                least ``top_k`` candidates are retrieved, and never more than
                the index contains

        Returns:
            Tuple of (unranked_results, ranked_results) where each is a list of
            dictionaries containing recipe information and scores.  Both lists
            are empty when ``top_k`` is not positive or the index is empty.
        """
        if top_k <= 0:
            return [], []

        # Never request more neighbours than the index holds: FAISS pads the
        # missing slots with id -1, and iloc[-1] would silently return the
        # last recipe as a bogus hit.
        n_candidates = min(max(top_j, top_k), self.index.ntotal)
        if n_candidates <= 0:
            return [], []

        # Initial retrieval
        query_emb = self.embedding_model.encode([query], convert_to_numpy=True)
        query_emb = np.ascontiguousarray(query_emb, dtype=np.float32)
        distances, indices = self.index.search(query_emb, n_candidates)
        hits = [
            (int(idx), float(score))
            for idx, score in zip(indices[0], distances[0])
            if idx >= 0  # defensive: drop any padding ids
        ]

        # Format unranked results in retrieval order
        unranked_results = [
            self._format_result(rank, idx, score)
            for rank, (idx, score) in enumerate(hits, start=1)
        ]

        # Rerank every retrieved candidate with the cross-encoder
        recipe_scores = [
            (idx, self._get_relevance_score(query, self._recipe_text(idx)))
            for idx, _ in hits
        ]
        recipe_scores.sort(key=lambda pair: pair[1], reverse=True)
        ranked_results = [
            self._format_result(rank, idx, score)
            for rank, (idx, score) in enumerate(recipe_scores, start=1)
        ]

        return unranked_results[:top_k], ranked_results[:top_k]
def print_results(results: List[dict], title: str, max_instructions: int = 200):
    """Print a titled, human-readable listing of search results to stdout.

    Args:
        results: Result dictionaries with 'rank', 'score', 'name',
            'description' and 'instructions' keys.
        title: Heading printed once before the listing.
        max_instructions: Maximum number of instruction characters shown per
            result; an ellipsis is always appended after the preview.
    """
    divider = "-" * 80 + "\n"
    print(f"\n{title}\n")
    for entry in results:
        preview = entry['instructions'][:max_instructions]
        block = (
            f"#{entry['rank']} (Score: {entry['score']:.3f})",
            f"Recipe: {entry['name']}",
            f"Description: {entry['description']}",
            f"Instructions: {preview}...",
            divider,
        )
        print("\n".join(block))
def main():
    """Run the interactive recipe search loop.

    Loads ``recipes.csv`` into a ``RecipeSearch`` instance, then repeatedly
    prompts for a query and a result count and prints both the raw retrieval
    results and the reranked results.  The loop ends when the user types
    ``quit`` or when standard input is closed.
    """
    # Initialize the search system
    searcher = RecipeSearch("recipes.csv")
    print("Welcome to the Holiday Cookie Recipe Search!")
    print("Enter your search queries (or 'quit' to exit)")
    try:
        while True:
            # Get user input
            query = input("\nEnter your search query: ").strip()
            if query.lower() == 'quit':
                break
            if not query:
                # An empty query carries no search intent; ask again.
                print("Please enter a non-empty search query.")
                continue
            # Get number of results to show
            try:
                top_k = int(input("How many results would you like to see? (default: 5) ") or "5")
            except ValueError:
                print("Invalid input, using default value of 5")
                top_k = 5
            if top_k < 1:
                # Zero or negative counts would slice the result lists incorrectly.
                print("Number of results must be at least 1, using default value of 5")
                top_k = 5
            # Perform search
            unranked_results, ranked_results = searcher.search(query, top_k)
            # Print results
            print_results(unranked_results, f"Unranked Top {top_k} matches for query: '{query}'")
            print_results(ranked_results, f"Reranked Top {top_k} matches for query: '{query}'")
    except EOFError:
        # Standard input was closed (e.g. piped input ran out): exit quietly.
        print()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment