
@rockerBOO
Created February 5, 2024 18:40
# Install dependencies
#
# python -m venv venv
# source ./venv/bin/activate # linux
# .\venv\Scripts\activate.bat   # windows (cmd)
# .\venv\Scripts\Activate.ps1   # windows (PowerShell)
#
# pip install transformers peft datasets
#
# Use the PyTorch instructions for your machine:
# [Get started — PyTorch](https://pytorch.org/get-started/locally/)
#
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
model_id = "Salesforce/blip-image-captioning-base"
# Set the device which works for you or use these defaults
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the model (pass torch_dtype=torch.float16 to from_pretrained if you want half precision)
model = AutoModelForVision2Seq.from_pretrained(model_id)
model.to(device)
# Print the model to see its architecture. PEFT works with Linear and Conv2d layers.
# print(model)
processor = AutoProcessor.from_pretrained(model_id)
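# Optional: a quick sketch for listing candidate target modules programmatically
# instead of reading print(model) by hand. It prints the fully qualified names of
# Linear and Conv2d layers; the target_modules entries below match these names by
# suffix.
# for name, module in model.named_modules():
#     if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
#         print(name)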
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # target specific modules (look for Linear in the model)
    # print(model) to see the architecture of the model
    target_modules=[
        "self.query",
        "self.key",
        "self.value",
        "output.dense",
        "self_attn.qkv",
        "self_attn.projection",
        "mlp.fc1",
        "mlp.fc2",
    ],
)
# Wrap the model with the PEFT adapter defined by the config above
model = get_peft_model(model, config)
model.print_trainable_parameters()
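# The line above reports something like:
#   trainable params: ... || all params: ... || trainable%: ...
# so you can confirm that only the LoRA weights will be updated.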
# Load the train split of the example dataset
dataset = load_dataset("ybelkada/football-dataset", split="train")
print(len(dataset))
print(next(iter(dataset)))
# Each dataset may have different column names and formats, so this class should
# be adjusted per dataset. It returns the processed image tensors plus a `text`
# key that the collator below expects.
class FootballImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(
            images=item["image"], padding="max_length", return_tensors="pt"
        )
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding
def collator(batch):
    # stack the image tensors and tokenize/pad the texts
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack(
                [example[key] for example in batch]
            )
            # tensor.to() is not in-place; assign the result to actually move it
            processed_batch[key] = processed_batch[key].to(device)
        else:
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch],
                padding=True,
                return_tensors="pt",
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch
train_dataset = FootballImageCaptioningDataset(dataset, processor)
# Set batch_size to a value that works for your hardware.
train_dataloader = DataLoader(
train_dataset, shuffle=True, batch_size=2, collate_fn=collator
)
# Setup AdamW
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
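# Optional: you could pass only the trainable LoRA parameters to the optimizer.
# AdamW skips parameters that never receive gradients, so the result is the same,
# just more explicit:
# optimizer = optim.AdamW(
#     (p for p in model.parameters() if p.requires_grad), lr=1e-3, weight_decay=1e-4
# )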
# Put the model in training mode so gradients are tracked and dropout is active
model.train()
for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            labels=input_ids,
            attention_mask=attention_mask,
        )

        loss = outputs.loss
        print("Loss:", loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # periodically generate a caption to check training progress
        if idx % 10 == 0:
            generated_output = model.generate(
                pixel_values=pixel_values, max_new_tokens=64
            )
            print(
                processor.batch_decode(
                    generated_output, skip_special_tokens=True
                )
            )
print("Saving to ./training/caption")
# Save the trained PEFT adapter (Hugging Face format) to the directory
model.save_pretrained("./training/caption")
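
# Optional: a minimal sketch of loading the saved adapter later for inference.
# It assumes the adapter was saved to ./training/caption as above; "example.jpg"
# is a placeholder path for an image of your own.
#
# from PIL import Image
# from peft import PeftModel
#
# base = AutoModelForVision2Seq.from_pretrained(model_id)
# inference_model = PeftModel.from_pretrained(base, "./training/caption").to(device)
# inference_model.eval()
#
# inputs = processor(images=Image.open("example.jpg"), return_tensors="pt").to(device)
# with torch.no_grad():
#     generated = inference_model.generate(**inputs, max_new_tokens=64)
# print(processor.batch_decode(generated, skip_special_tokens=True))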
@Linn0910

Hello, your code is really helpful. I have a question: is there any way to fine-tune BLIP so that the model generates a fixed text format that I want? This is really puzzling me. I hope I can get your advice, and I really appreciate your kind help!

@rockerBOO
Author

@Linn0910 Thank you for your kind words. What do you mean by fixed text format? Generally there will be limited control over the exact phrasing unless you overfit the model a lot. You can limit the number of tokens though.

@Linn0910

@Linn0910 Thank you for your kind words. What do you mean by fixed text format? Generally there will be limited control over the exact phrasing unless you overfit the model a lot. You can limit the number of tokens though.

Thank you for your reply! My goal is to have BLIP generate a fixed text format, such as "A [shape] mirror located in [where], positioned [where]." I understand the limitations you mentioned, but my idea is to create structured labels during fine-tuning and use a small dataset to help the model generate these fixed descriptions. I hope to guide the model towards more consistent outputs by limiting the length of the text or applying a specific format.

I understand the risk of overfitting, so I plan to use a small, targeted training set and aim to avoid overfitting while fine-tuning. Do you have any suggestions for this approach? Are there any additional techniques that could help the model maintain the fixed format without overfitting?
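
A minimal, purely illustrative sketch of that structured-label idea: the shape/room/position fields, file names, and values below are hypothetical placeholders for whatever annotations your dataset actually has, and the template mirrors the example format above.

# Hypothetical per-image annotations; replace with your own metadata.
raw_rows = [
    {"image": "mirror_01.jpg", "shape": "round", "room": "the bathroom", "position": "above the sink"},
    {"image": "mirror_02.jpg", "shape": "rectangular", "room": "the hallway", "position": "next to the door"},
]

TEMPLATE = "A {shape} mirror located in {room}, positioned {position}."

def to_caption(row):
    # Fill the fixed template so every training caption shares the same structure.
    return {"image": row["image"], "text": TEMPLATE.format(**row)}

train_rows = [to_caption(r) for r in raw_rows]
# train_rows[0]["text"] -> "A round mirror located in the bathroom, positioned above the sink."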

@rockerBOO
Author

rockerBOO commented Dec 12, 2024 via email

I think something like Florence2 might be better, as it has segmentation and roughly the same resource requirements. See https://github.com/rockerBOO/caption-train for a trainer for that. There might be masking, but I'm not sure you can input text for captioning generally. BLIP2 has Image Question Answering (IQA), which might be guidable enough, though I forget whether Florence2 has IQA options. So you'd ask "What is it?" or "Where is it?" and it would respond.

@Linn0910

Thank you for the suggestion! I see how Florence2 could be a good fit, especially with its segmentation capabilities and similar resource requirements. I’ll check out the trainer you mentioned for Florence2 and explore its potential for handling masking.

Regarding BLIP2, the Image Question Answering (IQA) feature sounds promising, especially if it can provide guided responses like "What is it?" or "Where is it?" I’ll investigate whether Florence2 has similar IQA options as well.

I’m still considering the approach for fixed-format captions, so I’ll explore both options and see which one might work best for my specific needs. Thanks again for your helpful insights!
