Bert Inner Workings output.
import torch

# In the Hugging Face transformers BERT source, BertLayerNorm is an
# alias for torch.nn.LayerNorm.
BertLayerNorm = torch.nn.LayerNorm


class BertOutput(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        # Project the intermediate (feed-forward) width back down to the hidden size.
        self.dense = torch.nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        print('\nHidden States:\n', hidden_states.shape)
        hidden_states = self.dense(hidden_states)
        print('\nHidden States Linear Layer:\n', hidden_states.shape)
        hidden_states = self.dropout(hidden_states)
        print('\nHidden States Dropout Layer:\n', hidden_states.shape)
        # Residual connection: add the attention output back in, then normalize.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        print('\nHidden States Layer Normalization:\n', hidden_states.shape)
        return hidden_states
# Create the BERT output block.
bert_output_block = BertOutput(bert_configuration)

# Perform the forward pass. attention_output is a tuple, so index [0]
# to get the attention block's hidden states.
layer_output = bert_output_block(hidden_states=intermediate_output, input_tensor=attention_output[0])
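If you want to exercise this block in isolation, without the tutorial's earlier model state (bert_configuration, intermediate_output, attention_output), the sketch below fakes both inputs. The SimpleNamespace config, tensor shapes, and variable names are illustrative assumptions; the sizes are the BERT-base defaults (hidden_size=768, intermediate_size=3072, layer_norm_eps=1e-12, hidden_dropout_prob=0.1), not values taken from the original gist.

from types import SimpleNamespace

import torch

# Illustrative stand-in for the tutorial's bert_configuration,
# filled with BERT-base defaults (assumption, not from the gist).
config = SimpleNamespace(
    hidden_size=768,
    intermediate_size=3072,
    layer_norm_eps=1e-12,
    hidden_dropout_prob=0.1,
)

block = BertOutput(config)

# Fake inputs: batch of 1, sequence length of 8.
# The feed-forward output has the intermediate width...
fake_intermediate_output = torch.rand(1, 8, config.intermediate_size)
# ...while the attention output has the hidden width.
fake_attention_output = torch.rand(1, 8, config.hidden_size)

output = block(hidden_states=fake_intermediate_output,
               input_tensor=fake_attention_output)
assert output.shape == (1, 8, config.hidden_size)

The prints should show the shapes collapsing from (1, 8, 3072) to (1, 8, 768) after the linear layer, then staying at (1, 8, 768) through dropout and layer normalization.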