Skip to content

Instantly share code, notes, and snippets.

@dranger003
Last active September 21, 2024 10:35
Show Gist options
Save dranger003/daff444ebf04951d4279b5b2dee71ab4 to your computer and use it in GitHub Desktop.
Phi-3-Vision-128K-Instruct Quick Local API
# server.py
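# Run with: uvicorn server:app --reload
# (--reload restarts the server on code changes, which also reloads the model at startup.)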
import base64
import queue
import threading
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    # BitsAndBytesConfig,
    TextStreamer,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load a single model on GPU 0 at startup.

    Nothing is done at shutdown; code after the ``yield`` would run then.
    """
    startup_request = LoadRequest(num_gpus=1, models_per_gpu=1)
    load(startup_request)
    yield  # the application serves requests while suspended here


app = FastAPI(lifespan=lifespan)
class TextStreamerEx(TextStreamer):
    """TextStreamer that forwards decoded text to ``output`` instead of stdout.

    After the final chunk, ``None`` is put on ``output`` so a consumer can tell
    that the stream has ended.
    """

    def __init__(self, tokenizer, output):
        decode_kwargs = dict(
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        super().__init__(tokenizer, skip_prompt=False, **decode_kwargs)
        # Any object with a put() method; Model.run passes the queue created
        # in Model.prompt.
        self.output = output

    def put(self, value):
        # Only tensors with at most one dimension are decoded; anything with
        # more dimensions (presumably the batched prompt ids that generate()
        # hands over first -- TODO confirm) is silently ignored.
        if len(value.shape) <= 1:
            super().put(value)

    def on_finalized_text(self, text, stream_end=False):
        self.output.put(text)
        if not stream_end:
            return
        # End-of-stream sentinel for the consumer loop.
        self.output.put(None)
class Model:
    """A processor/causal-LM pair pinned to one CUDA device.

    Construction only records configuration; weights are loaded by ``load()``.
    """

    def __init__(self, model_id, index):
        # model_id: Hugging Face repo id passed to ``from_pretrained``.
        # index: CUDA device ordinal used to build ``device_map``.
        self.model_id = model_id
        self.index = index
        self.loaded = False

    def load(self):
        """Load the processor and a bfloat16 copy of the model onto ``cuda:<index>``."""
        self.processor = AutoProcessor.from_pretrained(
            self.model_id, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map=f"cuda:{self.index}",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )
        self.loaded = True

    def run(self, input, output):
        """Generate a reply and push decoded text chunks onto *output*.

        Args:
            input: ``(image, text)`` pair; *image* may be ``None`` for a
                text-only prompt, *text* is the already-templated prompt.
            output: queue that receives ``str`` chunks and, when the streamer
                reports end of stream, ``None``.
        """
        image, text = input
        images = None if image is None else [image]
        inputs = self.processor(images=images, text=text, return_tensors="pt").to(
            self.model.device
        )
        self.model.generate(
            **inputs,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            max_new_tokens=4096,
            do_sample=False,
            repetition_penalty=1.2,
            streamer=TextStreamerEx(self.processor.tokenizer, output),
        )

    def prompt(self, image, text):
        """Run generation on a worker thread and yield text chunks as they arrive.

        Raises:
            Exception: whatever ``run`` raised on the worker thread, re-raised
                here after all chunks produced before the failure were yielded.
        """
        output = queue.Queue()
        errors = []

        def worker():
            try:
                self.run((image, text), output)
            except Exception as exc:  # surfaced to the consumer below
                errors.append(exc)
            finally:
                # Always terminate the consumer loop: if run() fails before the
                # streamer emits its own end-of-stream None, the loop below
                # would otherwise block on output.get() forever.
                output.put(None)

        thread = threading.Thread(target=worker)
        thread.start()
        while True:
            chunk = output.get()
            if chunk is None:
                break
            yield chunk
        # join() guarantees worker() has finished recording any error.
        thread.join()
        if errors:
            raise errors[0]
# Module-level registry of loaded Model instances; /prompt selects one by its
# position in this list.
models: list = []


class LoadRequest(BaseModel):
    """Body of POST /load."""

    # Number of GPUs to use, taken in ordinal order starting at cuda:0.
    num_gpus: int = 1
    # Independent model copies to load on each of those GPUs.
    models_per_gpu: int = 1
@app.post("/load")
def load(request: LoadRequest):
    """Load ``models_per_gpu`` Phi-3-Vision copies on each of ``num_gpus`` GPUs.

    New instances are appended to the module-level ``models`` list, so repeated
    calls accumulate models (and GPU memory) instead of replacing them.
    """
    placements = [
        gpu_index
        for gpu_index in range(request.num_gpus)
        for _ in range(request.models_per_gpu)
    ]
    for gpu_index in placements:
        instance = Model("microsoft/Phi-3-vision-128k-instruct", gpu_index)
        instance.load()
        models.append(instance)
    return {"status": "200 OK"}
class PromptRequest(BaseModel):
    """Body of POST /prompt."""

    # Position of the target model in the loaded-model list.
    model: int = 0
    # Base64-encoded image bytes, or omitted/null for a text-only prompt.
    # Optional[...] is required so an explicit JSON null is accepted: under
    # Pydantic v2 a bare ``str`` annotation rejects null even with a None default.
    image: Optional[str] = None
    # User message text (required).
    text: str
@app.post("/prompt")
def prompt(request: PromptRequest):
    """Stream the selected model's reply as server-sent events.

    Each event's ``data:`` field is the base64 encoding of a UTF-8 text chunk,
    so newlines in the generated text cannot break SSE framing.

    Raises:
        HTTPException: 404 if ``request.model`` is not the index of a loaded
            model (negative indices are rejected too); 400 if ``request.image``
            cannot be decoded as base64 image data.
    """
    if not 0 <= request.model < len(models):
        raise HTTPException(
            status_code=404, detail=f"Model {request.model} is not loaded"
        )
    model = models[request.model]

    if request.image is None:
        image = None
        user_content = request.text
    else:
        try:
            image = Image.open(BytesIO(base64.b64decode(request.image)))
        except (ValueError, OSError) as exc:
            # binascii.Error (bad base64) subclasses ValueError; PIL's
            # UnidentifiedImageError subclasses OSError.
            raise HTTPException(
                status_code=400, detail="Invalid base64-encoded image"
            ) from exc
        # Image prompts are prefixed with the <|image_1|> placeholder token.
        user_content = f"<|image_1|>\n{request.text}"

    templatized_prompt = model.processor.tokenizer.apply_chat_template(
        [{"role": "user", "content": user_content}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/sample_inference.py#L97
    # Remove exactly one trailing end-of-text marker. str.rstrip() would treat
    # the marker as a *set of characters* and could also strip legitimate
    # trailing characters of the prompt.
    end_marker = "<|endoftext|>"
    if templatized_prompt.endswith(end_marker):
        templatized_prompt = templatized_prompt[: -len(end_marker)]

    def stream():
        for chunk in model.prompt(image, templatized_prompt):
            yield f"data: {base64.b64encode(chunk.encode('utf-8')).decode('utf-8')}\n\n"

    return StreamingResponse(stream(), media_type="text/event-stream")
@dranger003
Copy link
Author

dranger003 commented May 25, 2024

C# client source code:

// Text prompt
await RunAsync(0, null, "Anyone home?", Console.Out);

// Image+text prompt
await RunAsync(0, @"Z:\image.png", "Describe the image.", Console.Out);

// Sends one prompt to the local server's /prompt endpoint and writes the streamed
// reply to `writer`. The server emits server-sent events whose "data: " payload is
// base64-encoded UTF-8 text.
static async Task RunAsync(int modelIndex, string? path, string prompt, TextWriter writer, CancellationToken cancellationToken = default)
{
    using var httpClient = new HttpClient();

    // For text-only prompts the "image" key is omitted entirely rather than sent as null.
    var request = (dynamic)new { model = modelIndex, text = prompt };
    if (path != null)
    {
        request = new { request.model, image = Convert.ToBase64String(await File.ReadAllBytesAsync(path, cancellationToken)), request.text };
    }

    using var response = await httpClient.SendAsync(
        new HttpRequestMessage { Method = HttpMethod.Post, RequestUri = new Uri("http://127.0.0.1:8000/prompt"), Content = JsonContent.Create(request) },
        HttpCompletionOption.ResponseHeadersRead,
        cancellationToken
    );
    // Fail fast on 4xx/5xx: an error body is JSON, not base64 SSE data.
    response.EnsureSuccessStatusCode();

    await using var httpStream = await response.Content.ReadAsStreamAsync(cancellationToken);
    using var httpReader = new StreamReader(httpStream);

    const string dataPrefix = "data: ";

    // ReadLineAsync returns null at end of stream; this avoids EndOfStream, which
    // performs a blocking synchronous read on the network stream.
    while (!cancellationToken.IsCancellationRequested && await httpReader.ReadLineAsync(cancellationToken) is { } line)
    {
        // ReadLineAsync strips line terminators; the blank lines between SSE
        // events (and any non-data lines) are skipped here.
        if (!line.StartsWith(dataPrefix, StringComparison.Ordinal))
            continue;

        var text = Encoding.UTF8.GetString(Convert.FromBase64String(line[dataPrefix.Length..]));
        if (String.IsNullOrEmpty(text))
            continue;

        await writer.WriteAsync(text);
        await writer.FlushAsync(cancellationToken);
    }
}

@ChristianWeyer
Copy link

Nice!

@dranger003 Do you also have a version that can run phi-3 vision on macOS?

@MrJarnould
Copy link

Nice!

@dranger003 Do you also have a version that can run phi-3 vision on macOS?

+1

@dranger003
Copy link
Author

I don't have a Mac, but that code should run. Have you tried it?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment