
@jvelezmagic
Created May 17, 2023 12:05
Langchain FastAPI stream with simple memory
# The goal of this file is to provide a FastAPI application for handling
# chat requests and generating AI-powered responses using conversation chains.
# The application uses the LangChain library, which includes a ChatOpenAI model
# for natural language processing.
# The `StreamingConversationChain` class is responsible for creating and storing
# conversation memories and generating responses. It utilizes the `ChatOpenAI` model
# and a callback handler to stream responses as they're generated.
# The application defines a `ChatRequest` model for handling chat requests,
# which includes the conversation ID and the user's message.
# The `/chat` endpoint is used to receive chat requests and generate responses.
# It utilizes the `StreamingConversationChain` instance to generate the responses and
# sends them back as a streaming response using the `StreamingResponse` class.
# Please note that the implementation relies on certain dependencies and imports,
# which are not included in the provided code snippet.
# Ensure that all necessary packages are installed and imported
# correctly before running the application.
#
# Install dependencies:
# pip install fastapi uvicorn[standard] python-dotenv langchain openai
#
# Example of usage:
# uvicorn main:app --reload
#
# Example of request:
#
# curl --no-buffer \
# -X POST \
# -H 'accept: text/event-stream' \
# -H 'Content-Type: application/json' \
# -d '{
# "conversation_id": "cat-conversation",
# "message": "what'\''s their size?"
# }' \
# http://localhost:8000/chat
#
# Cheers,
# @jvelezmagic
import asyncio
from functools import lru_cache
from typing import AsyncGenerator

from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from pydantic import BaseModel, BaseSettings


class Settings(BaseSettings):
    """
    Settings class for this application.
    Utilizes the BaseSettings from pydantic for environment variables.
    """

    openai_api_key: str

    class Config:
        env_file = ".env"


@lru_cache()
def get_settings():
    """Function to get and cache settings.
    The settings are cached to avoid repeated disk I/O.
    """
    return Settings()


class StreamingConversationChain:
    """
    Class for handling streaming conversation chains.
    It creates and stores memory for each conversation,
    and generates responses using the ChatOpenAI model from LangChain.
    """

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(
        self, conversation_id: str, message: str
    ) -> AsyncGenerator[str, None]:
        """
        Asynchronous function to generate a response for a conversation.
        It creates a new conversation chain for each message and uses a
        callback handler to stream responses as they're generated.
        :param conversation_id: The ID of the conversation.
        :param message: The message from the user.
        """
        callback_handler = AsyncIteratorCallbackHandler()
        llm = ChatOpenAI(
            callbacks=[callback_handler],
            streaming=True,
            temperature=self.temperature,
            openai_api_key=self.openai_api_key,
        )

        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(return_messages=True)
            self.memories[conversation_id] = memory

        chain = ConversationChain(
            memory=memory,
            prompt=CHAT_PROMPT_TEMPLATE,
            llm=llm,
        )

        run = asyncio.create_task(chain.arun(input=message))

        async for token in callback_handler.aiter():
            yield token

        await run


class ChatRequest(BaseModel):
    """Request model for chat requests.
    Includes the conversation ID and the message from the user.
    """

    conversation_id: str
    message: str


CHAT_PROMPT_TEMPLATE = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(
            "You're an AI that knows everything about cats."
        ),
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}"),
    ]
)

app = FastAPI(dependencies=[Depends(get_settings)])

streaming_conversation_chain = StreamingConversationChain(
    openai_api_key=get_settings().openai_api_key
)


@app.post("/chat", response_class=StreamingResponse)
async def generate_response(data: ChatRequest) -> StreamingResponse:
    """Endpoint for chat requests.
    It uses the StreamingConversationChain instance to generate responses,
    and then sends these responses as a streaming response.
    :param data: The request data.
    """
    return StreamingResponse(
        streaming_conversation_chain.generate_response(
            data.conversation_id, data.message
        ),
        media_type="text/event-stream",
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app)
@AvikantSrivastava

I tried the above, sadly it does not work :(

@talhaanwarch

I tried the above, sadly it does not work :(

I tested it, it is working. You may need to install the latest version of langchain.

@eggb4by

eggb4by commented Jun 1, 2023

I tried the above, sadly it does not work :(

It did work, clean and easy to understand. Thanks, author!

@eggb4by

eggb4by commented Jun 4, 2023

@jvelezmagic How can I get the latest response text after streaming response? Thanks
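(One way to do this, sketched below rather than taken from the gist: collect the tokens as they stream and join them once the run completes. The function name and arguments are illustrative, and the handler is assumed to be the same AsyncIteratorCallbackHandler attached to the chain's LLM.)

    import asyncio
    from typing import AsyncGenerator

    from langchain.callbacks import AsyncIteratorCallbackHandler

    async def stream_and_capture(
        chain, message: str, handler: AsyncIteratorCallbackHandler
    ) -> AsyncGenerator[str, None]:
        """Yield tokens as they arrive, then keep the joined text once streaming ends."""
        run = asyncio.create_task(chain.arun(input=message))
        tokens: list[str] = []
        async for token in handler.aiter():
            tokens.append(token)
            yield token
        await run
        final_text = "".join(tokens)  # full response text after streaming
        print(final_text)  # or store it in memory / a database instead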

@tang85

tang85 commented Jun 14, 2023

GREAT! Works fine

@ohmeow

ohmeow commented Jun 23, 2023

Thanks for the great example.

Question: Let's say we are using a ConversationalRetrievalChain and want to return the source_documents back as well ... do you have a recommendation on how to augment the example to do so?

Thanks again for probably the most straightforward example of how to stream data via langchain and fastapi.

@9akashnp8

9akashnp8 commented Jun 25, 2023

@jvelezmagic thank you for this!

However, StreamingResponse isn't working for me, but EventSourceResponse works fine.

Here is my version if anyone's curious:

memory = ConversationBufferMemory(memory_key="chat_history")

async def generate_response(message: str, memory):
    callback_handler = AsyncIteratorCallbackHandler()
    llm = OpenAI(
        callbacks=[callback_handler],
        streaming=True,
        temperature=1,
        openai_api_key=config('OPENAI_API_KEY')
    )
    chain = ConversationChain(
        llm=llm,
        memory=memory,
        prompt=CHAT_PROMPT
    )
    run = asyncio.create_task(chain.arun(input=message))
    async for token in callback_handler.aiter():
        yield token
    await run

@router.post('/stream', response_class=EventSourceResponse)
async def message_stream(request: Request, user_message: UserMessage):
    return EventSourceResponse(
        generate_response(user_message.message, memory),
        media_type="text/event-stream"
    )

@xleven

xleven commented Jul 6, 2023

👍 Thanks a lot! Been stuck on await chain.arun for a long time until saw your asyncio.create_task method. Need to learn coroutines again 😄
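(For anyone else puzzled by this: awaiting chain.arun() directly would suspend the generator until the whole chain finishes, so nothing could be yielded in the meantime; create_task schedules it in the background so the callback iterator can drain tokens concurrently. A tiny sketch with illustrative names:)

    import asyncio

    async def stream_tokens(chain, handler, message: str):
        # Don't `await chain.arun(...)` here; that would block until the chain
        # is finished and nothing would stream. Schedule it as a task instead:
        run = asyncio.create_task(chain.arun(input=message))
        async for token in handler.aiter():  # tokens arrive while the task runs
            yield token
        await run  # re-raise any exception from the chain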

@ohmeow

ohmeow commented Aug 1, 2023

Here's an example of using this with the ConversationalRetrievalChain. If y'all can think of some way to improve the code please reply here. Converting this to use EventSourceResponse might be a good first step :)

class ChatRequest(BaseModel):
    """Request model for chat requests. Includes the conversation ID and the message from the user."""

    conversation_id: str
    message: str


class ChatResponse(BaseModel):
    """Chat response schema"""

    sender: str
    message: str
    type: str
    xtra: dict = None

    @validator("sender")
    def sender_must_be_bot_or_you(cls, v):
        if v not in ["bot", "you"]:
            raise ValueError("sender must be bot or you")
        return v

    @validator("type")
    def validate_message_type(cls, v):
        if v not in ["start", "stream", "end", "error", "info"]:
            raise ValueError("type must be start, stream or end")
        return v


class StreamingLLMCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler for streaming LLM responses."""

    async def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        self.done.clear()
        self.queue.put_nowait(ChatResponse(sender="bot", message="", type="start"))

    async def on_llm_end(self, response, **kwargs) -> None:
        # we override this method since we want the ConversationalRetrievalChain to potentially return
        # other items (e.g., source_documents) after it is completed
        pass

    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs) -> None:
        self.queue.put_nowait(ChatResponse(sender="bot", message=str(error), type="error"))

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.queue.put_nowait(ChatResponse(sender="bot", message=token, type="stream"))


class ConvoChainCallbackHandler(AsyncCallbackHandler):
    """Use to add additional information (e.g., source_documents, etc...) once the chain finishes"""

    def __init__(self, callback_handler) -> None:
        super().__init__()
        self.callback_handler = callback_handler

    async def on_chain_end(self, outputs, *, run_id, parent_run_id, **kwargs) -> None:
        """Run after chain ends running."""

        source_docs = outputs.get("source_documents", None)
        source_docs_d = [{"page": doc.metadata["page"]} for doc in source_docs] if source_docs else None

        xtra = {"source_documents": source_docs_d}
        self.callback_handler.queue.put_nowait(ChatResponse(sender="bot", message="", xtra=xtra, type="info"))
        self.callback_handler.queue.put_nowait(ChatResponse(sender="bot", message="", type="end"))


class StreamingConversationChain:
    """Class for handling streaming conversation chains."""

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(self, conversation_id: str, message: str) -> AsyncGenerator[str, None]:
        """
        Asynchronous function to generate a response for a conversation.
        It creates a new conversation chain for each message and uses a
        callback handler to stream responses as they're generated.
        :param conversation_id: The ID of the conversation.
        :param message: The message from the user.
        """

        streaming_cb = StreamingLLMCallbackHandler()  # AsyncIteratorCallbackHandler()
        convo_cb_manager = AsyncCallbackManager([ConvoChainCallbackHandler(streaming_cb)])

        question_gen_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            max_retries=15,
            temperature=0.0,
            streaming=True,
            openai_api_key=self.openai_api_key,
        )

        streaming_llm = ChatOpenAI(
            max_retries=15,
            temperature=0,
            callbacks=[streaming_cb],
            streaming=True,
            openai_api_key=self.openai_api_key,
        )

        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
            self.memories[conversation_id] = memory

        question_gen_chain = LLMChain(llm=question_gen_llm, prompt=CONDENSE_QUESTION_PROMPT)  # , callback_manager=manager)

        final_qa_chain = load_qa_chain(
            streaming_llm,
            chain_type="stuff",
        )

        convo_chain = ConversationalRetrievalChain(
            retriever=retriever,
            question_generator=question_gen_chain,
            combine_docs_chain=final_qa_chain,
            memory=memory,
            return_source_documents=True,
            callback_manager=convo_cb_manager,
        )

        run = asyncio.create_task(convo_chain.acall({"question": message}))

        async for token in streaming_cb.aiter():
            # to return string
            # yield token

            # to return json
            if token.type in ["end", "error"]:
                streaming_cb.done.set()

            yield json.dumps(token.dict())
        await run


streaming_conversation_chain = StreamingConversationChain(
    openai_api_key="API_KEY", temperature=0.7
)


@app.post("/sse-chat-convo", response_class=StreamingResponse)
async def generate_response(data: ChatRequest) -> StreamingResponse:
    """Endpoint for chat requests"""
    return StreamingResponse(
        streaming_conversation_chain.generate_response(data.conversation_id, data.message),
        media_type="text/event-stream",
    )
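
(Following up on the EventSourceResponse idea mentioned above, a minimal sketch of what the endpoint swap might look like; it assumes the sse-starlette package and reuses the same generator:)

    # pip install sse-starlette
    from sse_starlette.sse import EventSourceResponse

    @app.post("/sse-chat-convo")
    async def generate_response_sse(data: ChatRequest) -> EventSourceResponse:
        """Same generator as above, wrapped as Server-Sent Events."""
        return EventSourceResponse(
            streaming_conversation_chain.generate_response(data.conversation_id, data.message)
        )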

@avikhandakar-dev

Hi, thanks. This code works fine for OpenAI. But when I try to use LLaMA-2 from Replicate, I get this error:
.venv/lib/python3.11/site-packages/langchain/llms/replicate.py:142: RuntimeWarning: coroutine 'AsyncCallbackManagerForLLMRun.on_llm_new_token' was never awaited run_manager.on_llm_new_token(output) RuntimeWarning: Enable tracemalloc to get the object allocation traceback

here is my code:
async def send_message(message: str) -> AsyncIterable[str]:
    callback = AsyncIteratorCallbackHandler()

    llm = Replicate(
        streaming=True,
        model="a16z-infra/llama-2-13b-chat:d5da4236b006f967ceb7da037be9cfc3924b20d21fed88e1e94f19d56e2d3111",
        input={"temperature": 0.75, "max_length": 2048, "top_p": 1},
        replicate_api_token=r8_token,
        callbacks=[callback],
        verbose=True,
    )

    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=False)

    async def wrap_done(fn: Awaitable, event: asyncio.Event):
        """Wrap an awaitable with an event to signal when it's done or an exception is raised."""
        try:
            await fn
        except Exception as e:
            # TODO: handle exception
            print(f"Caught exception: {e}")
        finally:
            # Signal the aiter to stop.
            event.set()

    # Begin a task that runs in the background.
    task = asyncio.create_task(wrap_done(
        qa.arun(message),
        callback.done),
    )

    async for token in callback.aiter():
        # Use server-sent-events to stream the response
        yield f"{token}\n\n"

    await task

Can anyone help? Thanks.

@ohmeow

ohmeow commented Aug 5, 2023

Not sure if this helps ... but I've simplified my example to simply use a callback for the retriever.

Lmk if this works for llama and company ...

# load document
loader = PyPDFLoader("example.pdf")
documents = loader.load()
# split the documents into chunks
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
# select which embeddings we want to use
embeddings = OpenAIEmbeddings()
# create the vectorstore to use as the index
db = Chroma.from_documents(texts, embeddings)
# expose this index in a retriever interface
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})


class ChatRequest(BaseModel):
    """Request model for chat requests. Includes the conversation ID and the message from the user."""
    conversation_id: str
    message: str

class RetrieverCallbackHandler(AsyncIteratorCallbackHandler):
    def __init__(self, streaming_callback_handler) -> None:
        super().__init__()
        self.streaming_callback_handler = streaming_callback_handler
        
    async def on_retriever_end(self, source_docs, *, run_id, parent_run_id, tags, **kwargs):
        source_docs_d = [{"page": doc.metadata["page"]} for doc in source_docs] if source_docs else None

        xtra = {"source_documents": source_docs_d}
        self.streaming_callback_handler.queue.put_nowait(xtra)


class StreamingConversationChain:
    """Class for handling streaming conversation chains."""

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(self, conversation_id: str, message: str) -> AsyncGenerator[str, None]:
        streaming_cb = AsyncIteratorCallbackHandler()
        retriever_cb = RetrieverCallbackHandler(streaming_callback_handler=streaming_cb)

        question_gen_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            max_retries=15,
            temperature=0.0,
            streaming=True,
            openai_api_key=self.openai_api_key,
        )

        streaming_llm = ChatOpenAI(
            max_retries=15,
            temperature=0,
            callbacks=[streaming_cb],
            streaming=True,
            openai_api_key=self.openai_api_key,
        )

        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key="answer")
            self.memories[conversation_id] = memory

        question_gen_chain = LLMChain(llm=question_gen_llm, prompt=CONDENSE_QUESTION_PROMPT)
        final_qa_chain = load_qa_chain(streaming_llm, chain_type="stuff")
        convo_chain = ConversationalRetrievalChain(
            retriever=retriever,
            question_generator=question_gen_chain,
            combine_docs_chain=final_qa_chain,
            memory=memory,
            return_source_documents=True,
        )

        run = asyncio.create_task(convo_chain.acall({"question": message}, callbacks=[retriever_cb]))

        async for token in streaming_cb.aiter():
            yield json.dumps(token) if isinstance(token, dict) else token
            
        await run

@ohmeow

ohmeow commented Aug 6, 2023

re: the Replicate class ... looking at the source, it doesn't support async as far as I can tell.

Kinda frustrating that you can't just drop in whatever LLM you want to use ... this is a feature that an abstraction library like LangChain should support imo.
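
(If async support is the blocker, one partial workaround, sketched here under the assumption of Python 3.9+ for asyncio.to_thread, is to push the blocking call onto a worker thread. It keeps the event loop responsive, but the answer arrives in one piece rather than token by token:)

    import asyncio

    async def run_blocking_chain(qa, message: str) -> str:
        # qa.run is synchronous; to_thread keeps it from blocking the event loop,
        # but you lose token-level streaming with a sync-only LLM.
        return await asyncio.to_thread(qa.run, message)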

@Princekrampah

(Quoting @ohmeow's ConversationalRetrievalChain example above.)

Thank you so much, but how do I return a string? I keep getting errors when I try to do that.

@Princekrampah

(Quoting @ohmeow's simplified retriever-callback example above.)

Thanks so much. I have made some modifications so that the return is only a string.

import asyncio
from functools import lru_cache
from typing import AsyncGenerator

from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain, ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from pydantic import BaseModel, BaseSettings, Field, validator
import json
from langchain.chains.question_answering import load_qa_chain
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings


persist_directory = '../../ai_agent_logic/vectorstores/chroma'


class Settings(BaseSettings):
    """
    Settings class for this application.
    Utilizes the BaseSettings from pydantic for environment variables.
    """

    openai_api_key: str = Field(..., env="OPENAI_API_KEY")

    class Config:
        env_file = ".env"


@lru_cache()
def get_settings():
    """
    Function to get and cache settings.
    The settings are cached to avoid repeated disk I/O.
    """
    return Settings()


class ChatRequest(BaseModel):
    """Request model for chat requests. Includes the conversation ID and the message from the user."""

    conversation_id: str
    message: str


embedding = OpenAIEmbeddings(openai_api_key=get_settings().openai_api_key)

vectordb = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding
)


class RetrieverCallbackHandler(AsyncIteratorCallbackHandler):
    def __init__(self, streaming_callback_handler) -> None:
        super().__init__()
        self.streaming_callback_handler = streaming_callback_handler

    async def on_retriever_end(self, source_docs, *, run_id, parent_run_id, tags, **kwargs):
        source_docs_d = [{"page": doc.metadata["page"]}
                         for doc in source_docs] if source_docs else None

        xtra = {"source_documents": source_docs_d}
        self.streaming_callback_handler.queue.put_nowait(xtra)


class StreamingConversationChain:
    """Class for handling streaming conversation chains."""

    def __init__(self, openai_api_key: str, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(self, conversation_id: str, message: str) -> AsyncGenerator[str, None]:
        streaming_cb = AsyncIteratorCallbackHandler()
        retriever_cb = RetrieverCallbackHandler(
            streaming_callback_handler=streaming_cb)

        question_gen_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            max_retries=15,
            temperature=0.0,
            streaming=True,
            openai_api_key=self.openai_api_key,
        )

        streaming_llm = ChatOpenAI(
            max_retries=15,
            temperature=0,
            callbacks=[streaming_cb],
            streaming=True,
            openai_api_key=self.openai_api_key,
        )

        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(
                memory_key="chat_history", return_messages=True, output_key="answer")
            self.memories[conversation_id] = memory

        question_gen_chain = LLMChain(
            llm=question_gen_llm, prompt=CHAT_PROMPT_TEMPLATE)
        final_qa_chain = load_qa_chain(streaming_llm, chain_type="stuff")
        convo_chain = ConversationalRetrievalChain(
            retriever=vectordb.as_retriever(),
            question_generator=question_gen_chain,
            combine_docs_chain=final_qa_chain,
            memory=memory,
            return_source_documents=True,
        )

        run = asyncio.create_task(convo_chain.acall(
            {"question": message}, callbacks=[retriever_cb]))

        async for token in streaming_cb.aiter():
            yield "" if isinstance(token, dict) else token

        await run


app = FastAPI(dependencies=[Depends(get_settings)])

streaming_conversation_chain = StreamingConversationChain(
    openai_api_key=get_settings().openai_api_key, temperature=0.7
)

CHAT_PROMPT_TEMPLATE = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(
            "You're a AI that knows everything about cats."
        ),
        MessagesPlaceholder(variable_name="history"),
        HumanMessagePromptTemplate.from_template("{input}"),
    ]
)


@app.post("/chat", response_class=StreamingResponse)
async def generate_response(data: ChatRequest) -> StreamingResponse:
    """Endpoint for chat requests"""
    return StreamingResponse(
        streaming_conversation_chain.generate_response(
            data.conversation_id, data.message),
        media_type="text/event-stream",
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app)

@jvelezmagic
Author

Hi, everyone! 🚀 Exciting news – LangChain Expression Language is now available, and it makes implementing streaming responses a breeze, eliminating the need for manual callbacks.

I've put together a special gist to showcase its capabilities. Inside, you'll find everything you need to set up a QA bot that streams responses along with source documents, all built on FastAPI.

gist: https://gist.github.com/jvelezmagic/f3653cc2ddab1c91e86751c8b423a1b6

Example includes:

  • Persistent Chat Memory: Stores chat history in a local file.
  • Persistent Vector Store: Stores document embeddings in a local vector store.
  • Standalone Question Generation: Rephrases follow-up questions to standalone questions in their original language.
  • Document Retrieval: Searches and retrieves relevant documents based on user queries.
  • Context-Aware Responses: Generates responses based on a combined context from relevant documents.
  • Streaming Responses: Streams responses in real time either as plain text or as Server-Sent Events (SSE). SSE also sends the relevant documents as context.

Happy coding! 🐈
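
(For a quick feel of the LCEL approach before opening the gist, a minimal sketch; it omits memory and retrieval, and assumes a recent langchain release with LCEL support plus OPENAI_API_KEY set in the environment:)

    from langchain.chat_models import ChatOpenAI
    from langchain.prompts import ChatPromptTemplate
    from langchain.schema.output_parser import StrOutputParser

    prompt = ChatPromptTemplate.from_messages(
        [("system", "You're an AI that knows everything about cats."), ("human", "{input}")]
    )
    chain = prompt | ChatOpenAI(temperature=0.0) | StrOutputParser()

    async def stream_answer(question: str):
        # .astream() yields output chunks as they're generated; no callback handler needed.
        async for chunk in chain.astream({"input": question}):
            yield chunk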

@jmc-123

jmc-123 commented Aug 9, 2023

Amazing. Any tutorial example would be appreciated.

@ThorPham

ThorPham commented Aug 9, 2023

@jvelezmagic I tested your code but it doesn't work.

@fadilparves

fadilparves commented Aug 22, 2023

This worked for me. Amazing. Thank you @jvelezmagic!
For those getting errors: try updating to the latest langchain version; that fixed the AsyncIteratorCallbackHandler() issue for me.

@coreation

I'm quite new to Python, and I'm a bit confused by how state is handled in the example. The only way this makes sense to me is that running the app via uvicorn.run() means the StreamingConversationChain object is not recreated every time a request is made, but is kept "alive" and re-used until the app is shut down, which happens when you restart the app to push a code update, for example.

Is that correct? If not, I'd much appreciate if someone elaborated on how state is maintained over different HTTP sessions using the example of @jvelezmagic , much obliged by the way! <3

@jvelezmagic
Author

@coreation, you are right, the application in the example preserves its state, so memory is available until shutdown. In a real-world scenario you could use a database-backed memory, like Redis or PostgreSQL, to keep the application itself unaware of the state. 🐾
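
(A rough sketch of that database-backed option, assuming langchain's RedisChatMessageHistory and a local Redis instance; the URL and helper name below are illustrative:)

    from langchain.memory import ConversationBufferMemory
    from langchain.memory.chat_message_histories import RedisChatMessageHistory

    def get_memory(conversation_id: str) -> ConversationBufferMemory:
        # History lives in Redis, so it survives restarts and is shared across workers.
        history = RedisChatMessageHistory(
            session_id=conversation_id, url="redis://localhost:6379/0"
        )
        return ConversationBufferMemory(chat_memory=history, return_messages=True)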

@coreation

@jvelezmagic thanks so much for the reply, I wasn't entirely sure but fully understanding how it works really helps out. Thanks for the gist, cheers!

@suraj143rosy

suraj143rosy commented Jan 8, 2024

Hi, I am trying to use ConversationalRetrievalChain with Azure Cognitive Search as the retriever, with streaming enabled. The code does not produce output in a streaming manner. I would like to know whether LangChain supports this when combining Azure Cognitive Search with an LLM.
The code snippet I used is below.

Code Snippet

def search_docs_chain_with_memory_streaming(
    search_index_name=os.getenv("AZURE_COGNITIVE_SEARCH_INDEX_NAME"),
    question_list=[],
    answer_list=[],
):
    code = detect(question)
    language_name = map_language_code_to_name(code)
    embeddings = OpenAIEmbeddings(
        deployment=oaienvs.OPENAI_EMBEDDING_DEPLOYMENT_NAME,
        model=oaienvs.OPENAI_EMBEDDING_MODEL_NAME,
        openai_api_base=os.environ["OPENAI_API_BASE"],
        openai_api_type=os.environ["OPENAI_API_TYPE"],
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer")
    acs = AzureSearch(
        azure_search_endpoint=os.getenv("AZURE_SEARCH_SERVICE_ENDPOINT"),
        azure_search_key=os.getenv("AZURE_COGNITIVE_SEARCH_API_KEY"),
        index_name=search_index_name,
        search_type="similarity",
        semantic_configuration_name="default",
        embedding_function=embeddings.embed_query,
    )
    retriever = acs.as_retriever()
    retriever.search_kwargs = {"score_threshold": 0.8}  # {'k': 1}
    print("language_name-----", language_name)
    hcp_conv_template = (
        get_prompt(workflows, "retrievalchain_hcp_conv_template1", "system_prompt", "v0")
        + language_name
        + get_prompt(workflows, "retrievalchain_hcp_conv_template2", "system_prompt", "v0")
    )
    CONDENSE_QUESTION_PROMPT = get_prompt(workflows, "retrievalchain_condense_question_prompt", "system_prompt", "v0")
    prompt = PromptTemplate(
        input_variables=["question"], template=CONDENSE_QUESTION_PROMPT
    )
    SYSTEM_MSG2 = get_prompt(workflows, "retrievalchain_system_msg_template", "system_prompt", "v0")
    messages = [
        SystemMessagePromptTemplate.from_template(SYSTEM_MSG2),
        HumanMessagePromptTemplate.from_template(hcp_conv_template),
    ]
    qa_prompt = ChatPromptTemplate.from_messages(messages)
    llm = AzureChatOpenAI(
        deployment_name=oaienvs.OPENAI_CHAT_MODEL_DEPLOYMENT_NAME,
        temperature=0.7,
        max_retries=4,
        # callbacks=[streaming_cb],
        streaming=True,
        # callback_manager=CallbackManager([MyCustomHandler()])
    )

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        get_chat_history=lambda o: o,
        memory=memory,
        condense_question_prompt=prompt,
        return_source_documents=True,
        verbose=True,
        # callback_manager=convo_cb_manager,
        # condense_question_llm=llm_condense_ques,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )

    if len(question_list) == 0:
        question = question + ". Give the answer only in " + language_name + "."

    for i in range(len(question_list)):
        qa_chain.memory.save_context(
            inputs={"question": question_list[i]}, outputs={"answer": answer_list[i]}
        )
    # return qa_chain.stream({"question": question, "chat_history": []})
    return qa_chain

Also I have tried different callback handlers and invoke methods as mentioned in https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68

Kindly suggest if there is any workaround to it.
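
(Not a verified fix, but the create_task + AsyncIteratorCallbackHandler pattern from the gist above might apply here as well; a sketch with illustrative names, assuming the handler is the one attached as callbacks=[handler] to the streaming AzureChatOpenAI when the chain was built:)

    import asyncio

    from langchain.callbacks import AsyncIteratorCallbackHandler

    async def stream_qa_chain(qa_chain, handler: AsyncIteratorCallbackHandler, question: str):
        # Run the chain in the background and drain tokens from the handler
        # instead of calling the chain synchronously.
        run = asyncio.create_task(qa_chain.acall({"question": question}))
        async for token in handler.aiter():
            yield token
        await run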

@ZanyuanYang

Is memory chat working well for anyone? It works well on the first API call, but not on the second.

@AvikantSrivastava

Is memory chat working well for anyone? It works well on the first API call, but not on the second.

@ZanyuanYang what database are you using? Is it in-memory?

If it's in-memory storage, building an API would be difficult because each call to the endpoint will not share the memory.

@ZanyuanYang

Is memory chat working well for anyone? It works well on the first API call, but not on the second.

@ZanyuanYang what database are you using? Is it in-memory?

If it's in-memory storage, building an API would be difficult because each call to the endpoint will not share the memory.

I used Elasticsearch, @AvikantSrivastava.

@ZanyuanYang

@AvikantSrivastava This is my code

class StreamingLLMCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler for streaming LLM responses."""

    async def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        logging.info("LLM start")
        self.done.clear()
        self.queue.put_nowait(ChatResponse(sender="bot", message="", type="start"))

    async def on_llm_end(self, response, **kwargs) -> None:
        logging.info("LLM end")
        # we override this method since we want the ConversationalRetrievalChain to potentially return
        # other items (e.g., source_documents) after it is completed
        pass

    async def on_llm_error(
        self, error: Exception | KeyboardInterrupt, **kwargs
    ) -> None:
        logging.error(f"LLM error: {error}")
        self.queue.put_nowait(
            ChatResponse(sender="bot", message=str(error), type="error")
        )

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        # if token not in ['"', "}"]:
        self.queue.put_nowait(ChatResponse(sender="bot", message=token, type="stream"))


class ConvoChainCallbackHandler(AsyncCallbackHandler):
    """Use to add additional information (e.g., source_documents, etc...) once the chain finishes"""

    def __init__(self, callback_handler) -> None:
        super().__init__()
        self.callback_handler = callback_handler

    async def on_chain_end(self, outputs, *, run_id, parent_run_id, **kwargs) -> None:
        """Run after chain ends running."""
        source_docs = outputs.get("source_documents", None)
        doc_list = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in source_docs
        ]

        metadata_list = getMetadataFromCourtListener(doc_list)

        # metadata = {"metadata": metadata_list}
        self.callback_handler.queue.put_nowait(
            ChatResponse(sender="bot", message="", metadata=metadata_list, type="info")
        )
        self.callback_handler.queue.put_nowait(
            ChatResponse(sender="bot", message="", type="end")
        )


class StreamingConversationChain:
    """Class for handling streaming conversation chains."""

    def __init__(self, temperature: float = 0.0):
        self.memories = {}
        self.openai_api_key = openai_api_key
        self.temperature = temperature

    async def generate_response(
        self, conversation_id: str, message: str
    ) -> AsyncGenerator[str, None]:
        """
        Asynchronous function to generate a response for a conversation.
        It creates a new conversation chain for each message and uses a
        callback handler to stream responses as they're generated.
        :param conversation_id: The ID of the conversation.
        :param message: The message from the user.
        """

        streaming_cb = StreamingLLMCallbackHandler()  # AsyncIteratorCallbackHandler()
        convo_cb_manager = AsyncCallbackManager(
            [ConvoChainCallbackHandler(streaming_cb)]
        )

        question_gen_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo-1106",
            temperature=0.0,
            streaming=True,
            openai_api_key=self.openai_api_key,
            max_tokens=4097,
        )

        streaming_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo-1106",
            temperature=0,
            callbacks=[streaming_cb],
            streaming=True,
            openai_api_key=self.openai_api_key,
        )

        memory = self.memories.get(conversation_id)

        if memory is None:
            memory = ConversationBufferMemory(
                memory_key="chat_history",
                return_messages=True,
                output_key="answer",
            )
            self.memories[conversation_id] = memory

        prompt_template = "Tell me a {adjective} joke"
        prompt = PromptTemplate(input_variables=["adjective"], template=prompt_template)

        question_gen_chain = LLMChain(
            llm=question_gen_llm, prompt=prompt
        )  # , callback_manager=manager)

        final_qa_chain = load_qa_chain(
            streaming_llm,
            chain_type="stuff",
        )

        es_retriever = LexARIElasticSearchBM25Retriever(
            client=ES_CLIENT, index_name=ES_INDEX_NAME
        )

        docs = es_retriever.get_relevant_documents(message)
        qdrant = Qdrant.from_documents(
            docs,
            EMBEDDINGS,
            location=":memory:",  # Local mode with in-memory storage only
            collection_name="my_documents",
        )

        convo_chain = ConversationalRetrievalChain(
            retriever=qdrant.as_retriever(search_type="similarity"),
            question_generator=question_gen_chain,
            combine_docs_chain=final_qa_chain,
            memory=memory,
            return_source_documents=True,
            callback_manager=convo_cb_manager,
            max_tokens_limit=16385,
        )

        run = asyncio.create_task(convo_chain.acall({"question": message}))

        async for token in streaming_cb.aiter():
            # Print for debugging purposes
            print("dict: ", token.dict())

            # Yield the response as JSON
            yield json.dumps(token.dict())

            if token.dict().get("type") in ["end", "error"]:
                streaming_cb.done.set()

        # Wait for the conversation chain task to complete
        await run

@ZanyuanYang

Does anyone know how to fix this error?
return inputs[prompt_input_key], outputs[output_key]
~~~~~~~^^^^^^^^^^^^
KeyError: 'answer'
