Template for building a model off of the BERT or DistilBERT architecture
import tensorflow as tf

MAX_LENGTH = 128
LAYER_DROPOUT = 0.2
LEARNING_RATE = 5e-5
RANDOM_STATE = 42
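
# focal_loss() is called in model.compile() below but is not defined in this
# gist. The following is a minimal sketch of a binary focal loss
# (Lin et al., 2017) so the template runs end to end; the gamma and alpha
# defaults are illustrative, not values from the original.
def focal_loss(gamma=2.0, alpha=0.25):
    def loss_fn(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        # Clip predictions away from 0 and 1 to keep the log finite
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        # pt is the model's predicted probability for the true class
        pt = tf.where(tf.equal(y_true, 1.0), y_pred, 1.0 - y_pred)
        alpha_t = tf.where(tf.equal(y_true, 1.0), alpha, 1.0 - alpha)
        return -tf.reduce_mean(alpha_t * tf.pow(1.0 - pt, gamma) * tf.math.log(pt))
    return loss_fn
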
def build_model(transformer, max_length=MAX_LENGTH):
"""""""""
Template for building a model off of the BERT or DistilBERT architecture
for a binary classification task.
Input:
- transformer: a base Hugging Face transformer model object (BERT or DistilBERT)
with no added classification head attached.
- max_length: integer controlling the maximum number of encoded tokens
in a given sequence.
Output:
- model: a compiled tf.keras.Model with added classification layers
on top of the base pre-trained model architecture.
""""""""""

    # Define weight initializer with a random seed to ensure reproducibility
    weight_initializer = tf.keras.initializers.GlorotNormal(seed=RANDOM_STATE)

    # Define input layers
    input_ids_layer = tf.keras.layers.Input(shape=(max_length,),
                                            name='input_ids',
                                            dtype='int32')
    input_attention_layer = tf.keras.layers.Input(shape=(max_length,),
                                                  name='input_attention',
                                                  dtype='int32')

    # DistilBERT outputs a tuple whose first element (index 0) is the
    # hidden state at the output of the model's last layer: a tf.Tensor of
    # shape (batch_size, sequence_length, hidden_size=768).
    last_hidden_state = transformer([input_ids_layer, input_attention_layer])[0]

    # We only care about DistilBERT's output for the [CLS] token, which is
    # located at index 0 of every encoded sequence. Slicing out the [CLS]
    # tokens gives us 2D data of shape (batch_size, hidden_size).
    cls_token = last_hidden_state[:, 0, :]

    ##                                                 ##
    ## Define additional dropout and dense layers here ##
    ##                                                 ##
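
    # For example (a sketch, not part of the original template): a dropout
    # layer using the LAYER_DROPOUT constant followed by a small dense hidden
    # layer. Uncomment and adapt as needed.
    # cls_token = tf.keras.layers.Dropout(LAYER_DROPOUT)(cls_token)
    # cls_token = tf.keras.layers.Dense(256,
    #                                   activation='relu',
    #                                   kernel_initializer=weight_initializer)(cls_token)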

    # Define a single node that makes up the output layer (for binary classification)
    output = tf.keras.layers.Dense(1,
                                   activation='sigmoid',
                                   kernel_initializer=weight_initializer,
                                   kernel_constraint=None,
                                   bias_initializer='zeros'
                                   )(cls_token)

    # Define the model
    model = tf.keras.Model([input_ids_layer, input_attention_layer], output)

    # Compile the model ('lr' is deprecated in tf.keras; use 'learning_rate')
    model.compile(tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                  loss=focal_loss(),
                  metrics=['accuracy'])

    return model
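
# Example usage (a sketch, not from the original gist): build the model from a
# pre-trained DistilBERT base and encode a batch of text for it. Assumes the
# Hugging Face transformers library is installed; 'distilbert-base-uncased' is
# an illustrative checkpoint name.
from transformers import DistilBertTokenizerFast, TFDistilBertModel

distilbert = TFDistilBertModel.from_pretrained('distilbert-base-uncased')
model = build_model(distilbert)
model.summary()

# The two input layers expect token ids and an attention mask, each padded or
# truncated to MAX_LENGTH, exactly as produced by the tokenizer below.
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
encodings = tokenizer(['an example sentence'],
                      padding='max_length',
                      truncation=True,
                      max_length=MAX_LENGTH,
                      return_tensors='tf')
predictions = model([encodings['input_ids'], encodings['attention_mask']])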