Skip to content

Instantly share code, notes, and snippets.

@kokeshing
Last active April 8, 2019 16:44
Show Gist options
  • Save kokeshing/42fadb03a29eb2a8b438848d97161701 to your computer and use it in GitHub Desktop.
Save kokeshing/42fadb03a29eb2a8b438848d97161701 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import numpy as np
import time
"""
https://github.com/kweisamx/TensorFlow-ESPCN
"""
def PS_1(X, r, out_filters=1):
    """Sub-pixel shuffle ``X`` by factor ``r`` (ESPCN-style phase shift).

    The channel axis (axis 3) is cut into ``out_filters`` equal groups. Each
    group is rearranged independently by ``_phase_shift_1``, and the results
    are joined back together along the channel axis.

    Args:
        X: 4-D tensor with channels on axis 3. Each channel group must hold
            ``r * r`` channels for ``_phase_shift_1``'s reshape to succeed.
        r: Integer upscaling factor for both spatial axes.
        out_filters: Number of channel groups, i.e. output channels.

    Returns:
        Tensor with ``out_filters`` channels, concatenated along axis 3.
    """
    groups = tf.split(X, out_filters, axis=3)
    shifted = [_phase_shift_1(group, r) for group in groups]
    return tf.concat(shifted, axis=3)
def _phase_shift_1(I, r):
    """Pixel-shuffle a single channel group of ``I`` by factor ``r``.

    Args:
        I: 4-D tensor ``(batch, a, b, c)`` whose static ``a`` and ``b`` are
            known. The reshape below requires ``c == r * r``.
        r: Integer upscaling factor applied to both spatial axes.

    Returns:
        Tensor of shape ``(batch, a * r, b * r, 1)``.
    """
    batch_size = tf.shape(I)[0]
    _, a, b, _ = I.get_shape().as_list()
    X = tf.reshape(I, (batch_size, a, b, r, r))
    # Squeeze ONLY the axis that tf.split just made size 1. A bare
    # tf.squeeze() also drops every other size-1 axis (batch == 1, r == 1,
    # or b == 1), which shifts the axis numbers used by the concat/split
    # below and corrupts or breaks the result.
    X = tf.split(X, a, 1)
    X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)  # (batch, b, a*r, r)
    X = tf.split(X, b, 1)
    X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2)  # (batch, a*r, b*r)
    return tf.reshape(X, (batch_size, a * r, b * r, 1))
"""
https://github.com/Rayhane-mamah/Tacotron-2/blob/ab5cb08a931fc842d3892ebeb27c8b8734ddd4b8/wavenet_vocoder/models/modules.py#L604
"""
def PS_2(inputs, shuffle_strides=(4, 4), out_filters=1):
    """Sub-pixel shuffle ``inputs`` using ``_phase_shift_2`` per channel group.

    Height and channel count are read from the static shape. Batch size and
    width are read at runtime via ``tf.shape``.

    Args:
        inputs: 4-D tensor with channels on axis 3. The static channel count
            must equal ``r1 * r2 * out_filters`` (checked by ``assert``).
        shuffle_strides: ``(r1, r2)`` upscaling factors for height and width.
        out_filters: Number of channel groups / output channels.

    Returns:
        Tensor ``(batch, H * r1, W * r2, out_filters)``, gated by runtime
        assertions on its channel count and height.
    """
    batch_size = tf.shape(inputs)[0]
    height = inputs.shape[1]
    width = tf.shape(inputs)[2]
    channels = inputs.shape[-1]
    r1, r2 = shuffle_strides
    assert channels == r1 * r2 * out_filters

    groups = tf.split(inputs, out_filters, axis=3)
    shifted = [
        _phase_shift_2(group, batch_size, height, width, r1, r2)
        for group in groups
    ]
    outputs = tf.concat(shifted, 3)

    # Runtime sanity checks; the identity below only runs after both pass.
    checks = [
        tf.assert_equal(out_filters, tf.shape(outputs)[-1]),
        tf.assert_equal(height * r1, tf.shape(outputs)[1]),
    ]
    with tf.control_dependencies(checks):
        outputs = tf.identity(outputs, name='SubPixelConv_output_check')

    # The final reshape restores the statically known height/channel sizes.
    dynamic_shape = tf.shape(outputs)
    return tf.reshape(
        outputs, [dynamic_shape[0], r1 * height, dynamic_shape[2], out_filters])
def _phase_shift_2(inputs, batch_size, H, W, r1, r2):
    """Pixel-shuffle one channel group using two batch_to_space_nd passes.

    Args:
        inputs: Tensor holding one channel group. Its element count must equal
            ``batch_size * H * W * r1 * r2`` for the first reshape to succeed.
        batch_size: Batch size (a scalar tensor or int).
        H, W: Input height and width.
        r1, r2: Upscaling factors along height and width.

    Returns:
        Tensor of shape ``(batch_size, H * r1, W * r2, 1)`` with
        ``out[b, h*r1 + i, w*r2 + j, 0] == x[b, h, w, i, j]``, where ``x`` is
        ``inputs`` reshaped to ``(batch_size, H, W, r1, r2)``.
    """
    # (B, H, W, r1, r2)
    x = tf.reshape(inputs, [batch_size, H, W, r1, r2])
    # (r2, W, r1, H, B): put r2 on the batch axis so batch_to_space can fold it into W.
    x = tf.transpose(x, [4, 2, 3, 1, 0])
    # (1, W*r2, r1, H, B): interleaves the r2 sub-pixels into the W axis.
    x = tf.batch_to_space_nd(x, [r2], [[0, 0]])
    # (W*r2, r1, H, B)
    x = tf.squeeze(x, [0])
    # (r1, H, W*r2, B): put r1 on the batch axis so it can be folded into H.
    x = tf.transpose(x, [1, 2, 0, 3])
    # (1, H*r1, W*r2, B): interleaves the r1 sub-pixels into the H axis.
    x = tf.batch_to_space_nd(x, [r1], [[0, 0]])
    # (B, H*r1, W*r2, 1): swap the leftover size-1 axis into the channel slot.
    x = tf.transpose(x, [3, 1, 2, 0])
    return x
"""
http://musyoku.github.io/2017/03/18/Deconvolution%E3%81%AE%E4%BB%A3%E3%82%8F%E3%82%8A%E3%81%ABPixel-Shuffler%E3%82%92%E4%BD%BF%E3%81%86/
http://disq.us/p/1hbhk1b
"""
def pixel_shuffler(inputs, shuffle_strides=(4, 4), out_filters=1):
    """Sub-pixel shuffle ``inputs`` with a single reshape/transpose/reshape.

    The channel axis is read as ``(r1, r2, out_filters)``, with
    ``out_filters`` varying fastest. ``PS_1``/``PS_2`` instead split the
    channels into ``out_filters`` contiguous groups of ``r1 * r2``. The two
    layouts therefore agree only when ``out_filters == 1``.

    Spatial dimensions unknown at graph-build time (``None``) fall back to
    their runtime size, so fully-convolutional inputs such as
    ``[None, None, None, C]`` are supported. The channel count must be
    statically known.

    Args:
        inputs: 4-D tensor ``(batch, H, W, C)`` of known rank.
        shuffle_strides: ``(r1, r2)`` upscaling factors for height and width.
        out_filters: Number of output channels. Must satisfy
            ``C == r1 * r2 * out_filters`` (checked by ``assert``).

    Returns:
        Tensor of shape ``(batch, H * r1, W * r2, out_filters)``.
    """
    dynamic_shape = tf.shape(inputs)
    batch_size = dynamic_shape[0]
    _, height, width, channels = inputs.get_shape().as_list()
    # Unknown static sizes come back as None; use the runtime value instead.
    if height is None:
        height = dynamic_shape[1]
    if width is None:
        width = dynamic_shape[2]
    r1, r2 = shuffle_strides
    out_c = out_filters
    assert channels == r1 * r2 * out_c
    out_h = height * r1
    out_w = width * r2
    # (B, H, W, r1, r2, C) -> (B, H, r1, W, r2, C) -> (B, H*r1, W*r2, C)
    x = tf.reshape(inputs, (batch_size, height, width, r1, r2, out_c))
    x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
    x = tf.reshape(x, (batch_size, out_h, out_w, out_c))
    return x
def main():
    """Check that the three pixel-shuffle implementations agree, then time them.

    Builds a graph on a float32 placeholder and applies ``PS_1``, ``PS_2``
    and ``pixel_shuffler`` with a 2x2 shuffle and one output channel. It then:

    * prints the input and output shapes and the pairwise ``np.allclose``
      results on one random batch;
    * prints the wall-clock time of ``n_iters`` ``sess.run`` calls per op.
    """
    n_iters = 100
    # Generate the batch once, already in float32. np.random.rand returns
    # float64, which would otherwise be converted to the placeholder dtype on
    # every run. Keeping RNG work out of the loop means the timings measure
    # only the TF ops plus feed/fetch.
    test = np.random.rand(16, 256, 256, 4).astype(np.float32)
    print(test.shape)
    x = tf.placeholder(tf.float32, shape=[None, 256, 256, 4])
    x_1 = PS_1(x, 2, out_filters=1)
    x_2 = PS_2(x, shuffle_strides=(2, 2), out_filters=1)
    x_3 = pixel_shuffler(x, shuffle_strides=(2, 2), out_filters=1)
    feed = {x: test}

    with tf.Session() as sess:
        x_1_, x_2_, x_3_ = sess.run([x_1, x_2, x_3], feed_dict=feed)
        print(x_1_.shape)
        print(x_2_.shape)
        print(x_3_.shape)
        print(np.allclose(x_1_, x_2_))
        print(np.allclose(x_1_, x_3_))
        print(np.allclose(x_2_, x_3_))

        for name, op in (("PS_1", x_1), ("PS_2", x_2), ("pixel_shuffler", x_3)):
            # Warm-up run so first-call setup cost is not charged to
            # whichever op happens to be timed first.
            sess.run(op, feed_dict=feed)
            start = time.perf_counter()
            for _ in range(n_iters):
                sess.run(op, feed_dict=feed)
            elapsed = time.perf_counter() - start
            print("{}: {:.3f} s for {} runs".format(name, elapsed, n_iters))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment