from collections import defaultdict
from time import sleep
from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

fake_response = "hogehoge!!!"


class FakeListChatModelWithStreaming(BaseChatModel):
    """A fake chat model that streams canned responses character by character."""

    model_id: str = "fake-chat-model-with-streaming"
    streaming: bool = True
    responses: List[str] = []
    sleep_time: float = 0  # delay between emitted characters, in seconds

    def __init__(self, responses=None, sleep_time=0, **kwargs: Any):
        super().__init__(**kwargs)
        # Fall back to the module-level fake_response if no responses are given.
        self.responses = responses if responses else [fake_response]
        self.sleep_time = sleep_time

    @property
    def _llm_type(self) -> str:
        return "fake_chat_model_with_streaming"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Emit each response one character at a time to simulate token streaming.
        for response in self.responses:
            for char in response:
                sleep(self.sleep_time)
                if run_manager:
                    run_manager.on_llm_new_token(char)
                yield ChatGenerationChunk(message=AIMessageChunk(content=char))

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        completion = ""
        llm_output: Dict[str, Any] = {"model_id": self.model_id}
        if self.streaming:
            # Accumulate the streamed chunks into one completion string.
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
        else:
            # Non-streaming: return the canned responses directly.
            completion = "".join(self.responses)
        # Must be a dict so _combine_llm_outputs can sum per-token-type counts.
        llm_output["usage"] = {}
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=completion))],
            llm_output=llm_output,
        )

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        # Merge llm_output dicts from multiple generations, summing the token
        # counts in each "usage" dict by token type.
        final_usage: Dict[str, int] = defaultdict(int)
        final_output = {}
        for output in llm_outputs:
            output = output or {}
            usage = output.get("usage", {})
            for token_type, token_count in usage.items():
                final_usage[token_type] += token_count
            final_output.update(output)
        final_output["usage"] = final_usage
        return final_output
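

# A minimal usage sketch, not part of the original gist: stream one canned
# reply through the Runnable .stream() interface and print characters as they
# arrive. The response text and sleep_time value here are illustrative.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    model = FakeListChatModelWithStreaming(responses=["Hello, world!"], sleep_time=0.05)
    for chunk in model.stream([HumanMessage(content="Hi")]):
        print(chunk.content, end="", flush=True)
    print()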