Created
July 1, 2023 08:05
-
-
Save harisrab/8fc67827ebf3acb997398b5252869351 to your computer and use it in GitHub Desktop.
Callback handler for a streaming LangChain agent that stops the stream when the last LLM call finishes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Token sequence that marks the start of the agent's final answer.
DEFAULT_ANSWER_PREFIX_TOKENS = ["AI", ":"]


class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
    """Callback handler for streaming in agents.

    Only works with agents using LLMs that support streaming.
    Only the final output of the agent — every token after the answer
    prefix (``["AI", ":"]`` by default) — is written to stdout and
    pushed onto an internal queue that ``get_response_gen`` drains.
    The stream is terminated when the LLM call that produced the final
    answer finishes (``on_llm_end`` sets an ``Event``).
    """

    def __init__(
        self,
        *,
        answer_prefix_tokens: Optional[List[str]] = None,
        strip_tokens: bool = True,
        stream_prefix: bool = False,
    ) -> None:
        """Instantiate FinalStreamingStdOutCallbackHandler.

        Args:
            answer_prefix_tokens: Token sequence that prefixes the answer.
                Default is ``["AI", ":"]`` (DEFAULT_ANSWER_PREFIX_TOKENS).
            strip_tokens: Ignore whitespace and newlines when comparing
                answer_prefix_tokens to the last tokens (to determine if
                the answer has been reached).
            stream_prefix: Should the answer prefix itself also be streamed?
        """
        super().__init__()
        if answer_prefix_tokens is None:
            self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
        else:
            self.answer_prefix_tokens = answer_prefix_tokens
        if strip_tokens:
            self.answer_prefix_tokens_stripped = [
                token.strip() for token in self.answer_prefix_tokens
            ]
        else:
            self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
        # Rolling window of the most recent len(answer_prefix_tokens) tokens,
        # pre-filled with empty strings so comparisons are well-defined
        # before enough tokens have arrived.
        self.last_tokens = [""] * len(self.answer_prefix_tokens)
        self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
        self.strip_tokens = strip_tokens
        self.stream_prefix = stream_prefix
        self.answer_reached = False
        # Tokens of the final answer, consumed by get_response_gen().
        self._token_queue: Queue = Queue()
        # Set by on_llm_end once the final answer's LLM call finishes;
        # tells get_response_gen() to stop after draining the queue.
        self._done = Event()

    def append_to_last_tokens(self, token: str) -> None:
        """Push *token* onto the rolling window, evicting the oldest entry."""
        self.last_tokens.append(token)
        self.last_tokens_stripped.append(token.strip())
        if len(self.last_tokens) > len(self.answer_prefix_tokens):
            self.last_tokens.pop(0)
            self.last_tokens_stripped.pop(0)

    def check_if_answer_reached(self) -> bool:
        """Return True if the rolling window matches the answer prefix."""
        if self.strip_tokens:
            return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
        return self.last_tokens == self.answer_prefix_tokens

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running: reset the per-call answer flag."""
        self.answer_reached = False

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running.

        If this call produced the final answer, signal the consumer
        generator that no more tokens will arrive.
        """
        if self.answer_reached:
            self._done.set()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Remember the last n tokens, where n = len(answer_prefix_tokens).
        self.append_to_last_tokens(token)

        # Check if the last n tokens match the answer_prefix_tokens list.
        if self.check_if_answer_reached():
            self.answer_reached = True
            if self.stream_prefix:
                for t in self.last_tokens:
                    # Fix: enqueue each prefix token `t`, not the current
                    # `token` (the original enqueued N copies of `token`).
                    self._token_queue.put_nowait(t)
                    sys.stdout.write(t)
                sys.stdout.flush()
            return

        # After the prefix has been seen, stream every subsequent token.
        if self.answer_reached:
            sys.stdout.write(token)
            self._token_queue.put_nowait(token)
            sys.stdout.flush()

    def get_response_gen(self) -> Generator:
        """Yield final-answer tokens until the last LLM call completes.

        Drains the token queue; when the queue is empty, waits briefly on
        the done-event instead of busy-spinning, and stops once the event
        is set (on_llm_end) and the queue has been drained.
        """
        while True:
            if not self._token_queue.empty():
                yield self._token_queue.get_nowait()
            elif self._done.wait(timeout=0.05):
                # Done and queue drained: end the stream.
                break
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment