TensorFlow Convolution Gradients
"""
Demostrating how to compute the gradients for convolution with:
tf.nn.conv2d
tf.nn.conv2d_backprop_input
tf.nn.conv2d_backprop_filter
tf.nn.conv2d_transpose
This is the scripts for this answer: https://stackoverflow.com/a/44350789/1255535
"""
import tensorflow as tf
import numpy as np
import scipy.signal

def tf_rot180(w):
    """
    Rotate a filter by 180 degrees (flip along both spatial axes).
    """
    return tf.reverse(w, axis=[0, 1])

def tf_pad_to_full_conv2d(x, w_size):
    """
    Pad x such that a 'VALID' convolution in TensorFlow is equivalent to a
    'FULL' convolution. See
    http://deeplearning.net/software/theano/library/tensor/nnet/conv.html#theano.tensor.nnet.conv2d
    for a description of 'FULL' convolution.
    """
    return tf.pad(x, [[0, 0],
                      [w_size - 1, w_size - 1],
                      [w_size - 1, w_size - 1],
                      [0, 0]])
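# Size check: for an n x n input and a k x k filter, a 'FULL' convolution
# yields an (n + k - 1) x (n + k - 1) output. Padding each spatial border by
# k - 1 and running a 'VALID' convolution gives (n + 2*(k - 1)) - k + 1
# = n + k - 1 per side, so the two are equivalent.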

def tf_NHWC_to_HWIO(out):
    """
    Converts [batch, in_height, in_width, in_channels]
    to [filter_height, filter_width, in_channels, out_channels].
    """
    return tf.transpose(out, perm=[1, 2, 0, 3])
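# Example: with batch = 1 and a single channel, an NHWC tensor of shape
# (1, H, W, 1) becomes (H, W, 1, 1) under perm=[1, 2, 0, 3], i.e. the batch
# dimension lands in the in_channels slot and channels land in out_channels,
# matching the HWIO filter layout expected by tf.nn.conv2d.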

# Sizes; strides are fixed, and in_channels / out_channels are 1 for now
x_size = 4
w_size = 3 # use an odd number here
x_shape = (1, x_size, x_size, 1)
w_shape = (w_size, w_size, 1, 1)
out_shape = (1, x_size - w_size + 1, x_size - w_size + 1, 1)
strides = (1, 1, 1, 1)

# numpy values
x_np = np.random.randint(10, size=x_shape)
w_np = np.random.randint(10, size=w_shape)
out_scale_np = np.random.randint(10, size=out_shape)
# tf forward
x = tf.constant(x_np, dtype=tf.float32)
w = tf.constant(w_np, dtype=tf.float32)
out = tf.nn.conv2d(input=x, filter=w, strides=strides, padding='VALID')
out_scale = tf.constant(out_scale_np, dtype=tf.float32)
f = tf.reduce_sum(tf.multiply(out, out_scale))
# tf backward
d_out = tf.gradients(f, out)[0]
# 4 different ways to compute d_x
d_x = tf.gradients(f, x)[0]
d_x_manual = tf.nn.conv2d(input=tf_pad_to_full_conv2d(d_out, w_size),
                          filter=tf_rot180(w),
                          strides=strides,
                          padding='VALID')
d_x_backprop_input = tf.nn.conv2d_backprop_input(input_sizes=x_shape,
                                                 filter=w,
                                                 out_backprop=d_out,
                                                 strides=strides,
                                                 padding='VALID')
d_x_transpose = tf.nn.conv2d_transpose(value=d_out,
                                       filter=w,
                                       output_shape=x_shape,
                                       strides=strides,
                                       padding='VALID')
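# All four expressions above compute the same tensor: d_x is the 'FULL'
# correlation of d_out with the 180-degree-rotated filter,
#     d_x = full_conv(d_out, rot180(w)),
# which is exactly what conv2d_backprop_input computes; conv2d_transpose
# expresses the same operation as a transposed convolution.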
# 3 different ways to compute d_w
d_w = tf.gradients(f, w)[0]
d_w_manual = tf_NHWC_to_HWIO(tf.nn.conv2d(input=x,
                                          filter=tf_NHWC_to_HWIO(d_out),
                                          strides=strides,
                                          padding='VALID'))
d_w_backprop_filter = tf.nn.conv2d_backprop_filter(input=x,
                                                   filter_sizes=w_shape,
                                                   out_backprop=d_out,
                                                   strides=strides,
                                                   padding='VALID')
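# Both alternatives compute the same tensor as d_w: d_w is the 'VALID'
# correlation of x with d_out,
#     d_w = valid_conv(x, d_out).
# The NHWC -> HWIO transposes in the manual version are needed because
# tf.nn.conv2d expects its filter argument in HWIO layout: d_out's batch
# dimension plays the role of in_channels, and the result is transposed
# back into the filter's HWIO layout.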
# run
with tf.Session() as sess:
    np.testing.assert_allclose(sess.run(d_x), sess.run(d_x_manual))
    np.testing.assert_allclose(sess.run(d_x), sess.run(d_x_backprop_input))
    np.testing.assert_allclose(sess.run(d_x), sess.run(d_x_transpose))
    np.testing.assert_allclose(sess.run(d_w), sess.run(d_w_manual))
    np.testing.assert_allclose(sess.run(d_w), sess.run(d_w_backprop_filter))
"""
Get the same results using numpy / scipy
"""

def rot180(x):
    """Rotate a 2D array by 180 degrees."""
    return np.flipud(np.fliplr(x))

def conv2d(x, w, mode='full', boundary='fill', fillvalue=0):
    """
    2D convolution without the implicit 180-degree rotation (i.e.
    cross-correlation, which is what tf.nn.conv2d computes): scipy's
    convolve2d rotates the filter by 180 degrees, so pre-rotating the
    filter here cancels that rotation.
    """
    return scipy.signal.convolve2d(x, rot180(w),
                                   mode=mode,
                                   boundary=boundary,
                                   fillvalue=fillvalue)
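# Example (illustrative values): conv2d(np.ones((3, 3)), np.ones((2, 2)),
# mode='valid') gives a 2 x 2 array filled with 4, matching what
# tf.nn.conv2d would produce for the same input and filter.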
# convert from 4d to 2d
x_np = x_np.squeeze()
w_np = w_np.squeeze()
d_out_np = out_scale_np.squeeze()  # d_out equals out_scale, since f = sum(out * out_scale)

# compute the gradients manually
out_np = conv2d(x_np, w_np, mode='valid')
d_x_np = conv2d(d_out_np, rot180(w_np), mode='full')
d_w_np = conv2d(x_np, d_out_np, mode='valid')
# run
with tf.Session() as sess:
    np.testing.assert_allclose(sess.run(d_x).squeeze(), d_x_np)
    np.testing.assert_allclose(sess.run(d_w).squeeze(), d_w_np)
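    # Extra sanity check (not part of the original gist): the forward pass
    # computed by scipy should match TensorFlow's as well.
    np.testing.assert_allclose(sess.run(out).squeeze(), out_np)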