@prafulgondane
Last active June 20, 2022 17:04
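import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import CategoricalAccuracy
from transformers import TFBertModel

# Assumption: `bert` is not defined in this snippet; any Hugging Face TF BERT checkpoint fits here.
bert = TFBertModel.from_pretrained("bert-base-uncased")

max_len = 70  # every sequence is padded/truncated to 70 tokens
input_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_ids")
input_mask = Input(shape=(max_len,), dtype=tf.int32, name="attention_mask")
# Element 0 of the BERT output is last_hidden_state, shape (batch, max_len, hidden_size)
embeddings = bert(input_ids, attention_mask=input_mask)[0]
out = tf.keras.layers.GlobalMaxPool1D()(embeddings)  # max over the token axis -> (batch, hidden_size)
out = Dense(128, activation="relu")(out)
out = tf.keras.layers.Dropout(0.1)(out)
out = Dense(32, activation="relu")(out)
y = Dense(2, activation="softmax")(out)  # probabilities for the two classes
model = tf.keras.Model(inputs=[input_ids, input_mask], outputs=y)
# Freeze BERT so only the classification head on top of it is trained
bert.trainable = False
optimizer = Adam(
    learning_rate=5e-05,  # this learning rate is for the BERT model, taken from the Hugging Face website
    epsilon=1e-08,
    clipnorm=1.0,  # clip each gradient tensor to an L2 norm of at most 1.0
)
# No `decay` argument: in legacy Keras optimizers `decay` means per-step learning-rate
# decay, lr / (1 + decay * step), not weight decay, and newer Keras optimizers reject it.
# Set loss and metrics: the last layer already applies softmax, so the loss must
# receive probabilities (from_logits=False). Labels must be one-hot encoded.
loss = CategoricalCrossentropy(from_logits=False)
metric = CategoricalAccuracy(name="accuracy")  # plain accuracy, not balanced accuracy
# Compile the model
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
# plot_model needs pydot and Graphviz installed
tf.keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)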
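The model takes two integer inputs of length 70: `input_ids` and `attention_mask`. In practice the Hugging Face tokenizer builds both (for example with `padding="max_length", truncation=True, max_length=70`). The pure-Python sketch below only illustrates their layout. `pad_and_mask` is a hypothetical helper, the token IDs are illustrative, and padding ID 0 is an assumption. Unlike the real tokenizer, its truncation does not keep the final `[SEP]` token.

```python
from typing import List, Tuple

MAX_LEN = 70  # matches max_len in the model above
PAD_ID = 0    # assumption: padding token id (0 is [PAD] in the standard BERT vocab)


def pad_and_mask(token_ids: List[int], max_len: int = MAX_LEN,
                 pad_id: int = PAD_ID) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: truncate/pad to max_len and build the matching attention mask."""
    ids = token_ids[:max_len]            # naive truncation (the real tokenizer keeps [SEP])
    mask = [1] * len(ids)                # 1 = real token, attended to
    padding = max_len - len(ids)
    # 0 in the mask marks padding positions, which BERT's attention ignores
    return ids + [pad_id] * padding, mask + [0] * padding


ids, mask = pad_and_mask([101, 7592, 2088, 102])  # illustrative token ids
print(len(ids), len(mask), sum(mask))
```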
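`GlobalMaxPool1D` takes, for each hidden feature, the maximum over all 70 token positions. It never sees `attention_mask`, so hidden states at padding positions also compete for the maximum. The pure-Python sketch below uses toy two-feature vectors to contrast this with a masked max that skips padded positions. The masked variant is shown only as an alternative. It is not what the model above does, and the function names are hypothetical.

```python
from typing import List


def global_max_pool(hidden: List[List[float]]) -> List[float]:
    """Max over the token axis for each feature, like GlobalMaxPool1D on one example."""
    return [max(column) for column in zip(*hidden)]


def masked_global_max_pool(hidden: List[List[float]], mask: List[int]) -> List[float]:
    """Same max, but only over positions where mask == 1 (real tokens)."""
    kept = [vec for vec, m in zip(hidden, mask) if m == 1]
    return [max(column) for column in zip(*kept)]


# Toy hidden states: 4 positions x 2 features; the last two positions are padding.
hidden = [[0.2, -1.0],
          [0.5, 0.3],
          [0.9, -0.4],   # padding position
          [0.1, 0.8]]    # padding position
mask = [1, 1, 0, 0]

print(global_max_pool(hidden))               # [0.9, 0.8], both taken from padding
print(masked_global_max_pool(hidden, mask))  # [0.5, 0.3]
```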
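The last `Dense` layer applies softmax, so the model outputs probabilities. Pairing that output with `CategoricalCrossentropy(from_logits=True)` makes Keras apply softmax a second time before taking the log. The pure-Python sketch below, using hand-written `softmax` and `cross_entropy` helpers, shows the effect for two classes. With probabilities passed in directly, a confident correct prediction gets a loss near 0. With the doubled softmax, the loss can never fall below ln(1 + e^-1) ≈ 0.313, however confident the model is.

```python
import math
from typing import List


def softmax(z: List[float]) -> List[float]:
    m = max(z)                              # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs: List[float], label: int) -> float:
    """Categorical cross-entropy for one example with an integer class label."""
    return -math.log(probs[label])


probs = [0.001, 0.999]  # a confident, correct softmax output for class 1

correct = cross_entropy(probs, 1)          # from_logits=False: probabilities used as-is
double = cross_entropy(softmax(probs), 1)  # from_logits=True on softmax output

# correct stays near 0; double cannot fall below ln(1 + e^-1) for two classes
print(round(correct, 4), round(double, 4))
```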