Skip to content

Instantly share code, notes, and snippets.

@CoderCowMoo
Last active April 18, 2024 11:42
Show Gist options
  • Save CoderCowMoo/735e84e35ca3b68a1125f738bf72f096 to your computer and use it in GitHub Desktop.
Save CoderCowMoo/735e84e35ca3b68a1125f738bf72f096 to your computer and use it in GitHub Desktop.
import torch
import transformers
import gradio as gr
import PIL
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import warnings
# Silence transformers' info/warning logs and download progress bars, and
# ignore Python warnings globally so the demo console stays readable.
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')

# Hugging Face Hub repository holding both the model weights and tokenizer.
MODEL_ID = 'qnguyen3/nanoLLaVA'

# Prefer the GPU when one is present and fall back to CPU otherwise.
# Half precision is only used on CUDA; CPU half-precision kernel support in
# PyTorch is limited, so float32 is the safe choice there.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float16 if DEVICE == 'cuda' else torch.float32

# Tensors created without an explicit device (e.g. via torch.tensor in the
# request handler) are placed on DEVICE. This replaces the deprecated
# torch.set_default_tensor_type('torch.cuda.FloatTensor'), whose dtype part
# (float32) is already PyTorch's default.
torch.set_default_device(DEVICE)

# trust_remote_code=True executes Python code shipped in the model repository
# (the request handler relies on its custom `process_images` method); only
# enable it for repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map='auto',
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)
# Sentinel id spliced into input_ids where the "<image>" marker was.
# Presumably nanoLLaVA's remote model code swaps it for image features --
# confirm against the model repository if this value ever changes.
IMAGE_TOKEN_INDEX = -200


def answer_question(img: PIL.Image.Image, prompt: str, max_new_tokens: int = 2048):
    """Stream nanoLLaVA's answer to *prompt* about *img*.

    Args:
        img: Image from the Gradio ``Image`` component (``type="pil"``), or
            ``None`` when nothing has been uploaded yet.
        prompt: The user's question about the image.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The cumulative response text after each streamed chunk, so the
        output box can be re-rendered progressively.

    Raises:
        gr.Error: If no image was provided.
        Exception: Any error raised by ``model.generate`` in the worker
            thread is re-raised here after the stream is closed.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # Build the chat-formatted prompt; the image marker precedes the text.
    messages = [{"role": "user", "content": f'<image>\n{prompt}'}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Split only on the first marker (the one inserted above) so a literal
    # "<image>" typed by the user stays in the prompt instead of being dropped.
    before, after = text.split('<image>', 1)
    input_ids = torch.tensor(
        tokenizer(before).input_ids + [IMAGE_TOKEN_INDEX] + tokenizer(after).input_ids,
        dtype=torch.long,
    ).unsqueeze(0).to(model.device)
    image_tensor = model.process_images([img], model.config).to(
        device=model.device, dtype=model.dtype
    )

    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    generation_kwargs = {
        "input_ids": input_ids,
        "images": image_tensor,
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "streamer": streamer,
    }
    failures = []

    def _generate():
        # Runs generation off the request thread so tokens can be yielded as
        # they arrive. The streamer's queue blocks indefinitely (timeout=None),
        # so if generate() fails, record the error and close the stream
        # explicitly; otherwise the consumer loop below would hang forever.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            failures.append(exc)
            streamer.end()

    worker = Thread(target=_generate, daemon=True)
    worker.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

    worker.join()
    if failures:
        raise failures[0]
with gr.Blocks() as demo:
    # Page header with a link to the model card.
    gr.Markdown(
        "\n"
        "# NanoLLaVA\n"
        "### A tiny vision language model. "
        "[HuggingFace 🤗](https://huggingface.co/qnguyen3/nanoLLaVA)\n"
    )

    # First row: the prompt box (wider) beside its submit button.
    with gr.Row():
        prompt = gr.Textbox(label="Input Prompt", placeholder="Type here...", scale=4)
        submit = gr.Button("Submit")

    # Second row: the image upload beside the streamed response.
    with gr.Row():
        img = gr.Image(type="pil", label="Upload an Image")
        output = gr.TextArea(label="Response")

    # Clicking the button and pressing Enter in the prompt box both run the
    # same handler with the same inputs and output.
    for trigger in (submit.click, prompt.submit):
        trigger(answer_question, [img, prompt], output)

# queue() enables streaming from the generator handler (required in older
# Gradio releases); debug=True keeps launch() blocking in the foreground.
demo.queue().launch(debug=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment