
@dranger003
Created May 25, 2024 22:51
Phi-3-Vision-128K-Instruct Stream
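An interactive command-line script that streams responses from microsoft/Phi-3-vision-128k-instruct token by token, using a custom TextStreamer subclass and a background generation thread. An image can be attached to a prompt by prefixing it with a backtick-quoted path.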
# run_model_stream.py
import sys
import re
import os
import queue
import threading
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    # BitsAndBytesConfig,  # Uncomment if you want to run quantized (see quantization_config below)
    TextStreamer,
)

class TextStreamerEx(TextStreamer):
    def __init__(self, tokenizer, output):
        super().__init__(
            tokenizer,
            skip_prompt=False,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        self.output = output

    def put(self, value):
        # The prompt is passed to the streamer as a 2-D tensor of input ids;
        # skip it so that only newly generated tokens are streamed.
        if len(value.shape) > 1:
            return
        super().put(value)

    def on_finalized_text(self, text, stream_end=False):
        # Push decoded text to the consumer queue; None signals end of stream.
        self.output.put(text)
        if stream_end:
            self.output.put(None)

class Model:
    def __init__(self, model_id, device_index=0):
        self.model_id = model_id
        self.index = device_index
        self.loaded = False

    def load(self):
        self.processor = AutoProcessor.from_pretrained(
            self.model_id, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map=f"cuda:{self.index}",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # Uncomment to quantize to 4-bit, or;
            # quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # Uncomment to quantize to 8-bit
        )
        self.loaded = True

    def run(self, input, output):
        # Runs generation and streams decoded text into the output queue.
        image, text = input
        images = None
        if image is not None:
            images = [image]
        inputs = self.processor(images=images, text=text, return_tensors="pt").to(
            self.model.device
        )
        _ = self.model.generate(
            **inputs,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            max_new_tokens=4096,
            do_sample=False,
            streamer=TextStreamerEx(self.processor.tokenizer, output),
        )

    def prompt(self, image, text):
        # Run generation on a background thread and yield text chunks as they arrive.
        output = queue.Queue()
        thread = threading.Thread(target=self.run, args=((image, text), output))
        thread.start()
        while True:
            text = output.get()
            if text is None:
                break
            yield text
        thread.join()

def get_image_path(line):
    # Extract a leading backtick-quoted image path, e.g. `~/tmp/image.png`Describe the image.
    match = re.search(r"^`(.*?)`", line)
    if match:
        return match.group(1), match.start(), match.end()
    return None, None, None

def main():
    print("Loading model...")
    model = Model("microsoft/Phi-3-vision-128k-instruct")
    model.load()
    print("Model loaded.")
    print("Press <enter> on a blank line to quit.")
    print(
        "To load an image, prefix your prompt with the image path using backticks, "
        "for example: `~/tmp/image.png`Describe the image."
    )
    print()
    while True:
        print("> ", end="", flush=True)
        line = sys.stdin.readline().strip()
        if not line:
            break
        image = None
        image_path, start, end = get_image_path(line)
        if image_path is not None:
            image_path = os.path.expandvars(os.path.expanduser(image_path))
            if os.path.exists(image_path):
                image = Image.open(image_path)
        if image is not None:
            # Insert the image placeholder and drop the backtick-quoted path from the prompt.
            prompt = f"<|image_1|>\n{line[:start] + line[end:]}"
        else:
            prompt = line
        templatized_prompt = model.processor.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for text in model.prompt(image, templatized_prompt):
            print(text, end="", flush=True)
        print(flush=True)


if __name__ == "__main__":
    main()
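For reference, a minimal sketch of driving the Model class programmatically instead of through the interactive loop. This assumes the gist is saved as run_model_stream.py next to the caller, that torch, transformers, pillow, and accelerate are installed, and that example.png is a placeholder path for a real image.

# usage_example.py (hypothetical)
from PIL import Image
from run_model_stream import Model

model = Model("microsoft/Phi-3-vision-128k-instruct")
model.load()

# Build the chat-templated prompt containing the image placeholder the processor expects.
prompt = model.processor.tokenizer.apply_chat_template(
    [{"role": "user", "content": "<|image_1|>\nDescribe the image."}],
    tokenize=False,
    add_generation_prompt=True,
)

# Model.prompt() yields decoded text chunks as they are generated.
for chunk in model.prompt(Image.open("example.png"), prompt):
    print(chunk, end="", flush=True)
print()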