@rplzzz
Created August 10, 2018 20:43
How to use the transpose convolution function in TensorFlow
## a batch of 3 4x4x2 images as input. We will upsample these to 8x8.
import numpy as np
import tensorflow as tf

input_data = np.ones([3, 4, 4, 2])
filter_shape = [3, 3, 4, 2]  # height, width, channels-out, channels-in
tconv_filter = np.ones(filter_shape)  # make the filter all ones so that we can manually calculate the output
# obviously, having all the filter channels have the same coefficients defeats
# the purpose of having multiple channels, but this is just an example
output_shape = [8, 8, 4]  # height, width, channels-out (notice we don't have the batch-size dimension -- more on that later)
## set up the slots for the data
xin = tf.placeholder(dtype=tf.float32, shape = (None, 4, 4, 2), name='input')
filt = tf.placeholder(dtype=tf.float32, shape = filter_shape, name='filter')
## Run the transpose convolution. With a stride of 2, this should upsample our image by a factor of 2.
## You have to specify the output shape, but it's not a free parameter; it's determined by the input size,
## the stride, and the padding. What's annoying is that the first (batch) dimension is usually unknown, but we
## have to include it, so we have to extract it from the input's shape and concatenate it onto the front of output_shape.
dimxin = tf.shape(xin)
ncase = dimxin[0:1]
oshp = tf.concat([ncase,output_shape], axis=0)
z1 = tf.nn.conv2d_transpose(xin, filt, oshp, strides=[1,2,2,1], name='xpose_conv')
## tf.layers has an all-in-one transpose convolution layer. It's a lot more convenient, but
## you don't get to specify the weights (they get initialized randomly). In this case I wanted to use
## specified weights so I could see what the actual effect is. Note that while the default padding for
## tf.nn.conv2d_transpose is 'same', the default for this function is 'valid'. You're @#$%ing killing
## me, Google.
z2 = tf.layers.conv2d_transpose(xin, 4, (3,3), strides=(2,2), padding='SAME')
with tf.Session() as sess:
    summary_writer = tf.summary.FileWriter('logs', sess.graph)
    sess.run(tf.global_variables_initializer())
    (z1out, z2out) = sess.run(fetches=[z1, z2],
                              feed_dict={xin: input_data, filt: tconv_filter})
    print(z1out.shape)
    print(z2out.shape)
    print(z1out[0, ..., 0])
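As an aside, the output-shape remark above can be checked by hand: with 'SAME' padding, a transpose convolution's output spatial size is simply `input_size * stride`, independent of the kernel size, while with 'VALID' it would be `(input_size - 1) * stride + kernel_size`. A quick sketch of that arithmetic (`tconv_output_size` is just a helper name used here, not a TensorFlow function):

```python
def tconv_output_size(in_size, stride, kernel, padding):
    """Spatial output size of a transpose convolution, per TensorFlow's rules."""
    if padding == 'SAME':
        return in_size * stride
    elif padding == 'VALID':
        return (in_size - 1) * stride + kernel
    raise ValueError(padding)

# The example above: 4x4 input, stride 2, 3x3 kernel, SAME padding -> 8x8
print(tconv_output_size(4, 2, 3, 'SAME'))   # 8
print(tconv_output_size(4, 2, 3, 'VALID'))  # 9
```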
rplzzz commented Aug 10, 2018

Here's the output:

(3, 8, 8, 4)
(3, 8, 8, 4)
array([[2., 2., 4., 2., 4., 2., 4., 2.],
       [2., 2., 4., 2., 4., 2., 4., 2.],
       [4., 4., 8., 4., 8., 4., 8., 4.],
       [2., 2., 4., 2., 4., 2., 4., 2.],
       [4., 4., 8., 4., 8., 4., 8., 4.],
       [2., 2., 4., 2., 4., 2., 4., 2.],
       [4., 4., 8., 4., 8., 4., 8., 4.],
       [2., 2., 4., 2., 4., 2., 4., 2.]], dtype=float32)

z2out will be filled with random values, since the weights are initialized randomly.


rplzzz commented Aug 10, 2018

Incidentally, in the output above you can see the well-known checkerboard pattern that transpose convolutions are often criticized for. Many sources recommend following this operation with a same-size convolution to smooth out the artifacts.
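The pattern is easy to reproduce without TensorFlow. A transpose convolution scatters each input pixel into a kernel-sized patch of the output; with stride 2 and a 3x3 kernel the patches overlap on every other row and column, which is exactly the 2/4/8 checkerboard printed above. A pure-NumPy sketch of the scatter view (single channel; the two all-ones input channels with an all-ones filter behave like one channel of twos):

```python
import numpy as np

def tconv2d_transpose_same(x, k, stride):
    """Single-channel transpose convolution with SAME padding, via scattering.

    x: (H, W) input, k: (kh, kw) kernel. Output is (H*stride, W*stride).
    """
    H, W = x.shape
    kh, kw = k.shape
    full = np.zeros(((H - 1) * stride + kh, (W - 1) * stride + kw))
    for i in range(H):
        for j in range(W):
            full[i*stride:i*stride+kh, j*stride:j*stride+kw] += x[i, j] * k
    # SAME padding: crop the scattered result down to H*stride x W*stride.
    # Here the excess is 1 row and 1 column, taken from the bottom/right.
    return full[:H*stride, :W*stride]

# Two input channels of ones with an all-ones filter act like a single
# channel of twos, so this reproduces the z1out[0, ..., 0] slice above.
x = 2.0 * np.ones((4, 4))
k = np.ones((3, 3))
out = tconv2d_transpose_same(x, k, stride=2)
print(out[0])  # [2. 2. 4. 2. 4. 2. 4. 2.]
```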
