@jvelezmagic
Created May 17, 2023 12:05
Langchain FastAPI stream with simple memory
# The goal of this file is to provide a FastAPI application for handling
# chat requests and generating AI-powered responses using conversation chains.
# The application uses the LangChain library, which includes a ChatOpenAI model
# for natural language processing.
# The `StreamingConversationChain` class is responsible for creating and storing
# conversation memories and generating responses. It utilizes the `ChatOpenAI` model
# and a callback handler to stream responses as they're generated.
# The application defines a `ChatRequest` model for handling chat requests,
# which includes the conversation ID and the user's message.
# The `/chat` endpoint is used to receive chat requests and generate responses.
# It utilizes the `StreamingConversationChain` instance to generate the responses and
# sends them back as a streaming response using the `StreamingResponse` class.
# Please note that the implementation relies on certain dependencies and imports,
# which are not included in the provided code snippet.
# Ensure that all necessary packages are installed and imported
# correctly before running the application.
#
# Install dependencies:
# pip install fastapi uvicorn[standard] python-dotenv langchain openai
#
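# Configuration:
# The Settings class below reads the OpenAI key from a `.env` file placed next
# to this script (the value shown here is only a placeholder, not a real key):
#
# OPENAI_API_KEY=sk-...
#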
# Example of usage:
# uvicorn main:app --reload
#
# Example of request:
#
# curl --no-buffer \
#   -X POST \
#   -H 'accept: text/event-stream' \
#   -H 'Content-Type: application/json' \
#   -d '{
#       "conversation_id": "cat-conversation",
#       "message": "what'\''s their size?"
#   }' \
#   http://localhost:8000/chat
#
# Cheers,
# @jvelezmagic
import asyncio
from functools import lru_cache
from typing import AsyncGenerator
from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from pydantic import BaseModel, BaseSettings
class Settings(BaseSettings):
    """
    Settings class for this application.
    Utilizes the BaseSettings from pydantic for environment variables.
    """

    openai_api_key: str

    class Config:
        env_file = ".env"


@lru_cache()
def get_settings():
    """Function to get and cache settings.
    The settings are cached to avoid repeated disk I/O.
    """
    return Settings()
class StreamingConversationChain:
    """
    Class for handling streaming conversation chains.
    It creates and stores memory for each conversation,
    and generates responses using the ChatOpenAI model from LangChain.
    """

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(
        self, conversation_id: str, message: str
    ) -> AsyncGenerator[str, None]:
        """
        Asynchronous function to generate a response for a conversation.
        It creates a new conversation chain for each message and uses a
        callback handler to stream responses as they're generated.

        :param conversation_id: The ID of the conversation.
        :param message: The message from the user.
        """
        callback_handler = AsyncIteratorCallbackHandler()
        llm = ChatOpenAI(
            callbacks=[callback_handler],
            streaming=True,
            temperature=self.temperature,
            openai_api_key=self.openai_api_key,
        )

        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(return_messages=True)
            self.memories[conversation_id] = memory

        chain = ConversationChain(
            memory=memory,
            prompt=CHAT_PROMPT_TEMPLATE,
            llm=llm,
        )

        run = asyncio.create_task(chain.arun(input=message))

        async for token in callback_handler.aiter():
            yield token

        await run
class ChatRequest(BaseModel):
    """Request model for chat requests.
    Includes the conversation ID and the message from the user.
    """

    conversation_id: str
    message: str
CHAT_PROMPT_TEMPLATE = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(
            "You're an AI that knows everything about cats."
        ),
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}"),
    ]
)
app = FastAPI(dependencies=[Depends(get_settings)])
streaming_conversation_chain = StreamingConversationChain(
    openai_api_key=get_settings().openai_api_key
)
@app.post("/chat", response_class=StreamingResponse)
async def generate_response(data: ChatRequest) -> StreamingResponse:
    """Endpoint for chat requests.
    It uses the StreamingConversationChain instance to generate responses,
    and then sends these responses as a streaming response.

    :param data: The request data.
    """
    return StreamingResponse(
        streaming_conversation_chain.generate_response(
            data.conversation_id, data.message
        ),
        media_type="text/event-stream",
    )
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app)
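
For anyone who would rather consume the stream from Python than from curl, here is a minimal client sketch. It assumes the app above is running at http://localhost:8000 and that httpx is installed (pip install httpx); the conversation id and message are only illustrative.

# client.py -- minimal sketch of a streaming client for the /chat endpoint.
import asyncio

import httpx


async def main() -> None:
    payload = {
        "conversation_id": "cat-conversation",
        "message": "Tell me about Maine Coons.",
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/chat",
            json=payload,
            headers={"accept": "text/event-stream"},
        ) as response:
            # Print tokens as they arrive from the server.
            async for chunk in response.aiter_text():
                print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())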
@ZanyuanYang

@AvikantSrivastava This is my code

class StreamingLLMCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler for streaming LLM responses."""

    async def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        logging.info("LLM start")
        self.done.clear()
        self.queue.put_nowait(ChatResponse(sender="bot", message="", type="start"))

    async def on_llm_end(self, response, **kwargs) -> None:
        logging.info("LLM end")
        # we override this method since we want the ConversationalRetrievalChain to potentially return
        # other items (e.g., source_documents) after it is completed
        pass

    async def on_llm_error(
        self, error: Exception | KeyboardInterrupt, **kwargs
    ) -> None:
        logging.error(f"LLM error: {error}")
        self.queue.put_nowait(
            ChatResponse(sender="bot", message=str(error), type="error")
        )

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        # if token not in ['"', "}"]:
        self.queue.put_nowait(ChatResponse(sender="bot", message=token, type="stream"))


class ConvoChainCallbackHandler(AsyncCallbackHandler):
    """Use to add additional information (e.g., source_documents, etc...) once the chain finishes"""

    def __init__(self, callback_handler) -> None:
        super().__init__()
        self.callback_handler = callback_handler

    async def on_chain_end(self, outputs, *, run_id, parent_run_id, **kwargs) -> None:
        """Run after chain ends running."""
        source_docs = outputs.get("source_documents", None)
        doc_list = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in source_docs
        ]

        metadata_list = getMetadataFromCourtListener(doc_list)

        # metadata = {"metadata": metadata_list}
        self.callback_handler.queue.put_nowait(
            ChatResponse(sender="bot", message="", metadata=metadata_list, type="info")
        )
        self.callback_handler.queue.put_nowait(
            ChatResponse(sender="bot", message="", type="end")
        )


class StreamingConversationChain:
    """Class for handling streaming conversation chains."""

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(
        self, conversation_id: str, message: str
    ) -> AsyncGenerator[str, None]:
        """
        Asynchronous function to generate a response for a conversation.
        It creates a new conversation chain for each message and uses a
        callback handler to stream responses as they're generated.
        :param conversation_id: The ID of the conversation.
        :param message: The message from the user.
        """

        streaming_cb = StreamingLLMCallbackHandler()  # AsyncIteratorCallbackHandler()
        convo_cb_manager = AsyncCallbackManager(
            [ConvoChainCallbackHandler(streaming_cb)]
        )

        question_gen_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo-1106",
            temperature=0.0,
            streaming=True,
            openai_api_key=self.openai_api_key,
            max_tokens=4097,
        )

        streaming_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo-1106",
            temperature=0,
            callbacks=[streaming_cb],
            streaming=True,
            openai_api_key=self.openai_api_key,
        )

        memory = self.memories.get(conversation_id)

        if memory is None:
            memory = ConversationBufferMemory(
                memory_key="chat_history",
                return_messages=True,
                output_key="answer",
            )
            self.memories[conversation_id] = memory

        prompt_template = "Tell me a {adjective} joke"
        prompt = PromptTemplate(input_variables=["adjective"], template=prompt_template)

        question_gen_chain = LLMChain(
            llm=question_gen_llm, prompt=prompt
        )  # , callback_manager=manager)

        final_qa_chain = load_qa_chain(
            streaming_llm,
            chain_type="stuff",
        )

        es_retriever = LexARIElasticSearchBM25Retriever(
            client=ES_CLIENT, index_name=ES_INDEX_NAME
        )

        docs = es_retriever.get_relevant_documents(message)
        qdrant = Qdrant.from_documents(
            docs,
            EMBEDDINGS,
            location=":memory:",  # Local mode with in-memory storage only
            collection_name="my_documents",
        )

        convo_chain = ConversationalRetrievalChain(
            retriever=qdrant.as_retriever(search_type="similarity"),
            question_generator=question_gen_chain,
            combine_docs_chain=final_qa_chain,
            memory=memory,
            return_source_documents=True,
            callback_manager=convo_cb_manager,
            max_tokens_limit=16385,
        )

        run = asyncio.create_task(convo_chain.acall({"question": message}))

        async for token in streaming_cb.aiter():
            # Print for debugging purposes
            print("dict: ", token.dict())

            # Yield the response as JSON
            yield json.dumps(token.dict())

            if token.dict().get("type") in ["end", "error"]:
                streaming_cb.done.set()

        # Wait for the conversation chain task to complete
        await run

@ZanyuanYang

Does anyone know how to fix this error?
return inputs[prompt_input_key], outputs[output_key]
~~~~~~~^^^^^^^^^^^^
KeyError: 'answer'

@aijdsofttech

Set return_source_documents=False.
I hope that after this change your chain will work fine.
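
For reference, a KeyError: 'answer' usually means the memory is trying to save an output key that the chain did not actually return. Below is a minimal sketch of the two usual fixes, reusing names from the snippet above (question_gen_chain, final_qa_chain, memory, and the qdrant retriever are assumed to be defined as in that snippet); it is not a drop-in replacement for the original code.

# Option 1 (the suggestion above): return a single output key so the memory
# can pick it up unambiguously.
convo_chain = ConversationalRetrievalChain(
    retriever=qdrant.as_retriever(search_type="similarity"),
    question_generator=question_gen_chain,
    combine_docs_chain=final_qa_chain,
    memory=memory,
    return_source_documents=False,
)

# Option 2: keep the source documents, but make sure the memory and the chain
# agree on which output key holds the answer text.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    output_key="answer",  # must match the chain's output_key
)
convo_chain = ConversationalRetrievalChain(
    retriever=qdrant.as_retriever(search_type="similarity"),
    question_generator=question_gen_chain,
    combine_docs_chain=final_qa_chain,
    memory=memory,
    return_source_documents=True,
    output_key="answer",
)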
