@hwchase17
Last active December 5, 2023 15:35
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable import RunnableMap
from langchain.schema import format_document
from typing import AsyncGenerator
# Create the retriever
vectorstore = Chroma.from_texts(["harrison worked at kensho"], embedding=OpenAIEmbeddings())
retriever = vectorstore.as_retriever()
from langchain.prompts.prompt import PromptTemplate
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
from typing import Tuple, List
def _format_chat_history(chat_history: List[Tuple]) -> str:
    buffer = ""
    for dialogue_turn in chat_history:
        human = "Human: " + dialogue_turn[0]
        ai = "Assistant: " + dialogue_turn[1]
        buffer += "\n" + "\n".join([human, ai])
    return buffer
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory(return_messages=True, output_key="answer", input_key="question")
standalone_question_chain = {
    "question": lambda x: x["question"],
    "chat_history": lambda x: _format_chat_history(x["chat_history"]),
} | CONDENSE_QUESTION_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser()
response_chain = ANSWER_PROMPT | ChatOpenAI()
import asyncio
async def generate_response(message: str) -> AsyncGenerator[str, None]:
    _memory = memory.load_memory_variables({})
    standalone_question = await standalone_question_chain.ainvoke({
        "question": message,
        "chat_history": _format_chat_history(_memory["history"]),
    })
    retrieved_docs = await retriever.ainvoke(standalone_question)
    final_response = ""
    async for m in response_chain.astream({
        "question": standalone_question,
        "context": _combine_documents(retrieved_docs),
    }):
        final_response += m.content
        yield m.content
    # Need to save memory explicitly
    memory.save_context({"question": message}, {"answer": final_response})
import nest_asyncio
nest_asyncio.apply()
async def run_async(gen):
    return [item async for item in gen]

async def main():
    responses = await run_async(generate_response("where did harrison work"))
    for response in responses:
        print(response)

await main()
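
Note: the top-level await main() above assumes a notebook-style environment (hence nest_asyncio). If you run this as a plain script instead, a minimal sketch of an equivalent entry point would be:

# Plain-script alternative to the notebook-style top-level `await main()`
if __name__ == "__main__":
    asyncio.run(main())
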
ohmeow commented Aug 8, 2023

I modified your example above to fix how the chat history is built, and to return the streamed data as JSON instead of plain strings. Let me know if you think this can be improved in any way ...

import os
from typing import Any, AsyncGenerator, List

from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, validator

from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.schema import AIMessage, HumanMessage, format_document
from langchain.schema.output_parser import StrOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.prompts.prompt import PromptTemplate

load_dotenv()

app = FastAPI()

vectorstore = Chroma.from_texts(["harrison worked at kensho", "wayde works at UCSD"], embedding=OpenAIEmbeddings())
retriever = vectorstore.as_retriever()


# templates
qst_gen_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
QST_GEN_PROMPT = PromptTemplate.from_template(qst_gen_template)

doc_template = "{page_content}"
DOCUMENT_PROMPT = PromptTemplate.from_template(template=doc_template)

ans_template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(template=ans_template)


# Helpers
def _format_chat_history(chat_history: List[HumanMessage | AIMessage]) -> str:
    buffer = ""
    for dialogue in chat_history:
        buffer += f"{dialogue.type.upper()}: {dialogue.content}\n"
    return buffer


def _combine_documents(docs, document_prompt=DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)


class ChatRequest(BaseModel):
    """Request model for chat requests. Includes the conversation ID and the message from the user."""

    conversation_id: str
    message: str


class ChatResponse(BaseModel):
    """Chat response schema"""

    message: str | dict
    type: str

    @validator("type")
    def validate_message_type(cls, v):
        if v not in ["start", "stream", "end", "error", "info"]:
            raise ValueError("type must be one of: start, stream, end, error, info")
        return v


class StreamingConversationChain:
    """Class for handling streaming conversation chains."""

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(self, conversation_id: str, message: str) -> AsyncGenerator[Any, None]:
        # Define LLMs
        question_gen_llm = ChatOpenAI(
            model="gpt-3.5-turbo", max_retries=15, temperature=self.temperature, streaming=True, openai_api_key=self.openai_api_key
        )

        streaming_llm = ChatOpenAI(
            model="gpt-3.5-turbo", max_retries=15, temperature=self.temperature, streaming=True, openai_api_key=self.openai_api_key
        )

        # Build Memory
        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(input_key="question", output_key="answer", memory_key="chat_history", return_messages=True)
            self.memories[conversation_id] = memory

        # Build standalone question chain, final QA chain
        question_gen_chain = (
            {"question": lambda x: x["question"], "chat_history": lambda x: x["chat_history"]}
            | QST_GEN_PROMPT
            | question_gen_llm
            | StrOutputParser()
        )

        final_qa_chain = ANSWER_PROMPT | streaming_llm

        # Run standalone question chain
        _memory = memory.load_memory_variables({})
        standalone_question = await question_gen_chain.ainvoke(
            {"question": message, "chat_history": _format_chat_history(_memory["chat_history"])}
        )

        # Run the retriever and yield the source docs
        retrieved_docs = await retriever.ainvoke(standalone_question)
        source_docs_d = [{"content": doc.page_content} for doc in retrieved_docs] if retrieved_docs else None

        xtra = {"source_documents": source_docs_d}
        yield ChatResponse(message=xtra, type="info").json()

        # Stream the response from the final QA chain
        final_response = ""
        async for m in final_qa_chain.astream({"question": standalone_question, "context": _combine_documents(retrieved_docs)}):
            final_response += m.content
            yield ChatResponse(message=m.content, type="stream").json()

        # Let client know that there is nothing else to stream
        yield ChatResponse(message="", type="end").json()

        # Need to save memory explicitly
        memory.save_context({"question": message}, {"answer": final_response})


# OPENAI_API_KEY is read from the environment (loaded from .env by load_dotenv above)
streaming_conversation_chain = StreamingConversationChain(openai_api_key=os.environ["OPENAI_API_KEY"], temperature=0.0)


@app.post("/sse-chat-convo", response_class=StreamingResponse)
async def generate_response(data: ChatRequest) -> StreamingResponse:
    """Endpoint for chat requests"""
    return StreamingResponse(
        streaming_conversation_chain.generate_response(data.conversation_id, data.message),
        media_type="text/event-stream",
    )
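
For reference, a minimal client sketch for consuming this endpoint (it assumes the app above is served at http://localhost:8000, e.g. via uvicorn, and the conversation_id value is arbitrary). Note that the generator yields back-to-back ChatResponse JSON objects with no delimiter or SSE data: framing, so the sketch just prints the raw chunks; in practice you would want a newline or proper SSE framing between messages so the client can parse each one:

import httpx

# Hypothetical client for the /sse-chat-convo endpoint defined above;
# assumes the FastAPI app is running locally (e.g. `uvicorn main:app --port 8000`).
payload = {"conversation_id": "demo-1", "message": "where did harrison work"}

with httpx.stream("POST", "http://localhost:8000/sse-chat-convo", json=payload, timeout=None) as response:
    # The server streams concatenated ChatResponse JSON objects; print them as they arrive.
    for chunk in response.iter_text():
        print(chunk, end="", flush=True)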
