
@kishida
Created August 21, 2023 03:08
UI for Stability AI's vision-language model
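# Setup note (added sketch, not in the original gist): the protobuf flag
# mentioned below must be set before transformers is imported; doing it
# in-process like this is an alternative to exporting it in the shell.
import os
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"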
import torch
from transformers import LlamaTokenizer, AutoModelForVision2Seq, BlipImageProcessor
from PIL import Image
import requests
# Requires the environment variable PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
# load model (load_in_8bit=True needs the bitsandbytes package and places the
# weights on the GPU, which is why the explicit model.to(device) is skipped)
model_name = "stabilityai/japanese-instructblip-alpha"
model = AutoModelForVision2Seq.from_pretrained(model_name, load_in_8bit=True, trust_remote_code=True)
processor = BlipImageProcessor.from_pretrained(model_name)
tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1", additional_special_tokens=['▁▁'])
device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
print("model loaded")
# prepare inputs
url = "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
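# Alternative (hypothetical local path, e.g. for offline use): skip the URL
# and load an image from disk instead.
#   image = Image.open("sample.jpg").convert("RGB")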
# helper function to format input prompts
def build_prompt(prompt="", sep="\n\n### "):
    # System message: "Below is a combination of an instruction describing a
    # task and contextual input. Write a response that appropriately fulfils
    # the request."
    sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    p = sys_msg
    roles = ["指示", "応答"]  # "instruction", "response"
    user_query = "与えられた画像について、詳細に述べてください。"  # "Describe the given image in detail."
    msgs = [": \n" + user_query, ": "]
    if prompt:
        roles.insert(1, "入力")  # "input"
        msgs.insert(1, ": \n" + prompt)
    for role, msg in zip(roles, msgs):
        p += sep + role + msg
    return p
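# For reference (derived directly from the code above): build_prompt("")
# yields the following template, which asks the model in Japanese for a
# detailed description of the given image.
#
#   以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
#
#   ### 指示: 
#   与えられた画像について、詳細に述べてください。
#
#   ### 応答: 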
def load_image(img):
    global image
    image = img
def generate(prompt):
    # prompt = ""  # input empty string for image captioning. You can also input questions as prompts
    prompt = build_prompt(prompt)
    inputs = processor(images=image, return_tensors="pt")
    text_encoding = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    # InstructBLIP also feeds the instruction text to the Q-Former
    text_encoding["qformer_input_ids"] = text_encoding["input_ids"].clone()
    text_encoding["qformer_attention_mask"] = text_encoding["attention_mask"].clone()
    inputs.update(text_encoding)
    # generate
    outputs = model.generate(
        **inputs.to(device, dtype=model.dtype),
        num_beams=5,
        max_new_tokens=32,
        min_length=1,
        pad_token_id=tokenizer.pad_token_id,
    )
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
    return generated_text
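# Example usage without the UI (commented out so the Gradio app below stays
# the entry point): an empty prompt gives plain image captioning, while a
# question gives visual question answering. The Japanese question here is a
# hypothetical example ("What is this a photo of?").
#   print(generate(""))
#   print(generate("これは何の写真ですか。"))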
import gradio as gr
with gr.Blocks() as demo:
    gr.Markdown("## Stability AI multi modal model")
    imgIn = gr.Image(lambda: image, type="pil")
    imgIn.change(load_image, inputs=imgIn)
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(lines=3, placeholder="質問を")  # placeholder: "Enter a question"
            submit = gr.Button("Submit", variant="primary")
            with gr.Row():
                default = gr.Button("Default")
                clear = gr.Button("Clear")
            default.click(lambda: "画像を説明して", outputs=question)  # fills in "Describe the image"
            clear.click(lambda: "", outputs=question)
        answer = gr.Textbox(lines=3)
    submit.click(generate, inputs=question, outputs=answer)
demo.launch()
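# Note (standard Gradio options, not used in the original): launch() accepts
# server_name="0.0.0.0" to listen on all interfaces, or share=True to get a
# temporary public URL.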
kishida commented Aug 21, 2023: (attached demo video: bandicam.2023-08-21.11-09-51-708.mp4)