
@hvaara
Created May 1, 2024 13:29
A bug in MPS/PyTorch that causes an overflow in cumsum prevents LLaVA models from running in transformers on MPS devices (https://github.com/huggingface/transformers/issues/30294). https://github.com/hvaara/transformers/commit/ed2f0df5992fbe11521a725efa65c99262633913 provides a workaround. This gist shows the output when running a LLaVA…
import os
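# Must be set before torch is imported: lets operators that are not implemented
# on MPS fall back to the CPU instead of raising an error.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"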
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to("mps")
# prepare image and text prompt, using the appropriate prompt template
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
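# Pass text and images by keyword so the call does not depend on the processor's positional argument order
inputs = processor(text=prompt, images=image, return_tensors="pt").to("mps")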
# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=100)
output_text = processor.decode(output[0], skip_special_tokens=True)
print(f"{output_text=}")
# Example output:
# output_text='<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n \nWhat is shown in this image?<|im_end|><|im_start|>assistant\nThe image shows a radar chart, which is a type of graph that displays multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. This particular radar chart is used to compare the performance of different models or algorithms across various metrics. The axes represent different metrics, such as MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM'
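The description attributes the failure to an overflow in cumsum on MPS but does not spell out the mechanism. As a hedged illustration only (it does not reproduce the MPS kernel or the linked workaround), this pure-Python sketch shows what an overflowing cumulative sum looks like, under the assumption that the running total is held in a narrow signed integer accumulator. The helper names `wrap_signed`, `cumsum_exact` and `cumsum_wrapped`, the `bits` parameter, and the sample inputs are all hypothetical.

```python
# Illustration only: NOT the actual MPS kernel or the linked workaround.
# Assumption: the running total is stored in a signed integer of `bits` bits,
# so it wraps around (two's complement) once it leaves that range.


def wrap_signed(value: int, bits: int) -> int:
    """Wrap an arbitrary Python int into the signed `bits`-bit range."""
    half = 1 << (bits - 1)
    return (value + half) % (1 << bits) - half


def cumsum_exact(values: list[int]) -> list[int]:
    """Reference cumulative sum using Python's arbitrary-precision ints."""
    out: list[int] = []
    total = 0
    for v in values:
        total += v
        out.append(total)
    return out


def cumsum_wrapped(values: list[int], bits: int) -> list[int]:
    """Cumulative sum whose accumulator wraps like a signed `bits`-bit integer."""
    out: list[int] = []
    total = 0
    for v in values:
        total = wrap_signed(total + v, bits)
        out.append(total)
    return out


if __name__ == "__main__":
    # Hypothetical input: a boolean mask of 200 True values (True counts as 1).
    mask = [True] * 200
    print(cumsum_exact(mask)[-1])            # 200
    print(cumsum_wrapped(mask, bits=8)[-1])  # -56 (wrapped after passing 127)

    # Hypothetical input: large increments that overflow a 32-bit accumulator.
    big = [1_000_000_000] * 3
    print(cumsum_wrapped(big, bits=32))      # [1000000000, 2000000000, -1294967296]
```

With an 8-bit accumulator, the mask is counted correctly up to 127, wraps to -128 on the next element, and ends at -56 instead of 200; a 32-bit accumulator wraps the same way once the total passes 2**31 - 1. Any positions derived from such a sum would be wrong. Accumulating in a wider dtype or on another device would avoid this, but that is an assumption here, not a description of what the linked commit does.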