@tubone24
Last active May 13, 2024 15:27
"""A fake LangChain chat model that simulates token-by-token streaming.

Useful for testing streaming callbacks and UIs without calling a real LLM.
"""
from collections import defaultdict
from time import sleep
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

fake_response = "hogehoge!!!"


class FakeListChatModelWithStreaming(BaseChatModel):
    """Fake chat model that replays canned responses, one character at a time."""

    model_id: str = "fake-chat-model-with-streaming"
    streaming: bool = True
    responses: List[str] = []
    sleep_time: int = 0

    def __init__(self, responses=None, sleep_time=0, **kwargs: Any):
        super().__init__(**kwargs)
        self.responses = responses if responses else [fake_response]
        self.sleep_time = sleep_time

    @property
    def _llm_type(self) -> str:
        return "fake_chat_model_with_streaming"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Emit each canned response one character at a time so downstream
        # consumers see chunked output, just like a real streaming model.
        for response in self.responses:
            for char in response:
                sleep(self.sleep_time)
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=char))
                if run_manager:
                    run_manager.on_llm_new_token(char, chunk=chunk)
                yield chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        completion = ""
        llm_output: Dict[str, Any] = {"model_id": self.model_id}
        if self.streaming:
            # Assemble the full completion from the streamed chunks.
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
        else:
            # Non-streaming path: return the canned responses directly.
            # Stop sequences and extra kwargs are ignored by this fake model.
            completion = "".join(self.responses)
        # Report usage as a dict so _combine_llm_outputs can sum token counts.
        llm_output["usage"] = {}
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=completion))],
            llm_output=llm_output,
        )

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        # Merge per-call llm_output dicts, summing the token counts in "usage".
        final_usage: Dict[str, int] = defaultdict(int)
        final_output = {}
        for output in llm_outputs:
            output = output or {}
            usage = output.get("usage", {})
            for token_type, token_count in usage.items():
                final_usage[token_type] += token_count
            final_output.update(output)
        final_output["usage"] = final_usage
        return final_output
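

if __name__ == "__main__":
    # A minimal usage sketch (not part of the original gist): the prompt text
    # and canned response below are illustrative assumptions.
    model = FakeListChatModelWithStreaming(responses=["Hello, world!"], sleep_time=0)

    # stream() surfaces the per-character chunks yielded by _stream().
    for chunk in model.stream("say something"):
        print(chunk.content, end="", flush=True)
    print()

    # invoke() goes through _generate(), which concatenates the chunks
    # into a single AIMessage.
    result = model.invoke("say something")
    print(result.content)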