@harisrab
Created July 1, 2023 08:05
Callback for streaming a LangChain agent's final answer and stopping the stream when the last LLM call finishes.
import sys
from queue import Queue
from threading import Event
from typing import Any, Dict, Generator, List, Optional

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult

# Token sequence that marks the start of the final answer for a
# conversational agent ("AI:"); override via answer_prefix_tokens if needed.
DEFAULT_ANSWER_PREFIX_TOKENS = ["AI", ":"]
class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
    """Callback handler for streaming in agents.

    Only works with agents using LLMs that support streaming.
    Only the final output of the agent will be streamed.
    """
    def append_to_last_tokens(self, token: str) -> None:
        """Keep a rolling window of the last n tokens, where n = len(answer_prefix_tokens)."""
        self.last_tokens.append(token)
        self.last_tokens_stripped.append(token.strip())
        if len(self.last_tokens) > len(self.answer_prefix_tokens):
            self.last_tokens.pop(0)
            self.last_tokens_stripped.pop(0)
    def check_if_answer_reached(self) -> bool:
        if self.strip_tokens:
            return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
        else:
            return self.last_tokens == self.answer_prefix_tokens
    def __init__(
        self,
        *,
        answer_prefix_tokens: Optional[List[str]] = None,
        strip_tokens: bool = True,
        stream_prefix: bool = False,
    ) -> None:
        """Instantiate FinalStreamingStdOutCallbackHandler.

        Args:
            answer_prefix_tokens: Token sequence that prefixes the answer.
                Defaults to ["AI", ":"].
            strip_tokens: Ignore whitespace and newlines when comparing
                answer_prefix_tokens to the last tokens (to determine whether
                the answer has been reached).
            stream_prefix: Whether the answer prefix itself should also be streamed.
        """
        super().__init__()
        if answer_prefix_tokens is None:
            self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
        else:
            self.answer_prefix_tokens = answer_prefix_tokens
        if strip_tokens:
            self.answer_prefix_tokens_stripped = [
                token.strip() for token in self.answer_prefix_tokens
            ]
        else:
            self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
        self.last_tokens = [""] * len(self.answer_prefix_tokens)
        self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
        self.strip_tokens = strip_tokens
        self.stream_prefix = stream_prefix
        self.answer_reached = False
        # Thread-safe queue of final-answer tokens plus a flag that marks the
        # end of the last LLM call; the consumer reads via get_response_gen().
        self._token_queue: Queue = Queue()
        self._done = Event()
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        self.answer_reached = False
    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        if self.answer_reached:
            # The answer prefix was seen during this call, so this was the
            # final LLM call: signal the consumer to stop streaming.
            self._done.set()
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Remember the last n tokens, where n = len(answer_prefix_tokens).
        self.append_to_last_tokens(token)

        # Check whether the last n tokens match the answer_prefix_tokens list ...
        if self.check_if_answer_reached():
            # The prefix (e.g. "AI:") was reached, so the final answer starts here.
            print("Answer reached")
            self.answer_reached = True
            if self.stream_prefix:
                # Write the prefix tokens to stdout and queue them if required.
                for t in self.last_tokens:
                    self._token_queue.put_nowait(t)
                    sys.stdout.write(t)
                sys.stdout.flush()
            return

        # ... and if so, stream every token from now on.
        if self.answer_reached:
            # Write to the standard stream and to the queue for the consumer.
            sys.stdout.write(token)
            self._token_queue.put_nowait(token)
            sys.stdout.flush()
    def get_response_gen(self) -> Generator:
        """Yield streamed final-answer tokens until the last LLM call finishes."""
        # Polls the queue; the loop spins while waiting for new tokens.
        while True:
            if not self._token_queue.empty():
                yield self._token_queue.get_nowait()
            elif self._done.is_set():
                # Queue drained and the final LLM call has ended: stop streaming.
                print("Cutting the stream")
                break
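
A minimal usage sketch, assuming a conversational agent built with the langchain 0.0.x-era API (whose final answers are prefixed with "AI:", matching the default prefix tokens above), run in a worker thread while the generator is consumed on the main thread. The tool list, model settings, and variable names here are illustrative, not part of the gist:

from threading import Thread

from langchain.agents import AgentType, initialize_agent
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory

# Hypothetical wiring; the gist only defines the callback handler itself.
handler = FinalStreamingStdOutCallbackHandler()
llm = ChatOpenAI(streaming=True, temperature=0, callbacks=[handler])
agent = initialize_agent(
    tools=[],  # add your tools here
    llm=llm,
    agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
    memory=ConversationBufferMemory(memory_key="chat_history"),
)

# Run the agent in a background thread so the main thread can consume tokens.
Thread(target=agent.run, args=("What is streaming?",), daemon=True).start()

for token in handler.get_response_gen():
    ...  # e.g. forward each token to a websocket or SSE response

The generator exits on its own once on_llm_end sets the done event after the final answer has been fully streamed.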