# langgraph_crag.py: Corrective RAG (CRAG) with LangGraph
import os
import pprint
from typing import Any, Dict, TypedDict

from langchain import hub
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_openai.chat_models import ChatOpenAI
from langgraph.graph import END, StateGraph
# Load
url = "https://www.analyticsvidhya.com/blog/2023/10/introduction-to-hnsw-hierarchical-navigable-small-world/"
loader = WebBaseLoader(url)
docs = loader.load()

# Split
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500, chunk_overlap=100
)
all_splits = text_splitter.split_documents(docs)

# Embed
embedding = SentenceTransformerEmbeddings(model_name="BAAI/bge-base-en-v1.5")

# Index
vectorstore = Chroma.from_documents(
    documents=all_splits,
    collection_name="rag-chroma",
    embedding=embedding,
)
retriever = vectorstore.as_retriever()
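
# Optional sanity check: query the retriever directly to confirm the index
# returns sensible chunks before wiring it into the graph. The query string
# below is just an illustrative example.
# for doc in retriever.get_relevant_documents("What is HNSW?"):
#     print(doc.page_content[:120])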

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        keys: A dictionary where each key is a string.
    """

    keys: Dict[str, Any]
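
# The state flowing between nodes looks roughly like this (illustrative;
# keys are added incrementally by the nodes below):
# {"keys": {"question": "...", "documents": [Document, ...],
#           "run_web_search": "Yes" / "No", "generation": "..."}}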

### Nodes ###

# TavilySearchResults reads TAVILY_API_KEY from the environment.
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

llm = ChatOpenAI(
    base_url="https://api.together.xyz/v1",
    api_key=os.environ.get("TOGETHER_API_KEY"),
    model="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
)

def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = retriever.get_relevant_documents(question)
    return {"keys": {"documents": documents, "question": question}}

def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains generation
    """
    print("---GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # Post-processing: join the retrieved chunks into a single context string
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run
    generation = rag_chain.invoke(
        {"context": format_docs(documents), "question": question}
    )
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }

def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with relevant documents
    """
    print("---CHECK RELEVANCE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keywords related to the user question, grade it as relevant. \n
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
        Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
        input_variables=["question", "context"],
    )
    chain = prompt | llm | JsonOutputParser()

    # Score each document, keeping only the relevant ones
    filtered_docs = []
    search = "No"  # Default: do not opt for web search to supplement retrieval
    for d in documents:
        score = chain.invoke(
            {
                "question": question,
                "context": d.page_content,
            }
        )
        grade = score["score"]
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            search = "Yes"  # Perform web search
            continue
    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
            "run_web_search": search,
        }
    }
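
# The grader chain above is expected to emit JSON such as {"score": "yes"},
# which JsonOutputParser converts into a Python dict. This relies on the LLM
# following the prompt; a malformed response would raise a parsing error here.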

def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """
    print("---TRANSFORM QUERY---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Create a prompt template with format instructions and the query
    prompt = PromptTemplate(
        template="""You are generating a question that is well optimized for retrieval. \n
        Look at the input and try to reason about the underlying semantic intent / meaning. \n
        Here is the initial question:
        \n ------- \n
        {question}
        \n ------- \n
        Provide an improved question without any preamble, only respond with the updated question: """,
        input_variables=["question"],
    )
    chain = prompt | llm | StrOutputParser()
    better_question = chain.invoke({"question": question})
    return {
        "keys": {"documents": documents, "question": better_question}
    }

def web_search(state):
    """
    Web search based on the re-phrased question using Tavily API.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Web results appended to documents.
    """
    print("---WEB SEARCH---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    tool = TavilySearchResults()
    docs = tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    print(web_results)
    web_results = Document(page_content=web_results)
    documents.append(web_results)
    return {"keys": {"documents": documents, "question": question}}

### Edges ###


def decide_to_generate(state):
    """
    Determines whether to generate an answer or re-generate a question for web search.

    Args:
        state (dict): The current state of the agent, including all keys.

    Returns:
        str: Next node to call
    """
    print("---DECIDE TO GENERATE---")
    state_dict = state["keys"]
    search = state_dict["run_web_search"]

    if search == "Yes":
        # At least one document was graded not relevant in grade_documents:
        # re-phrase the query and supplement retrieval with a web search
        print("---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---")
        return "transform_query"
    else:
        # We have relevant documents, so generate an answer
        print("---DECISION: GENERATE---")
        return "generate"

### Graph ###

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search", web_search)

# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()

# Run
inputs = {
    "keys": {
        "question": "Who is the author of the HNSW paper?",
    }
}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint.pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")

# Final generation
pprint.pprint(value["keys"]["generation"])
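
# Alternative to streaming (a minimal sketch): invoke the compiled graph once
# and read the final state directly.
# final_state = app.invoke(inputs)
# pprint.pprint(final_state["keys"]["generation"])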