
@Akashdesarda
Created April 23, 2020 10:13
import tensorflow as tf
from transformers import DistilBertConfig, TFDistilBertModel

distil_bert = 'distilbert-base-uncased'

# Configure DistilBERT with extra dropout; only the last hidden state is needed
config = DistilBertConfig(dropout=0.2, attention_dropout=0.2)
config.output_hidden_states = False
transformer_model = TFDistilBertModel.from_pretrained(distil_bert, config=config)

# Two inputs: token ids and the attention mask, both padded/truncated to 128
input_ids_in = tf.keras.layers.Input(shape=(128,), name='input_token', dtype='int32')
input_masks_in = tf.keras.layers.Input(shape=(128,), name='masked_token', dtype='int32')

# [0] selects the last hidden state, a 3D tensor of shape (batch, 128, 768)
embedding_layer = transformer_model(input_ids_in, attention_mask=input_masks_in)[0]
X = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(50, return_sequences=True, dropout=0.1, recurrent_dropout=0.1))(embedding_layer)
X = tf.keras.layers.GlobalMaxPool1D()(X)  # collapse the time axis: 3D -> 2D
X = tf.keras.layers.Dense(50, activation='relu')(X)
X = tf.keras.layers.Dropout(0.2)(X)
X = tf.keras.layers.Dense(6, activation='sigmoid')(X)  # 6 independent labels
model = tf.keras.Model(inputs=[input_ids_in, input_masks_in], outputs=X)

# Freeze the first three layers (the two inputs and the DistilBERT block)
# so that only the BiLSTM/Dense head is trained
for layer in model.layers[:3]:
    layer.trainable = False
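
For context, here is a minimal usage sketch, not part of the original gist, showing how inputs for this model could be prepared with the matching tokenizer (assuming a recent transformers release; the example texts are hypothetical placeholders):

from transformers import DistilBertTokenizer

# Hypothetical example texts; any list of strings works
texts = ['this movie was great', 'utterly boring film']
tokenizer = DistilBertTokenizer.from_pretrained(distil_bert)
encoded = tokenizer(texts, padding='max_length', truncation=True,
                    max_length=128, return_tensors='tf')

# Binary cross-entropy matches the 6 independent sigmoid outputs
model.compile(optimizer='adam', loss='binary_crossentropy')

# The model expects the token ids and the attention mask, in that order
probs = model.predict([encoded['input_ids'], encoded['attention_mask']])
print(probs.shape)  # (2, 6): one sigmoid score per label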
@marcosclima

The code doesn't match your observations:
→ Look at line #17: since the embedding layer earlier produces 3D data, we can use an LSTM to extract finer detail.
→ Next, the 3D data has to be reduced to 2D so that a fully connected layer can follow. Any pooling layer can perform this.
→ Also, note lines #18 & #19: we should always freeze the pre-trained weights of the transformer model, never update them, and train only the remaining weights (a sketch of this step follows below).
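
As a hedged sketch of that freezing step (an alternative to slicing model.layers[:3], which silently depends on layer ordering), the transformer layer can be marked non-trainable directly:

# Freeze the pre-trained DistilBERT block explicitly rather than by position
transformer_model.trainable = False

# Sanity check: only the BiLSTM/Dense head should report trainable weights
for layer in model.layers:
    print(layer.name, layer.trainable)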

This code is a copy of: model_as_feature_extractor.py

Do you have the correct one?
