A script to caption datikz graphs with Moondream using Modal
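# Caption the nllg/datikz-v2 dataset with Moondream 2, fanned out over Modal A100s.
# Run with (filename assumed): modal run moondream_label_datikz_v2.py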
import modal

app = modal.App(name="moondream-label-datikz_v2")
data_dict = modal.Dict.from_name("HF_DATASET", create_if_missing=True)


def download_dataset():
    # Download the dataset into the image's HF cache at build time
    from datasets import load_dataset

    data_dict["HF_DATASET"] = "nllg/datikz-v2"
    dataset = load_dataset(data_dict["HF_DATASET"])


def download_model():
    # Download the model weights into the image at build time, pinned to a revision
    model_id = "vikhyatk/moondream2"
    revision = "2024-05-20"
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True, revision=revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)


moondream_image = (
    modal.Image.micromamba(python_version="3.11")
    .apt_install("git")
    .micromamba_install(
        "cudatoolkit",
        "cudnn",
        "cuda-nvcc",
        channels=["conda-forge", "nvidia"],
    )
    .pip_install(
        "torch",
        "torchvision",
        "accelerate",
        "transformers",
        "datasets",
        "einops",
        "Pillow",
        "xxhash",
        gpu="A100",
    )
    .run_commands("pip install flash-attn --no-build-isolation")
    .run_function(download_dataset)
    .run_function(download_model)
)


@app.function(gpu="A100", image=moondream_image, timeout=3600)
def label_dataset(split):
    import torch
    import pandas as pd
    import xxhash
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load the Moondream model onto the GPU
    model_id = "vikhyatk/moondream2"
    revision = "2024-05-20"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True, revision=revision, device_map="cuda",
        torch_dtype=torch.float16, attn_implementation="flash_attention_2",
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

    print("torch.cuda.memory_allocated: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
    print("torch.cuda.memory_reserved: %fGB" % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024))
    print("torch.cuda.max_memory_reserved: %fGB" % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024))
    # Load the requested slice of the HF dataset
    from datasets import Image, load_dataset

    ds = load_dataset(data_dict["HF_DATASET"], split=split, keep_in_memory=True)
    # ds = ds.select(range(100))  # for debugging
    print(len(ds))

    # Batch size
    # N = 12  # fits in 16 GB VRAM when truncating the prompt
    N = 26  # fits into 40 GB VRAM

    # Simple mini-batch generator
    def batches(lst, n):
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    img_enc = Image()
    r = []
    for batch in batches(ds, N):
        # Ask Moondream to describe each image, conditioned on the (truncated) original caption
        prompts = [
            "Describe this diagram using the following context, excluding anything that is not directly deducible from the graph: " + c[:1280]
            for c in batch["caption"]
        ]
        answers = model.batch_answer(
            images=batch["image"],
            prompts=prompts,
            tokenizer=tokenizer,
            repetition_penalty=1.2,  # important to avoid repetitions; the chosen value might not be optimal
        )
        r.append(pd.DataFrame({
            "caption": answers,
            "orig_caption": batch["caption"],
            "image": [img_enc.encode_example(img) for img in batch["image"]],
        }))
        if len(r) % 10 == 0:
            print(len(r))
            print("torch.cuda.max_memory_allocated: %fGB" % (torch.cuda.max_memory_allocated(0) / 1024 / 1024 / 1024))
    return pd.concat(r)


@app.local_entrypoint()
def main():
    import pandas as pd
    from datasets import Dataset, Image

    # Split the dataset into 50 equal parts
    # splits = [f"train[{k}%:{k+10}%]" for k in range(0, 100, 10)]  # 10 parts instead
    splits = [f"train[{k}%:{k+2}%]" for k in range(0, 100, 2)]
    print(splits)

    # Fan the splits out to parallel GPU workers
    results = []
    # for part_result in label_dataset.map([splits[0]]):  # for debugging
    for part_result in label_dataset.map(splits):
        results.append(part_result)
        print(len(part_result))

    # Concatenate the partial results and load them as a HF dataset
    result_df = pd.concat(results)
    result_ds = Dataset.from_pandas(result_df)

    # Properly cast the image column
    result_ds = result_ds.cast_column("image", Image())
    print(result_ds)

    # Push the result to the HF Hub
    result_ds.push_to_hub("datikz-v2-moondream-labels")
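
# Note: main() runs on the local machine via the local entrypoint, so push_to_hub
# uses your local Hugging Face credentials (e.g. from `huggingface-cli login`).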