@ydennisy
Created August 11, 2023 18:44
A siamese network for text embedding: two shared-weight Transformer towers scored against each other with a multiple negatives ranking (in-batch negatives) loss. Requires tensorflow and keras-nlp.

import tensorflow as tf
from keras_nlp.layers import TransformerDecoder

# Toy-scale hyperparameters: sequence length, vocabulary size, embedding width.
MAX_LEN, VOCAB_SIZE, EMBED_DIMS = 128, 128, 32

class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Sums a learned token embedding with a learned position embedding."""

    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = tf.keras.layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.pos_emb = tf.keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        # Embed positions 0..seq_len-1 and broadcast-add them to the token embeddings.
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions
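
# Hedged shape check (added for illustration, not part of the original gist):
# integer token ids of shape (batch, MAX_LEN) come out as embeddings of shape
# (batch, MAX_LEN, EMBED_DIMS).
_demo_emb = TokenAndPositionEmbedding(MAX_LEN, VOCAB_SIZE, EMBED_DIMS)
print(_demo_emb(tf.zeros((2, MAX_LEN), dtype=tf.int32)).shape)  # (2, 128, 32)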

def compute_similarity_matrix(embeddings_1, embeddings_2):
    # Pairwise dot products: entry (i, j) scores query i against text j.
    similarity_matrix = tf.matmul(embeddings_1, embeddings_2, transpose_b=True)
    return similarity_matrix

loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)


@tf.function
def multiple_negatives_ranking_loss(y_true, similarity_scores):
    # In-batch negatives: the correct text for query i sits at column i, so the
    # row indices serve as labels and y_true is ignored.
    labels = tf.range(tf.shape(similarity_scores)[0])
    return loss_function(labels, similarity_scores)
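
# Hedged sanity check (added, not in the original gist): a score matrix with a
# dominant diagonal already ranks every positive first, so the loss is near zero.
print(multiple_negatives_ranking_loss(None, 9.0 * tf.eye(3)))  # ~0.0002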

@tf.function
def top_k_accuracy(y_true, similarity_scores, k=1):
    # Fraction of queries whose correct text (index i for row i) is in the top k.
    top_k_indices = tf.math.top_k(similarity_scores, k=k).indices
    correct = tf.reduce_any(
        tf.equal(
            top_k_indices,
            tf.expand_dims(tf.range(tf.shape(similarity_scores)[0]), axis=1),
        ),
        axis=1,
    )
    return tf.reduce_mean(tf.cast(correct, dtype=tf.float32))

@tf.function
def mean_reciprocal_rank(y_true, similarity_scores):
    # Position of the correct text in each row's descending sort, scored as 1 / (rank + 1).
    sorted_indices = tf.argsort(similarity_scores, direction="DESCENDING")
    rank = tf.where(
        tf.equal(
            sorted_indices,
            tf.expand_dims(tf.range(tf.shape(similarity_scores)[0]), axis=1),
        )
    )[:, 1]
    reciprocal_rank = 1 / (tf.cast(rank, tf.float32) + 1)
    return tf.reduce_mean(reciprocal_rank)
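
# Hedged worked example (added, not in the original gist): row 0 ranks its
# positive (column 0) second and row 1 ranks its positive (column 1) first, so
# top-1 accuracy is 0.5 and MRR is (1/2 + 1/1) / 2 = 0.75.
_toy_scores = tf.constant([[0.1, 0.9], [0.2, 0.8]])
print(top_k_accuracy(None, _toy_scores))        # 0.5
print(mean_reciprocal_rank(None, _toy_scores))  # 0.75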

# Shared encoder: token + position embedding, one Transformer block, then mean
# pooling over the sequence to produce a single fixed-size vector per input.
decoder = TransformerDecoder(intermediate_dim=8, num_heads=2)
embedding = TokenAndPositionEmbedding(MAX_LEN, VOCAB_SIZE, EMBED_DIMS)

inputs = tf.keras.Input(shape=(MAX_LEN,))
x = embedding(inputs)
x = decoder(x)  # no cross-attention inputs, so this runs as a self-attention block
outputs = tf.keras.layers.GlobalAveragePooling1D()(x)
embedding_model = tf.keras.Model(inputs, outputs, name="embedding_model")

# Siamese wiring: the same embedding_model (shared weights) encodes the queries
# and the texts, and the trainable model outputs the full similarity matrix
# that the loss and metrics consume.
inputs_1 = tf.keras.Input(shape=(MAX_LEN,), name="query_input")
inputs_2 = tf.keras.Input(shape=(MAX_LEN,), name="text_input")
tower_1 = embedding_model(inputs_1)
tower_2 = embedding_model(inputs_2)
similarity_matrix = compute_similarity_matrix(tower_1, tower_2)

model = tf.keras.Model(inputs=[inputs_1, inputs_2], outputs=similarity_matrix)
model.compile(
    loss=multiple_negatives_ranking_loss,
    metrics=[top_k_accuracy, mean_reciprocal_rank],
)
model.summary()
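
# Hedged usage sketch (added; the random data, batch size, and epoch count are
# assumptions, not part of the original gist). The loss and metrics ignore
# y_true and derive labels from the batch, so a dummy target of matching length
# is all fit() needs; compile() above falls back to Keras's default optimizer.
import numpy as np

num_samples = 256
queries = np.random.randint(0, VOCAB_SIZE, size=(num_samples, MAX_LEN))
texts = np.random.randint(0, VOCAB_SIZE, size=(num_samples, MAX_LEN))
dummy_y = np.zeros((num_samples,))  # unused by the loss, required by fit()

model.fit([queries, texts], dummy_y, batch_size=32, epochs=1)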