Last active
February 1, 2021 00:30
-
-
Save RayWilliam46/07e8718b2e7b102b9617e06c9faca27c to your computer and use it in GitHub Desktop.
Template for building a model off of the BERT or DistilBERT architecture
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Maximum number of encoded tokens per input sequence.
MAX_LENGTH = 128
# Dropout rate for any added classification-head dropout layers.
LAYER_DROPOUT = 0.2
# Adam learning rate (5e-5 is a common choice for BERT fine-tuning).
LEARNING_RATE = 5e-5
# Seed for reproducible weight initialization.
RANDOM_STATE = 42
def build_model(transformer, max_length=MAX_LENGTH):
    """
    Build and compile a binary classifier on top of a BERT/DistilBERT base.

    Input:
      - transformer: a base Hugging Face transformer model object (BERT or
        DistilBERT) with no added classification head attached.
      - max_length: integer controlling the maximum number of encoded tokens
        in a given sequence.

    Output:
      - model: a compiled tf.keras.Model with added classification layers
        on top of the base pre-trained model architecture.
    """
    # Define weight initializer with a fixed seed to ensure reproducibility.
    weight_initializer = tf.keras.initializers.GlorotNormal(seed=RANDOM_STATE)

    # Input layers: token ids and attention mask, each (max_length,) int32.
    input_ids_layer = tf.keras.layers.Input(shape=(max_length,),
                                            name='input_ids',
                                            dtype='int32')
    input_attention_layer = tf.keras.layers.Input(shape=(max_length,),
                                                  name='input_attention',
                                                  dtype='int32')

    # DistilBERT outputs a tuple whose element at index 0 is the hidden
    # state at the output of the model's last layer: a tf.Tensor of shape
    # (batch_size, sequence_length, hidden_size=768).
    last_hidden_state = transformer([input_ids_layer, input_attention_layer])[0]

    # We only care about the output for the [CLS] token, located at index 0
    # of every encoded sequence. Slicing it out gives us 2D data.
    cls_token = last_hidden_state[:, 0, :]

    ##                                                 ##
    ## Define additional dropout and dense layers here ##
    ##                                                 ##

    # Single sigmoid node as the output layer (binary classification).
    output = tf.keras.layers.Dense(1,
                                   activation='sigmoid',
                                   kernel_initializer=weight_initializer,
                                   kernel_constraint=None,
                                   bias_initializer='zeros'
                                   )(cls_token)

    # Assemble the functional model from its inputs and output.
    model = tf.keras.Model([input_ids_layer, input_attention_layer], output)

    # Compile. NOTE: `lr` was deprecated and later removed from tf.keras
    # optimizers; `learning_rate` is the supported keyword.
    # `focal_loss` is assumed to be defined elsewhere in the project —
    # TODO confirm it returns a Keras-compatible loss callable.
    model.compile(tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                  loss=focal_loss(),
                  metrics=['accuracy'])

    return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment