Last active
June 7, 2024 09:44
-
-
Save sroecker/5c3a9eb1fd0c898e4119b89ff1095038 to your computer and use it in GitHub Desktop.
Modal: Batch eval Moondream with Pokemon dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import modal

app = modal.App(name="pokemon-eval")

# Container image for Moondream inference.
# micromamba base supplies the CUDA toolchain (cudatoolkit/cudnn/cuda-nvcc)
# so the flash-attn extension can be compiled inside the image; flash-attn
# itself is installed last with --no-build-isolation so it builds against
# the already-installed torch.
moondream_image = modal.Image.micromamba(
    python_version="3.11"
).apt_install(
    "git"
).micromamba_install(
    "cudatoolkit",
    "cudnn",
    "cuda-nvcc",
    channels=["conda-forge", "nvidia"],
).pip_install(
    "torch",
    "torchvision",
    "accelerate",
    "transformers",
    "datasets",
    "einops",
    "Pillow",
    "requests",  # imported by eval_pokemon; declare it instead of relying on a transitive install
    gpu="A100",  # run the pip step on a GPU machine so CUDA-dependent wheels resolve correctly
).run_commands("pip install flash-attn --no-build-isolation")
@app.function(gpu="A100", image=moondream_image, timeout=1200)
def eval_pokemon(image_urls):
    """Caption a list of image URLs with Moondream2.

    Downloads each image, runs batched VQA inference ("Describe this
    image."), and returns a flat list with one answer string per input
    URL, in input order.

    Args:
        image_urls: iterable of HTTP(S) image URLs.

    Returns:
        list[str]: one generated description per URL.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "vikhyatk/moondream2"
    # Pin the remote code to a known revision so trust_remote_code is reproducible.
    revision = "2024-05-20"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        revision=revision,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    ).to("cuda")  # single explicit placement; device_map='cuda' + .to('cuda') was redundant
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

    from io import BytesIO

    import requests
    from PIL import Image

    # Batch size tuning (per the original experiments): 16/32 fit in 16 GB,
    # 64 OOMs on a 40 GB A100, 40 is a safe upper bound.
    batch_size = 40

    def _batches(seq, size):
        # Yield successive `size`-length slices of `seq`.
        for start in range(0, len(seq), size):
            yield seq[start:start + size]

    all_answers = []
    for batch in _batches(image_urls, batch_size):
        # Download each URL into a PIL image. Fail loudly on HTTP errors
        # rather than letting PIL try to decode an error page, and bound
        # the request so a dead host cannot hang the whole job.
        images = []
        for url in batch:
            resp = requests.get(url, timeout=30)
            resp.raise_for_status()
            images.append(Image.open(BytesIO(resp.content)))
        answers = model.batch_answer(
            images=images,
            prompts=["Describe this image."] * len(images),
            tokenizer=tokenizer,
        )
        all_answers.extend(answers)  # extend keeps the result flat; avoids quadratic sum(..., [])
    return all_answers
@app.local_entrypoint()
def main():
    """Fan captioning out over the PokemonCards train split and save a CSV.

    Loads the split in ten 10% shards, maps each shard's URL list to a
    remote eval_pokemon call, and writes all answers to
    pokemon_descriptions.csv (one 'answers' column).
    """
    import pandas as pd
    from datasets import load_dataset

    # Ten 10% shards of the train split so each remote call receives a
    # manageable chunk of URLs. (The original also loaded the full train
    # split up front and never used it — that wasted download is dropped.)
    splits = [f'train[{k}%:{k+10}%]' for k in range(0, 100, 10)]
    split_datasets = [load_dataset("TheFusion21/PokemonCards", split=split) for split in splits]
    split_urls = [split['image_url'] for split in split_datasets]

    answers = []
    # .map fans the shards out to parallel remote containers.
    for part in eval_pokemon.map(split_urls):
        answers.append(part)
        print(len(part))

    # Flatten the list of per-shard answer lists without quadratic sum(..., []).
    flat_answers = [answer for part in answers for answer in part]
    df = pd.DataFrame(flat_answers, columns=['answers'])
    df.to_csv('pokemon_descriptions.csv', index=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment