Skip to content

Instantly share code, notes, and snippets.

@ditwoo
Created December 19, 2019 12:42
Show Gist options
  • Save ditwoo/bc221036d04c103dcf4063819221d4c6 to your computer and use it in GitHub Desktop.
Save ditwoo/bc221036d04c103dcf4063819221d4c6 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
class PooledLstmTransfModel(nn.Module):
    """Sequence classifier: pretrained transformer encoder + BiLSTM + pooling.

    The encoder's token representations, and the outputs of a bidirectional
    LSTM run on top of them, are each max- and mean-pooled over the sequence
    axis (ignoring padded positions). The four pooled vectors are
    concatenated, projected back to ``hidden_size`` by ``pre_classifier`` and
    passed through a ReLU -> dropout -> linear head.

    Args:
        pretrain_dir: Name or path passed to ``AutoConfig.from_pretrained``
            and ``AutoModel.from_pretrained``.
        num_classes: Number of logits produced per example.
    """

    def __init__(self, pretrain_dir: str, num_classes: int = 1):
        super().__init__()
        config = AutoConfig.from_pretrained(pretrain_dir, num_labels=num_classes)
        self.bert = AutoModel.from_pretrained(pretrain_dir, config=config)

        lstm_hidden = config.hidden_size // 2
        self.rnns = nn.LSTM(
            input_size=config.hidden_size,
            hidden_size=lstm_hidden,
            batch_first=True,
            bidirectional=True,
        )
        # Pooled width = max+mean of the encoder output (2 * hidden_size)
        # plus max+mean of the BiLSTM output (2 * 2 * lstm_hidden). For an
        # even hidden_size this equals hidden_size * 4 (unchanged weights
        # layout); computing it explicitly keeps odd hidden sizes working.
        pooled_size = 2 * config.hidden_size + 4 * lstm_hidden
        self.pre_classifier = nn.Linear(pooled_size, config.hidden_size)
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, num_classes),
        )

    def forward(self, sequences, segments=None, head_mask=None, attention_mask=None):
        """Return classification logits for a batch of token-id sequences.

        Args:
            sequences: Token ids, assumed (batch, seq_len).
            segments: Optional token type ids, forwarded to the encoder.
            head_mask: Optional head mask, forwarded to the encoder.
            attention_mask: Optional mask, non-zero for real tokens. When
                omitted it is derived as ``sequences > 0``, i.e. pad id 0 is
                assumed; pass an explicit mask for tokenizers whose pad id
                differs.

        Returns:
            Logits of shape (batch, num_classes).
        """
        if attention_mask is None:
            attention_mask = (sequences > 0).float()
        bert_output = self.bert(
            input_ids=sequences,
            attention_mask=attention_mask,
            token_type_ids=segments,
            head_mask=head_mask,
        )
        hidden_state = bert_output[0]  # (bs, seq_len, dim)
        # NOTE(review): the LSTM still runs over padded positions (no
        # packing), so its states can be influenced by padding; only the
        # pooling below is masked. Packing would require right-padded,
        # non-empty rows — confirm before adding it.
        rnn_hidden_states, _ = self.rnns(hidden_state)  # (bs, seq_len, 2 * (dim // 2))

        # (bs, seq_len, 1) boolean mask; keeps padded positions out of the
        # pooled features so padding length does not change the logits.
        keep = (attention_mask > 0).unsqueeze(-1)
        pooled_output = torch.cat(
            [
                self._masked_max(hidden_state, keep),
                self._masked_mean(hidden_state, keep),
                self._masked_max(rnn_hidden_states, keep),
                self._masked_mean(rnn_hidden_states, keep),
            ],
            dim=1,
        )
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        return self.classifier(pooled_output)  # (bs, num_classes)

    @staticmethod
    def _masked_max(x, keep):
        """Max over dim 1 of ``x``, ignoring positions where ``keep`` is False.

        A row with no kept positions yields ``torch.finfo(x.dtype).min``.
        """
        filled = x.masked_fill(~keep, torch.finfo(x.dtype).min)
        return filled.max(dim=1).values

    @staticmethod
    def _masked_mean(x, keep):
        """Mean over dim 1 of ``x`` restricted to positions where ``keep`` is True.

        A row with no kept positions yields zeros instead of dividing by zero.
        """
        weights = keep.to(x.dtype)
        total = (x * weights).sum(dim=1)
        count = weights.sum(dim=1).clamp(min=1.0)
        return total / count
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment