Skip to content

Instantly share code, notes, and snippets.

@BrandonStudio
Created April 9, 2024 05:49
Show Gist options
  • Save BrandonStudio/638a629911e47fee29175ca5c0b7430c to your computer and use it in GitHub Desktop.
LangChain batch job progress bar callback
from typing import Any, Dict
from uuid import UUID

from tqdm.auto import tqdm
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
class BatchCallback(BaseCallbackHandler):
    """LangChain callback handler that drives a tqdm progress bar.

    The bar advances by one every time LangChain reports that an LLM call
    has finished (``on_llm_end``). If one chain input triggers several LLM
    calls, the bar advances once per call, not once per input. Size
    ``total`` accordingly.

    The handler can be used as a context manager. Leaving the ``with`` block
    closes the progress bar.

    Args:
        total: Expected number of ``on_llm_end`` events; used as the bar's total.
    """

    def __init__(self, total: int) -> None:
        super().__init__()
        # Number of LLM completions observed so far (mirrors the bar position).
        self.count = 0
        self.progress_bar = tqdm(total=total)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        """Advance the progress bar by one after an LLM call finishes."""
        self.count += 1
        self.progress_bar.update(1)

    def __enter__(self) -> "BatchCallback":
        self.progress_bar.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        # Delegate to tqdm so the bar gets closed. Returning None lets any
        # exception raised inside the ``with`` block propagate.
        self.progress_bar.__exit__(exc_type, exc_value, exc_traceback)

    def __del__(self) -> None:
        # ``progress_bar`` does not exist if ``tqdm(...)`` raised in __init__.
        # ``close()`` is tqdm's public API and is a no-op on an already-closed
        # bar, e.g. after ``__exit__`` has run.
        progress_bar = getattr(self, "progress_bar", None)
        if progress_bar is not None:
            progress_bar.close()
# Example usage: `chain` is your chain and `inputs` is the list of inputs to
# run as one batch (both are assumed to be defined already).
with BatchCallback(total=len(inputs)) as cb:
    batch_config = {"callbacks": [cb]}
    chain.batch(inputs, config=batch_config)
@hxia-neos
Copy link

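langchain-ai/langchain#6053 (comment): to advance the bar once per top-level chain run instead of once per LLM call, add an `on_chain_start` handler that only counts runs without a parent run: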

    def on_chain_start(
        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
    ) -> Any:
        """Advance the progress bar once per top-level chain run.

        Runs with a non-None ``parent_run_id`` are nested inside another run
        and are ignored. The bar therefore moves once per top-level run, no
        matter how many inner steps or LLM calls that run makes.

        Note: runs are counted when they *start*, so the bar can reach its
        total before the last run has actually finished.
        """
        # ``parent_run_id`` arrives as a keyword argument. Use .get so a caller
        # that omits it is treated as top-level instead of raising KeyError.
        if kwargs.get("parent_run_id") is None:
            self.count += 1
            self.progress_bar.update(1)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment