Last active: August 4, 2016 23:30
Save MInner/8b0c0a0e528303b132bf02e277199996 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
class tnpg:
    """NumPy-style fancy indexing for a TensorFlow tensor.

    ``tnpg(t)[...]`` translates two forms of NumPy advanced indexing into
    ``tf.gather`` / ``tf.gather_nd`` calls:

    * exactly one index vector, the other positions being ``:``
      (e.g. ``a[:, vec, :]`` or ``a[vec]``): gather along that axis and keep
      the tensor's rank (see ``plain_gather``);
    * only index vectors, no slices (e.g. ``a[vec1, vec2]``): zipped lookup
      ``[a[vec1[i], vec2[i]] for i]`` (see ``multi_gather``).

    Index vectors may be Python lists or ``tf.Tensor`` objects. Only full
    slices (``:``) are accepted in the remaining positions.
    """

    def __init__(self, tf_arr):
        # The tensor being indexed; never modified.
        self.arr = tf_arr

    def __getitem__(self, idx):
        """Dispatch ``self[idx]`` to ``plain_gather`` or ``multi_gather``.

        Raises:
          ValueError: on unsupported index element types, partial slices,
            an index with no vectors, or a mix of several vectors and slices.
        """
        # ``a[vec]`` passes ``vec`` itself rather than a 1-tuple; normalize
        # so it is treated like ``a[vec, ...]``, as NumPy does.
        if not isinstance(idx, tuple):
            idx = (idx,)
        if not all(isinstance(x, (slice, tf.Tensor, list)) for x in idx):
            raise ValueError("Unsupported type (not slice, tensor or list)! "
                             "%s" % ([type(x) for x in idx]))
        # Slice bounds are never read below, so e.g. ``1:3`` would silently
        # select the whole axis; refuse anything but a full ``:``.
        if any(isinstance(x, slice) and x != slice(None) for x in idx):
            raise ValueError("Only full slices ':' are supported, got %s"
                             % (idx,))
        n_gather = sum(not isinstance(x, slice) for x in idx)
        if n_gather == 0:
            raise ValueError("An empty gather!")
        if n_gather == 1:
            return self.plain_gather(idx)
        if any(isinstance(x, slice) for x in idx):
            raise ValueError("Support only a[:, vec, :] -> same shape "
                             "and a[vec1, vec2, vec3] -> vec "
                             "(one or all)")
        return self.multi_gather(idx)

    def plain_gather(self, idx):
        """Gather along the single non-slice axis of ``idx``, keeping rank.

        The indexed axis is transposed to the front, gathered there (the
        position ``tf.gather`` indexes), and transposed back. The index
        vector must be 1-D. Any other index rank changes the rank of the
        gathered tensor, and the inverse permutation, which is built for
        the original rank, no longer fits.
        """
        idx_true_s = [i for i, x in enumerate(idx) if not isinstance(x, slice)]
        # Indexed axes first, then the remaining listed axes in order.
        fwd_sub_perm = idx_true_s + [i for i in range(len(idx))
                                     if i not in idx_true_s]
        inv_sub_perm = [0] * len(idx)
        for i, p in enumerate(fwd_sub_perm):
            inv_sub_perm[p] = i

        def pad_perm(perm, l):
            # Axes not mentioned in idx (implicit trailing ':') stay put.
            return tf.concat(0, [perm, tf.range(len(perm), l)])

        full_len = tf.rank(self.arr)
        # Only the single-vector case is reachable through __getitem__.
        gather_idx = (tf.pack([idx[i] for i in idx_true_s])
                      if len(idx_true_s) > 1 else idx[idx_true_s[0]])
        val = tf.transpose(
            tf.gather(
                tf.transpose(self.arr, pad_perm(fwd_sub_perm, full_len)),
                gather_idx
            ), pad_perm(inv_sub_perm, full_len)
        )
        return val

    def multi_gather(self, idx):
        """Zipped lookup: ``result[i] = arr[idx[0][i], idx[1][i], ...]``.

        All index vectors are expected to have the same length.
        """
        # tf.pack yields shape (len(idx), k); gather_nd wants one index
        # tuple per row, i.e. shape (k, len(idx)).
        indices = tf.transpose(tf.pack(idx))
        val = tf.gather_nd(self.arr, indices)
        return val
def assert_np_tf_eq(arr, lambda_np, lambda_tf):
    """Assert that one or more TF expressions reproduce a NumPy result.

    Args:
      arr: NumPy array used as input on both sides.
      lambda_np: callable mapping ``arr`` to the expected NumPy result.
      lambda_tf: callable, or list of callables, mapping a ``tf.constant``
        built from ``arr`` to a tensor. Each one is evaluated in its own
        ``tf.Session``.

    Raises:
      AssertionError: if any TF result differs from the NumPy result in
        shape or (within ``np.allclose`` tolerance) in value.
    """
    result_np = lambda_np(arr)
    if not isinstance(lambda_tf, list):
        lambda_tf = [lambda_tf]
    for lambda_tf_i in lambda_tf:
        with tf.Session() as s:
            expr = tf.constant(arr)
            result_tf = s.run(lambda_tf_i(expr))
        # np.allclose broadcasts, so a wrongly shaped result could still pass
        # the value check; compare shapes explicitly first.
        assert np.shape(result_tf) == np.shape(result_np), (
            "shape mismatch: tf %s vs np %s"
            % (np.shape(result_tf), np.shape(result_np)))
        assert np.allclose(result_tf, result_np)
def test_arr_gather():
    """Compare tnpg indexing with NumPy on axis-0, axis-1 and zipped cases.

    Prints 'OK' once every comparison has passed.
    """

    def gather_axis1(t, index):
        # Hand-written reference: swap axes 0 and 1, gather, swap back.
        swap = [1, 0, 2]
        return tf.transpose(tf.gather(tf.transpose(t, swap), index), swap)

    rank3 = np.random.rand(5, 3, 4)

    # a[[2, 1], :, :] -- gather along the first axis.
    assert_np_tf_eq(
        rank3,
        lambda a: a[[2, 1], :, :],
        [
            lambda t: tf.gather(t, [2, 1]),
            lambda t: tf.gather(t, tf.constant([2, 1])),
            lambda t: tnpg(t)[[2, 1], :, :],
        ],
    )

    # a[:, [2, 1], :] -- gather along a middle axis, with and without the
    # trailing ':'.
    assert_np_tf_eq(
        rank3,
        lambda a: a[:, [2, 1], :],
        [
            lambda t: gather_axis1(t, [2, 1]),
            lambda t: gather_axis1(t, tf.constant([2, 1])),
            lambda t: tnpg(t)[:, tf.constant([2, 1]), :],
            lambda t: tnpg(t)[:, tf.constant([2, 1])],
            lambda t: tnpg(t)[:, [2, 1]],
        ],
    )

    # a[[2, 1], [3, 1]] -- zipped lookup giving [a[2, 3], a[1, 1]].
    rank2 = np.random.rand(5, 4)
    assert_np_tf_eq(
        rank2,
        lambda a: a[[2, 1], [3, 1]],
        [
            lambda t: tnpg(t)[tf.constant([2, 1]), tf.constant([3, 1])],
            lambda t: tnpg(t)[[2, 1], [3, 1]],
        ],
    )
    print('OK')
# Run the self-test only when executed as a script, so that importing this
# module for tnpg does not start TensorFlow sessions as a side effect.
if __name__ == "__main__":
    test_arr_gather()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
I need to slice a tensor this way:

`matrix[row_indices, col_indices]`

where `matrix` is an `m x n` tensor and `row_indices` and `col_indices` are int32 vectors of size `k`, such that k is less than m and n respectively. I want to slice and obtain a length-`k` vector, i.e. a vectorized version of `matrix[i, j]`.

What would be an optimized way of doing this?
`tf.gather(matrix, indices)` takes only a vector `indices` and can return all the rows corresponding to the vector `indices`.