Skip to content

Instantly share code, notes, and snippets.

@tezansahu
Last active February 16, 2024 13:12
Show Gist options
Save tezansahu/0367787e388c10dd7b8956c837d6c715 to your computer and use it in GitHub Desktop.
class MultimodalVQAModel(nn.Module):
    """Late-fusion VQA classifier built on a pretrained text encoder and image encoder.

    The question is encoded by ``text_encoder`` and the image by
    ``image_encoder`` (both loaded with ``AutoModel.from_pretrained``). Their
    ``pooler_output`` vectors are concatenated along dim 1. The result goes
    through a Linear -> ReLU -> Dropout fusion layer and a final linear
    classifier over ``num_labels`` answer classes.

    Args:
        pretrained_text_name: Checkpoint name or path for the text encoder.
        pretrained_image_name: Checkpoint name or path for the image encoder.
        num_labels: Number of answer classes. The default ``len(answer_space)``
            is evaluated once, when the class is defined.
        intermediate_dim: Output width of the fusion layer.
        dropout: Dropout probability applied after the fusion ReLU.
    """

    def __init__(self, pretrained_text_name, pretrained_image_name, num_labels=len(answer_space), intermediate_dim=512, dropout=0.5):
        super().__init__()
        self.num_labels = num_labels
        self.pretrained_text_name = pretrained_text_name
        self.pretrained_image_name = pretrained_image_name

        # Pretrained transformers for text & image featurization.
        self.text_encoder = AutoModel.from_pretrained(self.pretrained_text_name)
        self.image_encoder = AutoModel.from_pretrained(self.pretrained_image_name)

        # Fusion layer for cross-modal interaction. Its input is the
        # concatenation of one pooled vector from each encoder.
        fused_in_dim = (
            self.text_encoder.config.hidden_size
            + self.image_encoder.config.hidden_size
        )
        self.fusion = nn.Sequential(
            nn.Linear(fused_in_dim, intermediate_dim),
            nn.ReLU(),
            # Honour the configured rate. This was previously hard-coded to
            # 0.5, which silently ignored the ``dropout`` argument.
            nn.Dropout(dropout),
        )

        # Fully-connected classifier over the fused representation.
        self.classifier = nn.Linear(intermediate_dim, self.num_labels)
        self.criterion = nn.CrossEntropyLoss()

    def forward(
            self,
            input_ids: torch.LongTensor,
            pixel_values: torch.FloatTensor,
            attention_mask: Optional[torch.LongTensor] = None,
            token_type_ids: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None):
        """Compute answer logits and, if ``labels`` is given, the training loss.

        Args:
            input_ids: Token ids for the text encoder.
            pixel_values: Image tensor for the image encoder.
            attention_mask: Optional attention mask for the text encoder.
            token_type_ids: Optional segment ids for the text encoder. These
                are forwarded only when provided.
            labels: Optional target class indices. When given, the
                cross-entropy loss against ``logits`` is also returned.

        Returns:
            A dict holding ``"logits"`` (the classifier output) and, when
            ``labels`` is not None, ``"loss"``. The logits presumably have
            shape (batch, num_labels), assuming each ``pooler_output`` is
            (batch, hidden) — TODO confirm for the chosen encoders.
        """
        text_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        # Some text encoders (e.g. DistilBERT) have no ``token_type_ids``
        # parameter. Passing it only when set avoids a TypeError there and is
        # equivalent for encoders whose default is already None.
        if token_type_ids is not None:
            text_inputs["token_type_ids"] = token_type_ids
        encoded_text = self.text_encoder(**text_inputs, return_dict=True)

        encoded_image = self.image_encoder(
            pixel_values=pixel_values,
            return_dict=True,
        )

        fused_output = self.fusion(
            torch.cat(
                [
                    encoded_text['pooler_output'],
                    encoded_image['pooler_output'],
                ],
                dim=1,
            )
        )
        logits = self.classifier(fused_output)

        out = {
            "logits": logits
        }
        if labels is not None:
            out["loss"] = self.criterion(logits, labels)
        return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment