This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Per-timestep slots (placeholders of 0) for values we'll need to access later.
cs = [0] * T
mus = [0] * T
logsigmas = [0] * T
sigmas = [0] * T

# initial states
DO_SHARE = False  # passed as `reuse=` to the variable scopes used by the model functions
h_dec_prev = tf.zeros((batch_size, dec_size))
enc_state = lstm_enc.zero_state(batch_size, tf.float32)
dec_state = lstm_dec.zero_state(batch_size, tf.float32)

# build the graph
for t in range(T):
    # The first step starts from an all-zero canvas; later steps use the
    # canvas stored for the previous step.
    if t == 0:
        c_prev = tf.zeros((batch_size, img_size))
    else:
        c_prev = cs[t - 1]
    x_hat = x - tf.sigmoid(c_prev)  # error image
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def read_no_attn(x,x_hat,h_dec_prev):
    """
    Read without attention: join the image and the error image along dim 1.

    h_dec_prev is accepted but not used here; read_attn takes the same
    three arguments.
    """
    pieces = [x, x_hat]
    return tf.concat(1, pieces)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def linear(x,output_dim):
    """
    Affine transformation x*W + b.

    Assumes x.shape = (batch_size, num_features). The variables are fetched
    with tf.get_variable under the fixed names "w" and "b", so each call site
    is expected to wrap this in its own tf.variable_scope.
    """
    num_features = x.get_shape()[1]
    weights = tf.get_variable("w", [num_features, output_dim])
    bias = tf.get_variable(
        "b",
        [output_dim],
        initializer=tf.constant_initializer(0.0),
    )
    projected = tf.matmul(x, weights)
    return projected + bias
def attn_window(scope,h_dec,N): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def filterbank(gx, gy, sigma2,delta, N):
    """
    Build the Gaussian attention filters for an N x N grid.

    gx, gy -- grid centre coordinates
    sigma2 -- taken to be the variance (sigma squared) of each Gaussian, as
              the name suggests -- TODO confirm against attn_window
    delta  -- stride between neighbouring filter centres
    N      -- number of filters per axis

    A and B are module-level sizes of the two image axes.

    NOTE(review): this listing stops after Fx. The Fy computation, any
    normalisation and the return statement are not shown here; whatever
    computes Fy should use the same exponent form as Fx below.
    """
    grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1])
    mu_x = gx + (grid_i - N / 2 - 0.5) * delta # eq 19
    mu_y = gy + (grid_i - N / 2 - 0.5) * delta # eq 20
    a = tf.reshape(tf.cast(tf.range(A), tf.float32), [1, 1, -1])
    b = tf.reshape(tf.cast(tf.range(B), tf.float32), [1, 1, -1])
    mu_x = tf.reshape(mu_x, [-1, N, 1])
    mu_y = tf.reshape(mu_y, [-1, N, 1])
    sigma2 = tf.reshape(sigma2, [-1, 1, 1])
    # Gaussian kernel exp(-(a - mu)^2 / (2 * sigma^2)). Only the distance is
    # squared; squaring the whole quotient would give (a - mu)^2 / (4 * sigma^4).
    Fx = tf.exp(-tf.square(a - mu_x) / (2 * sigma2))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def read_attn(x,x_hat,h_dec_prev):
    """
    Read with attention: extract a read_n x read_n glimpse from both the
    image and the error image, using filters derived from h_dec_prev.

    NOTE(review): no return statement is present in this listing; the
    function ends after the two glimpses are computed.
    """
    Fx, Fy, gamma = attn_window("read", h_dec_prev, read_n)

    def filter_img(img, Fx, Fy, gamma, N):
        """Compute gamma * (Fy . img . Fx^T), flattened to [-1, N*N]."""
        Fx_t = tf.transpose(Fx, perm=[0, 2, 1])
        img_2d = tf.reshape(img, [-1, B, A])
        img_dot_fx = tf.batch_matmul(img_2d, Fx_t)
        patch = tf.batch_matmul(Fy, img_dot_fx)
        flat_patch = tf.reshape(patch, [-1, N * N])
        intensity = tf.reshape(gamma, [-1, 1])
        return flat_patch * intensity

    x = filter_img(x, Fx, Fy, gamma, read_n)  # batch x (read_n*read_n)
    x_hat = filter_img(x_hat, Fx, Fy, gamma, read_n)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# applying read filter. inspired by https://github.com/ikostrikov/TensorFlow-VAE-GAN-DRAW/blob/master/main-draw.py
# Computes Fy . x . Fx^T using broadcasting and reduce_sum.
# Assumes Fx is batch x N x A, Fy is batch x N x B and x holds batch x B x A values.
Fxt = tf.transpose(Fx, [0, 2, 1])  # batch x A x N
Fxt = tf.reshape(Fxt, [-1, 1, A, N, 1])  # batch x 1 x A x N x 1
Fxt = tf.tile(Fxt, [1, N, 1, 1, 1])  # batch x N x A x N x 1 (repmat'ed along dim=1)
Fy = tf.reshape(Fy, [-1, N, B, 1, 1])  # batch x N x B x 1 x 1
x = tf.reshape(x, [-1, 1, B, A, 1])  # batch x 1 x B x A x 1
x = tf.tile(x, [1, N, 1, 1, 1])  # batch x N x B x A x 1
Fydotx = tf.reduce_sum(Fy * x, 2)  # batch x N x A x 1
# Reshape the product Fy.x (not the tiled input x) so that it broadcasts
# against Fxt; reshaping x here would throw the Fy product away.
Fydotx = tf.reshape(Fydotx, [-1, N, A, 1, 1])  # batch x N x A x 1 x 1
FydotxdotFxt = tf.reduce_sum(Fydotx * Fxt, 2)  # batch x N x N x 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# encoder | |
from tensorflow.models.rnn.rnn_cell import LSTMCell | |
# Width of what the read operation produces. The factor 2 covers the two
# parts that are read (the image and the error image).
if FLAGS.read_attn:
    read_size = 2 * read_n * read_n
else:
    read_size = 2 * img_size
# The second argument is read_size + dec_size -- presumably the width of the
# encoder's input; confirm against the LSTMCell signature in use.
lstm_enc = LSTMCell(enc_size, read_size + dec_size)  # encoder Op
def encode(state,input):
    """
    Apply the encoder LSTM to `input` with the given `state`, inside the
    "encoder" variable scope (variables are reused when DO_SHARE is set).

    Returns whatever lstm_enc returns.
    """
    with tf.variable_scope("encoder", reuse=DO_SHARE):
        result = lstm_enc(input, state)
    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def sampleQ(h_enc):
    """
    Samples Zt ~ normrnd(mu,sigma) via reparameterization trick for normal dist.
    mu is (batch,z_size)

    NOTE(review): only the mu / logsigma / sigma computation is present in
    this listing; the sampling step and the return statement are not shown.
    """
    # Two separate linear maps of h_enc, each with its own variable scope so
    # that their "w"/"b" variables do not collide.
    with tf.variable_scope("mu", reuse=DO_SHARE):
        mu = linear(h_enc, z_size)
    with tf.variable_scope("sigma", reuse=DO_SHARE):
        logsigma = linear(h_enc, z_size)
    # The network outputs log(sigma); exponentiate to get sigma itself.
    sigma = tf.exp(logsigma)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def decode(state,input):
    """
    Apply the decoder LSTM to `input` with the given `state`, inside the
    "decoder" variable scope (variables are reused when DO_SHARE is set).

    Returns whatever lstm_dec returns.
    """
    with tf.variable_scope("decoder", reuse=DO_SHARE):
        result = lstm_dec(input, state)
    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def write_no_attn(h_dec):
    """
    Write without attention: a linear map of h_dec to img_size outputs,
    with its variables living in the "write" variable scope.
    """
    with tf.variable_scope("write", reuse=DO_SHARE):
        written = linear(h_dec, img_size)
    return written
def write_attn(h_dec): | |
with tf.variable_scope("writeW",reuse=DO_SHARE): | |
w=linear(h_dec,write_size) # batch x (write_n*write_n) | |
N=write_n | |
w=tf.reshape(w,[batch_size,N,N]) | |
Fx,Fy,gamma=attn_window("write",h_dec,write_n) |