Skip to content

Instantly share code, notes, and snippets.

@vihari
Created October 7, 2018 18:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vihari/c3c59bf2e4f18722a872499b0394986c to your computer and use it in GitHub Desktop.
Save vihari/c3c59bf2e4f18722a872499b0394986c to your computer and use it in GitHub Desktop.
def rankL(np_rank):
    """Return the harmonic weight L(r) = 1 + 1/2 + ... + 1/r for every rank.

    The weight is computed element-wise, so each entry of ``np_rank`` gets
    its own value (previously only the last entry was used and a single
    scalar was returned for the whole input).

    Args:
        np_rank: array-like of ranks. Values are truncated to integers;
            ranks <= 0 map to a weight of 0.

    Returns:
        A float32 NumPy array with the same shape as ``np_rank``.
    """
    # Truncate toward zero like int(); clamp negatives so they index entry 0.
    ranks = np.maximum(np.asarray(np_rank).astype(np.int64), 0)
    if ranks.size == 0:
        return np.zeros(ranks.shape, dtype=np.float32)
    max_rank = int(ranks.max())
    # harmonic[r] == sum_{k=1..r} 1/k, with harmonic[0] == 0 for "no terms".
    harmonic = np.concatenate(
        ([0.0], np.cumsum(1.0 / np.arange(1, max_rank + 1)))
    )
    return harmonic[ranks].astype(np.float32)
"""
Labels are assumed to be one-hot encoded.
"""
def warp_loss(labels, logits):
    """Rank-weighted hinge loss built from ``labels`` and ``logits``.

    Both inputs are transposed with perm ``[1, 0]``, so they must be 2-D;
    presumably they arrive as (batch, classes) -- confirm against the caller.
    ``labels`` are assumed to be one-hot encoded.

    Returns:
        ``rankL(rank) * diff / rank``, where ``rank`` counts the entries along
        the (transposed) first axis whose margin ``1 + logit - true_score`` is
        positive and ``diff`` is the sum of those positive margins.
    """
    # Move the reduced axis to the front so the per-column true score
    # broadcasts against the logits without extra reshaping.
    labels_t = tf.transpose(labels, [1, 0])
    logits_t = tf.transpose(logits, [1, 0])

    # With one-hot labels this picks out the logit of the labelled class.
    true_score = tf.reduce_sum(logits_t * labels_t, axis=0)
    margin = 1 + logits_t - true_score

    # NOTE: the labelled class has margin exactly 1, so with one-hot labels
    # it is counted in both the violation count and the hinge sum.
    violations = tf.reduce_sum(tf.maximum(tf.sign(margin), 0), axis=0)
    hinge_total = tf.reduce_sum(tf.maximum(margin, 0), axis=0)

    # Guard the division below: every count must be strictly positive.
    rank_is_positive = tf.assert_greater(violations, tf.zeros_like(violations))
    with tf.control_dependencies([rank_is_positive]):
        # rankL is plain Python/NumPy code, hence the py_func bridge.
        weight = tf.py_func(rankL, [violations], tf.float32)
        return weight * hinge_total / violations
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment