# October 22, 2019
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import BertEmbeddings, BertEncoder

resnet18 = models.resnet18()
class MMBDEmbeddings(nn.Module):
def __init__(self,
text_mod_embds = BertEmbeddings, # Or your favorite bidirectional transformer
vision_mod_embds = resnet18): # Or your favorite vision model
super(MMBDEmbeddings, self).__init__()
self.text_mod_embds = text_mod_embds
self.vision_mod_embds = vision_mod_embds
self.vision_to_text_proj = nn.Linear(vision_dim, text_dim)
def forward(self, input_ids, images):
image_embs = self.vision_mod_embds(images)
proj_image_embds = self.vision_to_text_proj(image_embds)
token_embds = self.text_mod_embds(input_ids)
rslt = {'image': proj_image_embds, 'text': token_embds}
return rslt
class MMBDModel(nn.Module):
def __init__(self,
embeddings, # a MMBDEmbeddings object
encoder, # BertEncoder for instance
pooler): # It can be as simplest as take the [CLS] hiddens state.
super(MMBDModel, self).__init__()
self.embeddings = embeddings
self.encoder = encoder
def forward(self,
images): # and other arguments such as attention_mask, token_type_ids, etc. (see the encoder)
embings = self.embeddings(inputs_ids, images)
# do the concatenation of the two sequences of embeddings --> embds_seq
hidden_states = self.encoder(embds_seq)
pooled_output = self.pooler(hidden_states)
outputs = (hidden_states, pooled_output)
return outputs
class MMDBForMultiModalClassification(nn.Module):
def __init__(self,
self.classification_head = nn.Linear(768, 2) # for instance for a binary classification
self.mmdb_model = mmdb_model
def forward(self,
_, pooled_output = self.mmdb_model(input_ids, images)
return self.classification_head(pooled_output)
