Skip to content

Instantly share code, notes, and snippets.

@ericjang
ericjang / draw_0.py
Last active February 24, 2016 05:19
DRAW high-level implementation
# Per-timestep bookkeeping: one slot per step (T entries each), kept around
# because these values need to be accessed again later.
cs = [0] * T
mus = [0] * T
logsigmas = [0] * T
sigmas = [0] * T

# Initial states.
DO_SHARE = False  # passed as `reuse=` to every tf.variable_scope in this file
h_dec_prev = tf.zeros((batch_size, dec_size))
enc_state = lstm_enc.zero_state(batch_size, tf.float32)
dec_state = lstm_dec.zero_state(batch_size, tf.float32)

# Build the graph by unrolling T steps.
for t in range(T):
    # The first step starts from an all-zero canvas; later steps reuse the
    # canvas stored for the previous step.
    if t == 0:
        c_prev = tf.zeros((batch_size, img_size))
    else:
        c_prev = cs[t - 1]
    # Error image: input minus the sigmoid of the previous canvas.
    x_hat = x - tf.sigmoid(c_prev)
    # NOTE(review): the excerpt ends here; the rest of the loop body is not
    # visible in this snippet.
def read_no_attn(x, x_hat, h_dec_prev):
    """Read without attention: join the image and the error image.

    Concatenates ``x`` and ``x_hat`` along dimension 1. ``h_dec_prev`` is
    accepted but not used by this function.
    """
    pieces = [x, x_hat]
    return tf.concat(1, pieces)
def linear(x, output_dim):
    """Apply the affine map ``x . w + b``.

    Expects ``x`` with shape (batch_size, num_features). The variables "w"
    and "b" are fetched with ``tf.get_variable``, so they live in whatever
    variable scope is active at the call site; "b" is initialised to zero.

    Returns the result of ``tf.matmul(x, w) + b`` with ``output_dim``
    columns.
    """
    num_features = x.get_shape()[1]
    weights = tf.get_variable("w", [num_features, output_dim])
    bias = tf.get_variable(
        "b", [output_dim], initializer=tf.constant_initializer(0.0)
    )
    return tf.matmul(x, weights) + bias
def attn_window(scope,h_dec,N):
def filterbank(gx, gy, sigma2, delta, N):
    """Build the Gaussian attention filters for an N x N grid.

    gx, gy : grid centre coordinates.
    sigma2 : variance of each Gaussian filter (reshaped to [-1, 1, 1] below).
    delta  : stride between neighbouring filter centres.
    N      : number of filters per axis.

    assumes gx, gy and delta are (batch_size, 1) so they broadcast against
    the [1, N] grid -- TODO confirm against the caller.

    NOTE(review): this excerpt ends after Fx is computed; the Fy filter,
    any normalisation and the return statement are not visible here.
    """
    # Filter indices 0..N-1 as a [1, N] row.
    grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1])
    mu_x = gx + (grid_i - N / 2 - 0.5) * delta  # eq 19
    mu_y = gy + (grid_i - N / 2 - 0.5) * delta  # eq 20
    # Pixel coordinates along each image axis, shaped [1, 1, A] / [1, 1, B].
    a = tf.reshape(tf.cast(tf.range(A), tf.float32), [1, 1, -1])
    b = tf.reshape(tf.cast(tf.range(B), tf.float32), [1, 1, -1])
    # Reshape so that (a - mu_x) broadcasts to batch x N x A.
    mu_x = tf.reshape(mu_x, [-1, N, 1])
    mu_y = tf.reshape(mu_y, [-1, N, 1])
    sigma2 = tf.reshape(sigma2, [-1, 1, 1])
    # Gaussian kernel exp(-(a - mu)^2 / (2 * sigma^2)): only the distance is
    # squared. The previous form squared the whole ratio, which yields
    # exp(-(a - mu)^2 / (4 * sigma^4)) instead.
    Fx = tf.exp(-tf.square(a - mu_x) / (2 * sigma2))
def read_attn(x, x_hat, h_dec_prev):
    """Read with attention: extract glimpses of ``x`` and ``x_hat``.

    The attention window (Fx, Fy, gamma) is obtained from
    ``attn_window("read", h_dec_prev, read_n)`` and applied to both the
    image and the error image.

    NOTE(review): the excerpt ends after the two glimpses are computed; no
    return statement is visible in this snippet.
    """
    Fx, Fy, gamma = attn_window("read", h_dec_prev, read_n)

    def filter_img(img, Fx, Fy, gamma, N):
        """Compute gamma * (Fy . img . Fx^T), flattened to N*N columns."""
        fx_transposed = tf.transpose(Fx, perm=[0, 2, 1])
        img = tf.reshape(img, [-1, B, A])
        img_times_fxt = tf.batch_matmul(img, fx_transposed)
        glimpse = tf.batch_matmul(Fy, img_times_fxt)
        glimpse = tf.reshape(glimpse, [-1, N * N])
        intensity = tf.reshape(gamma, [-1, 1])
        return glimpse * intensity

    x = filter_img(x, Fx, Fy, gamma, read_n)  # batch x (read_n*read_n)
    x_hat = filter_img(x_hat, Fx, Fy, gamma, read_n)
@ericjang
ericjang / tf_mm_vectorized.py
Last active February 24, 2016 16:19
vectorized matrix multiplication in TF
# applying read filter. inspired by https://github.com/ikostrikov/TensorFlow-VAE-GAN-DRAW/blob/master/main-draw.py
#
# Computes Fy . x . Fx^T for every batch element using only reshape, tile,
# broadcast multiplication and reduce_sum (no batched matmul).
Fxt = tf.transpose(Fx, [0, 2, 1])  # batch x A x N
Fxt = tf.reshape(Fxt, [-1, 1, A, N, 1])  # batch x 1 x A x N x 1
Fxt = tf.tile(Fxt, [1, N, 1, 1, 1])  # batch x N x A x N x 1 (repmat'ed along dim=1)
Fy = tf.reshape(Fy, [-1, N, B, 1, 1])  # batch x N x B x 1 x 1
x = tf.reshape(x, [-1, 1, B, A, 1])  # batch x 1 x B x A x 1
x = tf.tile(x, [1, N, 1, 1, 1])  # batch x N x B x A x 1
# Contract over B (dim 2): this is Fy . x.
Fydotx = tf.reduce_sum(Fy * x, 2)  # batch x N x A x 1
# Reshape the contracted product, not the tiled input `x`. Reshaping `x`
# here would discard the Fy contraction and produce a wrong leading
# dimension.
Fydotx = tf.reshape(Fydotx, [-1, N, A, 1, 1])  # batch x N x A x 1 x 1
# Contract over A (dim 2): this is (Fy . x) . Fx^T.
FydotxdotFxt = tf.reduce_sum(Fydotx * Fxt, 2)  # batch x N x N x 1
# encoder
from tensorflow.models.rnn.rnn_cell import LSTMCell
# Width of the read output: with attention, two read_n x read_n glimpses;
# without it, two full images of img_size each.
if FLAGS.read_attn:
    read_size = 2 * read_n * read_n
else:
    read_size = 2 * img_size
lstm_enc = LSTMCell(enc_size, read_size + dec_size)  # encoder Op
def encode(state, input):
    """Run one step of the encoder LSTM.

    Calls ``lstm_enc(input, state)`` inside the "encoder" variable scope
    (reuse governed by ``DO_SHARE``) and returns whatever the cell returns.
    """
    with tf.variable_scope("encoder", reuse=DO_SHARE):
        encoder_result = lstm_enc(input, state)
    return encoder_result
def sampleQ(h_enc):
    """Compute the parameters of the latent distribution Q from ``h_enc``.

    Intended to sample Zt ~ normrnd(mu, sigma) via the reparameterization
    trick for a normal distribution; mu is (batch, z_size).

    NOTE(review): only the computation of mu, logsigma and sigma is visible
    in this excerpt; the sampling step and the return statement are not.
    """
    # Two separate linear layers, each in its own variable scope.
    with tf.variable_scope("mu", reuse=DO_SHARE):
        mu = linear(h_enc, z_size)
    with tf.variable_scope("sigma", reuse=DO_SHARE):
        logsigma = linear(h_enc, z_size)
        # The layer predicts log(sigma); exponentiate to get sigma.
        sigma = tf.exp(logsigma)
def decode(state, input):
    """Run one step of the decoder LSTM.

    Calls ``lstm_dec(input, state)`` inside the "decoder" variable scope
    (reuse governed by ``DO_SHARE``) and returns whatever the cell returns.
    """
    with tf.variable_scope("decoder", reuse=DO_SHARE):
        decoder_result = lstm_dec(input, state)
    return decoder_result
def write_no_attn(h_dec):
    """Write without attention: map ``h_dec`` to a full-size image.

    Applies ``linear`` with ``img_size`` outputs inside the "write" variable
    scope (reuse governed by ``DO_SHARE``).
    """
    with tf.variable_scope("write", reuse=DO_SHARE):
        written = linear(h_dec, img_size)
    return written
def write_attn(h_dec):
with tf.variable_scope("writeW",reuse=DO_SHARE):
w=linear(h_dec,write_size) # batch x (write_n*write_n)
N=write_n
w=tf.reshape(w,[batch_size,N,N])
Fx,Fy,gamma=attn_window("write",h_dec,write_n)