This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
multi_args = TrainingArguments( | |
output_dir="checkpoint", | |
seed=12345, | |
evaluation_strategy="steps", | |
eval_steps=100, | |
logging_strategy="steps", | |
logging_steps=100, | |
save_strategy="steps", | |
save_steps=100, | |
save_total_limit=3, # Since models are large, save only the last 3 checkpoints at any given time while training |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Wrapper around the wup_measure(...) function to process batch inputs
def batch_wup_measure(labels, preds):
    """Return the mean WUP score over a batch of (label, prediction) pairs.

    Each entry of ``labels`` and ``preds`` is used as an index into the
    module-level ``answer_space`` (presumably integer class ids -- confirm
    against the caller).  The two resulting answer strings are scored with
    ``wup_measure`` and the scores are averaged with ``np.mean``.

    Pairs are formed with ``zip``, so the shorter sequence bounds the batch.
    An empty batch makes ``np.mean`` return NaN (NumPy also emits a
    RuntimeWarning).
    """
    scores = []
    for gold_idx, pred_idx in zip(labels, preds):
        gold_answer = answer_space[gold_idx]
        pred_answer = answer_space[pred_idx]
        scores.append(wup_measure(gold_answer, pred_answer))
    return np.mean(scores)
# Function to compute all relevant performance metrics, to be passed into the trainer | |
def compute_metrics(eval_tuple: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]: | |
logits, labels = eval_tuple | |
preds = logits.argmax(axis=-1) | |
return { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def createMultimodalVQACollatorAndModel(text='bert-base-uncased', image='google/vit-base-patch16-224-in21k'):
    """Build the batch collator and the VQA model for a text/image checkpoint pair.

    Args:
        text: Pretrained checkpoint name for the text side.  It is passed to
            ``AutoTokenizer.from_pretrained`` and used as the model's
            ``pretrained_text_name``.
        image: Pretrained checkpoint name for the image side.  It is passed to
            ``AutoFeatureExtractor.from_pretrained`` and used as the model's
            ``pretrained_image_name``.

    Returns:
        A ``(collator, model)`` tuple.  The model has already been moved with
        ``.to(device)``, where ``device`` is a module-level name defined
        elsewhere in the file.
    """
    # The tokenizer and feature extractor are loaded in that order and handed
    # straight to the collator.
    collator = MultimodalCollator(
        tokenizer=AutoTokenizer.from_pretrained(text),
        preprocessor=AutoFeatureExtractor.from_pretrained(image),
    )
    # The model is built from the same checkpoint names, so its encoders match
    # the collator's preprocessing.
    model = MultimodalVQAModel(
        pretrained_text_name=text,
        pretrained_image_name=image,
    )
    return collator, model.to(device)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MultimodalVQAModel(nn.Module): | |
def __init__(self, pretrained_text_name, pretrained_image_name, num_labels=len(answer_space), intermediate_dim=512, dropout=0.5): | |
super(MultimodalVQAModel, self).__init__() | |
self.num_labels = num_labels | |
self.pretrained_text_name = pretrained_text_name | |
self.pretrained_image_name = pretrained_image_name | |
# Pretrained transformers for text & image featurization | |
self.text_encoder = AutoModel.from_pretrained(self.pretrained_text_name) | |
self.image_encoder = AutoModel.from_pretrained(self.pretrained_image_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from copy import deepcopy | |
from dataclasses import dataclass | |
from typing import Dict, List, Optional, Tuple | |
from datasets import load_dataset, set_caching_enabled | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torch.nn as nn | |
from transformers import ( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The dataclass decorator is used to automatically generate special methods to classes, | |
# including __init__, __str__ and __repr__. It helps reduce some boilerplate code. | |
@dataclass | |
class MultimodalCollator: | |
tokenizer: AutoTokenizer | |
preprocessor: AutoFeatureExtractor | |
def tokenize_text(self, texts: List[str]): | |
encoded_text = self.tokenizer( | |
text=texts, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from IPython.display import display | |
def showExample(train=True, id=None): | |
if train: | |
data = dataset["train"] | |
else: | |
data = dataset["test"] | |
if id == None: | |
id = np.random.randint(len(data)) | |
image = Image.open(os.path.join("..", "dataset", "images", data[id]["image_id"] + ".png")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the training & evaluation splits from CSV files.  The result is indexed
# by split name ("train" / "test"), as done elsewhere with dataset["train"].
dataset = load_dataset(
    "csv",
    data_files=dict(
        train=os.path.join("dataset", "data_train.csv"),
        test=os.path.join("dataset", "data_eval.csv"),
    ),
)
# Load the space of all possible answers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Define a regex pattern to normalize the question &
# find the image ID for which the question is asked.
#
# The pattern matches a trailing phrase such as " in the image12 ?":
#   group 1: the whole phrase, including its leading space and the final " ?"
#   group 2: optional preposition ("in ", "on ", "of ")
#   group 3: optional determiner ("the ", "this ")
#   group 4: the image token, "image" followed by optional digits
# A raw string is used so that "\d" and "\?" reach the regex engine without
# triggering Python's invalid-escape-sequence warning (a SyntaxWarning on 3.12+).
# The compiled pattern is identical to the non-raw original.
image_pattern = re.compile(r"( (in |on |of )?(the |this )?(image\d*) \?)")
# Read every line of the raw QA file with its newline removed.  The loop that
# follows steps through qa_data two lines at a time, so the file presumably
# alternates question and answer lines -- confirm against the dataset.
# NOTE(review): open() uses the locale's default encoding here; consider an
# explicit encoding= if the file's encoding is known.
with open(os.path.join("dataset", "all_qa_pairs.txt")) as f:
    # Iterating the file object yields the same lines as f.readlines().
    qa_data = [line.replace("\n", "") for line in f]
# Empty frame with the three output columns.  It is presumably filled by the
# loop below -- confirm.
df = pd.DataFrame({"question": [], "answer": [], "image_id":[]})
for i in range(0, len(qa_data), 2): |
Newer Older