Skip to content

Instantly share code, notes, and snippets.

@sroecker
Last active June 7, 2024 09:44
Show Gist options
  • Save sroecker/5c3a9eb1fd0c898e4119b89ff1095038 to your computer and use it in GitHub Desktop.
Save sroecker/5c3a9eb1fd0c898e4119b89ff1095038 to your computer and use it in GitHub Desktop.
Modal: Batch eval Moondream with Pokemon dataset
import modal
# Modal application under which the function and entrypoint in this file are registered.
app = modal.App(name="pokemon-eval")

# CUDA packages installed through micromamba.
# NOTE(review): cuda-nvcc is presumably here so flash-attn can be compiled — confirm.
_CONDA_PACKAGES = ["cudatoolkit", "cudnn", "cuda-nvcc"]
_CONDA_CHANNELS = ["conda-forge", "nvidia"]

# Python packages installed with pip into the image.
_PIP_PACKAGES = [
    "torch",
    "torchvision",
    "accelerate",
    "transformers",
    "datasets",
    "einops",
    "Pillow",
]

# Build the container image one layer at a time: base Python + git, then the
# CUDA packages, then the Python dependencies, and finally flash-attn.
_base_layer = modal.Image.micromamba(python_version="3.11").apt_install("git")
_cuda_layer = _base_layer.micromamba_install(
    *_CONDA_PACKAGES,
    channels=_CONDA_CHANNELS,
)
_python_layer = _cuda_layer.pip_install(*_PIP_PACKAGES, gpu="A100")
moondream_image = _python_layer.run_commands(
    "pip install flash-attn --no-build-isolation"
)
@app.function(gpu="A100", image=moondream_image, timeout=1200)
def eval_pokemon(image_urls):
    """Describe every image in *image_urls* with the Moondream2 model.

    The URLs are processed in fixed-size batches: each batch is downloaded,
    opened as PIL images and passed to ``model.batch_answer`` with the prompt
    "Describe this image.".

    Args:
        image_urls: A sized, sliceable sequence (e.g. a list) of image URLs.

    Returns:
        A flat list with the answers of all batches concatenated in input
        order (presumably one answer per URL — depends on ``batch_answer``).

    Raises:
        requests.HTTPError: If an image URL responds with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    # Nothing to do: avoid loading the model onto the GPU for no work.
    if not image_urls:
        return []

    from io import BytesIO

    import requests
    import torch
    from PIL import Image
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "vikhyatk/moondream2"
    revision = "2024-05-20"
    # Images per batch_answer call. Author's tuning note: 64 was too high
    # for an A100 (40GB).
    batch_size = 40
    # Seconds to wait for a single image download before giving up, so one
    # stalled server cannot consume the whole function timeout.
    download_timeout = 30

    # device_map="cuda" already places the weights on the GPU, so no extra
    # .to("cuda") call is needed afterwards.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        revision=revision,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        device_map="cuda",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

    all_answers = []
    # One session for all downloads so connections can be reused.
    with requests.Session() as session:

        def fetch_image(url):
            """Download *url* and open the payload as a PIL image."""
            response = session.get(url, timeout=download_timeout)
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            return Image.open(BytesIO(response.content))

        for start in range(0, len(image_urls), batch_size):
            batch = image_urls[start:start + batch_size]
            # Download images from URLs into PIL Image objects.
            images = [fetch_image(url) for url in batch]
            answers = model.batch_answer(
                images=images,
                prompts=["Describe this image."] * len(images),
                tokenizer=tokenizer,
            )
            # extend keeps the result flat without a quadratic sum(..., []).
            all_answers.extend(answers)
    return all_answers
@app.local_entrypoint()
def main():
    """Describe all Pokemon card images remotely and save the answers as CSV.

    The train split of ``TheFusion21/PokemonCards`` is loaded in ten 10%
    slices; the ``image_url`` column of each slice is sent to
    ``eval_pokemon`` through ``.map`` so each slice is one remote call. The
    returned answers are concatenated and written to
    ``pokemon_descriptions.csv`` in a single ``answers`` column.
    """
    import pandas as pd
    from datasets import load_dataset

    dataset_name = "TheFusion21/PokemonCards"
    splits = [f"train[{k}%:{k + 10}%]" for k in range(0, 100, 10)]
    # Materialise each column as a plain list of URLs so the remote function
    # receives a simple list regardless of what the column access returns.
    split_urls = [
        list(load_dataset(dataset_name, split=split)["image_url"])
        for split in splits
    ]

    flat_answers = []
    for part in eval_pokemon.map(split_urls):
        print(len(part))
        # extend keeps the result flat without a quadratic sum(..., []).
        flat_answers.extend(part)

    df = pd.DataFrame(flat_answers, columns=["answers"])
    df.to_csv("pokemon_descriptions.csv", index=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment