Skip to content

Instantly share code, notes, and snippets.

@sroecker
Last active June 11, 2024 19:45
Show Gist options
  • Save sroecker/feaa61ea69182cb7ae1c9328b755786a to your computer and use it in GitHub Desktop.
Save sroecker/feaa61ea69182cb7ae1c9328b755786a to your computer and use it in GitHub Desktop.
A script that captions DaTikZ graphs with the Moondream vision-language model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import xxhash
from tqdm import tqdm
# Load the Moondream model and its tokenizer, both pinned to one revision
# so that the remote code and weights do not change between runs.
model_id = "vikhyatk/moondream2"
revision = "2024-05-20"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
# Move the half-precision model onto the GPU.
model = model.to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
# Load the Hugging Face dataset that is going to be captioned.
HF_DATASET = "nllg/datikz-v2"
from datasets import load_dataset, Dataset
# Full training split (currently disabled):
#dataset = load_dataset(HF_DATASET, split='train')
# FIXME: only a subset is loaded for testing; switch back to the full
# 'train' split above for a complete run.
# Smaller slices used during earlier test runs:
#dataset = load_dataset(HF_DATASET, split='train[2708:2716]')
#dataset = load_dataset(HF_DATASET, split='train[:36]')
dataset = load_dataset(HF_DATASET, split='train[:3200]')
# Create an id for every row using xxhash32.
#
# The id is derived from the row's *content*. Hashing ``str(value)`` of an
# image object is avoided because the default repr of a PIL image embeds the
# object's memory address, which would make the id differ from run to run.
# NOTE(review): assumes image columns are decoded to objects exposing
# ``tobytes()`` (as PIL images do) -- confirm against the dataset features.
def _row_id(row):
    """Return ``{'id_': <hex digest>}`` computed over all field values of *row*."""
    hasher = xxhash.xxh32()
    for value in row.values():
        if hasattr(value, "tobytes"):
            # Image-like value: hash the raw pixel data.
            payload = value.tobytes()
        else:
            payload = str(value).encode("utf-8")
        hasher.update(payload)
        # Field separator so that adjacent fields cannot run into each other.
        hasher.update(b"\x00")
    return {'id_': hasher.hexdigest()}


ds = dataset.map(_row_id)
# Batch size: number of dataset rows sent to the model per call.
# Previously tried value:
#N=8
N=12 # Fits in 16G VRAM when truncating prompt
import pandas as pd
from datasets import Image
# Image feature instance; its encode_example() is applied to every image
# before the image is stored in the per-batch result frames below.
img_enc = Image()
# simple batch generator
def batches(lst, n):
    """Yield consecutive slices of *lst* containing at most *n* items each.

    Args:
        lst: Any object supporting ``len()`` and slicing.
        n: Maximum number of items per slice; must be at least 1.

    Yields:
        ``lst[start:start + n]`` for ``start = 0, n, 2n, ...``; the last
        slice may be shorter than *n*. Nothing is yielded for empty input.

    Raises:
        ValueError: If *n* is smaller than 1. Because this is a generator,
            the error is raised when iteration starts, not at call time.
    """
    if n < 1:
        raise ValueError(f"batch size must be a positive integer, got {n!r}")
    for start in range(0, len(lst), n):
        yield lst[start:start + n]
# Per-batch result frames; concatenated after the loop has finished.
r = []
# Instruction prepended to every (truncated) original caption.
prompt_prefix = "Describe this diagram using the following context, excluding anything that is not directly deducible from the graph: "
with tqdm(total=len(ds)) as pbar:
    for batch in batches(ds, N):
        # DEBUG: uncomment to inspect image sizes and caption lengths.
        # for img in batch['image']:
        #     print(img.size)
        # for c in batch['caption']:
        #     print(len(c))

        # Captions are cut to 1280 characters; the batch size N was chosen
        # with this truncation in mind so that a batch fits into VRAM.
        prompts = [prompt_prefix + c[:1280] for c in batch['caption']]
        answers = model.batch_answer(
            images=batch['image'],
            prompts=prompts,
            tokenizer=tokenizer,
            repetition_penalty=1.2,  # Important to avoid repetitions, chosen value might not be best
        )
        # DEBUG
        print(answers)
        # Advance the progress bar by the number of rows just captioned.
        pbar.update(len(answers))
        # Keep the generated caption next to the original one; images are
        # encoded so that they can be stored in a DataFrame column.
        r.append(pd.DataFrame({
            'id': batch['id_'],
            'caption': answers,
            'orig_caption': batch['caption'],
            'image': [img_enc.encode_example(img) for img in batch['image']],
        }))
# Concatenate the per-batch DataFrames and load the result as an HF dataset.
# Every per-batch frame carries its own 0..len-1 index, so the index is
# rebuilt here and not preserved on conversion; otherwise it would end up
# in the dataset as an extra '__index_level_0__' column.
df = pd.concat(r, ignore_index=True)
result_ds = Dataset.from_pandas(df, preserve_index=False)
# Properly cast the encoded image column back to an Image feature.
result_ds = result_ds.cast_column("image", Image())
# Save the result to disk and push it to the HF hub.
result_ds.save_to_disk('datikz-v2-moondream-caption-test3')
result_ds.push_to_hub('datikz-v2-moondream-caption-test3')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment