Skip to content

Instantly share code, notes, and snippets.

@hccho2
Created February 21, 2019 08:39
Show Gist options
  • Save hccho2/3430e79703129d9ee815c9c226c4cb22 to your computer and use it in GitHub Desktop.
# coding: utf-8
import tensorflow as tf
import numpy as np
# Demo: from a (col_size*n, n) matrix, pick column i out of the i-th block of
# col_size rows, by flattening the matrix and gathering precomputed flat
# indices. Uses the TF1 graph/Session API.
tf.reset_default_graph()
col_size = 6  # rows per block
n = 3  # number of row blocks (also the number of columns)
a = np.arange(col_size * n * n).reshape(col_size * n, n)
x = tf.convert_to_tensor(a)
y = tf.reshape(x, [col_size * n * n])
# Element a[i*col_size + j, i] sits at row-major flat position
# (i*col_size + j)*n + i  (== n*j + n*i*col_size + i).
# Ordering: block i is the outer loop, row-within-block j the inner loop.
index = [(i * col_size + j) * n + i for i in range(n) for j in range(col_size)]
z = tf.gather(y, index)
# Use the session as a context manager so it is closed after the run;
# a bare tf.Session() is never released.
with tf.Session() as sess:
    z_value = sess.run(z)
# With the defaults above this prints
# [ 0  3  6  9 12 15 19 22 25 28 31 34 38 41 44 47 50 53]
print(z_value)
@hccho2
Copy link
Author

hccho2 commented Feb 22, 2019

def g(num_blocks=5, block_rows=2):
    """Extract column j of the j-th row block via reshape/transpose/diag_part.

    Builds a (num_blocks*block_rows, num_blocks) float64 matrix whose j-th
    block of ``block_rows`` rows holds the value j in column j (zeros
    elsewhere). It then pulls out ``a[j*block_rows + r, j]`` for every
    block j and row r, without using tf.gather:

    * reshape to (num_blocks, block_rows, num_blocks), indexed [j, r, c]
    * transpose to (block_rows, num_blocks, num_blocks), indexed [r, j, c]
    * diag_part over the last two axes keeps c == j, giving
      (block_rows, num_blocks)
    * transpose and flatten, giving a block-major vector of length
      num_blocks*block_rows

    Every intermediate is printed. Uses the TF1 Session API. The defaults
    reproduce the original hard-coded 10x5 example.

    Args:
        num_blocks: number of row blocks (also the number of columns).
        block_rows: number of rows in each block.
    """
    a = np.zeros([num_blocks * block_rows, num_blocks])
    for j in range(num_blocks):
        a[block_rows * j:block_rows * (j + 1), j] = j
    print(a)
    print('-'*10)

    # a_cast is only printed; the reshape pipeline below operates on the
    # float64 array `a`.
    # NOTE(review): if the float32 cast was meant to feed the pipeline,
    # pass a_cast to tf.reshape instead.
    a_cast = tf.cast(a, tf.float32)
    a_tf0 = tf.reshape(a, [num_blocks, block_rows, num_blocks])  # [j, r, c]
    a_tf1 = tf.transpose(a_tf0, [1, 0, 2])  # [r, j, c]
    a_tf2 = tf.linalg.diag_part(a_tf1)  # [r, j] = a[j*block_rows + r, j]
    a_tf3 = tf.reshape(tf.transpose(a_tf2, [1, 0]), [-1])  # block-major

    # Context manager closes the session (and frees its resources) when
    # g() returns, instead of leaking one Session per call.
    with tf.Session() as sess:
        print(sess.run(a_cast))
        print('-'*10)
        print("a_tf0\n", sess.run(a_tf0))
        print('-'*10)
        print("a_tf1\n", sess.run(a_tf1))

        print('-'*10)
        print("a_tf2\n", sess.run(a_tf2))

        print('-'*10)
        print("a_tf3\n", sess.run(a_tf3))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment