@ankitshekhawat
Last active May 8, 2019 16:42
Accuracy metric for online triplet loss
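import tensorflow as tf


def recall_top_k(y_true, y_pred):
    # Assumes _pairwise_distances(embeddings) is defined elsewhere (as in common
    # online triplet loss implementations) and returns a (batch, batch) matrix.
    # y_true is assumed to be a (batch, batch) 0/1 matrix marking matching pairs.
    # Get the batch size dynamically; y_pred.shape[0] is None in graph mode.
    batch_size = tf.shape(y_pred)[0]
    # Get pairwise distances between all embeddings in the batch.
    dists = _pairwise_distances(y_pred)
    # Get the index of the closest item for each prediction. The item itself is
    # always the closest (distance 0), so take the top 2 and keep the second.
    # top_k returns the largest values, so negate the distances to rank the
    # smallest first; sorted=True is required for column 1 to be the runner-up.
    _, pred_indexes = tf.nn.top_k(-dists, k=2, sorted=True)
    # Get the index of the true match, masking out the item itself with an
    # inverted identity matrix (assumes one true match per item in the batch).
    mask = 1.0 - tf.eye(batch_size)
    _, true_indexes = tf.nn.top_k(mask * tf.cast(y_true, tf.float32), k=1)
    # Fraction of items whose nearest neighbour is the true match (0 to 1).
    matches = tf.equal(pred_indexes[:, 1], true_indexes[:, 0])
    return tf.reduce_mean(tf.cast(matches, tf.float32))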
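To sanity-check the metric outside a training loop, a minimal NumPy sketch of the same nearest-neighbour accuracy might look like the following. It assumes `y_true` is a (batch, batch) 0/1 matrix of matching pairs and uses plain Euclidean distances; `pairwise_distances` and `nearest_neighbour_accuracy` are hypothetical names, not part of the gist. Instead of comparing against a single true index, it counts a hit whenever the nearest other item is any positive, which agrees with the TensorFlow check when each item has exactly one positive in the batch.

```python
import numpy as np


def pairwise_distances(embeddings: np.ndarray) -> np.ndarray:
    """Euclidean distance between every pair of rows, shape (n, n)."""
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
    sq_norms = np.sum(embeddings ** 2, axis=1)
    sq_dists = sq_norms[:, None] - 2.0 * embeddings @ embeddings.T + sq_norms[None, :]
    # Clip tiny negative values caused by floating-point error before sqrt.
    return np.sqrt(np.maximum(sq_dists, 0.0))


def nearest_neighbour_accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Fraction of items whose nearest other item in the batch is a positive.

    Hypothetical helper (assumption, not from the gist):
    y_true: (n, n) 0/1 matrix, y_true[i, j] == 1 when i and j share a label.
    y_pred: (n, d) embeddings.
    """
    dists = pairwise_distances(y_pred)
    # Exclude each item from its own neighbour search.
    np.fill_diagonal(dists, np.inf)
    nearest = np.argmin(dists, axis=1)
    # Look up whether each item's nearest neighbour is marked as a match.
    hits = y_true[np.arange(len(nearest)), nearest]
    return float(np.mean(hits))


# Usage: two tight clusters whose members share a label.
labels = np.array([0, 0, 1, 1])
y_true = (labels[:, None] == labels[None, :]).astype(np.float32)
emb = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.0, 5.1]])
print(nearest_neighbour_accuracy(y_true, emb))  # → 1.0
```

Masking the diagonal with `np.inf` plays the same role as taking the second result of `top_k` in the TensorFlow version, without relying on the item itself sorting first.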