Last active
October 28, 2023 01:41
-
-
Save dcbark01/dd5ca1f824145d4d44885031deaa9fa1 to your computer and use it in GitHub Desktop.
AutoHuggingFaceTextGenInference for Langchain PR
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# TODO: Add this to PR for Langchain so that it will be easy to use across all our different LLM projects | |
import re | |
import time | |
import warnings | |
from pathlib import Path | |
from typing import List, Union, Optional | |
import requests | |
from tqdm import tqdm | |
from pydantic import BaseModel, Field, field_validator, computed_field | |
from langchain.llms import HuggingFaceTextGenInference | |
try: | |
import docker | |
except ImportError: | |
# Allow module to be imported regardless of whether user has python docker library installed. | |
# We'll raise an error at runtime if they try to actually use any of the methods in | |
# the class where we implement all our logic. | |
docker = None | |
SERVICE_NAME = "auto-hftextgen"

# Per-user cache root. `Path("~").expanduser()` is the user's home directory; a bare
# `Path()` is "." and would make the cache (and the default Docker bind-mount source
# below) a path relative to whatever directory the caller happens to run from.
CACHE_DIR = Path("~").expanduser() / ".cache"
MODELS_DIR = CACHE_DIR / "models"

# Filter annoying warning about protected namespace in pydantic. This isn't caused
# by our case (it's from the text-generation import) so nothing we can do about it.
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
def get_image_tag(image: str) -> str:
    """ Return the tag from a Docker image reference ``[registry[:port]/]name[:tag][@digest]``.

    The tag is the text after the *last* ":" of the reference, ignoring any ``@digest``
    suffix. A colon that belongs to a registry ``host:port`` prefix (i.e. the text after
    it still contains "/") is not a tag. Returns "" when there is no tag.

    Examples:
        >>> get_image_tag("ghcr.io/huggingface/text-generation-inference:1.1.0")
        '1.1.0'
        >>> get_image_tag("localhost:5000/repo:2.0")
        '2.0'
        >>> get_image_tag("localhost:5000/repo")
        ''
    """
    # Drop a content digest first; it contains its own ":" (e.g. "sha256:...").
    reference = image.split("@", 1)[0]
    _, sep, tag = reference.rpartition(":")
    # No colon at all, or the colon was a registry port ("host:5000/repo").
    if not sep or "/" in tag:
        return ""
    return tag
def convert_shm_size(shm_size_str: str) -> int:
    """
    Convert shared memory size from string (e.g., "1g") to integer bytes (e.g., 1073741824).

    Accepts a plain non-negative integer byte count ("1024") or an integer followed by one
    unit letter: "b" (bytes), "k" (x1024), "m" (x1024**2) or "g" (x1024**3), in either
    case. Surrounding whitespace is ignored.

    Args:
        shm_size_str (str): Shared memory size as a string.

    Returns:
        int: Shared memory size in bytes.

    Raises:
        ValueError: If the string is not an integer with an optional recognised unit letter.
    """
    multipliers = {"b": 1, "k": 1024, "m": 1024 ** 2, "g": 1024 ** 3}
    text = shm_size_str.strip().lower()
    number, multiplier = text, 1
    if text and text[-1] in multipliers:
        number, multiplier = text[:-1], multipliers[text[-1]]
    # isascii() guards against Unicode digits like "²" that isdigit() accepts but int() rejects;
    # requiring only digits also rejects signs, decimals and empty numbers with our own message.
    if not (number.isascii() and number.isdigit()):
        raise ValueError(f"Invalid shm_size format: {shm_size_str}")
    return int(number) * multiplier
def _make_container_name_from_hf_model_name(model_name: str):
    """ Derive our Docker container name from a Hugging Face model id.

    Every "/" in the id (e.g. "org/model") becomes "--" because Docker container
    names can't contain "/", and the result is prefixed with SERVICE_NAME + "--".
    """
    return "--".join((SERVICE_NAME, model_name.replace("/", "--")))
def get_container(client, model_name):
    """ Return the container from ``client.containers.list()`` serving ``model_name``, or None.

    Matching is on the exact name produced by ``_make_container_name_from_hf_model_name``.
    NOTE: docker-py's ``list()`` without ``all=True`` only reports running containers.
    """
    wanted = _make_container_name_from_hf_model_name(model_name)
    return next(
        (candidate for candidate in client.containers.list() if candidate.name == wanted),
        None,
    )
def stop_container(model_name: str, timeout=30):
    """ Stop a container running one of our models and wait for it to go away.

    If no matching container is found, a message is printed and nothing else happens.
    Otherwise the container is stopped and we block in ``container.wait(timeout=timeout)``.
    A NotFound error during that wait, or an APIError whose text mentions
    "container not found", means the container has already shut down and is ignored.
    Any other APIError propagates.
    """
    client = docker.from_env()
    container = get_container(client, model_name)
    if container is None:
        print(f'No running container found for {model_name}')
        return

    print(f'Stopping container for {model_name}')
    container.stop()
    try:
        container.wait(timeout=timeout)
    except docker.errors.NotFound:
        # Container is already removed, it has shut down
        return
    except docker.errors.APIError as err:
        if "container not found" not in str(err).lower():
            raise
        # Container is not found, it has shut down
def remove_container(model_name):
    """ Remove the container serving ``model_name`` if ``get_container`` finds one, printing the outcome. """
    found = get_container(docker.from_env(), model_name)
    if found is None:
        print(f'No container found for {model_name}')
        return
    print(f'Removing container for {model_name}')
    found.remove()
def start_model_container(
    model_name: str,
    volume_path: Union[str, Path],
    port: int,
    shm_size: int,
    image_tag: str = '1.1.0',
):
    """ Start a detached text-generation-inference container for ``model_name``, or reuse a running one.

    If ``get_container`` already finds a container for this model it is returned unchanged,
    without comparing its settings to the arguments. Otherwise a container is run from
    ``ghcr.io/huggingface/text-generation-inference:<image_tag>`` with:

    * ``volume_path`` (``MODELS_DIR`` when None) bind-mounted read/write at ``/data``;
    * container port 80 published on host ``port``;
    * ``shm_size`` passed through as the shared-memory size;
    * the ``nvidia`` runtime and ``remove=True``.

    Returns:
        The docker-py container object.
    """
    if volume_path is None:
        volume_path = MODELS_DIR
    host_volume = volume_path if isinstance(volume_path, str) else str(volume_path)
    client = docker.from_env()

    running = get_container(client, model_name)
    if running is not None:
        print(f'Container for {model_name} is already running')
        return running

    container_name = _make_container_name_from_hf_model_name(model_name)
    print(f"Starting container: '{container_name}'")
    # Hardcoded; DockerParams.from_container reads the "80/tcp" binding back.
    container_port = 80
    return client.containers.run(
        f'ghcr.io/huggingface/text-generation-inference:{image_tag}',
        f'--model-id {model_name}',
        name=container_name,
        volumes={host_volume: {'bind': '/data', 'mode': 'rw'}},
        ports={f'{container_port}/tcp': port},
        shm_size=shm_size,
        detach=True,
        runtime='nvidia',
        remove=True,
    )
class _SharedParams(BaseModel):
    """ Things that both HFTextGen Docker container and Langchain HFTextGen wrapper need to share. """
    port: int = 8080  # This should be the port on the HOST machine

    # NOTE: `to_dict` is defined before `dict` on purpose. Once `dict` is bound in this
    # class body, a `-> dict` annotation written after it would refer to the method,
    # not the builtin type.
    def to_dict(self, **kwargs) -> dict:
        """ Serialise the model to a plain dict via pydantic's ``model_dump``.

        All keyword arguments (``exclude``, ``include``, ``exclude_none``, ...) are
        forwarded to ``model_dump``.
        """
        return super().model_dump(**kwargs)

    def dict(self, *args, **kwargs) -> dict:
        """ Disabled on purpose; always raises so callers switch to ``to_dict``.

        Raises:
            DeprecationWarning: Always (raised as an exception, not emitted as a warning).
        """
        raise DeprecationWarning(
            "Method `dict` is pending deprecation in pydantic. For the _SharedParams model and any child classes, "
            "to cast to a dict use our `to_dict` method instead."
        )
class DockerParams(_SharedParams):
    """ Settings used to start a text-generation-inference container.

    `AutoHuggingFaceTextGenInference` compares an instance with `==` against
    `DockerParams.from_container(...)` to decide whether a running container can be reused.
    """
    pretrained_model_name: str  # Would just use `model_name` but pydantic throws a protected value error.
    volume_path: str = str(MODELS_DIR)  # Host directory bind-mounted at /data in the container.
    image_tag: str = '1.1.0'  # text-generation-inference image tag.
    shm_size: Optional[Union[int, str]] = 1073741824  # 1gb; strings like "1g" are converted to bytes.

    @field_validator('shm_size')
    @classmethod
    def convert_str_to_int(cls, value: Optional[Union[int, str]]) -> Optional[int]:
        """ Normalise ``shm_size`` to a byte count.

        Ints pass through unchanged, strings such as "1g" go through `convert_shm_size`,
        and None (allowed by the Optional field type) is kept as None rather than being
        handed to `convert_shm_size`, which cannot parse it.
        """
        if value is None or isinstance(value, int):
            return value
        return convert_shm_size(value)

    @classmethod
    def from_container(cls, container):
        """ Given an existing container, get the params used to construct it.

        Reads ``container.attrs`` and expects the layout `start_model_container` produces:
        a ``--model-id <name>`` argument, a bind mount, and container port 80 published
        on the host.
        """
        attrs = container.attrs

        # Model name: the value that follows the required --model-id flag.
        cmd_args: list = attrs['Args']
        model_name = cmd_args[cmd_args.index("--model-id") + 1]

        # Volume: identity is based on the first bind only; we shouldn't need more than one.
        # NOTE: a container without any bind mount is not handled and will fail here.
        first_bind = attrs["HostConfig"]["Binds"][0]
        volume_path = first_bind.split(":")[0]

        # Host side of the hardcoded container port 80.
        host_port = int(attrs["HostConfig"]["PortBindings"]["80/tcp"][0]["HostPort"])

        shm_size = attrs["HostConfig"]["ShmSize"]
        image_tag = get_image_tag(attrs['Config']['Image'])

        # `cls` rather than a hardcoded class so subclasses get instances of themselves.
        return cls(
            port=host_port,
            pretrained_model_name=model_name,
            volume_path=volume_path,
            image_tag=image_tag,
            shm_size=shm_size,
        )
class HFTextGenParams(_SharedParams):
    """ Configuration params container for Langchain HuggingFaceTextGenInference wrapper.

    See docs here for details of available generation params:
    https://python.langchain.com/docs/integrations/llms/huggingface_textgen_inference
    """
    host: str = 'localhost'  # Combined with `port` into `inference_server_url`.
    max_new_tokens: int = 512
    top_k: Optional[int] = None
    top_p: Optional[float] = 0.95
    typical_p: Optional[float] = 0.95
    temperature: Optional[float] = 0.8
    repetition_penalty: Optional[float] = None
    return_full_text: bool = False
    truncate: Optional[int] = None
    stop_sequences: List[str] = Field(default_factory=list)
    seed: Optional[int] = None
    timeout: int = 120
    streaming: bool = False
    do_sample: bool = False
    watermark: bool = False

    @computed_field
    @property
    def inference_server_url(self) -> str:
        """ Base URL of the inference server: ``http://{host}:{port}``. """
        return f"http://{self.host}:{self.port}"

    def to_dict(self, **kwargs) -> dict:
        """ Get dict of valid args for HFTextGen instantiation.

        Langchain's HuggingFaceTextGenInference only accepts an `inference_server_url`
        arg (not host/port like we're using), so `host` and `port` are always excluded.
        The computed `inference_server_url`, derived from them, is included instead.
        Other keyword arguments are forwarded to the parent `to_dict`. Any
        caller-supplied `exclude` entries are added to {"host", "port"}.
        """
        extra_exclude = kwargs.pop("exclude", None) or ()
        exclude = {"host", "port", *extra_exclude}
        return super().to_dict(exclude=exclude, **kwargs)
class AutoHuggingFaceTextGenInference:
    """ Run a text-generation-inference server in Docker and wrap it for Langchain.

    `from_docker` starts (or reuses) a container for a model, waits for the weights and
    the server's /health endpoint, then returns a Langchain `HuggingFaceTextGenInference`
    pointed at it. `shutdown` stops the container. Everything is a class/static method.
    """
    HEALTH_CHECK_TIMEOUT = 30  # Seconds to wait for the container + /health endpoint.
    DOWNLOAD_TIMEOUT = 600  # 10 minutes; will probably have to raise this for larger models
    HEALTH_REQUEST_TIMEOUT = 5  # Seconds per /health request, so a hung server can't block forever.

    @staticmethod
    def validate_environment():
        """ Raise EnvironmentError if the optional `docker` package failed to import. """
        if docker is None:
            raise EnvironmentError(
                "Python `docker` library not found in current environment. "
                "Please `pip install docker` if you want to use AutoHuggingFaceTextGenInference."
            )

    @staticmethod
    def _get_container(client, model_name):
        """ Return the running container for `model_name`, or None (delegates to `get_container`). """
        return get_container(client, model_name)

    @staticmethod
    def _container_exists_with_same_params(container, docker_params: DockerParams) -> bool:
        """
        Check if a container exists with the same parameters.

        Rebuilds the container's params via `DockerParams.from_container` and compares
        them to `docker_params` with `==`. Returns False when `container` is None.
        """
        if container is None:
            return False
        # Get the existing container's information
        existing_params = DockerParams.from_container(container)
        return docker_params == existing_params

    @classmethod
    def _wait_for_model_download(cls, container):
        """ Wait until the container logs report that the model weights are ready.

        Logs are polled every 5 seconds for up to DOWNLOAD_TIMEOUT seconds. The line
        "Successfully downloaded weights." marks completion. If a
        "Downloaded ... in H:MM:SS." line is also present, its duration is printed.

        Raises:
            ValueError: If `container` is None.
            TimeoutError: If the success line does not appear within DOWNLOAD_TIMEOUT.
        """
        if container is None:
            raise ValueError("Container not found; cannot wait for model download.")

        download_pattern = re.compile(r"Downloaded .* in (\d+:\d+:\d+).")
        success_pattern = re.compile(r"Successfully downloaded weights.")

        with tqdm(total=cls.DOWNLOAD_TIMEOUT, desc="Waiting for download (time left before timeout)") as pbar:
            start_time = time.time()
            while time.time() - start_time < cls.DOWNLOAD_TIMEOUT:
                logs = container.logs().decode("utf-8", errors="replace")
                # Only the success line is required. NOTE(review): when weights are already
                # cached, TGI appears to skip the "Downloaded ... in" line; requiring it would
                # stall until DOWNLOAD_TIMEOUT. Confirm against the TGI version in use.
                if success_pattern.search(logs):
                    download_match = download_pattern.search(logs)
                    if download_match:
                        print(f"Model download completed in {download_match.group(1)}")
                    else:
                        print("Model weights ready (no download reported in logs)")
                    return
                time.sleep(5)  # Wait for 5 seconds before checking again
                pbar.update(5)

        raise TimeoutError(f"Model download for '{container.name}' timed out.")

    @classmethod
    def _wait_for_status(cls, container, inference_server_url: str):
        """ Wait for the container and inference API to become healthy.

        First waits for the model weights (`_wait_for_model_download`). Then, about once a
        second for up to HEALTH_CHECK_TIMEOUT seconds, it refreshes the container state.
        Once the status is "running" or "created", it requests
        `{inference_server_url}/health`.

        Returns:
            The container, once its status is acceptable and /health returned HTTP 200.

        Raises:
            TimeoutError: If the checks don't pass in time, or the container disappears.
        """
        # First handle waiting for download to complete, if it isn't cached
        cls._wait_for_model_download(container)

        health_check_url = f'{inference_server_url}/health'
        is_container_ready = False
        is_api_ready = False
        pbar = ".."
        # A single deadline bounds the whole health check to HEALTH_CHECK_TIMEOUT.
        deadline = time.time() + cls.HEALTH_CHECK_TIMEOUT
        while time.time() < deadline:
            try:
                # `container.status` is a snapshot from when the object was fetched; refresh it.
                container.reload()
            except docker.errors.NotFound:
                # Containers run with remove=True, so one that exits is deleted. Report it
                # with the exception type callers already handle for failed startups.
                raise TimeoutError(
                    f"Container '{container.name}' exited before the inference server became healthy."
                ) from None
            is_container_ready = container.status in ["running", "created"]
            if is_container_ready:
                try:
                    response = requests.get(health_check_url, timeout=cls.HEALTH_REQUEST_TIMEOUT)
                    is_api_ready = response.status_code == 200
                except requests.RequestException:
                    # Server not accepting connections (or answering) yet; retry next round.
                    is_api_ready = False
            if is_container_ready and is_api_ready:
                return container

            pbar += "."
            msg = (
                f"Waiting for inference server to start up "
                f"(CONTAINER={int(is_container_ready)}, "
                f"API={int(is_api_ready)})"
            )
            print(msg + pbar, flush=True)
            time.sleep(1)

        raise TimeoutError(
            f"Container startup timeout. "
            f"Did not receive healthy status within {cls.HEALTH_CHECK_TIMEOUT} seconds.\n"
            f"\tIs container ready: {is_container_ready}\n"
            f"\tIs inference API ready: {is_api_ready}"
        )

    @classmethod
    def _create_or_recreate_container(cls, client, model_name: str, docker_params: DockerParams):
        """
        Return a container for `model_name` configured with `docker_params`.

        A running container with identical params is reused. Otherwise any container
        for this model is shut down (which also removes it) and a new one is started.
        """
        container = get_container(client, model_name)
        if not cls._container_exists_with_same_params(container, docker_params):
            # User wants to switch to a new configuration; shutdown and remove old container
            cls.shutdown(model_name)
        # Start a container with these params (returns the running one if it was kept above)
        d_params_docker = docker_params.to_dict()
        d_params_docker['model_name'] = d_params_docker.pop('pretrained_model_name')  # Keeps pydantic happy
        return start_model_container(**d_params_docker)

    @classmethod
    def _from_params(
        cls,
        client,
        generation_params: HFTextGenParams,
        docker_params: DockerParams,
    ) -> HuggingFaceTextGenInference:
        """ Ensure the container is up and healthy, then build the Langchain wrapper for it. """
        container = cls._create_or_recreate_container(client, docker_params.pretrained_model_name, docker_params)
        cls._wait_for_status(container, generation_params.inference_server_url)
        # Create the langchain wrapper for the server once the server is up and running
        return HuggingFaceTextGenInference(**generation_params.to_dict())

    @classmethod
    def shutdown(cls, model_name: str):
        """ Stops the container and removes it.

        Since we've hardcoded `remove=True` in the startup method, shutting down
        automatically removes the container (doing it this way gives up some
        flexibility, but makes state management a lot easier, which we really
        want to avoid doing from within our python code, and let docker handle
        as much container state management as possible).
        """
        cls.validate_environment()
        stop_container(model_name)

    @classmethod
    def from_docker(cls, model_name: str, **kwargs) -> HuggingFaceTextGenInference:
        """ Create an HF TextGen Langchain LLM backed by a (possibly new) Docker container.

        `kwargs` feed both `DockerParams` and `HFTextGenParams`. `model_name` becomes
        `pretrained_model_name` unless `kwargs` overrides it.
        """
        cls.validate_environment()
        kwargs = {'pretrained_model_name': model_name, **kwargs}
        client = docker.from_env()
        docker_params = DockerParams(**kwargs)
        generation_params = HFTextGenParams(**kwargs)
        return cls._from_params(client, generation_params, docker_params)
if __name__ == "__main__":
    # Example usage: start (or reuse) a TGI container for a small model and query it.
    model_name = 'bigscience/bloom-560m'  # Larger alternative: 'tiiuae/falcon-7b-instruct'
    llm = AutoHuggingFaceTextGenInference.from_docker(
        model_name,
        host='0.0.0.0',
        port=8080,
        shm_size="1g",
    )
    print(llm("How old is the universe?"))
# import pytest | |
# | |
# import docker | |
# | |
# from llm import ( | |
# AutoHuggingFaceTextGenInference, | |
# HuggingFaceTextGenInference, | |
# DockerParams, | |
# HFTextGenParams, | |
# ) | |
# | |
# # Define some test data | |
# model_name = 'bigscience/bloom-560m' | |
# generation_params = HFTextGenParams(host='0.0.0.0', port=8081, max_new_tokens=512) | |
# | |
# # Start from pristine state with none of our containers running | |
# client = docker.from_env() | |
# for container in client.containers.list(): | |
# if container.name.startswith(SERVICE_NAME): | |
# container.stop() | |
# container.wait(timeout=30) | |
# | |
# | |
# @pytest.mark.slow | |
# def test_startup_shutdown(): | |
# import docker | |
# | |
# client = docker.from_env() | |
# | |
# # Test starting and stopping a container | |
# llm = AutoHuggingFaceTextGenInference.from_docker(model_name, host='0.0.0.0', port=8081, shm_size="1g") | |
# assert isinstance(llm, HuggingFaceTextGenInference) | |
# | |
# # Ensure the container is running | |
# container = AutoHuggingFaceTextGenInference._get_container(client, model_name) | |
# assert container.status == 'running' | |
# | |
# # Shutdown the container | |
# AutoHuggingFaceTextGenInference.shutdown(model_name) | |
# container = AutoHuggingFaceTextGenInference._get_container(client, model_name) | |
# assert container is None # Container should no longer exist | |
# | |
# | |
# @pytest.mark.slow | |
# def test_recreate_container(): | |
# # Test recreating a container with the same parameters | |
# import docker | |
# client = docker.from_env() | |
# | |
# # Start the container initially | |
# llm = AutoHuggingFaceTextGenInference.from_docker(model_name, host='0.0.0.0', port=8081, shm_size="1g") | |
# container = AutoHuggingFaceTextGenInference._get_container(client, model_name) | |
# assert container.status == 'running' | |
# | |
# # Change some parameters | |
# new_port = 8082 | |
# new_shm_size = '2g' | |
# new_generation_params = HFTextGenParams(host='0.0.0.0', port=new_port, max_new_tokens=256) | |
# | |
# # Recreate the container with new parameters | |
# llm = AutoHuggingFaceTextGenInference.from_docker( | |
# model_name, host='0.0.0.0', port=new_port, shm_size=new_shm_size) | |
# | |
# # Ensure the container is running with new parameters | |
# container = AutoHuggingFaceTextGenInference._get_container(client, model_name) | |
# assert container.status == 'running' | |
# assert container.attrs['HostConfig']['ShmSize'] == 2147483648 | |
# | |
# # Shutdown the container | |
# AutoHuggingFaceTextGenInference.shutdown(model_name) | |
# container = AutoHuggingFaceTextGenInference._get_container(client, model_name) | |
# assert container is None # Container should no longer exist | |
# | |
# | |
# def test_invalid_environment(): | |
# # Make sure we raise an exception if the python docker library is not installed | |
# # and we try to use the AutoHuggingFaceTextGenInference | |
# from llm import docker | |
# llm.docker = None | |
# | |
# # Test invalid environment without the docker library installed | |
# with pytest.raises(EnvironmentError): | |
# AutoHuggingFaceTextGenInference.validate_environment() | |
# | |
# # Re-import to go back to original state | |
# import docker | |
# llm.docker = docker | |
# | |
# | |
# def test_health_check_timeout(): | |
# # Test health check timeout | |
# with pytest.raises(TimeoutError): | |
# AutoHuggingFaceTextGenInference.HEALTH_CHECK_TIMEOUT = 0 | |
# llm = AutoHuggingFaceTextGenInference.from_docker( | |
# model_name, host='will-never-connect', port=8081, shm_size="1g") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment