Skip to content

Instantly share code, notes, and snippets.

@liuhengyue
Created September 6, 2018 20:20
Show Gist options
  • Save liuhengyue/8a86e0dfe03ef78a5d15df11d35746ae to your computer and use it in GitHub Desktop.
Save liuhengyue/8a86e0dfe03ef78a5d15df11d35746ae to your computer and use it in GitHub Desktop.
Comparison of tf.where and np.where
"""
Comparision of tf.where and np.where
"""
import tensorflow as tf
import numpy as np
X_batch = np.eye(5, dtype='bool')
X_batch[1,1] =False
X_batch[2,1] =True
X_batch[:,:] =False
print('Original ndarray:', X_batch)
X = tf.placeholder(dtype=tf.bool, shape=[5, 5])
tf_horizontal_indicies = tf.where(tf.reduce_any(X, axis=0))
tf_vertical_indicies = tf.where(tf.reduce_any(X, axis=1))
cond =tf.cond(
tf.equal(tf.size(tf_vertical_indicies), 0),
true_fn = lambda: tf_vertical_indicies,
false_fn = lambda: tf_vertical_indicies[1,-1]
)
with tf.Session() as sess:
print('tf indicies:')
print('vertical', sess.run(cond, feed_dict={X: X_batch}))
print(tf_vertical_indicies)
print(sess.run(tf_horizontal_indicies[1,-1], feed_dict={X: X_batch})) # access to specific element
print(sess.run(tf_vertical_indicies[1,-1], feed_dict={X: X_batch}))
print('np indicies:')
np_horizontal_indicies = np.where(np.any(X_batch, axis=0))[0]
np_vertical_indicies = np.where(np.any(X_batch, axis=1))[0]
print('vertical', np_vertical_indicies)
print(np_horizontal_indicies[1])
print(np_vertical_indicies[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment