Created
June 30, 2021 14:03
-
-
Save Muhammad4hmed/05cf977a7442c53a27df9346955229a4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LitModel(nn.Module):
    """RoBERTa-base encoder with attention-weighted pooling and a linear
    regression head that produces one score per input sequence.

    forward(input_ids, attention_mask) -> Tensor of shape (batch, 1).
    """

    def __init__(self):
        super().__init__()
        config = AutoConfig.from_pretrained(ROBERTA_PATH)
        # Expose all hidden states, disable hidden dropout, and tighten
        # layer-norm epsilon (common fine-tuning tweaks for regression).
        config.update({"output_hidden_states": True,
                       "hidden_dropout_prob": 0.0,
                       "layer_norm_eps": 1e-7})
        self.roberta = AutoModel.from_pretrained(ROBERTA_PATH, config=config)
        # Small attention net: one score per token, softmax over the
        # sequence dimension. 768 is roberta-base's hidden size.
        self.attention = nn.Sequential(
            nn.Linear(768, 512),
            nn.Tanh(),
            nn.Linear(512, 1),
            nn.Softmax(dim=1)
        )
        self.regressor = nn.Sequential(
            nn.Linear(768, 1)
        )

    def forward(self, input_ids, attention_mask):
        """Encode the batch and reduce it to one regression score per row.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len); 1 for real tokens, 0 for padding.

        Returns:
            (batch, 1) regression scores.
        """
        roberta_output = self.roberta(input_ids=input_ids,
                                      attention_mask=attention_mask)
        # hidden_states holds 13 tensors: the embedding layer plus the 12
        # RoBERTa layers. Take the final layer: (batch, seq_len, 768).
        last_layer_hidden_states = roberta_output.hidden_states[-1]
        # Per-token attention weights: (batch, seq_len, 1).
        weights = self.attention(last_layer_hidden_states)
        # BUG FIX: the original softmax ran over *all* positions, so padding
        # tokens received non-zero weight and polluted the pooled vector.
        # Zero out padded positions and renormalize so weights still sum to 1
        # over the real tokens. (self.attention is unchanged, so checkpoints
        # trained with the original code still load.)
        mask = attention_mask.unsqueeze(-1).to(weights.dtype)  # (batch, seq_len, 1)
        weights = weights * mask
        weights = weights / weights.sum(dim=1, keepdim=True).clamp(min=1e-9)
        # Weighted average of token states -> context vector (batch, 768).
        context_vector = torch.sum(weights * last_layer_hidden_states, dim=1)
        # Reduce the context vector to the final prediction score.
        return self.regressor(context_vector)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment