import time
import os
import json
from modal import Image, Secret, Stub, method, web_endpoint, enter, gpu, exit, asgi_app
from fastapi import FastAPI, Request
from typing import Literal, Optional, List, Dict, Any, Union
import shortuuid
from pydantic import BaseModel, Field
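

# OpenAI-compatible request/response schemas (FastChat-style, Pydantic v1).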
class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    code: int


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = True
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "fastchat"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = []


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class ChatCompletionRequest(BaseModel):
    model: str
    messages: Union[
        str,
        List[Dict[str, str]],
        List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
    ]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    top_k: Optional[int] = -1
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]


class TokenCheckRequestItem(BaseModel):
    model: str
    prompt: str
    max_tokens: int


class TokenCheckRequest(BaseModel):
    prompts: List[TokenCheckRequestItem]


class TokenCheckResponseItem(BaseModel):
    fits: bool
    tokenCount: int
    contextLength: int


class TokenCheckResponse(BaseModel):
    prompts: List[TokenCheckResponseItem]


class EmbeddingsRequest(BaseModel):
    model: Optional[str] = None
    engine: Optional[str] = None
    input: Union[str, List[Any]]
    user: Optional[str] = None
    encoding_format: Optional[str] = None


class EmbeddingsResponse(BaseModel):
    object: str = "list"
    data: List[Dict[str, Any]]
    model: str
    usage: UsageInfo


class CompletionRequest(BaseModel):
    model: str
    prompt: Union[str, List[Any]]
    suffix: Optional[str] = None
    temperature: Optional[float] = 0.7
    n: Optional[int] = 1
    max_tokens: Optional[int] = 16
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    top_p: Optional[float] = 1.0
    top_k: Optional[int] = -1
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    use_beam_search: Optional[bool] = False
    best_of: Optional[int] = None


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
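

# Model to serve, the directory it is downloaded to inside the image, and the GPU type.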
MODEL_DIR = "/model-gemma-telugulabs"
BASE_MODEL = "Telugu-LLM-Labs/Indic-gemma-7b-finetuned-sft-Navarasa-2.0"
GPU_CONFIG = gpu.A100(memory=40, count=1)


def download_model_to_folder():
    from huggingface_hub import snapshot_download
    from transformers.utils import move_cache

    os.makedirs(MODEL_DIR, exist_ok=True)
    snapshot_download(
        BASE_MODEL,
        local_dir=MODEL_DIR,
        token=os.environ["HUGGINGFACE_TOKEN"],
    )
    move_cache()
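

# Container image: CUDA base with pinned vLLM, Pydantic pinned to v1, and the
# model weights downloaded at image-build time.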
vllm_image = (
    Image.from_registry(
        "nvidia/cuda:12.1.1-devel-ubuntu22.04", add_python="3.10"
    )
    .pip_install(
        "vllm==0.3.2",
        "huggingface_hub==0.19.4",
        "hf-transfer==0.1.4",
        "torch==2.1.2",
        "shortuuid",
    )
    .run_commands("pip install pydantic==1.10.11")
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_function(
        download_model_to_folder,
        secrets=[Secret.from_name("huggingface")],
        timeout=60 * 10,
    )
)


stub = Stub("vllm-server-gemma-telugulabs")


@stub.cls(
    gpu=GPU_CONFIG,
    timeout=60 * 5,
    container_idle_timeout=60 * 5,
    allow_concurrent_inputs=15,
    image=vllm_image,
)
class Model:
    @enter()
    def start_engine(self):
        import time

        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine

        print("🥶 cold starting inference")
        start = time.monotonic_ns()

        if GPU_CONFIG.count > 1:
            # Patch issue from https://github.com/vllm-project/vllm/issues/1116
            import ray

            ray.shutdown()
            ray.init(num_gpus=GPU_CONFIG.count)

        engine_args = AsyncEngineArgs(
            model=MODEL_DIR,
            tensor_parallel_size=GPU_CONFIG.count,
            gpu_memory_utilization=0.90,
            enforce_eager=False,  # capture the graph for faster inference, but slower cold starts
            disable_log_stats=True,  # disable logging so we can stream tokens
            disable_log_requests=True,
        )

        # this can take some time!
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        duration_s = (time.monotonic_ns() - start) / 1e9
        print(f"🏎️ engine started in {duration_s:.0f}s")

    def chatml_format(self, messages):
        # ChatML-style prompt (not used by the endpoint below, which calls alpaca_format).
        prompt = """<|im_start|>system
You are a helpful AI assistant created by Bhabha AI.<|im_end|>
"""
        for message in messages:
            if message["role"] == "system":
                continue
            if message["role"] == "user":
                prompt += f"<|im_start|>user\n{message['content']}<|im_end|>\n<|im_start|>assistant\n"
            elif message["role"] == "assistant":
                prompt += f"{message['content']}<|im_end|>\n"
        return prompt

    def alpaca_format(self, messages):
        # Alpaca-style prompt used for the Navarasa fine-tune.
        prompt = ""
        for message in messages:
            if message["role"] == "system":
                continue
            if message["role"] == "user":
                prompt += f"### Instruction: {message['content']}\n\n### Response:"
            elif message["role"] == "assistant":
                prompt += f"{message['content']}\n\n"
        return prompt

    @method()
    async def completion_stream(self, request_data):
        from vllm import SamplingParams
        from vllm.utils import random_uuid

        request_model_name = request_data["model"]
        messages = request_data["messages"]
        input_prompt = self.alpaca_format(messages)
        print("Input prompt", input_prompt)

        # Copy any request fields that SamplingParams understands (temperature, top_p, ...).
        sampling_params = SamplingParams()
        for key, value in request_data.items():
            if hasattr(sampling_params, key):
                setattr(sampling_params, key, value)
        print("Sampling params", sampling_params)

        request_id = random_uuid()
        result_generator = self.engine.generate(
            input_prompt,
            sampling_params,
            request_id,
        )

        index, num_tokens = 0, 0
        start = time.monotonic_ns()

        # First chunk: announce the assistant role.
        yield "data: " + ChatCompletionStreamResponse(
            model=request_model_name,
            choices=[
                ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant"),
                    finish_reason=None,
                )
            ],
        ).json(exclude_unset=True, ensure_ascii=False) + "\n\n"

        async for output in result_generator:
            # Skip chunks that end in an incomplete unicode character.
            if (
                output.outputs[0].text
                and "\ufffd" == output.outputs[0].text[-1]
            ):
                continue
            text_delta = output.outputs[0].text[index:]
            index = len(output.outputs[0].text)
            num_tokens = len(output.outputs[0].token_ids)

            yield "data: " + ChatCompletionStreamResponse(
                model=request_model_name,
                choices=[
                    ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(content=text_delta),
                        finish_reason=None,
                    )
                ],
            ).json(exclude_unset=True, ensure_ascii=False) + "\n\n"
        duration_s = (time.monotonic_ns() - start) / 1e9

        # Final chunk with finish_reason, then the SSE terminator.
        yield "data: " + ChatCompletionStreamResponse(
            model=request_model_name,
            choices=[
                ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(content=""),
                    finish_reason="stop",
                )
            ],
        ).json(exclude_unset=True, ensure_ascii=False) + "\n\n"
        yield "data: [DONE]\n\n"

    @exit()
    def stop_engine(self):
        if GPU_CONFIG.count > 1:
            import ray

            ray.shutdown()
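

# FastAPI app exposing an OpenAI-compatible /v1/chat/completions route that
# streams tokens back as server-sent events; Modal serves it via asgi_app.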
web_app = FastAPI()


# @stub.function()
@web_app.post("/v1/chat/completions")
def web(request: ChatCompletionRequest):
    from fastapi.responses import StreamingResponse

    request_data = json.loads(request.json())
    model = Model()
    print("Sending new request:", request_data, "\n\n")
    return StreamingResponse(
        model.completion_stream.remote_gen(request_data),
        media_type="text/event-stream",
    )


@stub.function(image=vllm_image, allow_concurrent_inputs=15)
@asgi_app()
def fastapi_app():
    return web_app
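

# Example client call: a minimal sketch, assuming the stub has been deployed with
# `modal deploy` and that the placeholder URL below is replaced by the
# `fastapi_app` endpoint URL Modal prints for your workspace.
#
#   import requests
#
#   resp = requests.post(
#       "https://<your-workspace>--vllm-server-gemma-telugulabs-fastapi-app.modal.run/v1/chat/completions",
#       json={
#           "model": "Telugu-LLM-Labs/Indic-gemma-7b-finetuned-sft-Navarasa-2.0",
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "max_tokens": 128,
#           "temperature": 0.7,
#       },
#       stream=True,
#   )
#   for line in resp.iter_lines():
#       if line:
#           print(line.decode())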