@MInner
Last active August 4, 2016 23:30
import tensorflow as tf
import numpy as np


class tnpg:
    """Numpy-style fancy indexing for TF tensors (wraps gather / gather_nd)."""

    def __init__(self, tf_arr):
        self.arr = tf_arr

    def __getitem__(self, idx):
        if any([type(x) not in [slice, tf.Tensor, list] for x in idx]):
            raise ValueError("Unsupported type (not slice, tensor or list)! "
                             "%s" % ([type(x) for x in idx]))
        n_gather = sum([type(x) is not slice for x in idx])
        if n_gather == 0:
            raise ValueError("An empty gather!")
        elif n_gather == 1:
            return self.plain_gather(idx)
        elif n_gather > 1:
            if any([type(x) is slice for x in idx]):
                raise ValueError("Support only a[:, vec, :] -> same shape "
                                 "and a[vec1, vec2, vec3] -> vec "
                                 "(one or all)")
            return self.multi_gather(idx)

    def plain_gather(self, idx):
        # Single indexed axis, e.g. a[:, b, :]: move that axis to the front,
        # gather along it, then transpose back.
        # (Can actually also do A[:, b, :] with a multidimensional b.)
        idx_true_s = [i for i, x in enumerate(idx) if type(x) is not slice]
        fwd_sub_perm = idx_true_s + [x for x in range(len(idx)) if (x not in idx_true_s)]
        inv_sub_perm = [0] * len(idx)
        for i, p in enumerate(fwd_sub_perm):
            inv_sub_perm[p] = i

        def pad_perm(perm, l):
            # Extend a partial permutation with the remaining (untouched) axes.
            return tf.concat(0, [perm, tf.range(len(perm), l)])

        full_len = tf.rank(self.arr)
        gather_idx = (tf.pack([idx[i] for i in idx_true_s])
                      if len(idx_true_s) > 1 else idx[idx_true_s[0]])
        val = tf.transpose(
            tf.gather(
                tf.transpose(self.arr, pad_perm(fwd_sub_perm, full_len)),
                gather_idx
            ), pad_perm(inv_sub_perm, full_len)
        )
        return val

    def multi_gather(self, idx):
        # All axes indexed, e.g. a[vec1, vec2, vec3]: build (k, rank) index
        # tuples and pick one element per tuple with gather_nd.
        indices = tf.transpose(tf.pack(idx))
        val = tf.gather_nd(self.arr, indices)
        return val
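
# A minimal usage sketch (assumed shapes and index values, mirroring the tests below):
#   a = tf.constant(np.random.rand(5, 3, 4))
#   tnpg(a)[:, [2, 1], :]            # one indexed axis -> result keeps a's rank
#   tnpg(a)[[2, 1], [0, 2], [1, 3]]  # every axis indexed -> length-2 vector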

def assert_np_tf_eq(arr, lambda_np, lambda_tf):
    result_np = lambda_np(arr)
    if type(lambda_tf) != list:
        lambda_tf = [lambda_tf]
    for lambda_tf_i in lambda_tf:
        with tf.Session() as s:
            expr = tf.constant(arr)
            expr_l = lambda_tf_i(expr)
            result_tf = s.run(expr_l)
        assert np.all(np.isclose(result_tf, result_np))


def test_arr_gather():
    arr = np.random.rand(5, 3, 4)
    assert_np_tf_eq(
        arr,
        lambda arr: arr[[2, 1], :, :],
        [
            lambda arr: tf.gather(arr, [2, 1]),
            lambda arr: tf.gather(arr, tf.constant([2, 1])),
            lambda arr: tnpg(arr)[[2, 1], :, :]
        ],
    )
    assert_np_tf_eq(
        arr,
        lambda arr: arr[:, [2, 1], :],
        [
            (lambda arr:
                tf.transpose(
                    tf.gather(
                        tf.transpose(arr, [1, 0, 2]),
                        [2, 1]
                    ),
                    [1, 0, 2]
                )
            ),
            (lambda arr:
                tf.transpose(
                    tf.gather(
                        tf.transpose(arr, [1, 0, 2]),
                        tf.constant([2, 1])
                    ),
                    [1, 0, 2]
                )
            ),
            lambda arr: tnpg(arr)[:, tf.constant([2, 1]), :],
            lambda arr: tnpg(arr)[:, tf.constant([2, 1])],
            lambda arr: tnpg(arr)[:, [2, 1]]
        ]
    )

    arr = np.random.rand(5, 4)
    assert_np_tf_eq(
        arr,
        lambda arr: arr[[2, 1], [3, 1]],
        [
            lambda arr: tnpg(arr)[tf.constant([2, 1]), tf.constant([3, 1])],
            lambda arr: tnpg(arr)[[2, 1], [3, 1]]
        ]
    )
    print('OK')


test_arr_gather()
@tejaskhot
I need to index a tensor this way: matrix[row_indices, col_indices], where matrix is an m×n tensor and row_indices and col_indices are int32 vectors of length k whose entries are valid row and column indices. I want a length-k vector as the result, i.e. a vectorized lookup of matrix[i, j] for each pair.
What would be an efficient way of doing this? tf.gather(matrix, indices) only takes a vector of indices and returns the corresponding rows.
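
That element-wise lookup is what multi_gather above does with tf.gather_nd: stack the two index vectors into (row, col) pairs and gather one scalar per pair. A minimal sketch, assuming the TF 0.x API used in the gist and illustrative example values:

import tensorflow as tf
import numpy as np

matrix = tf.constant(np.arange(12).reshape(3, 4), dtype=tf.float32)  # m=3, n=4
row_indices = tf.constant([0, 2, 1])  # length k = 3
col_indices = tf.constant([3, 0, 2])

# Stack the index vectors into a (k, 2) tensor of (row, col) pairs;
# gather_nd then picks one element per pair, giving a length-k vector.
pairs = tf.transpose(tf.pack([row_indices, col_indices]))
values = tf.gather_nd(matrix, pairs)

with tf.Session() as s:
    print(s.run(values))  # [3. 8. 6.]

The wrapper above gives the same result via tnpg(matrix)[row_indices, col_indices].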
