Skip to content

Instantly share code, notes, and snippets.

@bowbowbow
Created August 17, 2019 19:26
Show Gist options
  • Save bowbowbow/fcd701d828a969932851bd23648ebb81 to your computer and use it in GitHub Desktop.
class BertNet(nn.Module):
    """Sequence classifier: BERT encoder -> bidirectional LSTM -> linear head.

    Args:
        finetuning: if True, BERT weights receive gradients during training;
            otherwise BERT runs frozen under ``torch.no_grad()``.
        num_classes: number of output classes.
        hidden_size: hidden size of each LSTM direction.
    """

    def __init__(self, finetuning=False, num_classes=3, hidden_size=50):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.bert_output_size = 768  # hidden width of bert-base
        self.hidden_size = hidden_size
        self.rnn = nn.LSTM(
            input_size=self.bert_output_size,
            hidden_size=self.hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        # * 2 because the LSTM is bidirectional (forward + backward states).
        self.fc = nn.Linear(self.hidden_size * 2, num_classes)
        self.finetuning = finetuning

    def forward(self, x):
        """Classify a batch of token ids.

        Args:
            x: token-id tensor of shape [batch_size, max_len].
               NOTE(review): no attention mask is passed to BERT, so padding
               tokens are attended to — confirm callers use unpadded input.

        Returns:
            logits of shape [batch_size, num_classes].
        """
        if self.training and self.finetuning:
            self.bert.train()
            encoded_layers, _ = self.bert(x)
            enc1 = encoded_layers[-1]  # [batch_size, max_len, 768]
        else:
            # Frozen BERT: skip gradient tracking to save memory/compute.
            self.bert.eval()
            with torch.no_grad():
                encoded_layers, _ = self.bert(x)
            enc1 = encoded_layers[-1]
        # final_hidden_state: [num_layers * num_directions, batch, hidden_size]
        _, (final_hidden_state, _) = self.rnn(enc1)
        # Bug fix: the original used enc[:, -1, :], but at the last time step
        # the backward direction has seen only ONE token. Concatenate the true
        # final hidden state of each direction instead
        # ([-2] = forward, [-1] = backward), giving [batch, 2 * hidden_size].
        enc = torch.cat((final_hidden_state[-2], final_hidden_state[-1]), dim=1)
        logits = self.fc(enc)  # [batch_size, num_classes]
        return logits
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment