Skip to content

Instantly share code, notes, and snippets.

@kishida
Last active August 2, 2023 18:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kishida/ee107e002546ce2ab26c5d9a6eac33b1 to your computer and use it in GitHub Desktop.
UI for using the rinna image-dialogue (bilingual MiniGPT-4) model
import torch
import requests
from PIL import Image
from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
from customized_mini_gpt4 import CustomizedMiniGPT4
# --- Model setup -----------------------------------------------------------
# Path to the fine-tuned BLIP2-LLM checkpoint for the rinna MiniGPT-4 model.
ckpt_path = "./checkpoint.pth"

model = CustomizedMiniGPT4(gpt_neox_model="rinna/bilingual-gpt-neox-4b")
tokenizer = model.gpt_neox_tokenizer

if torch.cuda.is_available():
    model = model.to("cuda")

if ckpt_path is not None:
    print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
    # map_location="cpu" keeps loading robust on machines without a GPU;
    # strict=False because the checkpoint only covers part of the model.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt['model'], strict=False)

vis_processor = Blip2ImageEvalProcessor()

# Fetch the bundled sample image and pre-compute its embedding so the UI
# has a working default before any upload happens.
image_url = "https://huggingface.co/rinna/bilingual-gpt-neox-4b-minigpt4/resolve/main/sample.jpg"
response = requests.get(image_url, stream=True)
response.raise_for_status()  # fail fast instead of passing an error page to PIL
raw_image = Image.open(response.raw).convert('RGB')
ini_image = raw_image  # kept around so "Default" can restore the sample image
image = vis_processor(raw_image).unsqueeze(0).to(model.device)
image_emb = model.encode_img(image)
import gradio as gr
def generate(input):
    """Generate a model reply for *input* about the currently loaded image.

    Builds the user/system prompt around the image placeholder, feeds the
    combined embeddings to the GPT-NeoX decoder, and returns the decoded text.

    Args:
        input: The user's question (plain text) from the UI textbox.

    Returns:
        The generated answer with special tokens stripped.
    """
    print("input comes")  # fix: dropped pointless f-prefix (no placeholders)
    # Prompt format expected by the rinna bilingual MiniGPT-4 checkpoint;
    # <ImageHere> is substituted with the image embedding by get_context_emb.
    prm = f"""ユーザー: <Img><ImageHere></Img> {input}
システム: """
    embs = model.get_context_emb(prm, [image_emb])
    output_ids = model.gpt_neox_model.generate(
        inputs_embeds=embs,
        max_new_tokens=512,
        do_sample=True,
        temperature=1.0,
        top_p=0.85,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    output = tokenizer.decode(output_ids.tolist()[0], skip_special_tokens=True)
    return output
def load_image(img):
    """Replace the active image and refresh its cached embedding.

    Updates the module-level ``raw_image`` and ``image_emb`` so subsequent
    calls to ``generate`` answer questions about *img*.
    """
    print("img comes")
    global raw_image, image_emb
    raw_image = img
    tensor = vis_processor(img).unsqueeze(0).to(model.device)
    image_emb = model.encode_img(tensor)
    return img
def init_image(img):
    """Reset the UI back to the bundled sample image (the *img* argument is ignored)."""
    return load_image(ini_image)
# --- Gradio UI -------------------------------------------------------------
# NOTE(review): indentation was lost in this paste; the nesting below is a
# plausible reconstruction of the original layout — confirm against the gist.
with gr.Blocks() as demo:
    gr.Markdown("## multi modal rinna")
    imgIn = gr.Image(lambda: raw_image, type="pil")
    with gr.Row():
        upload = gr.Button("Upload", variant="primary")
        def_img = gr.Button("Default")
    upload.click(load_image, inputs=imgIn)
    def_img.click(init_image, outputs=imgIn)
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(lines=3, placeholder="質問を")
            submit = gr.Button("Submit", variant="primary")
            with gr.Row():
                default = gr.Button("Default")
                clear = gr.Button("Clear")
            default.click(lambda: "画像を説明して", outputs=question)
            clear.click(lambda: "", outputs=question)
        answer = gr.Textbox(lines=3)
    submit.click(generate, inputs=question, outputs=answer)

demo.launch()
@kishida
Copy link
Author

kishida commented Aug 1, 2023

bandicam.2023-08-01.09-59-40-360.mp4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment