Skip to content

Instantly share code, notes, and snippets.

@johnchrishays
Created May 9, 2020 01:51
Show Gist options
  • Save johnchrishays/80c01d95c3bb586c880667b439987055 to your computer and use it in GitHub Desktop.
Save johnchrishays/80c01d95c3bb586c880667b439987055 to your computer and use it in GitHub Desktop.
class FaceClassifier(nn.Module):
    """Binary classifier fusing a video-feature sequence with an audio-feature sequence.

    Each modality goes through its own positional encoding and
    ``nn.TransformerEncoder``. The video branch is projected to one value per
    time step, passed through a sigmoid and averaged over time. The audio branch
    is passed through a sigmoid and averaged over time feature-wise. The two
    pooled vectors are concatenated and mapped to a single output by a linear
    layer.

    ``forward`` returns the raw (pre-sigmoid) output of that final linear layer.
    NOTE(review): presumably paired with ``BCEWithLogitsLoss`` in training —
    confirm against the caller.

    Args:
        n_vid_features: Size of each video feature vector (encoder ``d_model``).
            PyTorch requires it to be divisible by ``n_head``.
        n_aud_features: Size of each audio feature vector (encoder ``d_model``).
            PyTorch requires it to be divisible by ``n_aud_head``.
        n_head: Number of attention heads in the video encoder.
        n_layers: Number of encoder layers in each branch.
        n_linear_hidden: Currently unused; kept for backward compatibility.
        dropout: Dropout probability inside each transformer encoder layer
            (0.1 matches the value PyTorch used when this was not passed).
        n_aud_head: Number of attention heads in the audio encoder. Defaults to
            1, the value that was previously hard-coded.
    """

    def __init__(self, n_vid_features, n_aud_features, n_head, n_layers,
                 n_linear_hidden=30, dropout=0.1, n_aud_head=1):
        super().__init__()
        # Video branch: encoder, then a per-time-step projection to one value.
        self.vid_pos_encoder = PositionalEncoding(d_model=n_vid_features)
        vid_encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_vid_features, nhead=n_head, dropout=dropout)
        self.vid_transformer_encoder = nn.TransformerEncoder(
            vid_encoder_layer, num_layers=n_layers)
        self.vid_pred = nn.Linear(n_vid_features, 1)

        # Audio branch: encoder only; its pooled output keeps n_aud_features dims.
        self.aud_pos_encoder = PositionalEncoding(d_model=n_aud_features)
        aud_encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_aud_features, nhead=n_aud_head, dropout=dropout)
        self.aud_transformer_encoder = nn.TransformerEncoder(
            aud_encoder_layer, num_layers=n_layers)

        # Fusion: 1 pooled video value + n_aud_features pooled audio values.
        # (Was hard-coded to 2 inputs, which only matched n_aud_features == 1;
        # for that case the layer shape and state-dict keys are unchanged.)
        self.out_pred = nn.Linear(1 + n_aud_features, 1)

    def forward(self, vid, aud):
        """Score a batch of (video, audio) sequence pairs.

        Args:
            vid: Video features, assumed (batch, seq_len, n_vid_features). The
                tensor is permuted to (seq_len, batch, features), the default
                layout ``nn.TransformerEncoder`` expects.
            aud: Audio features, assumed (batch, seq_len, n_aud_features), same
                layout as ``vid``.

        Returns:
            Unnormalized scores of shape (batch, 1), assuming
            ``PositionalEncoding`` preserves its input shape.
        """
        # Video: (batch, seq, feat) -> (seq, batch, feat) for the encoder.
        vid = vid.permute(1, 0, 2)
        vid = self.vid_pos_encoder(vid)
        vid = self.vid_transformer_encoder(vid)
        vid = torch.sigmoid(self.vid_pred(vid))  # (seq, batch, 1)
        vid = torch.mean(vid, dim=0)  # average over time -> (batch, 1)

        # Audio: same layout change; features are squashed and time-averaged.
        aud = aud.permute(1, 0, 2)
        aud = self.aud_pos_encoder(aud)
        aud = self.aud_transformer_encoder(aud)
        aud = torch.sigmoid(aud)
        aud = torch.mean(aud, dim=0)  # (batch, n_aud_features)

        # Fuse the time-averaged summaries of both branches (mean pooling,
        # not the last encoder step).
        x = torch.cat((vid, aud), dim=1)  # (batch, 1 + n_aud_features)
        return self.out_pred(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment