A bug in MPS/PyTorch that causes an overflow in cumsum prevents LLaVA models from running in transformers on MPS devices (https://github.com/huggingface/transformers/issues/30294). https://github.com/hvaara/transformers/commit/ed2f0df5992fbe11521a725efa65c99262633913 provides a workaround. This gist shows the output when running a LLaVA…
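For illustration only, one common way to sidestep a misbehaving op on MPS is to run it on the CPU and move the result back. The helper below is a minimal sketch under that assumption; `cumsum_via_cpu` is a hypothetical name, and this is not the code from the linked commit, which may take a different approach.

```python
import torch


def cumsum_via_cpu(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: compute cumsum on the CPU when x lives on MPS.

    Assumption: only the MPS cumsum kernel misbehaves, so a round trip
    through the CPU gives the correct result at the cost of a device copy.
    """
    if x.device.type == "mps":
        return torch.cumsum(x.cpu(), dim=dim).to(x.device)
    # Other devices use the native kernel directly.
    return torch.cumsum(x, dim=dim)


# Small attention-mask-like example; runs on CPU so it works on any machine.
mask = torch.tensor([[1, 1, 0, 1]])
positions = cumsum_via_cpu(mask, dim=-1)
print(positions.tolist())
```

The round trip adds a device copy per call, so a helper like this is only a reasonable stopgap when the affected tensors are small, as attention masks typically are.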
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to("mps")
# prepare image and text prompt, using the appropriate prompt template
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
inputs = processor(prompt, image, return_tensors="pt").to("mps")
# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=100)
output_text = processor.decode(output[0], skip_special_tokens=True)
print(f"{output_text=}")
# Example output:
# output_text='<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n \nWhat is shown in this image?<|im_end|><|im_start|>assistant\nThe image shows a radar chart, which is a type of graph that displays multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. This particular radar chart is used to compare the performance of different models or algorithms across various metrics. The axes represent different metrics, such as MM-Vet, MM-Vet, MM-Vet, MM-Vet, MM'