@hanneshapke
Last active March 20, 2020 18:43
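import tensorflow as tf

# tf_transform_output, _LABEL_KEY, max_seq_length and load_bert_layer are not
# defined in this snippet; they come from the surrounding module.
feature_spec = tf_transform_output.transformed_feature_spec()
feature_spec.pop(_LABEL_KEY)  # the label is not a model input

# One int32 Keras input per transformed feature. The shape must be a tuple,
# hence the trailing comma in (max_seq_length,).
inputs = {
    key: tf.keras.layers.Input(shape=(max_seq_length,), name=key, dtype=tf.int32)
    for key in feature_spec.keys()
}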
input_word_ids = tf.cast(inputs["input_word_ids"], dtype=tf.int32)
input_mask = tf.cast(inputs["input_mask"], dtype=tf.int32)
input_type_ids = tf.cast(inputs["input_type_ids"], dtype=tf.int32)
bert_layer = load_bert_layer()
# Only the pooled output is used; the second return value is discarded.
pooled_output, _ = bert_layer(
    [input_word_ids, input_mask, input_type_ids]
)
x = tf.keras.layers.Dense(256, activation='relu')(pooled_output)
dense = tf.keras.layers.Dense(64, activation='relu')(x)
pred = tf.keras.layers.Dense(1, activation='sigmoid')(dense)
model = tf.keras.Model(
    inputs=[
        inputs['input_word_ids'],
        inputs['input_mask'],
        inputs['input_type_ids'],
    ],
    outputs=pred
)
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)