Skip to content

Instantly share code, notes, and snippets.

@bxshi
Created January 12, 2017 21:34
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bxshi/4911847068987d037e09b4bb973bf472 to your computer and use it in GitHub Desktop.
Save bxshi/4911847068987d037e09b4bb973bf472 to your computer and use it in GitHub Desktop.
Implement Hits@K evaluation metric for Knowledge Graph Completion tasks.
import tensorflow as tf
def hits_at_k(idx, top_k, triples):
    """Build a graph op scoring top-K tail predictions against known triples.

    For every query ``(h, r)`` in ``idx``, the true tails are every ``t`` such
    that ``(h, r, t)`` appears in ``triples``. Each of the K predicted tails in
    the matching row of ``top_k`` counts as a hit if it is one of those true
    tails. A query's score is the fraction of its K slots that are hits. The
    returned op is the mean of the per-query scores.

    NOTE(review): averaging per-slot hits measures the fraction of the K
    predictions that are correct (precision@K). The more common Hits@K
    definition counts a query (or test triple) as a hit when *any* true tail
    appears in the top K, i.e. ``reduce_any`` instead of ``reduce_mean`` over
    ``slot_hits`` below. This is kept as-is to reproduce the original script's
    numbers; confirm which definition is intended.

    Args:
        idx: int32 tensor of shape (num_queries, 2), one ``(h, r)`` pair per row.
        top_k: int32 tensor of shape (num_queries, K), predicted tails per query.
        triples: int32 tensor of shape (num_triples, 3), known ``(h, r, t)`` facts.

    Returns:
        A float32 scalar tensor: the mean per-query score.
    """

    def _query_score(item):
        """Score one query: the fraction of its K predictions that are true tails."""
        hr, top = item
        mask = tf.logical_and(tf.equal(triples[:, 0], hr[0]),
                              tf.equal(triples[:, 1], hr[1]))
        true_tails = tf.boolean_mask(triples[:, 2], mask)
        # Compare (K, 1) against (1, n_true), which broadcasts to (K, n_true).
        # A slot is a hit if it matches any true tail. With no true tails the
        # reduction over an empty axis yields False, so the score is 0.
        slot_hits = tf.reduce_any(
            tf.equal(tf.expand_dims(top, 1), tf.expand_dims(true_tails, 0)),
            axis=1)
        return tf.reduce_mean(tf.cast(slot_hits, dtype=tf.float32))

    # Each query is scored independently, so this is a map, not a scan with
    # an accumulator. It also needs no tf.Variable as a dummy initializer.
    per_query = tf.map_fn(_query_score, (idx, top_k), dtype=tf.float32)
    return tf.reduce_mean(per_query)


def main():
    """Run the worked example below and print its score.

    Worked example::

        idx (h,r)   top_3       true tails from triples
        [0, 1]      [0, 8, 3]   {2, 5}     -> 0/3
        [1, 3]      [7, 2, 1]   {2}        -> 1/3
        [2, 4]      [4, 3, 9]   {3, 9, 4}  -> 3/3

        score = (0 + 1/3 + 1) / 3 = 4/9 ~= 0.444

    ``triples`` also holds ``[1, 2, 7]``. No query matches it, since its
    relation is 2 rather than 3.
    """
    # Fixed example data: constants rather than non-trainable Variables, so
    # no variable initialization is required.
    idx = tf.constant([[0, 1], [1, 3], [2, 4]], dtype=tf.int32)
    top_3 = tf.constant([[0, 8, 3], [7, 2, 1], [4, 3, 9]], dtype=tf.int32)
    triples = tf.constant([[0, 1, 2], [0, 1, 5], [1, 3, 2], [1, 2, 7],
                           [2, 4, 3], [2, 4, 9], [2, 4, 4]],
                          dtype=tf.int32)
    score = hits_at_k(idx, top_3, triples)
    with tf.Session() as sess:
        print("Hits@3 is", sess.run(score))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment