Skip to content

Instantly share code, notes, and snippets.

@bxshi
Created January 19, 2017 18:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save bxshi/eecc1fe0aafac85e61be9c38190d5fd1 to your computer and use it in GitHub Desktop.
Save bxshi/eecc1fe0aafac85e61be9c38190d5fd1 to your computer and use it in GitHub Desktop.
Calculate Mean Rank and Hits@K for knowledge-graph link prediction using TensorFlow. Originally from github.com/nddsg/ProjC (a private repository at the time of writing).
def create_eval_ops(model_input, pred_y, all_triples, eval_triples, n_entity,
                    top_k, idx_1=0, idx_2=1, idx_3=2):
    """Build ops computing Mean Rank and Hits@top_k for any scoring model.

    For given <h,r> predict t: idx_1 = 0, idx_2 = 1, idx_3 = 2.
    For given <t,r> predict h: idx_1 = 2, idx_2 = 1, idx_3 = 0.

    Every triple in ``eval_triples`` is ranked against all ``n_entity``
    candidates using the score row of ``pred_y`` that belongs to its
    <idx_1, idx_2> pair.  Two ranks are produced:

    * raw (unfiltered) rank: 1 + number of entities whose score is strictly
      greater than the target's score;
    * filtered rank: the same count, but skipping every entity that completes
      a triple of ``all_triples`` with the same <idx_1, idx_2> pair.

    Because the comparison is strict, entities tied with the target are not
    counted (ties resolve in the target's favour).

    :param model_input: N by 3 matrix, each row is a h,r,t triple.  Expected
                        to be a subset of eval_triples whose joint index
                        <idx_1, idx_2> is unique in model_input; if a pair
                        does repeat, the scores of its first row are used.
    :param pred_y: N by n_entity matrix; row i scores every candidate entity
                   for row i of model_input (higher score = better rank).
                   Any real dtype is accepted.
    :param all_triples: M by 3 matrix, contains all triples in the KG.
    :param eval_triples: M_{eval} by 3 matrix, contains all triples that will
                         be evaluated; a subset of all_triples.  Every
                         <idx_1, idx_2> pair in it must appear in model_input.
    :param n_entity: Number of unique entities in the KG.
    :param top_k: Parameter of Hits@top_k.
    :param idx_1: First index of the <?,r> pair.
    :param idx_2: Second index of the <?,r> pair.
    :param idx_3: Target index in the h,r,t triple.
    :return: float32 tensor of shape [4], averaged over eval_triples:
             [raw mean rank, filtered mean rank,
              raw Hits@top_k, filtered Hits@top_k].
    """

    def get_id_mask(hrt, triples):
        # True for every row of `triples` that shares hrt's <idx_1, idx_2> pair.
        return tf.logical_and(tf.equal(hrt[idx_1], triples[:, idx_1]),
                              tf.equal(hrt[idx_2], triples[:, idx_2]))

    def count_rank(entity_scores, target_score, candidate_mask=None):
        # 1 + number of entities (restricted to candidate_mask when given)
        # whose score is strictly greater than the target's score.
        outscores_target = tf.greater(entity_scores, target_score)
        if candidate_mask is not None:
            outscores_target = tf.logical_and(outscores_target, candidate_mask)
        return tf.reduce_sum(tf.cast(outscores_target, tf.int32)) + 1

    def calculate_metrics(eval_hrt):
        # eval_hrt: a single 3-element h,r,t triple taken from eval_triples.

        # Scores of the model_input row with the same <idx_1, idx_2> pair.
        # Taking the first matching row keeps the vector at length n_entity
        # even if the pair is repeated in model_input.
        pred_y_mask = get_id_mask(eval_hrt, model_input)
        pred_score = tf.boolean_mask(pred_y, pred_y_mask)[0]
        # Score of the entity sitting at the target position of the triple.
        target_score = pred_score[eval_hrt[idx_3]]

        # Boolean vector over entities: True where the entity completes a
        # known triple (any row of all_triples) with this <idx_1, idx_2> pair.
        triple_mask = get_id_mask(eval_hrt, all_triples)
        # The ids can be unsorted and repeated; validate_indices=False skips
        # that check, and since every written value is True, repeats are
        # harmless.
        known_entity_mask = tf.sparse_to_dense(
            tf.boolean_mask(all_triples[:, idx_3], triple_mask),
            output_shape=[n_entity],
            sparse_values=True,
            default_value=False,
            validate_indices=False)

        unfiltered_rank = count_rank(pred_score, target_score)
        # Filtered setting: known entities never push the target down.  This
        # uses a boolean mask instead of arithmetic on scores, so it works for
        # any score dtype and needs no sentinel value.
        filtered_rank = count_rank(pred_score, target_score,
                                   tf.logical_not(known_entity_mask))

        metrics = [unfiltered_rank, filtered_rank,
                   unfiltered_rank <= top_k, filtered_rank <= top_k]
        return tf.stack([tf.cast(m, tf.float32) for m in metrics])

    # Average the per-triple [rank, rank, hit, hit] rows over eval_triples.
    metrics = tf.reduce_mean(
        tf.map_fn(calculate_metrics, eval_triples,
                  dtype=tf.float32, parallel_iterations=20,
                  back_prop=False, swap_memory=True),
        axis=0)
    return metrics
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment