@jvelezmagic
Created May 17, 2023 12:05
Langchain FastAPI stream with simple memory
# The goal of this file is to provide a FastAPI application for handling
# chat requests and generating AI-powered responses using conversation chains.
# The application uses the LangChain library, which includes a ChatOpenAI model
# for natural language processing.
# The `StreamingConversationChain` class is responsible for creating and storing
# conversation memories and generating responses. It utilizes the `ChatOpenAI` model
# and a callback handler to stream responses as they're generated.
# The application defines a `ChatRequest` model for handling chat requests,
# which includes the conversation ID and the user's message.
# The `/chat` endpoint is used to receive chat requests and generate responses.
# It utilizes the `StreamingConversationChain` instance to generate the responses and
# sends them back as a streaming response using the `StreamingResponse` class.
# Please note that the implementation relies on certain dependencies and imports,
# which are not included in the provided code snippet.
# Ensure that all necessary packages are installed and imported
# correctly before running the application.
#
# Install dependencies:
# pip install fastapi uvicorn[standard] python-dotenv langchain openai
#
# Example of usage:
# uvicorn main:app --reload
#
# Example of request:
#
# curl --no-buffer \
# -X POST \
# -H 'accept: text/event-stream' \
# -H 'Content-Type: application/json' \
# -d '{
# "conversation_id": "cat-conversation",
# "message": "what'\''s their size?"
# }' \
# http://localhost:8000/chat
#
# Cheers,
# @jvelezmagic
import asyncio
from functools import lru_cache
from typing import AsyncGenerator
from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from pydantic import BaseModel, BaseSettings
class Settings(BaseSettings):
    """
    Settings class for this application.
    Utilizes the BaseSettings from pydantic for environment variables.
    """

    openai_api_key: str

    class Config:
        env_file = ".env"


@lru_cache()
def get_settings():
    """Function to get and cache settings.
    The settings are cached to avoid repeated disk I/O.
    """
    return Settings()
class StreamingConversationChain:
    """
    Class for handling streaming conversation chains.
    It creates and stores memory for each conversation,
    and generates responses using the ChatOpenAI model from LangChain.
    """

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(
        self, conversation_id: str, message: str
    ) -> AsyncGenerator[str, None]:
        """
        Asynchronous function to generate a response for a conversation.
        It creates a new conversation chain for each message and uses a
        callback handler to stream responses as they're generated.
        :param conversation_id: The ID of the conversation.
        :param message: The message from the user.
        """
        callback_handler = AsyncIteratorCallbackHandler()
        llm = ChatOpenAI(
            callbacks=[callback_handler],
            streaming=True,
            temperature=self.temperature,
            openai_api_key=self.openai_api_key,
        )

        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(return_messages=True)
            self.memories[conversation_id] = memory

        chain = ConversationChain(
            memory=memory,
            prompt=CHAT_PROMPT_TEMPLATE,
            llm=llm,
        )

        run = asyncio.create_task(chain.arun(input=message))

        async for token in callback_handler.aiter():
            yield token

        await run
class ChatRequest(BaseModel):
    """Request model for chat requests.
    Includes the conversation ID and the message from the user.
    """

    conversation_id: str
    message: str


CHAT_PROMPT_TEMPLATE = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(
            "You're an AI that knows everything about cats."
        ),
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}"),
    ]
)
app = FastAPI(dependencies=[Depends(get_settings)])
streaming_conversation_chain = StreamingConversationChain(
    openai_api_key=get_settings().openai_api_key
)
@app.post("/chat", response_class=StreamingResponse)
async def generate_response(data: ChatRequest) -> StreamingResponse:
"""Endpoint for chat requests.
It uses the StreamingConversationChain instance to generate responses,
and then sends these responses as a streaming response.
:param data: The request data.
"""
return StreamingResponse(
streaming_conversation_chain.generate_response(
data.conversation_id, data.message
),
media_type="text/event-stream",
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app)
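
For reference, a minimal Python client for consuming the stream (a sketch, not part of the original gist; it assumes httpx is installed and the app is running locally):

# Hypothetical client sketch: stream tokens from the /chat endpoint.
# Assumes `pip install httpx` and the server running at http://localhost:8000.
import asyncio

import httpx


async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        async with client.stream(
            "POST",
            "/chat",
            json={
                "conversation_id": "cat-conversation",
                "message": "what's their size?",
            },
        ) as response:
            # Print each chunk of the event stream as it arrives.
            async for chunk in response.aiter_text():
                print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())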
jmc-123 commented Aug 9, 2023

Amazing. Any tutorial example would be appreciated.

ThorPham commented Aug 9, 2023

@jvelezmagic I tested your code, but it doesn't work.

fadilparves commented Aug 22, 2023

This worked for me. Amazing, thank you @jvelezmagic!
For those getting an error: try updating to the latest LangChain version; that fixed the `AsyncIteratorCallbackHandler()` issue for me.

@coreation

I'm quite new to Python and a bit confused by how state is handled in the example. The only way it makes sense to me is that running the app via uvicorn.run() keeps the StreamingConversationChain object alive and re-used across requests, rather than recreating it for every request, until the app is shut down (for example, when you restart it to push a code update).

Is that correct? If not, I'd much appreciate it if someone could elaborate on how state is maintained across different HTTP sessions in @jvelezmagic's example. Much obliged, by the way! <3

@jvelezmagic
Author

@coreation, you are right: the application in the example preserves state in the process, so the memories remain available until shutdown. In a real-world scenario you could use a database-backed memory, such as Redis or PostgreSQL, to make the application itself unaware of the state. 🐾
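
For instance, a minimal sketch of a Redis-backed memory (this assumes LangChain's RedisChatMessageHistory and a local Redis server; the get_memory helper is illustrative, not part of the gist):

# Hypothetical sketch: conversation memory stored in Redis instead of self.memories.
# Assumes `pip install redis` and a Redis server at localhost:6379.
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RedisChatMessageHistory


def get_memory(conversation_id: str) -> ConversationBufferMemory:
    """Build a memory whose messages live in Redis, keyed by conversation ID."""
    history = RedisChatMessageHistory(
        session_id=conversation_id,
        url="redis://localhost:6379/0",
    )
    return ConversationBufferMemory(chat_memory=history, return_messages=True)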

@coreation

@jvelezmagic thanks so much for the reply. I wasn't entirely sure, but fully understanding how it works really helps. Thanks for the gist, cheers!

suraj143rosy commented Jan 8, 2024

Hi, I am trying to use ConversationalRetrievalChain with Azure Cognitive Search as the retriever, with streaming enabled. The code does not produce output in a streaming manner. I would like to know whether LangChain supports such a feature when combining Azure Cognitive Search with an LLM.
The code snippet I used is below.

Code Snippet

def search_docs_chain_with_memory_streaming(
    search_index_name=os.getenv("AZURE_COGNITIVE_SEARCH_INDEX_NAME"),
    question_list=[],
    answer_list=[],
):
    code = detect(question)
    language_name = map_language_code_to_name(code)
    embeddings = OpenAIEmbeddings(
        deployment=oaienvs.OPENAI_EMBEDDING_DEPLOYMENT_NAME,
        model=oaienvs.OPENAI_EMBEDDING_MODEL_NAME,
        openai_api_base=os.environ["OPENAI_API_BASE"],
        openai_api_type=os.environ["OPENAI_API_TYPE"],
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer")
    acs = AzureSearch(
        azure_search_endpoint=os.getenv("AZURE_SEARCH_SERVICE_ENDPOINT"),
        azure_search_key=os.getenv("AZURE_COGNITIVE_SEARCH_API_KEY"),
        index_name=search_index_name,
        search_type="similarity",
        semantic_configuration_name="default",
        embedding_function=embeddings.embed_query,
    )
    retriever = acs.as_retriever()
    retriever.search_kwargs = {"score_threshold": 0.8}  # {'k': 1}
    print("language_name-----", language_name)
    hcp_conv_template = (
        get_prompt(workflows, "retrievalchain_hcp_conv_template1", "system_prompt", "v0")
        + language_name
        + get_prompt(workflows, "retrievalchain_hcp_conv_template2", "system_prompt", "v0")
    )
    CONDENSE_QUESTION_PROMPT = get_prompt(workflows, "retrievalchain_condense_question_prompt", "system_prompt", "v0")
    prompt = PromptTemplate(
        input_variables=["question"], template=CONDENSE_QUESTION_PROMPT
    )
    SYSTEM_MSG2 = get_prompt(workflows, "retrievalchain_system_msg_template", "system_prompt", "v0")
    messages = [
        SystemMessagePromptTemplate.from_template(SYSTEM_MSG2),
        HumanMessagePromptTemplate.from_template(hcp_conv_template),
    ]
    qa_prompt = ChatPromptTemplate.from_messages(messages)
    llm = AzureChatOpenAI(
        deployment_name=oaienvs.OPENAI_CHAT_MODEL_DEPLOYMENT_NAME,
        temperature=0.7,
        max_retries=4,
        # callbacks=[streaming_cb],
        streaming=True,
        # callback_manager=CallbackManager([MyCustomHandler()])
    )

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        get_chat_history=lambda o: o,
        memory=memory,
        condense_question_prompt=prompt,
        return_source_documents=True,
        verbose=True,
        # callback_manager=convo_cb_manager,
        # condense_question_llm=llm_condense_ques,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )

    if len(question_list) == 0:
        question = question + ". Give the answer only in " + language_name + "."

    for i in range(len(question_list)):
        qa_chain.memory.save_context(
            inputs={"question": question_list[i]}, outputs={"answer": answer_list[i]}
        )
    # return qa_chain.stream({"question": question, "chat_history": []})
    return qa_chain

Also I have tried different callback handlers and invoke methods as mentioned in https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68

Kindly suggest if there is any workaround to it.

@ZanyuanYang

Is memory chat working well for anyone? It works for the first API call, but not for the second.

@AvikantSrivastava

Is memory chat working well for anyone? It works for the first API call, but not for the second.

@ZanyuanYang what database are you using? Is it in-memory?

If it's in-memory storage, building an API will be difficult because calls to the endpoint will not share the memory.

@ZanyuanYang

Is memory chat working well for anyone? It works for the first API call, but not for the second.

@ZanyuanYang what database are you using? Is it in-memory?

If it's in-memory storage, building an API will be difficult because calls to the endpoint will not share the memory.

I used Elasticsearch, @AvikantSrivastava.

@ZanyuanYang

@AvikantSrivastava This is my code

class StreamingLLMCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler for streaming LLM responses."""

    async def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        logging.info("LLM start")
        self.done.clear()
        self.queue.put_nowait(ChatResponse(sender="bot", message="", type="start"))

    async def on_llm_end(self, response, **kwargs) -> None:
        logging.info("LLM end")
        # we override this method since we want the ConversationalRetrievalChain to potentially return
        # other items (e.g., source_documents) after it is completed
        pass

    async def on_llm_error(
        self, error: Exception | KeyboardInterrupt, **kwargs
    ) -> None:
        logging.error(f"LLM error: {error}")
        self.queue.put_nowait(
            ChatResponse(sender="bot", message=str(error), type="error")
        )

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        # if token not in ['"', "}"]:
        self.queue.put_nowait(ChatResponse(sender="bot", message=token, type="stream"))


class ConvoChainCallbackHandler(AsyncCallbackHandler):
    """Use to add additional information (e.g., source_documents, etc...) once the chain finishes"""

    def __init__(self, callback_handler) -> None:
        super().__init__()
        self.callback_handler = callback_handler

    async def on_chain_end(self, outputs, *, run_id, parent_run_id, **kwargs) -> None:
        """Run after chain ends running."""
        source_docs = outputs.get("source_documents", None)
        doc_list = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in source_docs
        ]

        metadata_list = getMetadataFromCourtListener(doc_list)

        # metadata = {"metadata": metadata_list}
        self.callback_handler.queue.put_nowait(
            ChatResponse(sender="bot", message="", metadata=metadata_list, type="info")
        )
        self.callback_handler.queue.put_nowait(
            ChatResponse(sender="bot", message="", type="end")
        )


class StreamingConversationChain:
    """Class for handling streaming conversation chains."""

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(
        self, conversation_id: str, message: str
    ) -> AsyncGenerator[str, None]:
        """
        Asynchronous function to generate a response for a conversation.
        It creates a new conversation chain for each message and uses a
        callback handler to stream responses as they're generated.
        :param conversation_id: The ID of the conversation.
        :param message: The message from the user.
        """

        streaming_cb = StreamingLLMCallbackHandler()  # AsyncIteratorCallbackHandler()
        convo_cb_manager = AsyncCallbackManager(
            [ConvoChainCallbackHandler(streaming_cb)]
        )

        question_gen_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo-1106",
            temperature=0.0,
            streaming=True,
            openai_api_key=self.openai_api_key,
            max_tokens=4097,
        )

        streaming_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo-1106",
            temperature=0,
            callbacks=[streaming_cb],
            streaming=True,
            openai_api_key=self.openai_api_key,
        )

        memory = self.memories.get(conversation_id)

        if memory is None:
            memory = ConversationBufferMemory(
                memory_key="chat_history",
                return_messages=True,
                output_key="answer",
            )
            self.memories[conversation_id] = memory

        prompt_template = "Tell me a {adjective} joke"
        prompt = PromptTemplate(input_variables=["adjective"], template=prompt_template)

        question_gen_chain = LLMChain(
            llm=question_gen_llm, prompt=prompt
        )  # , callback_manager=manager)

        final_qa_chain = load_qa_chain(
            streaming_llm,
            chain_type="stuff",
        )

        es_retriever = LexARIElasticSearchBM25Retriever(
            client=ES_CLIENT, index_name=ES_INDEX_NAME
        )

        docs = es_retriever.get_relevant_documents(message)
        qdrant = Qdrant.from_documents(
            docs,
            EMBEDDINGS,
            location=":memory:",  # Local mode with in-memory storage only
            collection_name="my_documents",
        )

        convo_chain = ConversationalRetrievalChain(
            retriever=qdrant.as_retriever(search_type="similarity"),
            question_generator=question_gen_chain,
            combine_docs_chain=final_qa_chain,
            memory=memory,
            return_source_documents=True,
            callback_manager=convo_cb_manager,
            max_tokens_limit=16385,
        )

        run = asyncio.create_task(convo_chain.acall({"question": message}))

        async for token in streaming_cb.aiter():
            # Print for debugging purposes
            print("dict: ", token.dict())

            # Yield the response as JSON
            yield json.dumps(token.dict())

            if token.dict().get("type") in ["end", "error"]:
                streaming_cb.done.set()

        # Wait for the conversation chain task to complete
        await run

@ZanyuanYang

Does anyone know how to fix this error?
return inputs[prompt_input_key], outputs[output_key]
~~~~~~~^^^^^^^^^^^^
KeyError: 'answer'

@aijdsofttech

Try setting return_source_documents=False; I hope your chain will work fine after this change.
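
As an alternative (an assumption based on how LangChain's ConversationBufferMemory chooses which output to store, not something confirmed in this thread), return_source_documents=True can be kept if the memory is told which key to save:

# Hypothetical sketch: keep source documents but tell the memory which output to save.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    output_key="answer",  # avoids KeyError: 'answer' when the chain returns extra keys
)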
