# Install dependencies
#
# python -m venv venv
# source ./venv/bin/activate       # linux
# .\venv\Scripts\activate.bat      # windows (cmd)
#
# pip install transformers peft datasets
#
# Use the PyTorch instructions for your machine:
# [Get started — PyTorch](https://pytorch.org/get-started/locally/)
#
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "Salesforce/blip-image-captioning-base"

# Set the device that works for you, or use these defaults
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model (pass torch_dtype=torch.float16 to from_pretrained if you want half precision)
model = AutoModelForVision2Seq.from_pretrained(model_id)
model.to(device)

# Print the model to see the model architecture. PEFT works with Linear and Conv2D layers
# print(model)

processor = AutoProcessor.from_pretrained(model_id)

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # target specific modules (look for Linear in the model)
    # print(model) to see the architecture of the model
    target_modules=[
        "self.query",
        "self.key",
        "self.value",
        "output.dense",
        "self_attn.qkv",
        "self_attn.projection",
        "mlp.fc1",
        "mlp.fc2",
    ],
)

# Wrap the base model with PEFT using the LoRA config
model = get_peft_model(model, config)
model.print_trainable_parameters()

# Load the train split of the dataset
dataset = load_dataset("ybelkada/football-dataset", split="train")

print(len(dataset))
print(next(iter(dataset)))

# Each dataset may have different column names and formats, so this class should be
# adjusted per dataset. It produces `pixel_values` (from the image) and a `text` key
# that the collator expects.
class FootballImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(
            images=item["image"], padding="max_length", return_tensors="pt"
        )
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding

def collator(batch):
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            # stack the image tensors (e.g. pixel_values) into a batch
            processed_batch[key] = torch.stack(
                [example[key] for example in batch]
            )
        else:
            # tokenize and pad the captions
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch],
                padding=True,
                return_tensors="pt",
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch

train_dataset = FootballImageCaptioningDataset(dataset, processor)

# Set batch_size to whatever works for your hardware.
train_dataloader = DataLoader(
    train_dataset, shuffle=True, batch_size=2, collate_fn=collator
)

# Set up the AdamW optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Put the model in training mode (enables dropout and gradient tracking for the LoRA layers)
model.train()

for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        attention_mask = batch.pop("attention_mask").to(device)

        # For captioning, the labels are the caption token ids themselves
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            labels=input_ids,
            attention_mask=attention_mask,
        )

        loss = outputs.loss
        print("Loss:", loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Every 10 steps, generate a sample caption to check progress
        if idx % 10 == 0:
            generated_output = model.generate(
                pixel_values=pixel_values, max_new_tokens=64
            )
            print(
                processor.batch_decode(
                    generated_output, skip_special_tokens=True
                )
            )

print("Saving to ./training/caption")
# Save the LoRA adapter weights in Hugging Face format to this directory
model.save_pretrained("./training/caption")
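
# Optional: reload the trained adapter for inference. This is a rough sketch, not part
# of the original training script; it assumes the adapter was saved to ./training/caption
# as above and reuses the processor, dataset, and device loaded earlier.
from peft import PeftModel

base_model = AutoModelForVision2Seq.from_pretrained(model_id)
inference_model = PeftModel.from_pretrained(base_model, "./training/caption")
inference_model.to(device)
inference_model.eval()

test_image = dataset[0]["image"]  # any PIL image works here
inputs = processor(images=test_image, return_tensors="pt").to(device)
with torch.no_grad():
    generated = inference_model.generate(pixel_values=inputs["pixel_values"], max_new_tokens=64)
print(processor.batch_decode(generated, skip_special_tokens=True))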
@Linn0910 Thank you for your kind words. What do you mean by fixed text format? Generally there will be limited control over the exact phrasing unless you overfit the model a lot. You can limit the number of tokens though.
Thank you for your reply! My goal is to have BLIP generate a fixed text format, such as "A [shape] mirror located in [where], positioned [where]." I understand the limitations you mentioned, but my idea is to create structured labels during fine-tuning and use a small dataset to help the model generate these fixed descriptions. I hope to guide the model towards more consistent outputs by limiting the length of the text or applying a specific format.
I understand the risk of overfitting, so I plan to use a small, targeted training set and aim to avoid overfitting while fine-tuning. Do you have any suggestions for this approach? Are there any additional techniques that could help the model maintain the fixed format without overfitting?
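For illustration, the structured captions could be assembled from metadata before training, roughly like the sketch below (the metadata.json file and its field names are purely hypothetical; adapt them to your data):

import json

TEMPLATE = "A {shape} mirror located in {room}, positioned {position}."

# hypothetical metadata.json: one record per image with the structured fields
with open("metadata.json") as f:
    records = json.load(f)

examples = [
    {"image": r["image_path"], "text": TEMPLATE.format(**r)}
    for r in records
]
print(examples[0]["text"])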
I think something like Florence2 might be better, as it has segmentation and roughly the same resource requirements. See https://github.com/rockerBOO/caption-train for a trainer for it. There might be masking support, but I'm not sure you can input text for captioning in general. BLIP2 does have Image Question Answering (IQA), which might be guidable enough, though. I forget whether Florence2 has IQA options. So you'd ask something like "What is it?" or "Where is it?" and it would respond.
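For illustration, prompted question answering with BLIP2 might look roughly like this (based on the standard Hugging Face BLIP2 usage; the checkpoint, image path, and question are placeholders, and the 2.7B model needs a GPU with enough memory):

import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
)
model.to("cuda")

image = Image.open("mirror.jpg")  # placeholder image path
prompt = "Question: where is the mirror positioned? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda", torch.float16)
out = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(out[0], skip_special_tokens=True).strip())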
Thank you for the suggestion! I see how Florence2 could be a good fit, especially with its segmentation capabilities and similar resource requirements. I’ll check out the trainer you mentioned for Florence2 and explore its potential for handling masking.
Regarding BLIP2, the Image Question Answering (IQA) feature sounds promising, especially if it can provide guided responses like "What is it?" or "Where is it?" I’ll investigate whether Florence2 has similar IQA options as well.
I’m still considering the approach for fixed-format captions, so I’ll explore both options and see which one might work best for my specific needs. Thanks again for your helpful insights!
Hello, your code is really helpful. I have a question: is there any way to fine-tune BLIP so that the model generates the fixed text format I want? This is really puzzling me. I hope I can get your advice, and I really appreciate your kind help!