Created September 19, 2017 14:58
Save drasros/cf2bea04d4e7c134a90e66156c626576 to your computer and use it in GitHub Desktop.
Application of a separate convolutional filter on each batch element
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Test 1-D convolution with one dedicated kernel per batch element.
import tensorflow as tf
import numpy as np

#################################################
# First, sanity-check a plain depthwise conv with channel_multiplier=1.
batch_size = 3
in_size = 128
nb_chan = 2

in_X = tf.placeholder(tf.float32, [batch_size, 1, in_size, nb_chan])
init_W = np.random.uniform(0., 0.1,
                           size=(1, 1, nb_chan, 1)).astype(np.float32)
W = tf.get_variable('W', initializer=init_W)
c = tf.nn.depthwise_conv2d(in_X,
                           filter=W,
                           strides=[1, 1, 1, 1],
                           padding='VALID')
print('Output shape after depthwise convolution, depth multiplier=1')
print(c.get_shape())
# Confirmed: each filter slice convolves exactly ONE input channel.
#################################################
# Next, use a depth multiplier > 1 and probe the ordering of the
# resulting output channels.
batch_size = 3
in_size = 128
nb_chan = 2

in_X_dm = tf.placeholder(tf.float32, [batch_size, 1, in_size, nb_chan])

# Multiplier slice 0 is all zeros, slice 1 is random, so the two
# multiplier outputs can be told apart downstream.
zeros_slice = np.zeros(shape=(1, 1, nb_chan, 1)).astype(np.float32)
rand_slice = np.random.uniform(0., 0.1,
                               size=(1, 1, nb_chan, 1)).astype(np.float32)
init_W_dm = np.concatenate((zeros_slice, rand_slice), axis=3)
print(init_W_dm.shape)

W_dm = tf.get_variable('W_dm', initializer=init_W_dm)
c_dm = tf.nn.depthwise_conv2d(in_X_dm,
                              filter=W_dm,
                              strides=[1, 1, 1, 1],
                              padding='VALID')
print(c_dm.get_shape())
# Depth multiplier works; now determine the output channel order:
#   ORDER1: (in_ch0, mul0), (in_ch0, mul1), (in_ch1, mul0), (in_ch1, mul1)
#   ORDER2: (in_ch0, mul0), (in_ch1, mul0), (in_ch0, mul1), (in_ch1, mul1)
# Strategy: feed an input that is non-zero only on its FIRST channel,
# with a filter that is non-zero only on its SECOND multiplier slice.
# Then, for any batch element:
#   ORDER1 -> the only possibly-non-zero output channel is out[..., 1]
#   ORDER2 -> the only possibly-non-zero output channel is out[..., 2]
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Probe input: random on channel 0, zeros on channel 1.
in_X_dm_val = np.concatenate(
    (np.random.uniform(0., 1., size=(batch_size, 1, in_size, 1)),
     np.zeros(
         shape=(batch_size, 1, in_size, 1))),
    axis=3).astype(np.float32)
c_dm_val = sess.run(c_dm, feed_dict={
    in_X_dm: in_X_dm_val
})
print('----- out[:, 1] : -------')
print(c_dm_val[0, 0, :, 1])
print('----- out[:, 2] : -------')
print(c_dm_val[0, 0, :, 2])
# RESULT : ORDER1 !!
# Hence we can reshape the output channels with
# shape (batch, 1, in_size, n_chan, chan_multiplier).
# (NOT (batch, 1, in_size, chan_multiplier, n_chan), which would reshape
# without error but scramble the (channel, multiplier) pairing.)
c_dm_val_r = np.reshape(c_dm_val, [batch_size, 1, in_size, nb_chan, 2])
print('----- first element of batch, one sample only: ---- ')
print('only element (0, 1) should be non-zero: ')
print('If so, no problem with reshape. ')
# BUG FIX: the original wrapped this print() in a second print(), which
# also emitted a spurious 'None' line.
print(c_dm_val_r[0, 0, 0, :, :])
##################################################
# Now, in the case when we use a 'batch-reshape' to apply
# a different filter for each batch_element:
batch_size = 3
in_size = 128
nb_chan = 5
chan_mul = 2
# Input: (batch_size, 1, in_size, nb_chan).
in_X_br = tf.placeholder(tf.float32,
                         [batch_size, 1, in_size, nb_chan])
# Fold the batch dimension into the channel dimension, so that one
# depthwise conv applies a distinct filter per (batch, channel) pair.
in_X_br_r = tf.transpose(in_X_br, [1, 2, 0, 3])
# now (1, in_size, batch_size, nb_chan)
in_X_br_r = tf.reshape(in_X_br_r,
                       [1, 1, in_size, batch_size*nb_chan])
# Per-batch filters: multiplier slice 0 is all zeros, slice 1 is random,
# so the two multiplier outputs remain distinguishable in the result.
init_W_br0 = np.zeros(
    shape=(batch_size, 1, 1, nb_chan, 1)).astype(np.float32)
init_W_br1 = np.random.uniform(0, 0.1,
                               size=(batch_size, 1, 1, nb_chan, 1)).astype(np.float32)
init_W_br = np.concatenate(
    [init_W_br0, init_W_br1], axis=-1)
# (batch_size, 1, 1, nb_chan=5, chan_mul=2)
W_br = tf.get_variable(
    'W_br_', initializer=init_W_br)
# make sure we reshape and apply it in the same order as in_X_br
# so that it gets applied the right way
W_br = tf.reshape(W_br, [1, 1, batch_size*nb_chan, chan_mul])
c_br = tf.nn.depthwise_conv2d(
    in_X_br_r,
    filter=W_br,
    strides=[1, 1, 1, 1],
    padding='VALID')  # REM: here we don't care about padding but the argument is required
print('After depthwise conv, out_shape is: ')
print(c_br.shape)
# c_br shape is (1, 1, in_size, batch_size*nb_chan*chan_mul)
# Reshape: ORDER1 above means the trailing dim factors, outer to inner,
# as (batch, nb_chan, chan_mul).
c_br = tf.reshape(c_br, [1, in_size, batch_size, nb_chan, chan_mul])
#c_br = tf.reshape(c_br, [1, in_size, nb_chan, batch_size, chan_mul])#chan_mul, nb_chan])
#c_br = tf.transpose(c_br, [3, 0, 1, 2, 4])
# Bring the batch dimension back to the front:
# (batch_size, 1, in_size, nb_chan, chan_mul).
c_br = tf.transpose(c_br, [2, 0, 1, 3, 4])
print('output shape before summing channels: ')
print(c_br.shape)
# and finally sum on input channels
c_br = tf.reduce_sum(c_br, axis=3)
print('End shape: ')
print(c_br.shape)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Probe input: batch element 0 is random, all other batch elements are
# zero -- any cross-batch filter leakage would therefore show up as
# non-zero output for elements 1..batch_size-1.
in_val_b_0 = np.random.uniform(0., 1.,
                               size=(1, 1, in_size, nb_chan))
in_val_b_other = np.zeros(
    shape=(batch_size-1, 1, in_size, nb_chan))
in_val = np.concatenate([in_val_b_0, in_val_b_other], axis=0)
c_br_val = sess.run(c_br, feed_dict={
    in_X_br: in_val,
})
print(c_br_val.shape)
print('-- First output channel should be 0 everywhere --')
print(c_br_val[0, 0, :, 0])
print('-- Second output channel should not be 0 everywhere --')
print(c_br_val[0, 0, :, 1])
# NOTE(review): removed a stale commented-out debug block here -- it
# indexed c_br_val with five dimensions (c_br_val[i, 0, 0, :, :]), which
# no longer matches the 4-D output produced after the reduce_sum above.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.