@kusal1990
Created June 25, 2022 06:36
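import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping
from transformers import TFBertForMultipleChoice

tf.keras.backend.clear_session()
# BERT with a multiple-choice head: it returns one logit per answer choice.
pre_trained_model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
# Each example holds 5 answer choices, each tokenized to 512 positions.
model_input_ids = Input(shape=(5, 512), name='input_tokens', dtype='int32')
masks_input = Input(shape=(5, 512), name='attention_mask', dtype='int32')
model_token_type_ids = Input(shape=(5, 512), name='token_type_ids', dtype='int32')
x = {'input_ids': model_input_ids,
     'attention_mask': masks_input,
     'token_type_ids': model_token_type_ids}
x = pre_trained_model(x)['logits']           # shape (batch, 5): one score per choice
outputs = Dense(5, activation='softmax')(x)  # probabilities over the 5 choices
model = Model(inputs=[model_input_ids, masks_input, model_token_type_ids], outputs=outputs)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
loss = tf.keras.losses.CategoricalCrossentropy()
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
# easy_train_* arrays and easy_train_labels (answer indices 0-4) come from earlier preprocessing.
# num_classes=5 keeps the one-hot width at 5 even if some answer index never appears.
y = tf.keras.utils.to_categorical(easy_train_labels.ravel(), num_classes=5)
# EarlyStopping needs validation data to report val_accuracy; the 10% hold-out is an assumption.
early_stop = EarlyStopping(patience=2, monitor='val_accuracy')
model.fit(x=[easy_train_input_ids, easy_train_attention_mask, easy_train_token_type_ids], y=y,
          epochs=3, batch_size=2,
          validation_split=0.1,
          callbacks=[early_stop])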
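`TFBertForMultipleChoice` takes inputs shaped `(batch, 5, 512)` and returns logits shaped `(batch, 5)`. Roughly, the Hugging Face head folds the choices into the batch dimension, scores every (question, choice) sequence with one shared linear layer on BERT's pooled output, and reshapes the scores back to one row per example. The NumPy sketch below mimics only that shape flow; the toy encoder, the sizes, and the names `pooled_encoding` and `multiple_choice_logits` are hypothetical stand-ins, not BERT itself.

```python
import numpy as np

# Hypothetical toy sizes: 2 examples, 5 choices, 8 tokens per choice, hidden width 4.
BATCH, NUM_CHOICES, SEQ_LEN, HIDDEN, VOCAB = 2, 5, 8, 4, 100
rng = np.random.default_rng(0)
embedding = rng.normal(size=(VOCAB, HIDDEN))  # toy token embeddings
w = rng.normal(size=(HIDDEN,))                # shared scoring vector: one logit per sequence
b = 0.0

def pooled_encoding(flat_ids):
    """Toy stand-in for BERT's pooled output: (n, seq_len) ids -> (n, hidden)."""
    return np.tanh(embedding[flat_ids].mean(axis=1))

def multiple_choice_logits(input_ids):
    """(batch, num_choices, seq_len) ids -> (batch, num_choices) logits."""
    batch, num_choices, seq_len = input_ids.shape
    flat = input_ids.reshape(-1, seq_len)       # (batch * num_choices, seq_len)
    scores = pooled_encoding(flat) @ w + b      # (batch * num_choices,)
    return scores.reshape(batch, num_choices)   # back to one row per example

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

input_ids = rng.integers(0, VOCAB, size=(BATCH, NUM_CHOICES, SEQ_LEN))
logits = multiple_choice_logits(input_ids)
probs = softmax(logits)
print(logits.shape, probs.sum(axis=-1))
```

Because every choice is scored with the same weights, reordering the choices only reorders the logits. A plain softmax over those logits keeps that symmetry, while a `Dense(5)` layer stacked on top is free to mix scores across choice positions.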
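The labels are one-hot encoded with `to_categorical` so they match `CategoricalCrossentropy`. By default, `to_categorical` makes the width `max(label) + 1`, so a label set that never contains the last answer index would be one column too narrow for the 5-way output. The NumPy sketch below reproduces that default and a basic categorical cross-entropy; `one_hot`, `categorical_crossentropy`, and the sample labels are hypothetical illustrations, not the Keras source.

```python
import numpy as np

def one_hot(labels, num_classes=None):
    """Like to_categorical's default: width is max(label) + 1 unless num_classes is given."""
    labels = np.asarray(labels, dtype=int).ravel()
    if num_classes is None:
        num_classes = int(labels.max()) + 1
    out = np.zeros((labels.size, num_classes), dtype=np.float32)
    out[np.arange(labels.size), labels] = 1.0
    return out

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over examples of -sum(y_true * log(y_pred)); clipping avoids log(0)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))

labels = np.array([[0], [3], [1]])           # hypothetical answer indices; 4 never appears
print(one_hot(labels).shape)                 # (3, 4): too narrow for a 5-way softmax
print(one_hot(labels, num_classes=5).shape)  # (3, 5)

uniform = np.full((3, 5), 0.2)               # an untrained model guessing evenly
print(categorical_crossentropy(one_hot(labels, 5), uniform))  # about 1.609, i.e. log(5)
```

A loss near `log(5)` is a useful sanity check at the start of training: it is what a model with no preference among the five choices would score.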
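With `patience=2` and `epochs=3`, the `early_stop` callback has little room to act. The pure-Python sketch below is a simplified model of the patience counter for a metric where higher is better. It assumes `min_delta=0`, no `baseline`, and no `restore_best_weights`. The function name `epochs_run` and the accuracy curves are hypothetical, and this is not the Keras implementation.

```python
import itertools

def epochs_run(val_history, patience, max_epochs):
    """Simplified patience rule (mode='max', min_delta=0).

    Returns how many epochs train before stopping, given per-epoch val_accuracy values.
    """
    best = float('-inf')
    wait = 0
    for epoch in range(max_epochs):
        current = val_history[epoch]
        wait += 1
        if current > best:   # an improvement resets the counter
            best = current
            wait = 0
        if wait >= patience and epoch > 0:
            return epoch + 1  # training stops after this epoch
    return max_epochs

# Hypothetical val_accuracy curves: with patience=2 and 3 epochs, none stops early.
for history in itertools.product([0.2, 0.4, 0.6], repeat=3):
    assert epochs_run(list(history), patience=2, max_epochs=3) == 3

print(epochs_run([0.5, 0.4, 0.4, 0.3, 0.3], patience=2, max_epochs=5))  # stops after 3 of 5
```

Under this model, the earliest possible stop with `patience=2` comes after the third epoch, which is already the last scheduled one, so early stopping only starts to matter once `epochs` is larger than `patience + 1`.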
@kusal1990 (author) commented:

ok
