@yqshao
Last active December 18, 2018 20:40
Counting occurrence in Tensorflow

Suppose you have a tensor such as

a = [0, 0, 0, 1, 1, 2]

and want to know, for each element i of a, how many times i has already appeared before it (its zero-based occurrence index), that is, to get

b = [0, 1, 2, 0, 1, 0]

Here's the solution, which relies on a being sorted (tf.segment_min requires sorted segment ids):

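# 1-based running position of every element: [1, 2, 3, 4, 5, 6]
count_sum = tf.cumsum(tf.ones_like(a))
# position of the first element of each segment: [1, 4, 6]
count_min = tf.segment_min(count_sum, a)
# offset of each element from the start of its own segment
b = count_sum - tf.gather(count_min, a)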
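
To make the three steps concrete, here is a minimal plain-Python sketch that mirrors them on the sorted example. It is only an illustration and does not use TensorFlow: the dictionary stands in for the output of tf.segment_min, and the list comprehension stands in for the tf.gather and subtraction.

```python
from itertools import accumulate

a = [0, 0, 0, 1, 1, 2]  # sorted values, as in the example above

# Step 1: 1-based running position, mirroring tf.cumsum(tf.ones_like(a))
count_sum = list(accumulate([1] * len(a)))

# Step 2: smallest position seen for each value, mirroring tf.segment_min
count_min = {}
for pos, seg in zip(count_sum, a):
    count_min[seg] = min(pos, count_min.get(seg, pos))

# Step 3: distance from the first position of the same value,
# mirroring count_sum - tf.gather(count_min, a)
b = [pos - count_min[seg] for pos, seg in zip(count_sum, a)]

print(count_sum)  # [1, 2, 3, 4, 5, 6]
print(b)          # [0, 1, 2, 0, 1, 0]
```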

Now let a be unsorted:

a = [0, 1, 0, 2, 1, 2, 1, 2, 2, 2]

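Then we need to sort a first, using a stable argsort so that equal values keep their original order, and scatter the result back to the original positions afterwards: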

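a = tf.constant([0,1,0,2,1,2,1,2,2,2])
# stable sort: indices of equal values stay in their original order
a_args = tf.contrib.framework.argsort(a, stable=True)
a_sort = tf.gather(a, a_args)
# same three steps as in the sorted case, applied to a_sort
count_sum = tf.cumsum(tf.ones_like(a_sort))
count_min = tf.segment_min(count_sum, a_sort)
b_sort = count_sum - tf.gather(count_min, a_sort)
# scatter each count back to the position its element had in a
b = tf.unsorted_segment_sum(b_sort, a_args, b_sort.shape[0])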
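
For checking the result on the unsorted example, a plain-Python reference can be handy. The sketch below is not the TensorFlow method above: it simply walks the list once with a counter, and the helper name occurrence_index is made up for this illustration.

```python
from collections import defaultdict

def occurrence_index(values):
    """Return, for each element, how many equal elements precede it."""
    seen = defaultdict(int)  # value -> number of times seen so far
    out = []
    for v in values:
        out.append(seen[v])
        seen[v] += 1
    return out

a = [0, 1, 0, 2, 1, 2, 1, 2, 2, 2]
b_expected = occurrence_index(a)
print(b_expected)  # [0, 0, 1, 0, 1, 1, 2, 2, 3, 4]
```

Evaluating b from the TensorFlow snippet should give the same values if the sort is stable.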