@eustin
Created June 8, 2020 04:41
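import tensorflow as tf

# EMBEDDING_DIMS, NUM_DOCS_PER_QUERY and the custom ExpandBatchLayer are defined
# outside this snippet. ExpandBatchLayer is assumed to return one feature vector
# per query-document pair, i.e. a (batch, NUM_DOCS_PER_QUERY, features) tensor.

# Inputs: one query embedding and NUM_DOCS_PER_QUERY document embeddings per example.
query_input = tf.keras.layers.Input(shape=(1, EMBEDDING_DIMS), dtype=tf.float32, name='query')
docs_input = tf.keras.layers.Input(shape=(NUM_DOCS_PER_QUERY, EMBEDDING_DIMS), dtype=tf.float32,
                                   name='docs')

# Layers: combine query and documents, score each pair, then map the
# per-document scores to a probability distribution over the documents.
expand_batch = ExpandBatchLayer(name='expand_batch')
dense_1 = tf.keras.layers.Dense(units=3, activation='linear', name='dense_1')
dense_out = tf.keras.layers.Dense(units=1, activation='linear', name='scores')
scores_prob_dist = tf.keras.layers.Dense(units=NUM_DOCS_PER_QUERY, activation='softmax',
                                         name='scores_prob_dist')

# Forward pass: (batch, docs, features) -> (batch, docs, 1) -> flatten to (batch, docs) -> softmax.
expanded_batch = expand_batch([query_input, docs_input])
dense_1_out = dense_1(expanded_batch)
scores = tf.keras.layers.Flatten()(dense_out(dense_1_out))
model_out = scores_prob_dist(scores)

model = tf.keras.models.Model(inputs=[query_input, docs_input], outputs=[model_out])
# The model outputs a distribution over each query's documents, so KLDivergence
# expects the targets to be per-query probability distributions as well.
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.03, momentum=0.9),
              loss=tf.keras.losses.KLDivergence())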
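The snippet relies on ExpandBatchLayer and two constants that are defined outside it. Below is a minimal, self-contained sketch that rebuilds the same architecture end to end. It assumes, without confirmation from the gist, that ExpandBatchLayer tiles the single query embedding across the document axis and concatenates it onto each document embedding. The sizes, the random toy data and the ListNet-style targets (a softmax over integer relevance grades) are likewise hypothetical. They are chosen only to show the KL-divergence loss receiving one probability distribution per query.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes for illustration only; the gist defines its own
# EMBEDDING_DIMS and NUM_DOCS_PER_QUERY outside the snippet.
EMBEDDING_DIMS = 4
NUM_DOCS_PER_QUERY = 5


class ExpandBatchLayer(tf.keras.layers.Layer):
    """Assumed behaviour (not the author's code): pair the query with every document.

    queries: (batch, 1, dims), docs: (batch, num_docs, dims)
    returns: (batch, num_docs, 2 * dims)
    """

    def call(self, inputs):
        queries, docs = inputs
        # Copy the query vector once per document. docs.shape[1] is a static
        # Python int here, so downstream layers see a fixed shape.
        tiled_queries = tf.tile(queries, [1, docs.shape[1], 1])
        # Concatenate query features and document features along the last axis.
        return tf.concat([tiled_queries, docs], axis=-1)


def build_model():
    query_input = tf.keras.layers.Input(shape=(1, EMBEDDING_DIMS), name='query')
    docs_input = tf.keras.layers.Input(shape=(NUM_DOCS_PER_QUERY, EMBEDDING_DIMS), name='docs')
    pairs = ExpandBatchLayer(name='expand_batch')([query_input, docs_input])
    hidden = tf.keras.layers.Dense(3, activation='linear', name='dense_1')(pairs)
    # One score per document, flattened to (batch, NUM_DOCS_PER_QUERY).
    scores = tf.keras.layers.Flatten()(
        tf.keras.layers.Dense(1, activation='linear', name='scores')(hidden))
    probs = tf.keras.layers.Dense(
        NUM_DOCS_PER_QUERY, activation='softmax', name='scores_prob_dist')(scores)
    model = tf.keras.models.Model(inputs=[query_input, docs_input], outputs=probs)
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.03, momentum=0.9),
                  loss=tf.keras.losses.KLDivergence())
    return model


# Toy data: random embeddings and integer relevance grades (0-4) per document.
rng = np.random.default_rng(0)
num_queries = 32
queries = rng.normal(size=(num_queries, 1, EMBEDDING_DIMS)).astype('float32')
docs = rng.normal(size=(num_queries, NUM_DOCS_PER_QUERY, EMBEDDING_DIMS)).astype('float32')
grades = rng.integers(0, 5, size=(num_queries, NUM_DOCS_PER_QUERY)).astype('float32')

# ListNet-style target: a softmax over each query's grades yields a
# probability distribution over its documents, which KLDivergence compares against.
targets = tf.nn.softmax(grades, axis=-1).numpy()

model = build_model()
model.fit([queries, docs], targets, epochs=5, batch_size=8, verbose=0)
preds = model.predict([queries, docs], verbose=0)
print(preds.shape)  # (32, 5): one probability per document, per query
```

Because the last layer is a softmax, each row of preds sums to 1, matching the shape and normalisation of the target rows. Note that a Dense layer over all scores mixes information across document positions, so the model's output depends on document order. Applying the softmax directly to the flattened scores would be an order-independent alternative.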