@indraforyou
Created November 21, 2016 00:46
Keras 3D Deconvolution
from keras.backend import *
from keras.backend.tensorflow_backend import (_preprocess_conv3d_input,
                                              _preprocess_conv3d_kernel,
                                              _preprocess_border_mode,
                                              _postprocess_conv3d_output)
import tensorflow as tf


def _preprocess_deconv3d_output_shape(shape, dim_ordering):
    if dim_ordering == 'th':
        # Convert (batch, channels, dim1, dim2, dim3) into the
        # (batch, dim1, dim2, dim3, channels) layout TensorFlow expects.
        shape = (shape[0], shape[2], shape[3], shape[4], shape[1])
    return shape


def deconv3d(x, kernel, output_shape, strides=(1, 1, 1),
             border_mode='valid',
             dim_ordering='default',
             volume_shape=None, filter_shape=None):
    '''3D deconvolution (i.e. transposed convolution).

    # Arguments
        x: input tensor.
        kernel: kernel tensor.
        output_shape: 1D int tensor for the output shape.
        strides: strides tuple.
        border_mode: string, "same" or "valid".
        dim_ordering: "tf" or "th".
            Whether to use Theano or TensorFlow dimension ordering
            for inputs/kernels/outputs.
    '''
    # print '*** Deconv3d: \n\tX:{0} \n\tkernel:{1} \n\tstride:{2} \n\tfilter_shape:{3} \n\toutput_shape:{4}'.format(int_shape(x), int_shape(kernel), strides, filter_shape, output_shape)
    if dim_ordering == 'default':
        dim_ordering = image_dim_ordering()
    if dim_ordering not in {'th', 'tf'}:
        raise ValueError('Unknown dim_ordering ' + str(dim_ordering))

    x = _preprocess_conv3d_input(x, dim_ordering)
    output_shape = _preprocess_deconv3d_output_shape(output_shape, dim_ordering)
    kernel = _preprocess_conv3d_kernel(kernel, dim_ordering)
    # conv3d_transpose expects the kernel as (d1, d2, d3, out_channels, in_channels).
    kernel = tf.transpose(kernel, (0, 1, 2, 4, 3))
    padding = _preprocess_border_mode(border_mode)
    strides = (1,) + strides + (1,)

    # print '*** Deconv3d: \n\tkernel:{0} \n\tfilter_shape:{1} '.format(int_shape(kernel), filter_shape)
    # print output_shape
    x = tf.nn.conv3d_transpose(x, kernel, output_shape, strides,
                               padding)
    return _postprocess_conv3d_output(x, dim_ordering)
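

# Hedged sketch, not part of the original gist: the output_shape passed to
# tf.nn.conv3d_transpose generally has to satisfy the usual transposed-
# convolution size relation along each spatial axis. The helper name below
# is illustrative only.
def _expected_deconv_dim(in_dim, kernel_dim, stride, border_mode='same'):
    """Expected output size along one spatial axis of a transposed conv."""
    if border_mode == 'same':
        return in_dim * stride
    # 'valid' padding
    return (in_dim - 1) * stride + kernel_dim

# Example: an axis of size 8, a 3-wide kernel and stride 2 under 'same'
# padding should be given an output dim of 16.
# assert _expected_deconv_dim(8, 3, 2, 'same') == 16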
import numpy as np
import warnings
from keras import activations, initializations, regularizers
from keras.engine import Layer, InputSpec
from keras.utils.np_utils import conv_output_length
from keras.layers.convolutional import Convolution3D
# `backend_updated` is the author's local module holding the patched backend
# (the deconv3d function defined above).
from new import backend_updated as K
class Deconvolution3D(Convolution3D):
    def __init__(self, nb_filter, kernel_dim1, kernel_dim2, kernel_dim3, output_shape,
                 init='glorot_uniform', activation=None, weights=None,
                 border_mode='valid', subsample=(1, 1, 1),
                 dim_ordering='default',
                 W_regularizer=None, b_regularizer=None, activity_regularizer=None,
                 W_constraint=None, b_constraint=None,
                 bias=True, **kwargs):
        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()
        if border_mode not in {'valid', 'same', 'full'}:
            raise Exception('Invalid border mode for Deconvolution3D:', border_mode)

        # Full output shape, including the batch dimension:
        # (batch, channels, dim1, dim2, dim3) for 'th',
        # (batch, dim1, dim2, dim3, channels) for 'tf'.
        self.output_shape_ = output_shape

        super(Deconvolution3D, self).__init__(nb_filter, kernel_dim1, kernel_dim2, kernel_dim3,
                                              init=init, activation=activation,
                                              weights=weights, border_mode=border_mode,
                                              subsample=subsample, dim_ordering=dim_ordering,
                                              W_regularizer=W_regularizer, b_regularizer=b_regularizer,
                                              activity_regularizer=activity_regularizer,
                                              W_constraint=W_constraint, b_constraint=b_constraint,
                                              bias=bias, **kwargs)

    def get_output_shape_for(self, input_shape):
        if self.dim_ordering == 'th':
            conv_dim1 = self.output_shape_[2]
            conv_dim2 = self.output_shape_[3]
            conv_dim3 = self.output_shape_[4]
        elif self.dim_ordering == 'tf':
            conv_dim1 = self.output_shape_[1]
            conv_dim2 = self.output_shape_[2]
            conv_dim3 = self.output_shape_[3]
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)

        if self.dim_ordering == 'th':
            return (input_shape[0], self.nb_filter, conv_dim1, conv_dim2, conv_dim3)
        elif self.dim_ordering == 'tf':
            return (input_shape[0], conv_dim1, conv_dim2, conv_dim3, self.nb_filter)
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)

    def call(self, x, mask=None):
        input_shape = self.input_spec[0].shape
        output = K.deconv3d(x, self.W, self.output_shape_,
                            strides=self.subsample,
                            border_mode=self.border_mode,
                            dim_ordering=self.dim_ordering,
                            volume_shape=input_shape,
                            filter_shape=self.W_shape)
        if self.bias:
            # Broadcast the bias over the 5D output tensor.
            if self.dim_ordering == 'th':
                output += K.reshape(self.b, (1, self.nb_filter, 1, 1, 1))
            elif self.dim_ordering == 'tf':
                output += K.reshape(self.b, (1, 1, 1, 1, self.nb_filter))
            else:
                raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
        output = self.activation(output)
        return output

    def get_config(self):
        config = {'output_shape': self.output_shape_}
        base_config = super(Deconvolution3D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
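

# Illustrative usage check, not part of the original gist (names and shapes
# below are assumptions): output_shape must include the batch dimension, and
# get_output_shape_for simply reads its spatial dims back out.
if __name__ == '__main__':
    layer = Deconvolution3D(nb_filter=10, kernel_dim1=3, kernel_dim2=3, kernel_dim3=3,
                            output_shape=(10, 15, 8, 8, 10),
                            border_mode='same', subsample=(1, 1, 1),
                            dim_ordering='tf')
    # With 'tf' ordering the layer reports (batch, dim1, dim2, dim3, nb_filter).
    assert layer.get_output_shape_for((10, 15, 4, 4, 10)) == (10, 15, 8, 8, 10)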
from keras.models import Sequential, Model
from keras.layers import Input
from keras.layers.convolutional import Convolution3D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint, EarlyStopping
from new.deconv3D import Deconvolution3D
import numpy as np
import pylab as plt

filename = 'model'

_shape = (16, 16)
# _shape = (128, 128)
# _shape = (64, 64)
# time_batch_sz = (None,)
time_batch_sz = (15,)
batch_sz = (10,)

x = Input(batch_shape=(batch_sz + time_batch_sz + _shape + (1,)))

conv1 = Convolution3D(nb_filter=5, kernel_dim1=3, kernel_dim2=3, kernel_dim3=3,
                      border_mode='same', subsample=(1, 2, 2))
conv2 = Convolution3D(nb_filter=10, kernel_dim1=3, kernel_dim2=3, kernel_dim3=3,
                      border_mode='same', subsample=(1, 2, 2))

# Target output shapes for the transposed convolutions, including the batch
# dimension: (batch, dim1, dim2, dim3, filters) with 'tf' dim ordering.
out_shape_2 = (10, 15, 8, 8, 10)
dconv1 = Deconvolution3D(nb_filter=10, kernel_dim1=3, kernel_dim2=3, kernel_dim3=3, output_shape=out_shape_2,
                         border_mode='same', subsample=(1, 1, 1))
out_shape_1 = (10, 16, 17, 17, 5)
dconv2 = Deconvolution3D(nb_filter=5, kernel_dim1=3, kernel_dim2=3, kernel_dim3=3, output_shape=out_shape_1,
                         border_mode='same', subsample=(1, 1, 1))
decoder_squash = Convolution3D(1, 2, 2, 2, border_mode='valid', activation='sigmoid')

out = decoder_squash(dconv2(dconv1(conv2(conv1(x)))))
seq = Model(x, out)
seq.compile(loss='mse', optimizer='adadelta')
seq.summary(line_length=150)
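
# Illustrative shape note, not part of the original gist: with 'same' padding
# and subsample (1, 2, 2), each encoder convolution halves the spatial dims,
# so the encoder maps
#   (10, 15, 16, 16, 1) -> conv1 -> (10, 15, 8, 8, 5) -> conv2 -> (10, 15, 4, 4, 10);
# the Deconvolution3D output_shape tuples above are the decoder's targets.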
# Artificial data generation:
# Generate movies with 3 to 7 moving squares inside.
# Each square has side 2*w pixels (with w drawn from {2, 3})
# and moves linearly over time.
# For convenience we first create movies with bigger width and height
# (_shape*2) and at the end we crop a central window of size _shape.
_shape = (16, 16)


def generate_movies(n_samples=1200, n_frames=15):
    row = _shape[0] * 2
    col = _shape[1] * 2
    noisy_movies = np.zeros((n_samples, n_frames, row, col, 1), dtype=np.float)
    shifted_movies = np.zeros((n_samples, n_frames, row, col, 1),
                              dtype=np.float)

    # Bounds of the central crop window (also used for the initial positions).
    x_clip_st = _shape[0] - _shape[0] // 2
    x_clip_ed = _shape[0] + x_clip_st
    y_clip_st = _shape[1] - _shape[1] // 2
    y_clip_ed = _shape[1] + y_clip_st

    for i in range(n_samples):
        # Add 3 to 7 moving squares
        n = np.random.randint(3, 8)
        for j in range(n):
            # Initial position
            xstart = np.random.randint(x_clip_st, x_clip_ed)
            ystart = np.random.randint(y_clip_st, y_clip_ed)
            # Direction of motion
            directionx = np.random.randint(0, 3) - 1
            directiony = np.random.randint(0, 3) - 1
            # Size of the square
            w = np.random.randint(2, 4)
            for t in range(n_frames):
                x_shift = xstart + directionx * t
                y_shift = ystart + directiony * t
                noisy_movies[i, t, x_shift - w: x_shift + w,
                             y_shift - w: y_shift + w, 0] += 1

                # Make the task more robust by adding noise.
                # The idea is that if, during inference, a pixel value is not
                # exactly one, the network should still consider it as a pixel
                # belonging to a square.
                if np.random.randint(0, 2):
                    noise_f = (-1) ** np.random.randint(0, 2)
                    noisy_movies[i, t,
                                 x_shift - w - 1: x_shift + w + 1,
                                 y_shift - w - 1: y_shift + w + 1,
                                 0] += noise_f * 0.1

                # Shift the ground truth by 1 frame
                x_shift = xstart + directionx * (t + 1)
                y_shift = ystart + directiony * (t + 1)
                shifted_movies[i, t, x_shift - w: x_shift + w,
                               y_shift - w: y_shift + w, 0] += 1

    # Crop to the central window of size _shape
    noisy_movies = noisy_movies[::, ::, x_clip_st:x_clip_ed, y_clip_st:y_clip_ed, ::]
    shifted_movies = shifted_movies[::, ::, x_clip_st:x_clip_ed, y_clip_st:y_clip_ed, ::]
    noisy_movies[noisy_movies >= 1] = 1
    shifted_movies[shifted_movies >= 1] = 1
    return noisy_movies, shifted_movies
# Train the network
noisy_movies, shifted_movies = generate_movies(n_samples=1200)

callbacks = []
callbacks.append(EarlyStopping(monitor='val_loss', patience=5, verbose=1, mode='auto'))

print(noisy_movies.shape)
print(shifted_movies.shape)

seq.fit(noisy_movies[:1000], shifted_movies[:1000], batch_size=10,
        nb_epoch=300, validation_split=0.05, callbacks=callbacks)
seq.save_weights('{0}_final_wts.h5'.format(filename))
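
# Hedged note, not part of the original gist: the saved weights can later be
# restored into an identically constructed model before inference, e.g.
# seq.load_weights('{0}_final_wts.h5'.format(filename))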
# Test the network on one movie:
# feed it the first 7 frames and then
# predict the following frames.
which = 1004
track = noisy_movies[which][:7, ::, ::, ::]

for j in range(16):
    new_pos = seq.predict(track[np.newaxis, ::, ::, ::, ::])
    new = new_pos[::, -1, ::, ::, ::]
    track = np.concatenate((track, new), axis=0)

# And then compare the predictions
# to the ground truth
track2 = noisy_movies[which][::, ::, ::, ::]
for i in range(15):
    fig = plt.figure(figsize=(10, 5))

    ax = fig.add_subplot(121)
    if i >= 7:
        ax.text(1, 3, 'Predictions!', fontsize=20, color='w')
    else:
        ax.text(1, 3, 'Initial trajectory', fontsize=20)
    toplot = track[i, ::, ::, 0]
    plt.imshow(toplot)

    ax = fig.add_subplot(122)
    plt.text(1, 3, 'Ground truth', fontsize=20)
    toplot = track2[i, ::, ::, 0]
    if i >= 2:
        toplot = shifted_movies[which][i - 1, ::, ::, 0]
    plt.imshow(toplot)

    plt.savefig('%i_animate.png' % (i + 1))