Created January 5, 2024 12:24
-
-
Save mathislucka/c777bd16010cf56b876012454ace6b05 to your computer and use it in GitHub Desktop.
An example implementation of HyDE embeddings with Haystack 2.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from haystack import Pipeline, component, Document | |
from haystack.components.embedders import SentenceTransformersDocumentEmbedder | |
from haystack.components.generators.openai import GPTGenerator | |
from haystack.components.builders import PromptBuilder | |
from typing import List, Any, Literal, Optional | |
import numpy as np | |
# utility function used in Accumulator
def flatten(lst):
    """Flatten ``lst`` by exactly one level.

    Each element of ``lst`` that is a ``list`` has its items spliced into the
    result in order; every other element (including tuples and strings) is
    kept as a single item. Lists nested deeper than one level are left intact.

    :param lst: An iterable whose items may be lists or non-list values.
    :returns: A new list; ``lst`` itself is not modified.
    """
    return [
        leaf
        for entry in lst
        # Wrap non-list entries in a 1-tuple so both cases share one inner loop.
        for leaf in (entry if isinstance(entry, list) else (entry,))
    ]
# We need 4 custom components
@component
class HydeVectorBuilder:
    """
    A component returning an average over a sequence of input embeddings.

    In this pipeline the inputs are the embeddings of the generated
    hypothetical answers plus the question; their element-wise mean is
    emitted as a single query vector.
    """

    @component.output_types(embedding=List[float])
    def run(self, embeddings: List[List[float]]):
        """
        Average the given embedding vectors element-wise.

        :param embeddings: Non-empty list of equally sized embedding vectors.
        :returns: ``{"embedding": mean_vector}``, where ``mean_vector`` is a
            plain Python list of floats with the same length as each input.
        :raises ValueError: If ``embeddings`` is empty or is not a list of
            vectors. NumPy also raises ``ValueError`` while building the array
            when the vectors have unequal lengths.
        """
        matrix = np.asarray(embeddings, dtype=float)
        # Reject inputs np.mean would mishandle: an empty list (mean of nothing
        # is NaN with a RuntimeWarning) or a flat list of numbers (the mean
        # would collapse to a scalar instead of a vector).
        if matrix.ndim != 2 or matrix.shape[0] == 0:
            raise ValueError(
                "HydeVectorBuilder expects a non-empty list of equally sized "
                f"embedding vectors, got an array of shape {matrix.shape}."
            )
        # Mean over axis 0 (across vectors) keeps the embedding dimensionality;
        # tolist() converts numpy floats to Python floats for the output socket.
        return {"embedding": matrix.mean(axis=0).tolist()}
@component
class DocumentExtractor:
    """
    A component that pulls one named attribute out of every incoming Document.

    Values are read with ``getattr``, so the name must be an attribute of the
    Document object itself (for example ``embedding``). A missing attribute
    raises ``AttributeError``. The declared output type is ``List[List[float]]``,
    which matches extracting embeddings.
    """

    def __init__(self, extraction_attribute):
        # Name of the Document attribute that run() reads from each document.
        self.extraction_attribute = extraction_attribute

    @component.output_types(values=List[List[float]])
    def run(self, documents: List[Document]):
        """
        Collect the configured attribute from each document.

        :param documents: Documents to read from.
        :returns: ``{"values": [...]}`` with one entry per document, in input order.
        """
        attribute_name = self.extraction_attribute
        extracted = []
        for document in documents:
            extracted.append(getattr(document, attribute_name))
        return {"values": extracted}
@component
class Accumulator:
    """
    A component that accumulates values from other components and turns them into a list.

    One input socket is created per name in ``runtime_variables``. This lets
    values arrive from several upstream components, or directly from the data
    passed to ``Pipeline.run``, and be merged into a single list.
    """

    def __init__(self, runtime_variables: List[str], accumulation_strategy: Literal["flatten", "no_op"] = "flatten"):
        """
        :param runtime_variables: Names of the input sockets to create.
        :param accumulation_strategy: ``"flatten"`` splices list-valued inputs
            into the output (one level deep). ``"no_op"`` keeps each input as a
            single element.
        :raises ValueError: If ``accumulation_strategy`` is not ``"flatten"``
            or ``"no_op"``. Previously any other value silently behaved like
            ``"no_op"``.
        """
        if accumulation_strategy not in ("flatten", "no_op"):
            raise ValueError(
                "accumulation_strategy must be 'flatten' or 'no_op', "
                f"got {accumulation_strategy!r}"
            )
        self.accumulation_strategy = accumulation_strategy
        # Every socket is typed Optional[Any] so any upstream output can connect.
        kwargs_input_slots = {var: Optional[Any] for var in runtime_variables}
        component.set_input_types(self, **kwargs_input_slots)

    # Declared as List[str] so the output can feed DocumentBuilder.contents;
    # the actual element types are whatever the input sockets receive.
    @component.output_types(values=List[str])
    def run(self, **kwargs):
        """
        Merge every received input into one list.

        :param kwargs: One keyword argument per socket named in ``runtime_variables``.
        :returns: ``{"values": [...]}``. Values follow the order of ``kwargs``,
            i.e. the order in which the pipeline passes them. That is not
            necessarily ``runtime_variables`` order (TODO confirm).
        """
        values = list(kwargs.values())
        if self.accumulation_strategy == "flatten":
            # flatten() already returns a new list; no extra list() needed.
            values = flatten(values)
        return {"values": values}
@component
class DocumentBuilder:
    """
    A component that wraps each incoming string in a Document.
    """

    @component.output_types(documents=List[Document])
    def run(self, contents: List[str]):
        """
        :param contents: Texts to wrap; each becomes the ``content`` of one Document.
        :returns: ``{"documents": [...]}`` in the same order as ``contents``.
        """
        documents = []
        for text in contents:
            documents.append(Document(content=text))
        return {"documents": documents}
# Initializing all components
prompt = """
Given a question, generate a paragraph of text that answers the question.
Question: {{question}}
Paragraph:
"""

# "n" asks the OpenAI API for several completions per prompt, giving several
# hypothetical answers to embed and average; temperature adds variety between them.
generation_settings = {"n": 5, "temperature": 0.75, "max_tokens": 400}

embedder = SentenceTransformersDocumentEmbedder(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
)
prompt_builder = PromptBuilder(template=prompt)
generator = GPTGenerator(
    model_name="gpt-3.5-turbo",
    generation_kwargs=generation_settings,
)
document_builder = DocumentBuilder()
# The "question" socket lets the original query be embedded next to the generated answers.
accumulator = Accumulator(runtime_variables=["replies", "question"])
extractor = DocumentExtractor(extraction_attribute="embedding")
hyde_builder = HydeVectorBuilder()
# Creating the pipeline
pipe = Pipeline()

# Registered in the same order as the component list; names are used by connect() below.
for component_name, component_instance in (
    ("embedder", embedder),
    ("prompt_builder", prompt_builder),
    ("generator", generator),
    ("document_builder", document_builder),
    ("accumulator", accumulator),
    ("hyde_builder", hyde_builder),
    ("extractor", extractor),
):
    pipe.add_component(name=component_name, instance=component_instance)

# Data flow: prompt -> generated replies (+ question) -> Documents
# -> embedded Documents -> embedding vectors -> averaged HyDE vector.
for sender, receiver in (
    ("prompt_builder", "generator"),
    ("generator.replies", "accumulator.replies"),
    ("accumulator.values", "document_builder.contents"),
    ("document_builder", "embedder.documents"),
    ("embedder", "extractor"),
    ("extractor.values", "hyde_builder.embeddings"),
):
    pipe.connect(sender, receiver)
# Testing things out
test_query = "What should I see in the capital of France?"
# The question goes both to the prompt (to generate hypothetical answers) and
# to the accumulator's "question" socket, so it is embedded alongside them.
result = pipe.run(data={"prompt_builder": {"question": test_query}, "accumulator": {"question": test_query}})
# Pipeline.run returns the outputs of components whose outputs are not
# connected onward. Keep and show the HyDE vector instead of discarding it.
hyde_embedding = result["hyde_builder"]["embedding"]
print(f"HyDE embedding ({len(hyde_embedding)} dimensions), first values: {hyde_embedding[:5]}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.