Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save deelipku23/df51d90a7b375bb81f5f874b2a754734 to your computer and use it in GitHub Desktop.
Save deelipku23/df51d90a7b375bb81f5f874b2a754734 to your computer and use it in GitHub Desktop.
def _mean_pooled_encoding(encoder, max_sequence_length):
    """Create (ids, mask) inputs for one encoder and return them with its mean-pooled output.

    Returns a tuple ``(ids_input, mask_input, summary)`` where ``summary`` is the
    encoder's first output averaged over the sequence axis.
    """
    ids_input = tf.keras.layers.Input((max_sequence_length,), dtype=tf.int32)
    mask_input = tf.keras.layers.Input((max_sequence_length,), dtype=tf.int32)
    # Index 0 of the transformers model output is the per-token hidden state
    # sequence (last_hidden_state).
    hidden = encoder(ids_input, attention_mask=mask_input)[0]
    # NOTE(review): this plain mean includes padding positions -- the mask only
    # affects the encoder's attention, not this pooling step. Equivalent to the
    # previous AveragePooling1D(seq_len) + Flatten, without coupling the pool
    # size to the sequence length.
    summary = tf.keras.layers.GlobalAveragePooling1D()(hidden)
    return ids_input, mask_input, summary


def create_model(model_name, *, num_outputs=30, dropout_rate=0.2,
                 max_sequence_length=None, print_summary=True):
    """Build a two-tower RoBERTa model that scores a question/answer pair.

    Two separate ``TFRobertaModel`` encoders (no weight sharing) are loaded from
    ``model_name``, one for the question and one for the answer. Each encoder's
    token outputs are mean-pooled. The two pooled vectors are concatenated,
    passed through dropout, and mapped to ``num_outputs`` independent sigmoid
    scores (presumably multi-label targets -- TODO confirm against the training
    data).

    Args:
        model_name: Checkpoint name or path passed to
            ``TFRobertaModel.from_pretrained``.
        num_outputs: Width of the sigmoid output layer. The default of 30
            matches the original.
        dropout_rate: Dropout rate applied to the concatenated summary. The
            default of 0.2 matches the original.
        max_sequence_length: Fixed length of every input tensor. ``None`` means
            use the module-level ``MAX_SEQUENCE_LENGTH``.
        print_summary: If True, print the Keras model summary.

    Returns:
        A ``tf.keras.Model`` taking four int32 inputs, in this order:
        ``[question_ids, question_mask, answer_ids, answer_mask]``, each of
        shape ``(batch, max_sequence_length)``. It returns a
        ``(batch, num_outputs)`` tensor. Token-type ids are not model inputs.
    """
    if max_sequence_length is None:
        max_sequence_length = MAX_SEQUENCE_LENGTH

    # Load the question encoder before the answer encoder so layer
    # creation/naming order matches the original.
    question_bert_model = TFRobertaModel.from_pretrained(model_name)
    answer_bert_model = TFRobertaModel.from_pretrained(model_name)

    question_enc, question_mask, question_summary = _mean_pooled_encoding(
        question_bert_model, max_sequence_length)
    answer_enc, answer_mask, answer_summary = _mean_pooled_encoding(
        answer_bert_model, max_sequence_length)

    combined_summary = tf.keras.layers.Concatenate()([question_summary, answer_summary])
    dropped = tf.keras.layers.Dropout(dropout_rate)(combined_summary)
    output = tf.keras.layers.Dense(num_outputs, activation='sigmoid')(dropped)

    model = tf.keras.models.Model(
        inputs=[question_enc, question_mask, answer_enc, answer_mask],
        outputs=output,
    )
    if print_summary:
        # Model.summary() prints on its own and returns None. Wrapping it in
        # print() would emit a stray "None" line.
        model.summary()
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment