import time
import os
import json
from modal import Image, Secret, Stub, method, enter, gpu, exit, asgi_app
from fastapi import FastAPI
from typing import Literal, Optional, List, Dict, Any, Union
import shortuuid
from pydantic import BaseModel, Field


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    code: int


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = True
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "fastchat"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = []


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class ChatCompletionRequest(BaseModel):
    model: str
    messages: Union[
        str,
        List[Dict[str, str]],
        List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
    ]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    top_k: Optional[int] = -1
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]


class TokenCheckRequestItem(BaseModel):
    model: str
    prompt: str
    max_tokens: int


class TokenCheckRequest(BaseModel):
    prompts: List[TokenCheckRequestItem]


class TokenCheckResponseItem(BaseModel):
    fits: bool
    tokenCount: int
    contextLength: int


class TokenCheckResponse(BaseModel):
    prompts: List[TokenCheckResponseItem]


class EmbeddingsRequest(BaseModel):
    model: Optional[str] = None
    engine: Optional[str] = None
    input: Union[str, List[Any]]
    user: Optional[str] = None
    encoding_format: Optional[str] = None


class EmbeddingsResponse(BaseModel):
    object: str = "list"
    data: List[Dict[str, Any]]
    model: str
    usage: UsageInfo


class CompletionRequest(BaseModel):
    model: str
    prompt: Union[str, List[Any]]
    suffix: Optional[str] = None
    temperature: Optional[float] = 0.7
    n: Optional[int] = 1
    max_tokens: Optional[int] = 16
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    top_p: Optional[float] = 1.0
    top_k: Optional[int] = -1
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    use_beam_search: Optional[bool] = False
    best_of: Optional[int] = None


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
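

# The models above mirror the OpenAI chat/completions API schema. A request
# body the /v1/chat/completions endpoint below accepts looks roughly like this
# (a sketch; values are illustrative):
#
#     {
#         "model": "Indic-gemma-7b-finetuned-sft-Navarasa-2.0",
#         "messages": [
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user", "content": "Translate 'hello' to Telugu."}
#         ],
#         "temperature": 0.7,
#         "max_tokens": 256,
#         "stream": true
#     }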


MODEL_DIR = "/model-gemma-telugulabs"
BASE_MODEL = "Telugu-LLM-Labs/Indic-gemma-7b-finetuned-sft-Navarasa-2.0"
GPU_CONFIG = gpu.A100(memory=40, count=1)


def download_model_to_folder():
    from huggingface_hub import snapshot_download
    from transformers.utils import move_cache

    os.makedirs(MODEL_DIR, exist_ok=True)
    snapshot_download(
        BASE_MODEL,
        local_dir=MODEL_DIR,
        token=os.environ["HUGGINGFACE_TOKEN"],
    )
    move_cache()


vllm_image = (
    Image.from_registry("nvidia/cuda:12.1.1-devel-ubuntu22.04", add_python="3.10")
    .pip_install(
        "vllm==0.3.2",
        "huggingface_hub==0.19.4",
        "hf-transfer==0.1.4",
        "torch==2.1.2",
        "shortuuid",
    )
    .run_commands("pip install pydantic==1.10.11")
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_function(
        download_model_to_folder,
        secrets=[Secret.from_name("huggingface")],
        timeout=60 * 10,
    )
)

stub = Stub("vllm-server-gemma-telugulabs")


@stub.cls(
    gpu=GPU_CONFIG,
    timeout=60 * 5,
    container_idle_timeout=60 * 5,
    allow_concurrent_inputs=15,
    image=vllm_image,
)
class Model:
    @enter()
    def start_engine(self):
        import time

        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine

        print("🥶 cold starting inference")
        start = time.monotonic_ns()

        if GPU_CONFIG.count > 1:
            # Patch issue from https://github.com/vllm-project/vllm/issues/1116
            import ray

            ray.shutdown()
            ray.init(num_gpus=GPU_CONFIG.count)

        engine_args = AsyncEngineArgs(
            model=MODEL_DIR,
            tensor_parallel_size=GPU_CONFIG.count,
            gpu_memory_utilization=0.90,
            enforce_eager=False,  # capture the graph for faster inference, but slower cold starts
            disable_log_stats=True,  # disable logging so we can stream tokens
            disable_log_requests=True,
        )

        # this can take some time!
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        duration_s = (time.monotonic_ns() - start) / 1e9
        print(f"🏎️ engine started in {duration_s:.0f}s")

    def chatml_format(self, messages):
        # Render the conversation as ChatML turns with a fixed system prompt.
        # (Unused here; completion_stream calls alpaca_format instead.)
        prompt = """<|im_start|>system
You are a helpful AI assistant created by Bhabha AI.<|im_end|>
"""
        for message in messages:
            if message["role"] == "system":
                continue
            if message["role"] == "user":
                prompt += f"<|im_start|>user\n{message['content']}<|im_end|>\n<|im_start|>assistant\n"
            elif message["role"] == "assistant":
                prompt += f"{message['content']}<|im_end|>\n"
        return prompt

    def alpaca_format(self, messages):
        # Render the conversation as Alpaca-style instruction/response turns.
        prompt = ""
        for message in messages:
            if message["role"] == "system":
                continue
            if message["role"] == "user":
                prompt += f"### Instruction: {message['content']}\n\n### Response:"
            elif message["role"] == "assistant":
                prompt += f"{message['content']}\n\n"
        return prompt
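
    # For a single user turn {"role": "user", "content": "Hello"},
    # alpaca_format renders:
    #
    #     ### Instruction: Hello
    #
    #     ### Response: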

    @method()
    async def completion_stream(self, request_data):
        from vllm import SamplingParams
        from vllm.utils import random_uuid

        request_model_name = request_data["model"]
        messages = request_data["messages"]
        input_prompt = self.alpaca_format(messages)
        print("Input prompt", input_prompt)

        # Copy any matching request fields (temperature, top_p, ...) onto the
        # vLLM sampling params.
        sampling_params = SamplingParams()
        for key, value in request_data.items():
            if hasattr(sampling_params, key):
                setattr(sampling_params, key, value)
        print("Sampling params", sampling_params)

        request_id = random_uuid()
        result_generator = self.engine.generate(
            input_prompt,
            sampling_params,
            request_id,
        )
        index, num_tokens = 0, 0
        start = time.monotonic_ns()

        # First chunk carries only the assistant role, per the OpenAI
        # streaming format.
        yield "data: " + ChatCompletionStreamResponse(
            model=request_model_name,
            choices=[
                ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant"),
                    finish_reason=None,
                )
            ],
        ).json(exclude_unset=True, ensure_ascii=False) + "\n\n"

        async for output in result_generator:
            # Skip chunks that end in a partial multi-byte character.
            if output.outputs[0].text and "\ufffd" == output.outputs[0].text[-1]:
                continue
            text_delta = output.outputs[0].text[index:]
            index = len(output.outputs[0].text)
            num_tokens = len(output.outputs[0].token_ids)

            yield "data: " + ChatCompletionStreamResponse(
                model=request_model_name,
                choices=[
                    ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(content=text_delta),
                        finish_reason=None,
                    )
                ],
            ).json(exclude_unset=True, ensure_ascii=False) + "\n\n"

        duration_s = (time.monotonic_ns() - start) / 1e9
        print(f"Generated {num_tokens} tokens in {duration_s:.1f}s")

        # Final chunk signals completion, then the SSE terminator.
        yield "data: " + ChatCompletionStreamResponse(
            model=request_model_name,
            choices=[
                ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(content=""),
                    finish_reason="stop",
                )
            ],
        ).json(exclude_unset=True, ensure_ascii=False) + "\n\n"
        yield "data: [DONE]\n\n"

    @exit()
    def stop_engine(self):
        if GPU_CONFIG.count > 1:
            import ray

            ray.shutdown()


web_app = FastAPI()


@web_app.post("/v1/chat/completions")
def web(request: ChatCompletionRequest):
    from fastapi.responses import StreamingResponse

    # Round-trip through JSON to get a plain, serializable dict for the
    # remote call.
    request_data = json.loads(request.json())
    model = Model()
    print("Sending new request:", request_data, "\n\n")
    return StreamingResponse(
        model.completion_stream.remote_gen(request_data),
        media_type="text/event-stream",
    )


@stub.function(image=vllm_image, allow_concurrent_inputs=15)
@asgi_app()
def fastapi_app():
    return web_app
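

# A minimal client sketch (assumes the app has been deployed with
# `modal deploy`; YOUR_APP_URL is a placeholder for the URL Modal prints for
# fastapi_app, not part of this gist):
#
#     import requests
#
#     resp = requests.post(
#         "https://YOUR_APP_URL/v1/chat/completions",
#         json={
#             "model": "Indic-gemma-7b-finetuned-sft-Navarasa-2.0",
#             "messages": [{"role": "user", "content": "Hello!"}],
#             "max_tokens": 128,
#             "stream": True,
#         },
#         stream=True,
#     )
#     for line in resp.iter_lines():
#         if line and line != b"data: [DONE]":
#             print(line.decode())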