Skip to content

Instantly share code, notes, and snippets.

@danijar
Created June 29, 2018 14:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danijar/c5d93d051a60b356ecb99b561153ceb1 to your computer and use it in GitHub Desktop.
Save danijar/c5d93d051a60b356ecb99b561153ceb1 to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
def gaussian_blur(image, diameter):
    """Blur each channel of a batch of images with a separable Gaussian kernel.

    The kernel spans ``diameter`` taps along each spatial axis. Its standard
    deviation is ``diameter / 4``, so the kernel covers roughly +/- 2 standard
    deviations. Borders are handled with symmetric (mirror) padding, and the
    'VALID' convolutions then bring the spatial size back to the input's.

    Args:
      image: 4-D floating point tensor laid out as
        ``[batch, height, width, channels]`` (axes 1 and 2 are padded, and
        axis 3 is treated as channels). The channel count must be statically
        known, because it is read via ``image.shape[3].value``.
      diameter: Kernel width in pixels, as a Python int or a scalar int tensor.
        NOTE(review): values < 1 are not validated here and presumably fail
        inside TensorFlow (negative padding / empty kernel).

    Returns:
      Tensor with the same shape and dtype as ``image``.
    """
    # Pad height and width by diameter - 1 in total, so the 'VALID'
    # convolutions below preserve the spatial size. For even diameters the
    # extra pixel goes before, matching the kernel's taps (e.g. -5..4 for 10).
    padding = [[0, 0]] + [[diameter // 2, (diameter - 1) // 2]] * 2 + [[0, 0]]
    # Build the kernel in the image's own dtype. depthwise_conv2d requires the
    # input and filter dtypes to match, so this also supports float64 images.
    dtype = image.dtype
    diameter = tf.cast(diameter, dtype)
    # Tap offsets, e.g. [-1, 0, 1] for diameter 3 or [-5, ..., 4] for 10.
    filter_ = tf.range(-(diameter - 1) // 2, (diameter - 1) // 2 + 1)
    filter_ = tf.exp(-0.5 * filter_ ** 2 / (diameter / 4) ** 2)  # 2 stds.
    filter_ /= tf.reduce_sum(filter_)
    # One copy of the 1-D kernel per channel, with shape [diameter, channels, 1].
    filter_ = tf.tile(filter_[:, None, None], [1, image.shape[3].value, 1])
    image = tf.pad(image, padding, 'SYMMETRIC')  # No 'edge' mode.
    # Separable blur: a vertical pass with a [d, 1, C, 1] filter, then a
    # horizontal pass with a [1, d, C, 1] filter.
    image = tf.nn.depthwise_conv2d(image, filter_[:, None], [1, 1, 1, 1], 'VALID')
    image = tf.nn.depthwise_conv2d(image, filter_[None, :], [1, 1, 1, 1], 'VALID')
    return image
def example_image(height=60, width=80, seed=3):
    """Return a deterministic synthetic RGB test image.

    Every pixel is the product of three linear ramps:
    - one along rows (5 to 10),
    - one along columns (-2 to 10),
    - one across the three channels (-1 to 10).

    The product is scaled by ``seed / 255`` and wrapped with ``% 1``, which
    keeps values in [0, 1) up to floating-point rounding.

    Args:
      height: Number of rows.
      width: Number of columns.
      seed: Multiplier that changes the resulting pattern. No randomness is
        involved.

    Returns:
      float64 NumPy array of shape ``(height, width, 3)``.
    """
    row_ramp = np.linspace(5, 10, height)
    col_ramp = np.linspace(-2, 10, width)
    channel_ramp = np.linspace(-1, 10, 3)
    # Outer product of the three ramps. The products are taken in the order
    # (row * col) * channel.
    pattern = np.multiply.outer(np.multiply.outer(row_ramp, col_ramp), channel_ramp)
    return (seed * pattern / 255) % 1
# Build the graph once: a scalar int32 diameter, and a batch of RGB images
# whose batch size, height and width are left dynamic.
diameter = tf.placeholder(tf.int32, [])
image = tf.placeholder(tf.float32, [None, None, None, 3])
# The placeholder is already float32, so no extra cast is needed here.
output = gaussian_blur(image, diameter)

original = example_image()
diameters = [1, 3, 10, 20]

# Top row shows the blurred result for each diameter. Bottom row shows
# blurred - original, shifted from [-1, 1] into [0, 1] so imshow can render
# it (0.5 gray means no change).
fig, ax = plt.subplots(2, len(diameters), figsize=(5 * len(diameters), 7))
ax[0, 0].set_ylabel('Blurred')
ax[1, 0].set_ylabel('Difference')
with tf.Session() as sess:
    for index, value in enumerate(diameters):
        # Feed a batch of one image, then take the single result back out.
        blurred = sess.run(output, {image: [original], diameter: value})[0]
        ax[0, index].set_title('Diameter {}'.format(value))
        ax[0, index].imshow(blurred, interpolation='nearest')
        ax[1, index].imshow((blurred - original + 1) / 2, interpolation='nearest')
for axes in ax.flatten():
    axes.set_xticks([])
    axes.set_yticks([])
# Without this, the figure is never displayed when run as a plain script.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment