Justin Evans (eustin)
letor-part-2-model-training.py

# train the model to match the target probability distributions over relevance grades
hist = model.fit(
    [query_embeddings, docs_averaged_embeddings],
    relevance_grades_prob_dist,
    epochs=50,
    verbose=False
)
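A minimal sketch of the compile step that would precede the fit call above; the optimizer and learning rate here are assumptions, but the KL-divergence loss matches the loss snippets further down.

# assumed compile step (optimizer choice is a guess; loss follows the KL snippets below)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=tf.keras.losses.KLDivergence()
)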
letor-part-2-listnet.py

query_input = tf.keras.layers.Input(shape=(1, EMBEDDING_DIMS, ), dtype=tf.float32, name='query')
docs_input = tf.keras.layers.Input(shape=(NUM_DOCS_PER_QUERY, EMBEDDING_DIMS, ), dtype=tf.float32,
                                   name='docs')

# layers: tile the query across its documents, score each document, then
# turn the per-document scores into a probability distribution
expand_batch = ExpandBatchLayer(name='expand_batch')
dense_1 = tf.keras.layers.Dense(units=3, activation='linear', name='dense_1')
dense_out = tf.keras.layers.Dense(units=1, activation='linear', name='scores')
scores_prob_dist = tf.keras.layers.Dense(units=NUM_DOCS_PER_QUERY, activation='softmax',
                                         name='scores_prob_dist')
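A minimal sketch of how these layers could be wired together; the exact topology, in particular the Flatten between the per-document scores and the final softmax layer, is an assumption rather than the article's verbatim code.

# assumed wiring of the layers defined above
expanded = expand_batch([query_input, docs_input])    # (batch, NUM_DOCS_PER_QUERY, 2 * EMBEDDING_DIMS)
dense_1_out = dense_1(expanded)                       # (batch, NUM_DOCS_PER_QUERY, 3)
scores_out = dense_out(dense_1_out)                   # (batch, NUM_DOCS_PER_QUERY, 1)
flat_scores = tf.keras.layers.Flatten()(scores_out)   # (batch, NUM_DOCS_PER_QUERY)
model_out = scores_prob_dist(flat_scores)             # probability distribution over each query's docs
model = tf.keras.Model(inputs=[query_input, docs_input], outputs=model_out)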
letor-part-2-expand-batch-layer.py

class ExpandBatchLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(ExpandBatchLayer, self).__init__(**kwargs)

    def call(self, input):
        queries, docs = input
        batch, num_docs, embedding_dims = tf.unstack(tf.shape(docs))
        # repeat the single query embedding once per document...
        expanded_queries = tf.gather(queries, tf.zeros([num_docs], tf.int32), axis=1)
        # ...and concatenate it onto each document embedding
        return tf.concat([expanded_queries, docs], axis=-1)
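A quick shape check with hypothetical sizes (2 queries, 3 documents per query, 4-dimensional embeddings), just to show what the layer produces:

queries = tf.random.uniform((2, 1, 4))
docs = tf.random.uniform((2, 3, 4))
print(ExpandBatchLayer()([queries, docs]).shape)  # (2, 3, 8): query repeated and concatenated onto each doc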
letor-part-2-kl-divergence-manual-eg-loss.py

# KL divergence between the target grade distribution and the predicted
# score distribution, computed separately for each example
per_example_loss = tf.reduce_sum(
    relevance_grades_prob_dist * tf.math.log(relevance_grades_prob_dist / scores_prob_dist),
    axis=-1
)
print(per_example_loss)
letor-part-2-kl-divergence-tf.py

# the same loss via Keras' built-in KLDivergence, which averages over the batch by default
loss = tf.keras.losses.KLDivergence()
batch_loss = loss(relevance_grades_prob_dist, scores_prob_dist)
print(batch_loss)
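Assuming the default reduction on tf.keras.losses.KLDivergence (mean over the batch), the built-in loss should agree with the manual per-example calculation above, up to the epsilon clipping Keras applies internally:

print(tf.reduce_mean(per_example_loss))  # should (approximately) match batch_loss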
letor-part-2-grades-softmax.py

# turn the raw relevance grades into a probability distribution over each query's documents
relevance_grades_prob_dist = tf.nn.softmax(relevance_grades, axis=-1)
print(relevance_grades_prob_dist)
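A worked example with made-up grades (not the article's data), showing how the softmax concentrates probability mass on the higher grades:

print(tf.nn.softmax(tf.constant([[3.0, 1.0, 0.0]]), axis=-1))
# ≈ [[0.844, 0.114, 0.042]]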
letor-part-2-scores-softmax.py

# drop the trailing dimension of size 1, then softmax over each query's documents
scores_for_softmax = tf.squeeze(scores_out, axis=-1)
scores_prob_dist = tf.nn.softmax(scores_for_softmax, axis=-1)
print(scores_prob_dist)
letor-part-2-scores.py

# one raw (unbounded) relevance score per document
scores = tf.keras.layers.Dense(units=1, activation='linear')
scores_out = scores(dense_1_out)
print(scores_out)
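For context, the score layer emits one value per document, so its output carries a trailing dimension of size 1 (the squeeze in letor-part-2-scores-softmax.py above removes it). A shape sketch with made-up sizes:

# hypothetical sizes: batch of 2 queries, 3 docs per query, dense_1 width of 3
fake_dense_1_out = tf.random.uniform((2, 3, 3))
print(scores(fake_dense_1_out).shape)  # (2, 3, 1): one score per document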