Skip to content

Instantly share code, notes, and snippets.

@lewtun
Created July 4, 2023 15:28
Show Gist options
  • Save lewtun/3f1d23d4748043d5443d33c6eacdd336 to your computer and use it in GitHub Desktop.
Save lewtun/3f1d23d4748043d5443d33c6eacdd336 to your computer and use it in GitHub Desktop.
M4 inference
import torch
from m4.training.packing import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
from m4.training.utils import build_image_transform
from io import BytesIO
from PIL import Image
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM
# Maximum sequence length, in tokens, used when truncating the tokenized prompt.
MAX_SEQ_LEN = 2048
# Local utils
def fetch_images(url_images, *, timeout=30.0):
    """Load images from URLs, passing through items that are not URL strings.

    Args:
        url_images: Iterable whose items are either URL strings (downloaded
            over HTTP and opened with PIL) or any other object, which is
            returned unchanged (e.g. an already-loaded PIL image).
        timeout: Seconds to wait on each HTTP request before giving up, so a
            slow or unreachable host cannot hang the script indefinitely.

    Returns:
        A new list with one entry per input item, in input order.

    Raises:
        requests.HTTPError: If a download answers with a 4xx/5xx status.
    """
    images = []
    for url in url_images:
        if isinstance(url, str):
            response = requests.get(url, timeout=timeout)
            # Fail on HTTP errors here rather than handing an error page to PIL,
            # which would only report that it cannot identify the image file.
            response.raise_for_status()
            images.append(Image.open(BytesIO(response.content)))
        else:
            images.append(url)
    return images
# Inputs to the model: URL strings are downloaded by fetch_images, other items
# (e.g. PIL images) are used as-is.
url_images = ["https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"]
# Preparing inputs
images = fetch_images(url_images)
model_id = "HuggingFaceH4/m4-9b-ift"
revision = "v0.0"
# Use the first CUDA device when one is present, otherwise fall back to CPU so the
# script still runs (slowly) on GPU-less hosts.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, revision=revision)
# NOTE(review): no torch_dtype is passed, so weights load in transformers' default
# dtype (float32 unless configured otherwise); for a ~9B-parameter model that is
# roughly 36 GB — consider torch_dtype=torch.bfloat16 if memory is tight.
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, revision=revision)
model.to(device)
# Few-shot example
# NOTE(review): this prompt has three <image> placeholders, so url_images presumably
# needs three entries (one per placeholder) when using it — confirm.
# prompt = "<|system|>\n</s>\n<|user|>\n<fake_token_around_image><image><fake_token_around_image>Why do people find this image funny?</s>\n<|assistant|>\nThe picture depicts an orange cat lying on the ground with a sticky note affixed to its head. On the sticky note, a human face is drawn, and it looks like the cat appears to be frowning and judging its human.</s>\n<|user|>\n<fake_token_around_image><image><fake_token_around_image>Can you locate the water bottle placed near the yellow tennis balls on the ground?</s>\n<|assistant|>\nThere is no object resembling a water bottle in this image. This question introduces a new object that doesn't exist in the image</s>\n<|user|>\n<fake_token_around_image><image><fake_token_around_image>What is unusual about this image?</s>\n<|assistant|>\n"
# Zero-shot example (same turn layout as the few-shot prompt: "<|user|>\n" opens a user turn)
prompt = "<|system|>\n</s>\n<|user|>\n<fake_token_around_image><image><fake_token_around_image>What is unusual about this image?</s>\n<|assistant|>\n"
# Special tokens are disabled and BOS is prepended by hand below, so leave one slot
# for it: the final sequence then stays within MAX_SEQ_LEN tokens.
tokens = tokenizer(
    [prompt],
    truncation=True,
    max_length=MAX_SEQ_LEN - 1,
    padding=True,
    add_special_tokens=False,
)
input_ids = torch.tensor([[tokenizer.bos_token_id] + tokens.input_ids[0]])
attention_mask = torch.tensor([[1] + tokens.attention_mask[0]])
# Image attention mask built by the m4 packing helpers from the (1, seq_len)
# input_ids, with one class per loaded image.
image_attention_mask = incremental_to_binary_attention_mask(
    image_attention_mask_for_packed_input_ids(input_ids, tokenizer)[0], num_classes=len(images)
)
image_transform = build_image_transform(eval=True)
# Stack the transformed images, then add a leading batch axis of size 1:
# shape (1, num_images, *transformed_image_shape).
pixel_values = torch.stack([image_transform(img) for img in images]).unsqueeze(0)
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
pixel_values = pixel_values.to(device)
image_attention_mask = image_attention_mask.to(device)
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "pixel_values": pixel_values,
    "image_attention_mask": image_attention_mask,
}
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    pad_token_id=tokenizer.unk_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
# outputs[0] is the single sequence in the batch; for decoder-only models generate()
# returns the prompt tokens followed by the new ones, so both are printed.
print(tokenizer.decode(outputs[0]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment