@Ending2015a
Last active March 14, 2020 11:17
My gather
import numpy as np
import tensorflow as tf # 2.0
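'''
General pattern (shapes only, not runnable as-is):

input = ... # shape (batch, h*w, 19, 2)
indices = ... # shape (h*w)
# Note: tf.gather defaults to axis=0, so each call below indexes the first (batch) axis
# of a (batch, 19, 2) slice; pass axis=1 to index the 19-entry axis instead.
tf.stack([tf.gather(x, y) for x, y in zip(tf.unstack(input, axis=1), tf.unstack(tf.reshape(indices, [-1, 1]), axis=0))], axis=1)
'''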
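# Example: gather a different set of columns for each batch element.
x = np.array([[[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]],
              [[10, 11, 12],
               [13, 14, 15],
               [16, 17, 18]]], dtype=np.float32)  # shape (2, 3, 3)
indices = np.array([[0, 1], [0, 2]], dtype=np.int32)  # shape (2, 2): column indices per batch element
# Unstack along the batch axis, gather columns (axis=1) from each (3, 3) slice,
# then stack the (3, 2) results back into a tensor of shape (2, 3, 2).
# The loop variables are named x_b / idx_b so they do not shadow x and indices.
tf.stack([tf.gather(x_b, idx_b, axis=1)
          for x_b, idx_b in zip(tf.unstack(x, axis=0), tf.unstack(indices, axis=0))], axis=0)
# Result:
# [[[ 1.,  2.], [ 4.,  5.], [ 7.,  8.]],
#  [[10., 12.], [13., 15.], [16., 18.]]]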
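
As a cross-check on the example above, here is a minimal NumPy-only sketch of the same per-batch column gather. It is not part of the original gist; it assumes `np.take_along_axis` (NumPy 1.15 or later) and the variable name `gathered` is just illustrative. In TensorFlow, `tf.gather(x, indices, axis=2, batch_dims=1)` should give the same result without the Python loop, but that call is not exercised here.

```python
import numpy as np

x = np.array([[[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]],
              [[10, 11, 12],
               [13, 14, 15],
               [16, 17, 18]]], dtype=np.float32)  # shape (2, 3, 3)
indices = np.array([[0, 1], [0, 2]], dtype=np.int32)  # shape (2, 2)

# Insert a length-1 row axis so the (2, 1, 2) index array broadcasts over
# the 3 rows of each batch element; axis=2 selects along the column axis.
gathered = np.take_along_axis(x, indices[:, np.newaxis, :], axis=2)

print(gathered.shape)  # (2, 3, 2)
print(gathered)
```

Each batch element `b` ends up with the columns listed in `indices[b]`, i.e. `gathered[b, i, j] == x[b, i, indices[b, j]]`.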