Skip to content

Instantly share code, notes, and snippets.

@ddomen
Last active February 28, 2024 08:38
Show Gist options
  • Save ddomen/8eaa49879d42a4a42a243437b5ddfa83 to your computer and use it in GitHub Desktop.
Save ddomen/8eaa49879d42a4a42a243437b5ddfa83 to your computer and use it in GitHub Desktop.
langchain_textgen_tempfix
# llm version
import json
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
import requests
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field
from langchain.schema.output import GenerationChunk
logger = logging.getLogger(__name__)
class TextGen(LLM):
    """LLM wrapper around a text-generation-webui server.

    A text-generation-webui instance must be running with a model loaded and
    the ``--api`` command-line option enabled. The one-click installers are
    the suggested way to set it up:
    https://github.com/oobabooga/text-generation-webui#one-click-installers

    By default requests go to ``{model_url}/v1/completions``. Setting
    ``legacy_api=True`` switches to ``{model_url}/api/v1/generate`` and, for
    streaming, the ``{model_url}/api/v1/stream`` websocket.

    Sampling parameters follow the text-generation-webui API example:
    https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py

    Example:
        .. code-block:: python

            from langchain.llms import TextGen
            llm = TextGen(model_url="http://localhost:8500")
    """

    model_url: str
    """The full URL to the textgen webui including http[s]://host:port """

    preset: Optional[str] = None
    """The preset to use in the textgen webui """

    max_tokens: Optional[int] = 250
    """The maximum number of tokens to generate."""

    do_sample: bool = Field(True, alias="do_sample")
    """Do sample"""

    temperature: Optional[float] = 1.3
    """Primary factor to control randomness of outputs. 0 = deterministic
    (only the most likely token is used). Higher value = more randomness."""

    top_p: Optional[float] = 0.1
    """If not set to 1, select tokens with probabilities adding up to less than this
    number. Higher value = higher range of possible random results."""

    typical_p: Optional[float] = 1
    """If not set to 1, select only tokens that are at least this much more likely to
    appear than random tokens, given the prior text."""

    epsilon_cutoff: Optional[float] = 0  # In units of 1e-4
    """Epsilon cutoff"""

    eta_cutoff: Optional[float] = 0  # In units of 1e-4
    """ETA cutoff"""

    repetition_penalty: Optional[float] = 1.18
    """Exponential penalty factor for repeating prior tokens. 1 means no penalty,
    higher value = less repetition, lower value = more repetition."""

    top_k: Optional[float] = 40
    """Similar to top_p, but select instead only the top_k most likely tokens.
    Higher value = higher range of possible random results."""

    min_length: Optional[int] = 0
    """Minimum generation length in tokens."""

    no_repeat_ngram_size: Optional[int] = 0
    """If not set to 0, specifies the length of token sets that are completely blocked
    from repeating at all. Higher values = blocks larger phrases,
    lower values = blocks words or letters from repeating.
    Only 0 or high values are a good idea in most cases."""

    num_beams: Optional[int] = 1
    """Number of beams"""

    penalty_alpha: Optional[float] = 0
    """Penalty Alpha"""

    length_penalty: Optional[float] = 1
    """Length Penalty"""

    early_stopping: bool = Field(False, alias="early_stopping")
    """Early stopping"""

    seed: int = Field(-1, alias="seed")
    """Seed (-1 for random)"""

    add_bos_token: bool = Field(True, alias="add_bos_token")
    """Add the bos_token to the beginning of prompts.
    Disabling this can make the replies more creative."""

    truncation_length: Optional[int] = 2048
    """Truncate the prompt up to this length. The leftmost tokens are removed if
    the prompt exceeds this length. Most models require this to be at most 2048."""

    ban_eos_token: bool = Field(False, alias="ban_eos_token")
    """Ban the eos_token. Forces the model to never end the generation prematurely."""

    skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
    """Skip special tokens. Some specific models need this unset."""

    stopping_strings: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    streaming: bool = False
    """Whether to stream the results, token by token."""

    legacy_api: bool = False
    """Whether to use the legacy REST / websocket API instead of /v1/completions."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling textgen."""
        return {
            "max_tokens": self.max_tokens,
            "do_sample": self.do_sample,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "typical_p": self.typical_p,
            "epsilon_cutoff": self.epsilon_cutoff,
            "eta_cutoff": self.eta_cutoff,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "min_length": self.min_length,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
            "num_beams": self.num_beams,
            "penalty_alpha": self.penalty_alpha,
            "length_penalty": self.length_penalty,
            "early_stopping": self.early_stopping,
            "seed": self.seed,
            "add_bos_token": self.add_bos_token,
            "truncation_length": self.truncation_length,
            "ban_eos_token": self.ban_eos_token,
            "skip_special_tokens": self.skip_special_tokens,
            "stopping_strings": self.stopping_strings,
            "legacy_api": self.legacy_api,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model_url": self.model_url, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "textgen"

    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
        """Build the parameter dict sent to textgen.

        Args:
            stop: Stop sequences supplied for this call.

        Returns:
            The default parameters, or only ``preset`` when one is configured,
            plus ``legacy_api`` and ``stopping_strings``.

        Raises:
            ValueError: If stop sequences are configured on the instance and
                also passed through ``stop``.
        """
        if self.stopping_strings and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")

        if self.preset is None:
            params = self._default_params
        else:
            params = {"preset": self.preset}

        # Callers pick the endpoint from this key, so it has to survive even
        # when a preset replaces the rest of the defaults.
        params["legacy_api"] = self.legacy_api
        # Use the configured stop strings, else the per-call ones, else none.
        params["stopping_strings"] = self.stopping_strings or stop or []
        return params

    def _build_request(
        self,
        prompt: str,
        stop: Optional[List[str]],
        overrides: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Merge parameters, per-call overrides and the prompt into a JSON body."""
        request = {**self._get_parameters(stop), **overrides}
        request["prompt"] = prompt
        if not request.get("legacy_api") and request.get("stopping_strings"):
            # OpenAI-style completion endpoints name this parameter ``stop``.
            # NOTE(review): confirm the server ignores ``stopping_strings`` here.
            request.setdefault("stop", request["stopping_strings"])
        return request

    def _complete(self, prompt: str, stop: Optional[List[str]]) -> str:
        """Run one blocking, non-streaming completion and return its text.

        Returns an empty string (and logs the response) on a non-200 status.
        """
        request = self._build_request(prompt, stop, {})
        if request.get("legacy_api"):
            url = f"{self.model_url}/api/v1/generate"
            result_key = "results"
        else:
            url = f"{self.model_url}/v1/completions"
            result_key = "choices"

        response = requests.post(url, json=request)
        if response.status_code != 200:
            logger.error("textgen request to %s failed: %s", url, response)
            return ""
        return response.json()[result_key][0]["text"]

    def _iter_sse_payloads(self, response: Any) -> Iterator[Dict[str, Any]]:
        """Yield the decoded JSON object of each data line in a streamed response."""
        for raw_line in response.iter_lines():
            line = raw_line.decode("utf8").strip()
            # Blank separators, ``:`` comments (keep-alive pings) and non-data
            # fields carry no payload.
            if not line or line.startswith((":", "event:", "id:", "retry:")):
                continue
            if line.startswith("data:"):
                line = line[len("data:"):].strip()
            if not line:
                continue
            if line == "[DONE]":  # OpenAI-style end-of-stream sentinel
                return
            yield json.loads(line)

    def _iter_legacy_stream_text(self, request: Dict[str, Any]) -> Iterator[str]:
        """Yield text fragments from the legacy ``/api/v1/stream`` websocket."""
        try:
            import websocket
        except ImportError as err:
            raise ImportError(
                "The `websocket-client` package is required for streaming "
                "(pip install websocket-client)."
            ) from err

        client = websocket.WebSocket()
        client.connect(f"{self.model_url}/api/v1/stream")
        try:
            client.send(json.dumps(request))
            while True:
                event = json.loads(client.recv())
                kind = event.get("event")
                if kind == "text_stream":
                    yield event["text"]
                elif kind == "stream_end":
                    return
        finally:
            # Close the socket on normal end, on error, and when the consumer
            # abandons the generator.
            client.close()

    def _iter_stream_text(self, request: Dict[str, Any]) -> Iterator[str]:
        """Yield generated text fragments for ``request`` from the chosen API."""
        if request.get("legacy_api"):
            yield from self._iter_legacy_stream_text(request)
            return

        url = f"{self.model_url}/v1/completions"
        request["stream"] = True
        with requests.Session() as session, session.post(
            url, json=request, stream=True
        ) as response:
            if response.status_code != 200:
                logger.error("textgen request to %s failed: %s", url, response)
                return
            for payload in self._iter_sse_payloads(response):
                choices = payload.get("choices") or []
                if choices:
                    yield choices[0].get("text") or ""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the textgen web API and return the output.

        When ``streaming`` is enabled the text is assembled from ``_stream``;
        otherwise a single blocking request is made (``kwargs`` are only used
        in the streaming case).

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.
            run_manager: Callback manager notified of streamed tokens.

        Returns:
            The generated text, or ``""`` if the server answers with a
            non-200 status.

        Example:
            .. code-block:: python

                from langchain.llms import TextGen
                llm = TextGen(model_url="http://localhost:5000")
                llm("Write a story about llamas.")
        """
        if self.streaming:
            return "".join(
                chunk.text
                for chunk in self._stream(
                    prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
                )
            )
        return self._complete(prompt, stop)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Async counterpart of ``_call``.

        NOTE: the HTTP request itself is still made with the blocking
        ``requests`` library.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.
            run_manager: Async callback manager notified of streamed tokens.

        Returns:
            The generated text, or ``""`` if the server answers with a
            non-200 status.
        """
        if self.streaming:
            pieces = []
            # Must be the async generator: ``async for`` over the synchronous
            # ``_stream`` raises TypeError.
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                pieces.append(chunk.text)
            return "".join(pieces)
        return self._complete(prompt, stop)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Yield generation chunks as the server produces them.

        After each chunk is yielded, ``run_manager.on_llm_new_token`` is
        called with the chunk text (when a run manager is given).

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Callback manager notified of each new token.
            **kwargs: Extra request fields; they override the defaults.

        Yields:
            ``GenerationChunk`` objects holding the newly generated text.

        Example:
            .. code-block:: python

                from langchain.llms import TextGen
                llm = TextGen(
                    model_url="http://localhost:5000",
                    streaming=True,
                )
                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
                        stop=["'", "\\n"]):
                    print(chunk, end='', flush=True)
        """
        request = self._build_request(prompt, stop, kwargs)
        for text in self._iter_stream_text(request):
            yield GenerationChunk(text=text, generation_info=None)
            if run_manager:
                run_manager.on_llm_new_token(token=text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Async counterpart of ``_stream``.

        Parses the stream exactly like ``_stream`` and awaits
        ``run_manager.on_llm_new_token`` after each yielded chunk. The
        underlying network reads are still blocking.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Async callback manager notified of each new token.
            **kwargs: Extra request fields; they override the defaults.

        Yields:
            ``GenerationChunk`` objects holding the newly generated text.
        """
        request = self._build_request(prompt, stop, kwargs)
        for text in self._iter_stream_text(request):
            yield GenerationChunk(text=text, generation_info=None)
            if run_manager:
                await run_manager.on_llm_new_token(token=text)
import json
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
import requests
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import SimpleChatModel
from langchain.pydantic_v1 import Field
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.messages import BaseMessage, HumanMessageChunk, AIMessageChunk, SystemMessageChunk
logger = logging.getLogger(__name__)
# Chat role used in requests for each message ``type``; lookups elsewhere in
# this file fall back to 'user' for types not listed here.
TYPE_MAP = {
    "human": "user",
    "ai": "assistant",
    "chat": "user",
    "function": "user",
    "tool": "user",
    "system": "system",
}
# Message-chunk class used to wrap streamed content for each API role.
INVTYPE_MAP = {
    "user": HumanMessageChunk,
    "assistant": AIMessageChunk,
    "system": SystemMessageChunk,
}
class ChatTextGen(SimpleChatModel):
    """Chat model wrapper around a text-generation-webui server.

    A text-generation-webui instance must be running with a model loaded and
    the ``--api`` command-line option enabled. The one-click installers are
    the suggested way to set it up:
    https://github.com/oobabooga/text-generation-webui#one-click-installers

    By default messages are sent to ``{model_url}/v1/chat/completions``. With
    ``legacy_api=True`` the messages are flattened into a single
    ``"role: content"`` prompt and sent to ``{model_url}/api/v1/generate``
    (or the ``{model_url}/api/v1/stream`` websocket when streaming).

    Sampling parameters follow the text-generation-webui API example:
    https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py

    Example:
        .. code-block:: python

            from langchain.schema.messages import HumanMessage
            chat = ChatTextGen(model_url="http://localhost:5000")
            chat([HumanMessage(content="Write a story about llamas.")])
    """

    model_url: str
    """The full URL to the textgen webui including http[s]://host:port """

    preset: Optional[str] = None
    """The preset to use in the textgen webui """

    max_tokens: Optional[int] = 250
    """The maximum number of tokens to generate."""

    do_sample: bool = Field(True, alias="do_sample")
    """Do sample"""

    temperature: Optional[float] = 1.3
    """Primary factor to control randomness of outputs. 0 = deterministic
    (only the most likely token is used). Higher value = more randomness."""

    top_p: Optional[float] = 0.1
    """If not set to 1, select tokens with probabilities adding up to less than this
    number. Higher value = higher range of possible random results."""

    typical_p: Optional[float] = 1
    """If not set to 1, select only tokens that are at least this much more likely to
    appear than random tokens, given the prior text."""

    epsilon_cutoff: Optional[float] = 0  # In units of 1e-4
    """Epsilon cutoff"""

    eta_cutoff: Optional[float] = 0  # In units of 1e-4
    """ETA cutoff"""

    repetition_penalty: Optional[float] = 1.18
    """Exponential penalty factor for repeating prior tokens. 1 means no penalty,
    higher value = less repetition, lower value = more repetition."""

    top_k: Optional[float] = 40
    """Similar to top_p, but select instead only the top_k most likely tokens.
    Higher value = higher range of possible random results."""

    min_length: Optional[int] = 0
    """Minimum generation length in tokens."""

    no_repeat_ngram_size: Optional[int] = 0
    """If not set to 0, specifies the length of token sets that are completely blocked
    from repeating at all. Higher values = blocks larger phrases,
    lower values = blocks words or letters from repeating.
    Only 0 or high values are a good idea in most cases."""

    num_beams: Optional[int] = 1
    """Number of beams"""

    penalty_alpha: Optional[float] = 0
    """Penalty Alpha"""

    length_penalty: Optional[float] = 1
    """Length Penalty"""

    early_stopping: bool = Field(False, alias="early_stopping")
    """Early stopping"""

    seed: int = Field(-1, alias="seed")
    """Seed (-1 for random)"""

    add_bos_token: bool = Field(True, alias="add_bos_token")
    """Add the bos_token to the beginning of prompts.
    Disabling this can make the replies more creative."""

    truncation_length: Optional[int] = 2048
    """Truncate the prompt up to this length. The leftmost tokens are removed if
    the prompt exceeds this length. Most models require this to be at most 2048."""

    ban_eos_token: bool = Field(False, alias="ban_eos_token")
    """Ban the eos_token. Forces the model to never end the generation prematurely."""

    skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
    """Skip special tokens. Some specific models need this unset."""

    stopping_strings: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    streaming: bool = False
    """Whether to stream the results, token by token."""

    legacy_api: bool = False
    """Whether to use the legacy REST / websocket API instead of /v1/chat/completions."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling textgen."""
        return {
            "max_tokens": self.max_tokens,
            "do_sample": self.do_sample,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "typical_p": self.typical_p,
            "epsilon_cutoff": self.epsilon_cutoff,
            "eta_cutoff": self.eta_cutoff,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "min_length": self.min_length,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
            "num_beams": self.num_beams,
            "penalty_alpha": self.penalty_alpha,
            "length_penalty": self.length_penalty,
            "early_stopping": self.early_stopping,
            "seed": self.seed,
            "add_bos_token": self.add_bos_token,
            "truncation_length": self.truncation_length,
            "ban_eos_token": self.ban_eos_token,
            "skip_special_tokens": self.skip_special_tokens,
            "stopping_strings": self.stopping_strings,
            "legacy_api": self.legacy_api,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model_url": self.model_url, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "textgen"

    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
        """Build the parameter dict sent to textgen.

        Args:
            stop: Stop sequences supplied for this call.

        Returns:
            The default parameters, or only ``preset`` when one is configured,
            plus ``legacy_api`` and ``stopping_strings``.

        Raises:
            ValueError: If stop sequences are configured on the instance and
                also passed through ``stop``.
        """
        if self.stopping_strings and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")

        if self.preset is None:
            params = self._default_params
        else:
            params = {"preset": self.preset}

        # Callers pick the endpoint from this key, so it has to survive even
        # when a preset replaces the rest of the defaults.
        params["legacy_api"] = self.legacy_api
        # Use the configured stop strings, else the per-call ones, else none.
        params["stopping_strings"] = self.stopping_strings or stop or []
        return params

    def _build_request(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]],
        overrides: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Merge parameters, per-call overrides and the messages into a JSON body.

        The legacy API receives one flattened ``prompt``; the chat completions
        API receives a ``messages`` list of ``{"content", "role"}`` dicts.
        """
        request = {**self._get_parameters(stop), **overrides}
        if request.get("legacy_api"):
            request["prompt"] = "\n".join(
                f"{TYPE_MAP.get(message.type, 'user')}: {message.content}"
                for message in messages
            )
            return request

        request["messages"] = [
            {"content": message.content, "role": TYPE_MAP.get(message.type, "user")}
            for message in messages
        ]
        if request.get("stopping_strings"):
            # OpenAI-style chat endpoints name this parameter ``stop``.
            # NOTE(review): confirm the server ignores ``stopping_strings`` here.
            request.setdefault("stop", request["stopping_strings"])
        return request

    def _complete(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> str:
        """Run one blocking, non-streaming request and return the reply text.

        Returns an empty string (and logs the response) on a non-200 status.
        """
        request = self._build_request(messages, stop, {})
        legacy = bool(request.get("legacy_api"))
        if legacy:
            url = f"{self.model_url}/api/v1/generate"
        else:
            url = f"{self.model_url}/v1/chat/completions"

        response = requests.post(url, json=request)
        if response.status_code != 200:
            logger.error("textgen request to %s failed: %s", url, response)
            return ""

        body = response.json()
        if legacy:
            # The legacy generate endpoint answers with plain text results,
            # not chat messages.
            return body["results"][0]["text"]
        return body["choices"][0]["message"]["content"]

    def _make_chunk(self, role: str, content: str) -> ChatGenerationChunk:
        """Wrap streamed ``content`` in a chunk whose message class matches ``role``."""
        message_cls = INVTYPE_MAP.get(role, AIMessageChunk)
        return ChatGenerationChunk(
            text=content,
            message=message_cls(content=content),
            generation_info=None,
        )

    def _iter_sse_payloads(self, response: Any) -> Iterator[Dict[str, Any]]:
        """Yield the decoded JSON object of each data line in a streamed response."""
        for raw_line in response.iter_lines():
            line = raw_line.decode("utf8").strip()
            # Blank separators, ``:`` comments (keep-alive pings) and non-data
            # fields carry no payload.
            if not line or line.startswith((":", "event:", "id:", "retry:")):
                continue
            if line.startswith("data:"):
                line = line[len("data:"):].strip()
            if not line:
                continue
            if line == "[DONE]":  # OpenAI-style end-of-stream sentinel
                return
            yield json.loads(line)

    def _iter_legacy_stream_text(self, request: Dict[str, Any]) -> Iterator[str]:
        """Yield text fragments from the legacy ``/api/v1/stream`` websocket."""
        try:
            import websocket
        except ImportError as err:
            raise ImportError(
                "The `websocket-client` package is required for streaming "
                "(pip install websocket-client)."
            ) from err

        client = websocket.WebSocket()
        client.connect(f"{self.model_url}/api/v1/stream")
        try:
            client.send(json.dumps(request))
            while True:
                event = json.loads(client.recv())
                kind = event.get("event")
                if kind == "text_stream":
                    yield event["text"]
                elif kind == "stream_end":
                    return
        finally:
            # Close the socket on normal end, on error, and when the consumer
            # abandons the generator.
            client.close()

    def _iter_stream_chunks(
        self, request: Dict[str, Any]
    ) -> Iterator[ChatGenerationChunk]:
        """Yield chat generation chunks for ``request`` from the chosen API."""
        if request.get("legacy_api"):
            # The legacy stream carries bare text; it is the model's reply.
            for text in self._iter_legacy_stream_text(request):
                yield self._make_chunk("assistant", text)
            return

        url = f"{self.model_url}/v1/chat/completions"
        request["stream"] = True
        with requests.Session() as session, session.post(
            url, json=request, stream=True
        ) as response:
            if response.status_code != 200:
                logger.error("textgen request to %s failed: %s", url, response)
                return
            for payload in self._iter_sse_payloads(response):
                choices = payload.get("choices") or []
                if not choices:
                    continue
                choice = choices[0]
                # OpenAI-style streams put the increment under "delta"; the
                # "message" key is kept as a fallback for servers that use it.
                delta = choice.get("delta") or choice.get("message") or {}
                # "role" may be missing on later increments; the streamed text
                # is the assistant's reply.
                yield self._make_chunk(
                    delta.get("role") or "assistant",
                    delta.get("content") or "",
                )

    def _call(
        self,
        prompt: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send the conversation to textgen and return the reply text.

        When ``streaming`` is enabled the text is assembled from ``_stream``;
        otherwise a single blocking request is made (``kwargs`` are only used
        in the streaming case).

        Args:
            prompt: The messages making up the conversation so far.
            stop: A list of strings to stop generation when encountered.
            run_manager: Callback manager notified of streamed tokens.

        Returns:
            The generated text, or ``""`` if the server answers with a
            non-200 status.
        """
        if self.streaming:
            return "".join(
                chunk.text
                for chunk in self._stream(
                    prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
                )
            )
        return self._complete(prompt, stop)

    async def _acall(
        self,
        prompt: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Async counterpart of ``_call``.

        NOTE: the HTTP request itself is still made with the blocking
        ``requests`` library.

        Args:
            prompt: The messages making up the conversation so far.
            stop: A list of strings to stop generation when encountered.
            run_manager: Async callback manager notified of streamed tokens.

        Returns:
            The generated text, or ``""`` if the server answers with a
            non-200 status.
        """
        if self.streaming:
            pieces = []
            # Must be the async generator: ``async for`` over the synchronous
            # ``_stream`` raises TypeError.
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                pieces.append(chunk.text)
            return "".join(pieces)
        return self._complete(prompt, stop)

    def _stream(
        self,
        prompt: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield chat generation chunks as the server produces them.

        After each chunk is yielded, ``run_manager.on_llm_new_token`` is
        called with the chunk text (when a run manager is given).

        Args:
            prompt: The messages making up the conversation so far.
            stop: Optional list of stop words to use when generating.
            run_manager: Callback manager notified of each new token.
            **kwargs: Extra request fields; they override the defaults.

        Yields:
            ``ChatGenerationChunk`` objects holding the newly generated text.

        Example:
            .. code-block:: python

                from langchain.schema.messages import HumanMessage
                chat = ChatTextGen(
                    model_url="http://localhost:5000",
                    streaming=True,
                )
                question = HumanMessage(content="Say hi like a pirate.")
                for chunk in chat.stream([question]):
                    print(chunk.content, end='', flush=True)
        """
        request = self._build_request(prompt, stop, kwargs)
        for chunk in self._iter_stream_chunks(request):
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(token=chunk.text)

    async def _astream(
        self,
        prompt: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async counterpart of ``_stream``.

        Parses the stream exactly like ``_stream`` and awaits
        ``run_manager.on_llm_new_token`` after each yielded chunk, for both
        the legacy and the chat completions API. The underlying network reads
        are still blocking.

        Args:
            prompt: The messages making up the conversation so far.
            stop: Optional list of stop words to use when generating.
            run_manager: Async callback manager notified of each new token.
            **kwargs: Extra request fields; they override the defaults.

        Yields:
            ``ChatGenerationChunk`` objects holding the newly generated text.
        """
        request = self._build_request(prompt, stop, kwargs)
        for chunk in self._iter_stream_chunks(request):
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(token=chunk.text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment