# Gist by @bowbowbow, created August 21, 2019 16:05
import torch
import torch.nn as nn

# The original gist shows no imports. BertModel is assumed to come from
# pytorch_pretrained_bert, inferred from the `encoded_layers, _ = self.bert(x)`
# handling in BertNet.forward, which matches that package's return convention
# (a list of all encoded layers plus the pooled output).
from pytorch_pretrained_bert import BertModel


# Convolutional neural network (three convolutional layers) over
# 34-channel sound-feature sequences.
class ConvNet(nn.Module):
    def __init__(self, hidden_size=200):
        super(ConvNet, self).__init__()
        self.hidden_size = hidden_size
        self.layer1 = nn.Sequential(
            nn.Conv1d(34, 68, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(68),
            nn.CELU(),
            nn.MaxPool1d(kernel_size=5, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv1d(68, 128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(128),
            nn.CELU(),
            nn.MaxPool1d(kernel_size=5, stride=2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv1d(128, 256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(256),
            nn.CELU(),
            nn.MaxPool1d(kernel_size=5, stride=2),
            nn.Dropout(0.5),
        )
        self.fc = nn.Sequential(
            nn.Linear(15104, 6000),  # 15104 = 256 channels * 59 pooled time steps
            nn.CELU(),
            nn.Linear(6000, 2000),
            nn.Linear(2000, self.hidden_size),
            nn.CELU(),
        )

    def forward(self, x):
        # x: [batch_size, 34, time_steps]
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.reshape(out.size(0), -1)  # flatten to [batch_size, 15104]
        out = self.fc(out)
        return out
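

# --- Added: shape check for ConvNet (not in the original gist) ---------------
# A minimal sketch, assuming 34-channel sound features over 500 time steps and
# a batch of 8. The length 500 is an assumption: it is one input length that is
# consistent with the hard-coded flatten size 15104 = 256 * 59, since
# 500 -> 248 -> 122 -> 59 under the three MaxPool1d(kernel_size=5, stride=2)
# stages.
def _check_convnet_shapes():
    conv = ConvNet(hidden_size=200)
    dummy_sound = torch.randn(8, 34, 500)  # [batch_size, channels, time_steps]
    out = conv(dummy_sound)
    assert out.shape == (8, 200)           # [batch_size, hidden_size]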
class BertNet(nn.Module):
    def __init__(self, finetuning=True, hidden_size=200):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert_output_size = 768
        self.hidden_size = hidden_size
        self.rnn = nn.LSTM(input_size=self.bert_output_size, hidden_size=self.hidden_size,
                           batch_first=True, bidirectional=True)
        self.fc = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.drop = nn.Dropout(0.5)
        self.finetuning = finetuning

    def forward(self, x):
        # x: [batch_size, max_len] token ids
        if self.training and self.finetuning:
            self.bert.train()
            encoded_layers, _ = self.bert(x)
            enc1 = encoded_layers[-1]  # last layer: [batch_size, max_len, 768]
        else:
            self.bert.eval()
            with torch.no_grad():
                encoded_layers, _ = self.bert(x)
                enc1 = encoded_layers[-1]
        # enc: [batch_size, max_len, num_directions * hidden_size]
        # final_hidden_state: [2, batch_size, hidden_size] (one layer, bidirectional)
        enc, (final_hidden_state, final_cell_state) = self.rnn(enc1)
        # Decode the hidden state of the last time step
        enc = enc[:, -1, :]
        logits = self.fc(enc)
        logits = self.drop(logits)
        return logits
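

# --- Added note (not in the original gist): current `transformers` API -------
# The `encoded_layers, _ = self.bert(x)` calls above follow the older
# pytorch_pretrained_bert convention, where BertModel returns a list of all
# encoded layers plus the pooled output. If the model is ported to the current
# `transformers` package (4.x), the BERT call would change roughly as sketched
# below; this is an assumption-laden sketch, not a drop-in edit:
#
#     from transformers import BertModel
#     ...
#     enc1 = self.bert(x).last_hidden_state  # [batch_size, max_len, 768]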
class MultiModal(nn.Module):
    def __init__(self, num_classes=3, hidden_size=300):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.bert_net = BertNet(hidden_size=self.hidden_size)
        self.conv_et = ConvNet(hidden_size=self.hidden_size)
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size * 2),
            nn.CELU(),
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.Linear(self.hidden_size, self.num_classes),
        )

    def forward(self, text_x, sound_x):
        text_x = self.bert_net(text_x)    # [batch_size, hidden_size]
        sound_x = self.conv_et(sound_x)   # [batch_size, hidden_size]
        out = torch.cat([text_x, sound_x], 1)
        out = self.fc(out)
        return out
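

# --- Added: end-to-end usage sketch (not in the original gist) ---------------
# A minimal forward pass, assuming padded BERT token ids for the text branch
# and 34-channel, 500-step sound features for the sound branch. The batch size,
# sequence lengths, and eval-mode call are illustrative assumptions; 30522 is
# the bert-base-uncased vocabulary size.
def _demo_forward():
    model = MultiModal(num_classes=3, hidden_size=300)
    model.eval()                                   # routes BertNet through its no_grad branch
    text_ids = torch.randint(0, 30522, (4, 64))    # [batch_size, max_len] token ids
    sound_feats = torch.randn(4, 34, 500)          # [batch_size, channels, time_steps]
    with torch.no_grad():
        logits = model(text_ids, sound_feats)
    print(logits.shape)                            # expected: torch.Size([4, 3])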