Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save deelipku23/a39ee0216ac0932eaa52264aa1e8ee55 to your computer and use it in GitHub Desktop.
Save deelipku23/a39ee0216ac0932eaa52264aa1e8ee55 to your computer and use it in GitHub Desktop.
def _build_encoder_branch(model_name):
    """Build one XLNet encoder branch and return its inputs and pooled output.

    Args:
        model_name: Checkpoint name or path forwarded to
            ``TFXLNetModel.from_pretrained``.

    Returns:
        A ``(inputs, summary)`` tuple where ``inputs`` is the list
        ``[token_ids, attention_mask, token_type_ids]`` of Keras ``Input``
        tensors (each of length ``MAX_SEQUENCE_LENGTH``, dtype int32) and
        ``summary`` is the encoder output averaged over the sequence axis
        and flattened to one vector per example.
    """
    # Override the setting through from_pretrained's config kwargs instead of
    # passing a fresh XLNetConfig(): a default config would replace the
    # checkpoint's own architecture settings.
    encoder = TFXLNetModel.from_pretrained(model_name, output_hidden_states=False)

    token_ids = tf.keras.layers.Input((MAX_SEQUENCE_LENGTH,), dtype=tf.int32)
    attention_mask = tf.keras.layers.Input((MAX_SEQUENCE_LENGTH,), dtype=tf.int32)
    token_type_ids = tf.keras.layers.Input((MAX_SEQUENCE_LENGTH,), dtype=tf.int32)

    # Element 0 of the encoder output is used as the per-token representation
    # (the last-layer hidden states according to the transformers docs).
    hidden_states = encoder(
        token_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )[0]

    # A pooling window spanning the whole sequence collapses it to a single
    # step; Flatten then drops that singleton axis.
    # NOTE(review): the attention mask is not applied here, so padded
    # positions contribute to the mean -- confirm this is intended.
    flatten = tf.keras.layers.Flatten()
    pooled = tf.keras.layers.AveragePooling1D(MAX_SEQUENCE_LENGTH)(hidden_states)
    summary = flatten(pooled)

    return [token_ids, attention_mask, token_type_ids], summary


def create_model(model_name):
    """Create a two-tower XLNet model over a question and an answer.

    Two separate ``TFXLNetModel`` encoders are loaded from ``model_name``
    (no weight sharing between them), one fed the question inputs and one
    the answer inputs. Each encoder's output is mean-pooled over the
    sequence axis, the two pooled vectors are concatenated, passed through
    dropout (rate 0.2) and a 30-unit sigmoid ``Dense`` layer.

    Args:
        model_name: Checkpoint name or path forwarded to
            ``TFXLNetModel.from_pretrained``.

    Returns:
        A ``tf.keras.models.Model`` whose inputs are, in order,
        ``[question_enc, question_mask, question_type_ids,
        answer_enc, answer_mask, answer_type_ids]`` and whose output has
        30 sigmoid-activated units per example.
    """
    # Question branch is built first, then the answer branch, so layer
    # creation order (and therefore Keras auto-naming) stays stable.
    question_inputs, question_bert_summary = _build_encoder_branch(model_name)
    answer_inputs, answer_bert_summary = _build_encoder_branch(model_name)

    combined_bert_summary = tf.keras.layers.Concatenate()(
        [question_bert_summary, answer_bert_summary]
    )
    dropout_bert = tf.keras.layers.Dropout(0.2)(combined_bert_summary)
    output = tf.keras.layers.Dense(30, activation='sigmoid')(dropout_bert)

    return tf.keras.models.Model(
        inputs=question_inputs + answer_inputs,
        outputs=output,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment