Skip to content

Instantly share code, notes, and snippets.

@vblagoje
Created August 2, 2023 15:30
Show Gist options
  • Save vblagoje/430def6cda347c0b65f5f244bc0f2ede to your computer and use it in GitHub Desktop.
Two Haystack LFQA/RAG pipelines matchup
import logging
import os
from typing import Dict, Any, List, Optional, Union, Tuple
import torch
from sentence_transformers import SentenceTransformer, util
from haystack import Pipeline, Document, BaseComponent, MultiLabel
from haystack.nodes import PromptNode, PromptTemplate, TopPSampler, PreProcessor
from haystack.nodes.ranker.diversity import DiversityRanker
from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker
from haystack.nodes.retriever.web import WebRetriever
class WordCountThreshold(BaseComponent):
    """
    A node to filter documents based on a word count threshold.

    Documents are kept in their incoming order until adding the next one would push the
    cumulative whitespace-separated word count above the threshold; that document and
    every document after it are dropped.
    """

    outgoing_edges = 1

    def __init__(self, word_count_threshold: int):
        """
        :param word_count_threshold: The maximum number of words allowed in the concatenated documents.
        """
        super().__init__()
        self.word_count_threshold = word_count_threshold

    def _truncate(self, documents: Optional[List[Document]]) -> List[Document]:
        """Return the longest prefix of ``documents`` whose total word count fits the threshold."""
        kept: List[Document] = []
        word_count = 0
        # ``documents`` is optional in the node interface; a missing list is treated as empty.
        for doc in documents or []:
            doc_word_count = len(doc.content.split())
            if word_count + doc_word_count > self.word_count_threshold:
                # Stop at the first document that does not fit, even if a later one would.
                break
            kept.append(doc)
            word_count += doc_word_count
        return kept

    def run(
        self,
        query: Optional[str] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[MultiLabel] = None,
        documents: Optional[List[Document]] = None,
        meta: Optional[dict] = None,
    ) -> Tuple[Dict, str]:  # type: ignore
        """
        Filters documents based on the word count threshold.

        :param query: Unused; accepted for node-interface compatibility.
        :param file_paths: Unused; accepted for node-interface compatibility.
        :param labels: Unused; accepted for node-interface compatibility.
        :param documents: List of Documents to filter. ``None`` is treated as an empty list.
        :param meta: Unused; accepted for node-interface compatibility.
        :return: Tuple of ``{"documents": <filtered list>}`` and the outgoing edge name ``"output_1"``.
        """
        results: Dict = {"documents": self._truncate(documents)}
        return results, "output_1"

    def run_batch(
        self,
        queries: Optional[Union[str, List[str]]] = None,
        file_paths: Optional[List[str]] = None,
        labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
        documents: Optional[Union[List[Document], List[List[Document]]]] = None,
        meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        params: Optional[dict] = None,
        debug: Optional[bool] = None,
    ):
        """
        Batch variant of :meth:`run`.

        :param documents: Either a flat list of Documents (filtered as a single list) or a list of
            Document lists (each inner list is filtered independently against the threshold).
        :return: Tuple of ``{"documents": <filtered documents, same nesting as the input>}`` and the
            outgoing edge name ``"output_1"``. All other parameters are unused.
        """
        if not documents:
            return {"documents": []}, "output_1"
        if isinstance(documents[0], list):
            filtered: Any = [self._truncate(doc_list) for doc_list in documents]
        else:
            filtered = self._truncate(documents)  # type: ignore[arg-type]
        return {"documents": filtered}, "output_1"
def average_pairwise_cosine_distance(encoder: SentenceTransformer, document: List[Document]):
    """
    Compute the mean cosine distance over all distinct pairs of documents.

    :param encoder: Model whose ``encode`` method turns the document contents into embeddings.
    :param document: The documents to compare (a list, despite the singular parameter name).
    :return: Mean of ``1 - cosine_similarity`` over every unordered document pair as a Python
        float, or ``0.0`` when fewer than two documents are given.
    """
    # With fewer than two documents there is no pair to measure; the mean over an empty
    # selection would be NaN and would poison any running total kept by the caller.
    if not document or len(document) < 2:
        return 0.0

    sentences = [doc.content for doc in document]

    # Encode all docs into embeddings
    embeddings = encoder.encode(sentences, convert_to_tensor=True)

    # Cosine similarity for all doc pairs
    cos_sim_matrix = util.cos_sim(embeddings, embeddings)

    # Cosine similarity -> cosine distance
    cos_dist_matrix = 1 - cos_sim_matrix

    # Boolean mask selecting the strict upper triangle (diagonal=1 skips the diagonal),
    # so every unordered pair is counted once and self-pairs are excluded
    mask = torch.triu(torch.ones_like(cos_dist_matrix, dtype=torch.bool), diagonal=1)

    # Use the mask to calculate the mean cosine distance for all doc pairs
    avg_cosine_distance = cos_dist_matrix[mask].mean().item()
    return avg_cosine_distance
# The search API key is read from the environment and later handed to the WebRetriever;
# fail fast with a clear message when it is missing or empty.
search_key = os.environ.get("SERPERDEV_API_KEY")
if not search_key:
    raise ValueError("Please set the SERPERDEV_API_KEY environment variable")
# Per-provider generation settings: the API key is read from the provider's environment
# variable (None when unset) and paired with the model name to use for that provider.
models_info: Dict[str, Any] = {
    provider: {"api_key": os.environ.get(env_var), "model_name": model_name}
    for provider, env_var, model_name in (
        ("openai", "OPENAI_API_KEY", "gpt-3.5-turbo"),
        ("anthropic", "ANTHROPIC_API_KEY", "claude-instant-1"),
        ("hf", "HF_API_KEY", "tiiuae/falcon-7b"),
    )
}
# Prompt template text handed to PromptTemplate below. The string is not raw, so every "\n"
# escape becomes a real newline in addition to the physical line break that follows it.
# NOTE(review): {join(documents)} and {query} look like PromptTemplate placeholders for the
# incoming documents and the user query — confirm against the Haystack PromptTemplate docs.
prompt_text = """
Synthesize a comprehensive answer from the provided paragraphs and the given question.\n
Answer in full sentences and paragraphs, don't use bullet points or lists.\n
If the answer includes multiple chronological events, order them chronologically.\n
\n\n Paragraphs: {join(documents)} \n\n Question: {query} \n\n Answer:
"""
# Streaming flag: forwarded to the PromptNode through model_kwargs and checked by
# run_pipeline before it prints an answer.
stream = False
# Entry of models_info used for generation; change the key to select another provider.
model: Dict[str, str] = models_info["openai"]
# Single PromptNode instance; it is registered as the final node of both pipelines.
prompt_node = PromptNode(
    model["model_name"],
    default_prompt_template=PromptTemplate(prompt_text),
    api_key=model["api_key"],
    max_length=768,
    model_kwargs={"stream": stream, "model_max_length": 2048},
)
# Web retriever authenticated with the key read from SERPERDEV_API_KEY; it is the first
# node of both pipelines. Its PreProcessor is configured with split_length=150 and no
# progress bar.
web_retriever = WebRetriever(
    api_key=search_key,
    top_search_results=10,
    preprocessor=PreProcessor(progress_bar=False, split_length=150),
    mode="preprocessed_documents",
    top_k=50,
)
# Sentence embedding model; run_pipeline passes it to average_pairwise_cosine_distance
# to embed the documents returned by each pipeline run.
encoder = SentenceTransformer("all-MiniLM-L6-v2")
# Sampler node shared by both pipelines, placed directly after the retriever.
sampler = TopPSampler(top_p=0.99)
def _build_pipeline(nodes):
    """Create a Pipeline and register ``nodes`` in order; each node is a (component, name, inputs) triple."""
    pipeline = Pipeline()
    for component, name, inputs in nodes:
        pipeline.add_node(component=component, name=name, inputs=inputs)
    return pipeline


# "Optimized" variant: retrieved documents pass through the sampler, a diversity ranker and
# a lost-in-the-middle ranker (capped at 1024 words) before reaching the prompt node.
opt_pipe = _build_pipeline(
    [
        (web_retriever, "Retriever", ["Query"]),
        (sampler, "Sampler", ["Retriever"]),
        (DiversityRanker(), "DiversityRanker", ["Sampler"]),
        (LostInTheMiddleRanker(word_count_threshold=1024), "LITM", ["DiversityRanker"]),
        (prompt_node, "PromptNode", ["LITM"]),
    ]
)

# "Regular" variant: same retriever, sampler and prompt node, but the sampled documents are
# only truncated to 1024 words by WordCountThreshold, with no re-ranking.
regular_pipe = _build_pipeline(
    [
        (web_retriever, "Retriever", ["Query"]),
        (sampler, "Sampler", ["Retriever"]),
        (WordCountThreshold(word_count_threshold=1024), "Counter", ["Sampler"]),
        (prompt_node, "PromptNode", ["Counter"]),
    ]
)

# Let only CRITICAL records through from the "boilerpy3" logger.
logger = logging.getLogger("boilerpy3")
logger.setLevel(logging.CRITICAL)
# Questions that every pipeline is run against, in this order.
question_list = [
    "What are the main reasons for long-standing animosities between Russia and Poland?",
    "What are the primary causes and effects of climate change on global and local scales?",
    (
        # One question split over two literals to keep the line short.
        "What were the key events and influences that led to Renaissance; how did these developments "
        "shape modern Western culture?"
    ),
    "How have advances in technology in the 21st century affected job markets and economies around the world?",
    "What are the main reasons behind the Israel-Palestine conflict and how have they evolved over time?",
    "How has the European Union influenced the political, economic, and social dynamics of Europe?",
]
def run_pipeline(pipe: Optional[Pipeline] = None):
    """
    Run every question in ``question_list`` through ``pipe`` and report document diversity.

    For each question the pipeline result's ``"documents"`` are scored with
    ``average_pairwise_cosine_distance``. The full result is printed unless the module-level
    ``stream`` flag is set. After all questions, the mean score and the per-question scores
    are printed.

    :param pipe: The pipeline to run. Must be provided; the ``None`` default exists only for
        signature compatibility.
    :raises ValueError: If ``pipe`` is ``None``.
    """
    if pipe is None:
        raise ValueError("run_pipeline requires a Pipeline instance, got None")

    cosine_distance_list = []
    for question in question_list:
        print(f"\nQuestion: {question}")
        answer = pipe.run(query=question)
        avg_cosine_distance = average_pairwise_cosine_distance(encoder, answer["documents"])
        # NOTE(review): with streaming on, the answer was presumably already emitted token by
        # token, so it is not printed again here — confirm.
        if not stream:
            print(f"Answer: {answer}")
        cosine_distance_list.append(avg_cosine_distance)

    # One score is collected per question, so the list length equals len(question_list);
    # guard the division so an empty question list does not raise ZeroDivisionError.
    overall_average = sum(cosine_distance_list) / len(cosine_distance_list) if cosine_distance_list else 0.0
    print(f"\nAverage pairwise cosine distance: {overall_average}")
    for question, avg_cosine_distance in zip(question_list, cosine_distance_list):
        print(f"Question: {question}, average pairwise cosine distance: {avg_cosine_distance}")
# Run both pipelines back to back over the same questions so their outputs can be compared.
_matchup = (
    (f"\nRunning optimized pipeline with {model['model_name']}\n", opt_pipe),
    (f"\nRunning regular pipeline with {model['model_name']}\n", regular_pipe),
)
for _banner, _pipeline in _matchup:
    print(_banner)
    run_pipeline(_pipeline)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment