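# Corrective RAG (CRAG) over a LangGraph state machine: retrieve from a
# Chroma index, grade each retrieved chunk for relevance with an LLM, and
# fall back to a query rewrite plus Tavily web search when retrieval looks off.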
import os
import pprint
from typing import Any, Dict, TypedDict

from langchain import hub
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_openai.chat_models import ChatOpenAI
from langgraph.graph import END, StateGraph
# Load
url = "https://www.analyticsvidhya.com/blog/2023/10/introduction-to-hnsw-hierarchical-navigable-small-world/"
loader = WebBaseLoader(url)
docs = loader.load()

# Split
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500, chunk_overlap=100
)
all_splits = text_splitter.split_documents(docs)

# Embed and index
embedding = SentenceTransformerEmbeddings(model_name="BAAI/bge-base-en-v1.5")
vectorstore = Chroma.from_documents(
    documents=all_splits,
    collection_name="rag-chroma",
    embedding=embedding,
)
retriever = vectorstore.as_retriever()
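# Optional sanity check (a sketch; the query string is illustrative):
# the retriever returns the most similar chunks for a query (Chroma's
# similarity search defaults to k=4).
# preview = retriever.get_relevant_documents("What is HNSW?")
# print(preview[0].page_content[:200])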
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        keys: A dictionary where each key is a string.
    """

    keys: Dict[str, Any]
### Nodes ###

# Assumes TOGETHER_API_KEY and TAVILY_API_KEY are exported in the environment;
# TavilySearchResults reads TAVILY_API_KEY on its own.
llm = ChatOpenAI(
    base_url="https://api.together.xyz/v1",
    api_key=os.environ.get("TOGETHER_API_KEY"),
    model="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
)
def retrieve(state):
    """
    Retrieve documents.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = retriever.get_relevant_documents(question)
    return {"keys": {"documents": documents, "question": question}}
def generate(state):
    """
    Generate answer.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains the LLM generation
    """
    print("---GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # Post-processing: join the retrieved chunks into a single context string
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run
    generation = rag_chain.invoke(
        {"context": format_docs(documents), "question": question}
    )
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }
def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with relevant documents
    """
    print("---CHECK RELEVANCE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keywords related to the user question, grade it as relevant. \n
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
        Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
        input_variables=["question", "context"],
    )
    chain = prompt | llm | JsonOutputParser()

    # Score each document; any irrelevant chunk triggers a web-search fallback
    filtered_docs = []
    search = "No"  # Default: do not opt for web search to supplement retrieval
    for d in documents:
        score = chain.invoke({"question": question, "context": d.page_content})
        grade = score["score"]
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            search = "Yes"  # Perform web search

    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search,
        }
    }
def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """
    print("---TRANSFORM QUERY---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Prompt the model to rewrite the question for better retrieval
    prompt = PromptTemplate(
        template="""You are generating a question that is well optimized for retrieval. \n
        Look at the input and try to reason about the underlying semantic intent / meaning. \n
        Here is the initial question:
        \n ------- \n
        {question}
        \n ------- \n
        Provide an improved question without any preamble, only respond with the updated question: """,
        input_variables=["question"],
    )
    chain = prompt | llm | StrOutputParser()
    better_question = chain.invoke({"question": question})

    return {"keys": {"documents": documents, "question": better_question}}
def web_search(state):
    """
    Web search based on the re-phrased question using the Tavily API.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Web results appended to documents.
    """
    print("---WEB SEARCH---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    tool = TavilySearchResults()
    docs = tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    print(web_results)
    web_results = Document(page_content=web_results)
    documents.append(web_results)
    return {"keys": {"documents": documents, "question": question}}
### Edges ###

def decide_to_generate(state):
    """
    Determines whether to generate an answer or re-phrase the question for web search.

    Args:
        state (dict): The current state of the agent, including all keys.

    Returns:
        str: Next node to call
    """
    print("---DECIDE TO GENERATE---")
    state_dict = state["keys"]
    search = state_dict["run_web_search"]

    if search == "Yes":
        # At least one document was graded irrelevant in grade_documents,
        # so re-phrase the query and supplement retrieval with a web search
        print("---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---")
        return "transform_query"
    else:
        # We have relevant documents, so generate the answer
        print("---DECISION: GENERATE---")
        return "generate"
workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate
workflow.add_node("transform_query", transform_query)  # transform query
workflow.add_node("web_search", web_search)  # web search
# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()
# Run
inputs = {
    "keys": {
        "question": "Who is the author of HNSW paper?",
    }
}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint.pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")

# Final generation
pprint.pprint(value["keys"]["generation"])
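# Alternative (a sketch): skip streaming and run the graph to completion in
# one call; invoke() on a compiled LangGraph app returns the final state.
# final_state = app.invoke(inputs)
# print(final_state["keys"]["generation"])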