AutoHuggingFaceTextGenInference for Langchain PR
# TODO: Add this to a PR for Langchain so that it will be easy to use across all our different LLM projects
import re
import time
import warnings
from pathlib import Path
from typing import List, Union, Optional

import requests
from tqdm import tqdm
from pydantic import BaseModel, Field, field_validator, computed_field
from langchain.llms import HuggingFaceTextGenInference

try:
    import docker
except ImportError:
    # Allow the module to be imported regardless of whether the user has the Python docker library installed.
    # We'll raise an error at runtime if they try to actually use any of the methods in
    # the class where we implement all our logic.
    docker = None

SERVICE_NAME = "auto-hftextgen"
CACHE_DIR = Path.home() / ".cache"
MODELS_DIR = CACHE_DIR / "models"

# Filter an annoying warning about protected namespaces in pydantic. This isn't caused
# by our code (it's from the text-generation import), so there's nothing we can do about it.
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")


def get_image_tag(image: str) -> str:
    """ Return the tag from a docker image reference of the form `image_name:tag`. """
    if ":" not in image:
        return ""
    try:
        return image.split(":")[1]
    except IndexError:
        return ""


def convert_shm_size(shm_size_str: str) -> int:
    """
    Convert shared memory size from a string (e.g., "1g") to an integer number of bytes (e.g., 1073741824).

    Args:
        shm_size_str (str): Shared memory size as a string.

    Returns:
        int: Shared memory size in bytes.
    """
    if shm_size_str.isdigit():
        return int(shm_size_str)
    if shm_size_str.endswith(('g', 'G')):
        return int(shm_size_str[:-1]) * 1024 * 1024 * 1024
    if shm_size_str.endswith(('m', 'M')):
        return int(shm_size_str[:-1]) * 1024 * 1024
    if shm_size_str.endswith(('k', 'K')):
        return int(shm_size_str[:-1]) * 1024
    raise ValueError(f"Invalid shm_size format: {shm_size_str}")


def _make_container_name_from_hf_model_name(model_name: str) -> str:
    model_name_sanitized = model_name.replace("/", "--")  # Docker container names can't have / in them
    container_name = f'{SERVICE_NAME}--{model_name_sanitized}'
    return container_name


def get_container(client, model_name):
    container_name = _make_container_name_from_hf_model_name(model_name)
    for container in client.containers.list():
        if container_name == container.name:
            return container
    return None


def stop_container(model_name: str, timeout=30):
    """ Stop a container running one of our models. """
    client = docker.from_env()
    container = get_container(client, model_name)
    if container is None:
        print(f'No running container found for {model_name}')
        return
    print(f'Stopping container for {model_name}')
    container.stop()
    try:
        container.wait(timeout=timeout)
    except docker.errors.NotFound:
        # Container has already been removed, i.e. it has shut down
        return
    except docker.errors.APIError as e:
        if "container not found" in str(e).lower():
            # Container is not found, so it has shut down
            return
        raise


def remove_container(model_name):
    client = docker.from_env()
    container = get_container(client, model_name)
    if container is not None:
        print(f'Removing container for {model_name}')
        container.remove()
    else:
        print(f'No container found for {model_name}')


def start_model_container(
    model_name: str,
    volume_path: Union[str, Path],
    port: int,
    shm_size: int,
    image_tag: str = '1.1.0',
):
    volume_path = volume_path if volume_path is not None else MODELS_DIR
    volume_path = volume_path if isinstance(volume_path, str) else str(volume_path)
    client = docker.from_env()

    # Check if the container is already running
    container = get_container(client, model_name)
    if container is not None:
        print(f'Container for {model_name} is already running')
        return container

    # If not, start a new container
    container_name = _make_container_name_from_hf_model_name(model_name)
    print(f"Starting container: '{container_name}'")
    host_port = port
    container_port = 80  # Hardcoded; the TGI server listens on port 80 inside the container
    container = client.containers.run(
        f'ghcr.io/huggingface/text-generation-inference:{image_tag}',
        f'--model-id {model_name}',
        name=container_name,
        volumes={volume_path: {'bind': '/data', 'mode': 'rw'}},
        ports={f'{container_port}/tcp': host_port},
        shm_size=shm_size,
        detach=True,
        runtime='nvidia',
        remove=True,
    )
    return container

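
# For reference, the `client.containers.run(...)` call above corresponds roughly to the usual
# text-generation-inference quickstart command. This is only an illustrative sketch (exact flags
# depend on your Docker / NVIDIA runtime setup and TGI version):
#
#   docker run --runtime nvidia --rm --shm-size 1g -p 8080:80 \
#       -v $HOME/.cache/models:/data \
#       ghcr.io/huggingface/text-generation-inference:1.1.0 \
#       --model-id bigscience/bloom-560m

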
class _SharedParams(BaseModel):
    """ Things that both the HFTextGen Docker container and the Langchain HFTextGen wrapper need to share. """
    port: int = 8080  # This should be the port on the HOST machine

    def dict(self, *args, **kwargs) -> dict:
        raise DeprecationWarning(
            "Method `dict` is pending deprecation in pydantic. For the _SharedParams model and any child classes, "
            "to cast to a dict use our `to_dict` method instead."
        )

    def to_dict(self, **kwargs) -> dict:
        exclude = kwargs.get('exclude', set())
        return super().model_dump(exclude=exclude)


class DockerParams(_SharedParams):
    pretrained_model_name: str  # Would just use `model_name`, but pydantic complains about its protected `model_` namespace.
    volume_path: str = str(MODELS_DIR)
    image_tag: str = '1.1.0'
    shm_size: Optional[Union[int, str]] = 1073741824  # 1 GB

    @field_validator('shm_size')
    @classmethod
    def convert_str_to_int(cls, value: Union[int, str]):
        if isinstance(value, int):
            return value
        return convert_shm_size(value)

    @classmethod
    def from_container(cls, container):
        """ Given an existing container, get the params used to construct it. """
        c = container.attrs

        # Get model name
        cmd_args: list = c['Args']
        i = cmd_args.index("--model-id")  # Required arg, so should be safe
        model_name = cmd_args[i + 1]

        # Get volume details.
        # Identity is formulated from the first volume only; that's fine for our use case here since
        # we shouldn't need more than one volume anyway. Note this also implicitly requires that a
        # volume always be present; it will fail otherwise, and we're not currently handling that.
        binds = c["HostConfig"]["Binds"][0]
        volume_path = binds.split(":")[0]

        # Get ports
        host_port = int(c["HostConfig"]["PortBindings"]["80/tcp"][0]["HostPort"])

        # Get shared memory size and misc items
        shm_size = c["HostConfig"]["ShmSize"]
        image_tag = get_image_tag(c['Config']['Image'])

        return DockerParams(
            port=host_port,
            pretrained_model_name=model_name,
            volume_path=volume_path,
            image_tag=image_tag,
            shm_size=shm_size,
        )


class HFTextGenParams(_SharedParams):
    """ Configuration params container for the Langchain HuggingFaceTextGenInference wrapper.

    See the docs here for details of the available generation params:
    https://python.langchain.com/docs/integrations/llms/huggingface_textgen_inference
    """
    host: str = 'localhost'
    max_new_tokens: int = 512
    top_k: Optional[int] = None
    top_p: Optional[float] = 0.95
    typical_p: Optional[float] = 0.95
    temperature: Optional[float] = 0.8
    repetition_penalty: Optional[float] = None
    return_full_text: bool = False
    truncate: Optional[int] = None
    stop_sequences: List[str] = Field(default_factory=list)
    seed: Optional[int] = None
    timeout: int = 120
    streaming: bool = False
    do_sample: bool = False
    watermark: bool = False

    @computed_field
    @property
    def inference_server_url(self) -> str:
        url = f"http://{self.host}:{self.port}"
        return url

    def to_dict(self):
        """ Get a dict of valid args for HuggingFaceTextGenInference instantiation.

        Since Langchain's HuggingFaceTextGenInference only accepts an `inference_server_url`
        arg (not host/port like we're using), we override `to_dict` so that it pops the
        host/port args and only includes the inference server URL, which we derive from host/port.
        """
        return super().to_dict(exclude={"host", "port"})

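
# Illustrative example (hypothetical values): pydantic v2's `model_dump()` includes
# `@computed_field` properties, so something like
#     HFTextGenParams(host="0.0.0.0", port=8080).to_dict()
# should drop `host`/`port` and instead contain
#     {"inference_server_url": "http://0.0.0.0:8080", "max_new_tokens": 512, ...}
# which lines up with the HuggingFaceTextGenInference constructor args.

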
class AutoHuggingFaceTextGenInference:
    HEALTH_CHECK_TIMEOUT = 30
    DOWNLOAD_TIMEOUT = 600  # 10 minutes; will probably have to raise this for larger models

    @staticmethod
    def validate_environment():
        if docker is None:
            raise EnvironmentError(
                "Python `docker` library not found in current environment. "
                "Please `pip install docker` if you want to use AutoHuggingFaceTextGenInference."
            )

    @staticmethod
    def _get_container(client, model_name):
        return get_container(client, model_name)

    @staticmethod
    def _container_exists_with_same_params(container, docker_params: DockerParams) -> bool:
        """
        Check if a container exists with the same parameters.

        Compare the container's configuration (image, volumes, etc.) with the new DockerParams.
        """
        if container is None:
            return False
        # Get the existing container's information
        existing_params = DockerParams.from_container(container)
        return docker_params == existing_params

    @classmethod
    def _wait_for_model_download(cls, container):
        """ Wait for the model download to complete by monitoring the container logs. """
        if container is None:
            raise ValueError("Container not found; cannot monitor model download.")

        # Define regular expressions to match relevant log messages
        download_pattern = r"Downloaded .* in (\d+:\d+:\d+)."
        success_pattern = r"Successfully downloaded weights."

        # Initialize a progress bar
        with tqdm(total=cls.DOWNLOAD_TIMEOUT, desc="Waiting for download (time left before timeout)") as pbar:
            # Keep checking the container logs until the download is complete
            start_time = time.time()
            download_complete = False
            while not download_complete and time.time() - start_time < cls.DOWNLOAD_TIMEOUT:
                logs = container.logs().decode("utf-8")
                # Search for the patterns in the logs
                download_match = re.search(download_pattern, logs)
                success_match = re.search(success_pattern, logs)
                if download_match and success_match:
                    download_time = download_match.group(1)
                    print(f"Model download completed in {download_time}")
                    download_complete = True
                else:
                    time.sleep(5)  # Wait for 5 seconds before checking again
                    pbar.update(5)  # Update the progress bar

        if not download_complete:
            raise TimeoutError(f"Model download for container '{container.name}' timed out.")

    @classmethod
    def _wait_for_status(cls, container, inference_server_url: str):
        """ Wait for the container and the inference API to become healthy. """
        # First handle waiting for the download to complete, if the model isn't already cached
        cls._wait_for_model_download(container)

        is_container_ready = False
        is_api_ready = False
        pbar = ".."
        start_time = time.time()
        while time.time() - start_time < cls.HEALTH_CHECK_TIMEOUT:
            pbar += "."
            if is_container_ready and is_api_ready:
                break
            if container.status in ["running", "created"]:
                is_container_ready = True

            # Add a health check for the /health API
            health_check_start_time = time.time()
            health_check_url = f'{inference_server_url}/health'
            while time.time() - health_check_start_time < cls.HEALTH_CHECK_TIMEOUT:
                pbar += "."
                msg = (
                    f"Waiting for inference server to start up "
                    f"(CONTAINER={int(is_container_ready)}, "
                    f"API={int(is_api_ready)})"
                )
                msg += pbar
                print(msg, flush=True)
                if is_api_ready:
                    break
                try:
                    response = requests.get(health_check_url)
                    if response.status_code == 200:
                        is_api_ready = True
                except requests.ConnectionError:
                    pass
                time.sleep(1)
            time.sleep(1)

        if is_container_ready and is_api_ready:
            return container
        raise TimeoutError(
            f"Container startup timeout. "
            f"Did not receive healthy status within {cls.HEALTH_CHECK_TIMEOUT} seconds."
            f"\n\tIs container ready: {is_container_ready}"
            f"\n\tIs inference API ready: {is_api_ready}"
        )

    @classmethod
    def _create_or_recreate_container(cls, client, model_name: str, docker_params: DockerParams):
        """
        Create or recreate a container with the given parameters.

        Check if a container with the same name and the same parameters already exists.
        If it does, reuse it; otherwise shut down any old container and create a new one.
        """
        container = get_container(client, model_name)
        if not cls._container_exists_with_same_params(container, docker_params):
            # User wants to switch to a new configuration; shut down and remove the old container
            cls.shutdown(model_name)
            # Create a new container with the updated parameters
            d_params_docker = docker_params.to_dict()
            d_params_docker['model_name'] = d_params_docker.pop('pretrained_model_name')  # Keeps pydantic happy
            container = start_model_container(**d_params_docker)
        return container

    @classmethod
    def _from_params(
        cls,
        client,
        generation_params: HFTextGenParams,
        docker_params: DockerParams,
    ) -> HuggingFaceTextGenInference:
        # Check if a container with the same name and parameters exists, and recreate it if needed
        container = cls._create_or_recreate_container(client, docker_params.pretrained_model_name, docker_params)

        # Wait for healthy status based on the running container and a healthcheck of the inference API
        cls._wait_for_status(container, generation_params.inference_server_url)

        # Create the langchain wrapper for the server once the server is up and running
        d_params_generation = generation_params.to_dict()
        llm = HuggingFaceTextGenInference(**d_params_generation)
        return llm

    @classmethod
    def shutdown(cls, model_name: str):
        """ Stop the container and remove it.

        Since we've hardcoded `remove=True` in the startup method, shutting down
        automatically removes the container. Doing it this way gives up some
        flexibility, but makes state management a lot easier: we want to avoid
        managing container state from within our Python code and instead let Docker
        handle as much of it as possible.
        """
        cls.validate_environment()
        stop_container(model_name)

    @classmethod
    def from_docker(cls, model_name: str, **kwargs) -> HuggingFaceTextGenInference:
        """ Create an instance of the HF TextGen server wrapper by spinning up a docker container. """
        cls.validate_environment()
        kwargs = {**{'pretrained_model_name': model_name}, **kwargs}
        client = docker.from_env()
        docker_params = DockerParams(**kwargs)
        generation_params = HFTextGenParams(**kwargs)
        return cls._from_params(client, generation_params, docker_params)


if __name__ == "__main__":
    """ Example usage """
    # model_name = 'tiiuae/falcon-7b-instruct'
    model_name = 'bigscience/bloom-560m'
    llm = AutoHuggingFaceTextGenInference.from_docker(model_name, host='0.0.0.0', port=8080, shm_size="1g")
    answer = llm("How old is the universe?")
    print(answer)
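
    # Sketch of a possible follow-on: compose the returned `llm` into a simple LangChain chain.
    # This assumes the classic `PromptTemplate`/`LLMChain` API available in langchain versions
    # from around this time; adjust the imports if your version differs.
    from langchain.prompts import PromptTemplate
    from langchain.chains import LLMChain

    prompt = PromptTemplate.from_template("Answer concisely: {question}")
    chain = LLMChain(llm=llm, prompt=prompt)
    print(chain.run(question="How old is the universe?"))

    # When finished, stop and remove the container:
    # AutoHuggingFaceTextGenInference.shutdown(model_name)

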
# import pytest
#
# import docker
#
# import llm
# from llm import (
#     SERVICE_NAME,
#     AutoHuggingFaceTextGenInference,
#     HuggingFaceTextGenInference,
#     DockerParams,
#     HFTextGenParams,
# )
#
# # Define some test data
# model_name = 'bigscience/bloom-560m'
# generation_params = HFTextGenParams(host='0.0.0.0', port=8081, max_new_tokens=512)
#
# # Start from a pristine state with none of our containers running
# client = docker.from_env()
# for container in client.containers.list():
#     if container.name.startswith(SERVICE_NAME):
#         container.stop()
#         container.wait(timeout=30)
#
#
# @pytest.mark.slow
# def test_startup_shutdown():
#     import docker
#
#     client = docker.from_env()
#
#     # Test starting and stopping a container
#     llm = AutoHuggingFaceTextGenInference.from_docker(model_name, host='0.0.0.0', port=8081, shm_size="1g")
#     assert isinstance(llm, HuggingFaceTextGenInference)
#
#     # Ensure the container is running
#     container = AutoHuggingFaceTextGenInference._get_container(client, model_name)
#     assert container.status == 'running'
#
#     # Shut down the container
#     AutoHuggingFaceTextGenInference.shutdown(model_name)
#     container = AutoHuggingFaceTextGenInference._get_container(client, model_name)
#     assert container is None  # Container should no longer exist
#
#
# @pytest.mark.slow
# def test_recreate_container():
#     # Test recreating a container when its parameters change
#     import docker
#     client = docker.from_env()
#
#     # Start the container initially
#     llm = AutoHuggingFaceTextGenInference.from_docker(model_name, host='0.0.0.0', port=8081, shm_size="1g")
#     container = AutoHuggingFaceTextGenInference._get_container(client, model_name)
#     assert container.status == 'running'
#
#     # Change some parameters
#     new_port = 8082
#     new_shm_size = '2g'
#     new_generation_params = HFTextGenParams(host='0.0.0.0', port=new_port, max_new_tokens=256)
#
#     # Recreate the container with the new parameters
#     llm = AutoHuggingFaceTextGenInference.from_docker(
#         model_name, host='0.0.0.0', port=new_port, shm_size=new_shm_size)
#
#     # Ensure the container is running with the new parameters
#     container = AutoHuggingFaceTextGenInference._get_container(client, model_name)
#     assert container.status == 'running'
#     assert container.attrs['HostConfig']['ShmSize'] == 2147483648
#
#     # Shut down the container
#     AutoHuggingFaceTextGenInference.shutdown(model_name)
#     container = AutoHuggingFaceTextGenInference._get_container(client, model_name)
#     assert container is None  # Container should no longer exist
#
#
# def test_invalid_environment():
#     # Make sure we raise an exception if the python docker library is not installed
#     # and we try to use AutoHuggingFaceTextGenInference
#     llm.docker = None
#
#     # Test invalid environment without the docker library installed
#     with pytest.raises(EnvironmentError):
#         AutoHuggingFaceTextGenInference.validate_environment()
#
#     # Re-import to go back to the original state
#     import docker
#     llm.docker = docker
#
#
# def test_health_check_timeout():
#     # Test health check timeout
#     with pytest.raises(TimeoutError):
#         AutoHuggingFaceTextGenInference.HEALTH_CHECK_TIMEOUT = 0
#         llm = AutoHuggingFaceTextGenInference.from_docker(
#             model_name, host='will-never-connect', port=8081, shm_size="1g")