@drasros
Created September 19, 2017 14:58
Application of a separate convolutional filter on each batch element
# test convolution1d with one specific kernel per batch element
import tensorflow as tf
import numpy as np
#################################################
# try normal separable conv with channel_multiplier=1 first
batch_size = 3
in_size = 128
nb_chan = 2
in_X = tf.placeholder(tf.float32, [batch_size, 1, in_size, nb_chan])
init_W = np.random.uniform(0., 0.1,
                           size=(1, 1, nb_chan, 1)).astype(np.float32)
W = tf.get_variable(
    'W', initializer=init_W)
c = tf.nn.depthwise_conv2d(
    in_X,
    filter=W,
    strides=[1, 1, 1, 1],
    padding='VALID')
print('Output shape after depthwise convolution, depth multiplier=1')
print(c.get_shape())
# OK, EACH filter convolves on ONE channel...
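# (Added sketch, not in the original gist): a quick numerical check of the
# claim above. With a 1x1 kernel and channel_multiplier=1, the depthwise
# conv is just a per-channel scaling of the input by its own filter value.
# _sess and _x_val are hypothetical names introduced only for this check.
_sess = tf.Session()
_sess.run(tf.global_variables_initializer())
_x_val = np.random.uniform(
    0., 1., size=(batch_size, 1, in_size, nb_chan)).astype(np.float32)
_c_val = _sess.run(c, feed_dict={in_X: _x_val})
# expected: output channel k == input channel k scaled by init_W[0, 0, k, 0]
print('depthwise conv (mul=1) == per-channel scaling:',
      np.allclose(_c_val, _x_val * init_W[0, 0, :, 0]))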
#################################################
# Now use depth multiplier to check the order of output channels
batch_size = 3
in_size = 128
nb_chan = 2
in_X_dm = tf.placeholder(tf.float32, [batch_size, 1, in_size, nb_chan])
init_W_0 = np.zeros(shape=(1, 1, nb_chan, 1)).astype(np.float32)
init_W_1 = np.random.uniform(0., 0.1,
                             size=(1, 1, nb_chan, 1)).astype(np.float32)
init_W_dm = np.concatenate((init_W_0, init_W_1), axis=3)
print(init_W_dm.shape)
W_dm = tf.get_variable(
    'W_dm', initializer=init_W_dm)
c_dm = tf.nn.depthwise_conv2d(
    in_X_dm,
    filter=W_dm,
    strides=[1, 1, 1, 1],
    padding='VALID')
print(c_dm.get_shape())
# OK, the depth multiplier works. Now check the order of the resulting
# output channels:
# ORDER1: (chfil0, mul0), (chfil0, mul1), (chfil1, mul0), (chfil1, mul1)
# OR ORDER2: (chfil0, mul0), (chfil1, mul0), (chfil0, mul1), (chfil1, mul1)
# To test this, feed an INPUT that is non-zero only on its FIRST (chfil0)
# channel and a FILTER that is non-zero only on its SECOND (mul1)
# multiplier slice. Then, picking one batch element and enumerating its
# output channels:
# IF THE ORDER IS ORDER1, the only output channel that is non-null
# everywhere will be out[.., 1];
# OTHERWISE, IF THE ORDER IS ORDER2, the only output channel that is
# non-null everywhere will be out[.., 2].
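# (Added sketch, not in the original gist): the two candidate orderings give
# different flat channel indices. Under ORDER1 the flat index is
# in_channel * channel_multiplier + multiplier; under ORDER2 it is
# multiplier * nb_chan + in_channel. With in_channel=0 and multiplier=1:
# ORDER1 -> 0 * 2 + 1 = 1, ORDER2 -> 1 * 2 + 0 = 2,
# which is exactly what the two prints below distinguish.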
sess = tf.Session()
sess.run(tf.global_variables_initializer())
in_X_dm_val = np.concatenate(
    (np.random.uniform(0., 1., size=(batch_size, 1, in_size, 1)),
     np.zeros(shape=(batch_size, 1, in_size, 1))),
    axis=3).astype(np.float32)
c_dm_val = sess.run(c_dm, feed_dict={
    in_X_dm: in_X_dm_val
})
print('----- out[:, 1] : -------')
print(c_dm_val[0, 0, :, 1])
print('----- out[:, 2] : -------')
print(c_dm_val[0, 0, :, 2])
# RESULT: ORDER1 !!
# Hence, we can reshape the output channels to
# shape (batch, 1, in_size, n_chan, chan_multiplier)
# (and NOT (batch, 1, in_size, chan_multiplier, n_chan), which would
# still compute but give a wrongly ordered result).
c_dm_val_r = np.reshape(c_dm_val, [batch_size, 1, in_size, nb_chan, 2])
print('----- first element of batch, one sample only: ---- ')
print('only element (0, 1) should be non-zero: ')
print('If so, no problem with reshape. ')
print(c_dm_val_r[0, 0, 0, :, :])
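# (Added sketch, not in the original gist): the reshape above relies on the
# flat output channel index being c * chan_multiplier + m (ORDER1).
# A direct spot-check: element (.., c=1, m=0) of the reshaped array should
# equal flat channel 1*2 + 0 = 2 of the un-reshaped array.
print('reshape consistent with ORDER1:',
      np.array_equal(c_dm_val_r[..., 1, 0], c_dm_val[..., 2]))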
##################################################
# Now, the case where we use a 'batch-reshape' to apply
# a different filter to each batch element:
batch_size = 3
in_size = 128
nb_chan = 5
chan_mul = 2
in_X_br = tf.placeholder(tf.float32,
                         [batch_size, 1, in_size, nb_chan])
in_X_br_r = tf.transpose(in_X_br, [1, 2, 0, 3])
# now (1, in_size, batch_size, nb_chan)
in_X_br_r = tf.reshape(in_X_br_r,
                       [1, 1, in_size, batch_size*nb_chan])
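# (Added sketch, not in the original gist): a tiny numpy illustration of the
# channel ordering produced by the transpose + reshape above, with in_size=1
# and entries encoded as 10*b + c so the layout is easy to read. The flat
# "channel" axis enumerates (batch, channel) pairs with the batch index
# varying slowest, i.e. flat index b*nb_chan + c.
_toy = np.array([[[[10*b + c for c in range(nb_chan)]]]
                 for b in range(batch_size)])          # (batch_size, 1, 1, nb_chan)
_toy_r = np.transpose(_toy, [1, 2, 0, 3]).reshape(1, 1, 1, batch_size*nb_chan)
print(_toy_r[0, 0, 0])  # [0 1 2 3 4 10 11 ... 24] -> batch-major, channel-minor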
init_W_br0 = np.zeros(
    shape=(batch_size, 1, 1, nb_chan, 1)).astype(np.float32)
init_W_br1 = np.random.uniform(
    0, 0.1, size=(batch_size, 1, 1, nb_chan, 1)).astype(np.float32)
init_W_br = np.concatenate(
    [init_W_br0, init_W_br1], axis=-1)
# (batch_size, 1, 1, nb_chan=5, chan_mul=2)
W_br = tf.get_variable(
    'W_br_', initializer=init_W_br)
# make sure we reshape the filter in the same (batch-major, channel-minor)
# order as in_X_br_r so that it gets applied the right way
W_br = tf.reshape(W_br, [1, 1, batch_size*nb_chan, chan_mul])
c_br = tf.nn.depthwise_conv2d(
    in_X_br_r,
    filter=W_br,
    strides=[1, 1, 1, 1],
    padding='VALID')  # REM: padding is irrelevant for a 1x1 kernel, but the argument is required
print('After depthwise conv, out_shape is: ')
print(c_br.shape)
# c_br shape is (1, 1, in_size, batch_size*nb_chan*chan_mul)
# Reshape
c_br = tf.reshape(c_br, [1, in_size, batch_size, nb_chan, chan_mul])
#c_br = tf.reshape(c_br, [1, in_size, nb_chan, batch_size, chan_mul])#chan_mul, nb_chan])
#c_br = tf.transpose(c_br, [3, 0, 1, 2, 4])
c_br = tf.transpose(c_br, [2, 0, 1, 3, 4])
print('output shape before summing channels: ')
print(c_br.shape)
# and finally sum on input channels
c_br = tf.reduce_sum(c_br, axis=3)
print('End shape: ')
print(c_br.shape)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
in_val_b_0 = np.random.uniform(0., 1.,
                               size=(1, 1, in_size, nb_chan))
in_val_b_other = np.zeros(
    shape=(batch_size-1, 1, in_size, nb_chan))
in_val = np.concatenate([in_val_b_0, in_val_b_other], axis=0)
c_br_val = sess.run(c_br, feed_dict={
    in_X_br: in_val,
})
print(c_br_val.shape)
print('-- First output channel should be 0 everywhere --')
print(c_br_val[0, 0, :, 0])
print('-- Second output channel should not be 0 everywhere --')
print(c_br_val[0, 0, :, 1])
# print('############################################')
# print('------ FIRST ELEMENT OF BATCH, LOOKING AT ONLY ONE SAMPLE ----')
# print('Should be zero on first columns, not on second..')
# print(c_br_val[0, 0, 0, :, :])
# print('------ SECOND ELEMENT OF BATCH, LOOKING AT ONLY ONE SAMPLE ----')
# print('Should be zero everywhere..')
# print(c_br_val[1, 0, 0, :, :])
# print('############################################')
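# (Added sketch, not in the original gist): an end-to-end cross-check of the
# batch-reshape trick against a plain per-element loop, done in numpy.
# Because the kernels are 1x1, each batch element's output is simply the
# sum over input channels c of input[b, 0, :, c] * init_W_br[b, 0, 0, c, m].
# _in_full and _ref are hypothetical names introduced only for this check.
_in_full = np.random.uniform(
    0., 1., size=(batch_size, 1, in_size, nb_chan)).astype(np.float32)
_c_br_full = sess.run(c_br, feed_dict={in_X_br: _in_full})
_ref = np.zeros((batch_size, 1, in_size, chan_mul), dtype=np.float32)
for b in range(batch_size):
    for m in range(chan_mul):
        _ref[b, 0, :, m] = np.sum(
            _in_full[b, 0, :, :] * init_W_br[b, 0, 0, :, m], axis=-1)
print('batch-reshape depthwise conv matches per-element loop:',
      np.allclose(_c_br_full, _ref, atol=1e-5))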