Created
July 4, 2023 15:28
-
-
Save lewtun/3f1d23d4748043d5443d33c6eacdd336 to your computer and use it in GitHub Desktop.
M4 inference
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask | |
from m4.training.utils import build_image_transform | |
from io import BytesIO | |
from PIL import Image | |
import requests | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Upper bound, in tokens, passed to the tokenizer as `max_length` when the
# prompt is tokenized with truncation enabled.
MAX_SEQ_LEN = 2048
# Local utils
def fetch_images(url_images, *, timeout=30):
    """Return a list of images, downloading every entry that is a URL string.

    Args:
        url_images: Iterable whose items are either URL strings or objects
            that are already usable as images. Non-string items are passed
            through unchanged.
        timeout: Seconds to wait for the server before giving up on a
            download. Forwarded to ``requests.get``.

    Returns:
        A list with one entry per input item, in the same order. String
        items are replaced by the result of ``PIL.Image.open`` on the
        downloaded bytes.

    Raises:
        requests.HTTPError: If a download returns an HTTP error status.
        requests.RequestException: On connection problems or timeout.
    """
    images = []
    for item in url_images:
        if not isinstance(item, str):
            # Not a URL: the caller already supplied an image object.
            images.append(item)
            continue
        # A timeout keeps an unresponsive server from hanging the script.
        response = requests.get(item, timeout=timeout)
        # Fail here on 4xx/5xx rather than handing an error page to PIL.
        response.raise_for_status()
        images.append(Image.open(BytesIO(response.content)))
    return images
# Images referenced by the prompt. Entries may be URLs (downloaded by
# fetch_images) or already-loaded image objects (passed through unchanged).
url_images = ["https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"]
images = fetch_images(url_images)

# Load the tokenizer and model from the Hub at a pinned revision.
# NOTE: trust_remote_code=True executes Python code shipped in the model
# repository; pinning `revision` keeps that code from changing silently.
model_id = "HuggingFaceH4/m4-9b-ift"
revision = "v0.0"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, revision=revision)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, revision=revision)
# Move the model to device index 0; the input tensors are sent to the same
# device before generation.
model.to(0)
# Few-shot example
# NOTE(review): this prompt contains three <image> placeholders; presumably
# `url_images` must then list three images, one per placeholder -- confirm.
# prompt = "<|system|>\n</s>\n<|user|>\n<fake_token_around_image><image><fake_token_around_image>Why do people find this image funny?</s>\n<|assistant|>\nThe picture depicts an orange cat lying on the ground with a sticky note affixed to its head. On the sticky note, a human face is drawn, and it looks like the cat appears to be frowning and judging its human.</s>\n<|user|>\n<fake_token_around_image><image><fake_token_around_image>Can you locate the water bottle placed near the yellow tennis balls on the ground?</s>\n<|assistant|>\nThere is no object resembling a water bottle in this image. This question introduces a new object that doesn't exist in the image</s>\n<|user|>\n<fake_token_around_image><image><fake_token_around_image>What is unusual about this image?</s>\n<|assistant|>\n"
# Zero-shot example
# Each role tag is followed by a newline, matching the turn layout used by
# the few-shot template above.
prompt = "<|system|>\n</s>\n<|user|>\n<fake_token_around_image><image><fake_token_around_image>What is unusual about this image?</s>\n<|assistant|>\n"
# Preparing inputs
# Tokenize without automatic special tokens; the BOS token is prepended by
# hand below.
tokens = tokenizer(
    [prompt],
    truncation=True,
    # Reserve one position for the manually prepended BOS token so the final
    # sequence never exceeds MAX_SEQ_LEN.
    max_length=MAX_SEQ_LEN - 1,
    padding=True,
    add_special_tokens=False,
)
# Batch of one sequence: BOS followed by the prompt tokens, with a matching
# attention-mask entry of 1 for the BOS position.
input_ids = torch.tensor([[tokenizer.bos_token_id] + tokens.input_ids[0]])
attention_mask = torch.tensor([[1] + tokens.attention_mask[0]])

# Build the image attention mask from the token ids with the m4 packing
# helpers, using one class per fetched image.
# NOTE(review): the exact shape/semantics of the helpers' outputs are defined
# in m4.training.packing and are not visible here -- confirm before changing.
image_attention_mask = [
    incremental_to_binary_attention_mask(
        image_attention_mask_for_packed_input_ids(input_ids[0].unsqueeze(0), tokenizer)[0], num_classes=len(images)
    )
]

# Apply the evaluation-mode image transform to every image and stack the
# results; the outer list is stacked again below to add a leading dimension.
image_transform = build_image_transform(eval=True)
pixel_values = [torch.stack([image_transform(img) for img in images])]

# Move every input to device index 0 (the device the model was moved to).
input_ids = input_ids.to(0)
attention_mask = attention_mask.to(0)
pixel_values = torch.stack(pixel_values).to(0)
image_attention_mask = torch.cat(image_attention_mask, 0).to(0)
inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values, "image_attention_mask": image_attention_mask}
# Generate up to 256 new tokens for the prepared inputs. The tokenizer's
# unknown-token id is passed as the padding id, and generation stops at the
# tokenizer's end-of-sequence token.
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    pad_token_id=tokenizer.unk_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
# Decode the first (and only) sequence in the batch and print it.
print(tokenizer.decode(outputs[0]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment