Skip to content

Instantly share code, notes, and snippets.

@VictorSanh
Created October 22, 2019 00:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save VictorSanh/84333b5ed0f3673d30be20737d7a1be7 to your computer and use it in GitHub Desktop.
Save VictorSanh/84333b5ed0f3673d30be20737d7a1be7 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torchvision.models as models
# Default vision backbone used by MMBDEmbeddings.
# NOTE(review): this is a single module-level instance used as a default
# argument value, so every MMBDEmbeddings built with the default shares (and
# jointly trains) these same weights. Pass a fresh model to avoid that.
resnet18 = models.resnet18()
from transformers import BertEmbeddings, BertEncoder
class MMBDEmbeddings(nn.Module):
    """Embed text tokens and images, projecting image features to the text width.

    Args:
        text_mod_embds: an *instantiated* module mapping ``input_ids`` to token
            embeddings (e.g. ``BertEmbeddings(config)``, or the embedding layer
            of your favorite bidirectional transformer). Passing a class
            instead of an instance raises ``TypeError``.
        vision_mod_embds: an *instantiated* vision module mapping ``images``
            to features whose last dimension is ``vision_dim``. Defaults to
            the module-level ``resnet18``.
        vision_dim: size of the last dimension of the vision features.
            Defaults to 1000, the output size of torchvision's default
            ``resnet18`` classifier.
        text_dim: width of the text embeddings; image features are projected
            to this size. Defaults to 768 (the BERT-base hidden size).

    Raises:
        TypeError: if either modality module is passed as a class rather
            than as a module instance.
    """

    def __init__(self,
                 text_mod_embds=BertEmbeddings,  # Or your favorite bidirectional transformer
                 vision_mod_embds=resnet18,  # Or your favorite vision model
                 vision_dim=1000,
                 text_dim=768):
        super(MMBDEmbeddings, self).__init__()
        # The default ``BertEmbeddings`` is a class, not a module. If it were
        # stored as is, forward() would try to construct a new BertEmbeddings
        # from input_ids. Fail early with an actionable message instead.
        for arg_name, module in (('text_mod_embds', text_mod_embds),
                                 ('vision_mod_embds', vision_mod_embds)):
            if isinstance(module, type):
                raise TypeError(
                    '{} must be an instantiated module, got the class {}; '
                    'construct it first (e.g. BertEmbeddings(config)).'.format(
                        arg_name, module.__name__))
        self.text_mod_embds = text_mod_embds
        # NOTE(review): the default ``resnet18`` is one module-level instance,
        # so every MMBDEmbeddings built with the default shares its weights.
        self.vision_mod_embds = vision_mod_embds
        # Maps vision features (last dim vision_dim) into the text embedding
        # space (last dim text_dim) so both modalities can be concatenated.
        self.vision_to_text_proj = nn.Linear(vision_dim, text_dim)

    def forward(self, input_ids, images):
        """Return ``{'image': projected image features, 'text': token embeddings}``.

        The image features keep whatever leading shape the vision module
        produces. Only their last dimension is projected from ``vision_dim``
        to ``text_dim``.
        """
        image_embds = self.vision_mod_embds(images)
        proj_image_embds = self.vision_to_text_proj(image_embds)
        token_embds = self.text_mod_embds(input_ids)
        return {'image': proj_image_embds, 'text': token_embds}
class MMBDModel(nn.Module):
    """Run one encoder over the concatenated text and image embeddings.

    Args:
        embeddings: a module (e.g. an MMBDEmbeddings) whose
            ``forward(input_ids, images)`` returns a dict with ``'text'`` and
            ``'image'`` embedding tensors of the same width.
        encoder: the sequence encoder applied to the joint sequence
            (BertEncoder, for instance).
        pooler: callable mapping the encoder output to a pooled vector. It
            can be as simple as taking the [CLS] hidden state.
    """

    def __init__(self,
                 embeddings,  # a MMBDEmbeddings object
                 encoder,  # BertEncoder for instance
                 pooler):  # It can be as simple as taking the [CLS] hidden state.
        super(MMBDModel, self).__init__()
        self.embeddings = embeddings
        self.encoder = encoder
        # Must be stored: forward() calls self.pooler.
        self.pooler = pooler

    def forward(self,
                input_ids,
                images):  # and other arguments such as attention_mask, token_type_ids, etc. (see the encoder)
        """Return ``(hidden_states, pooled_output)`` for the joint sequence.

        The joint sequence is the text embeddings followed by the image
        embeddings, concatenated along dim 1.
        """
        embds = self.embeddings(input_ids, images)
        image_embds = embds['image']
        if image_embds.dim() == 2:
            # A vision module that yields one vector per image (batch, dim)
            # becomes a one-token sequence (batch, 1, dim).
            image_embds = image_embds.unsqueeze(1)
        # Assumes text embeddings are (batch, seq_len, dim); TODO confirm for
        # non-BERT text modules. Text goes first so position 0 is still the
        # first text token (e.g. [CLS]), which is what a pooler that takes
        # the first hidden state expects.
        embds_seq = torch.cat([embds['text'], image_embds], dim=1)
        # NOTE(review): some encoders return a tuple whose first element is
        # the sequence output; unwrap it here if the encoder in use does.
        hidden_states = self.encoder(embds_seq)
        pooled_output = self.pooler(hidden_states)
        return (hidden_states, pooled_output)
class MMDBForMultiModalClassification(nn.Module):
    """Linear classification head on top of a multimodal bidirectional model.

    Args:
        mmdb_model: a module whose ``forward(input_ids, images)`` returns
            ``(hidden_states, pooled_output)``, e.g. an MMBDModel.
        hidden_size: size of the last dimension of ``pooled_output``.
            Defaults to 768.
        num_labels: number of output classes. Defaults to 2, i.e. binary
            classification.
    """

    def __init__(self,
                 mmdb_model,
                 hidden_size=768,
                 num_labels=2):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise nn.Module.__setattr__ raises AttributeError.
        super(MMDBForMultiModalClassification, self).__init__()
        self.classification_head = nn.Linear(hidden_size, num_labels)
        self.mmdb_model = mmdb_model

    def forward(self,
                input_ids,
                images):
        """Return unnormalized class scores (logits), last dim ``num_labels``."""
        _, pooled_output = self.mmdb_model(input_ids, images)
        return self.classification_head(pooled_output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment