Skip to content

Instantly share code, notes, and snippets.

@aidin-zadeh
Last active October 28, 2019 01:56
Show Gist options
  • Save aidin-zadeh/58f7e9bd2b43778c5b600a0ac397c5df to your computer and use it in GitHub Desktop.
Save aidin-zadeh/58f7e9bd2b43778c5b600a0ac397c5df to your computer and use it in GitHub Desktop.
import os
import datetime
import numpy as np
# Select the Keras backend and configure Theano via environment variables.
# These assignments deliberately come before `import keras` below so that the
# settings are already in the environment when Keras is first imported.
os.environ['KERAS_BACKEND'] = 'theano'
# Theano flags: FAST_RUN mode on GPU device cuda0, float32 as the default
# float type, and optimizer=None (graph optimizations disabled, per Theano's
# config docs — confirm this is intended, it can slow execution).
os.environ['THEANO_FLAGS'] = \
'mode=FAST_RUN,device=cuda0,floatX=float32,optimizer=None'
import keras as ks
class Conv3DAutoEncoder(object):
    """3-D convolutional autoencoder built with the Keras functional API.

    Encoder: for every entry of ``n_filters`` except the last, a ReLU
    ``Conv3D`` followed by ``MaxPooling3D``; then a final ReLU ``Conv3D`` with
    ``n_filters[-1]`` filters.

    Decoder: a ReLU ``Conv3D`` with ``n_filters[-1]`` filters, then for every
    remaining entry of ``n_filters`` in reverse order a ReLU ``Conv3D``
    followed by ``UpSampling3D``, and finally a sigmoid ``Conv3D`` with
    ``n_channels_out`` filters.

    Parameters
    ----------
    sz_input : tuple of int
        Shape of one sample without the batch axis (passed to
        ``keras.layers.Input(shape=...)``).
    n_filters : sequence of int
        Filters per encoder level; must contain at least one entry.
    sz_kernel : int or tuple of int
        Kernel size of every ``Conv3D`` layer.
    sz_pool : int or tuple of int
        Pool size of every ``MaxPooling3D`` layer.
    sz_upsample : int or tuple of int
        Upsampling factor of every ``UpSampling3D`` layer.
    padding : str
        Padding mode of the convolution and pooling layers.
    data_format : str or None
        ``'channels_first'`` or ``'channels_last'``; ``None`` uses
        ``keras.backend.image_data_format()``.
    loss : str or callable
        Loss passed to ``Model.compile`` in :meth:`fit`.
    name : str or None
        Stem of the saved model file name; defaults to the class name.
    verbose : int
        Verbosity forwarded to ``Model.fit`` / ``Model.fit_generator``.
    n_channels_out : int
        Channels produced by the final sigmoid layer (default 1). When
        ``fit`` receives a single array, that array is also the target, so
        this must then equal the input's channel count.

    Raises
    ------
    ValueError
        If ``n_filters`` is empty.
    """

    def __init__(self,
                 sz_input,
                 n_filters,
                 sz_kernel,
                 sz_pool,
                 sz_upsample,
                 padding='same',
                 data_format=None,
                 loss='mean_squared_error',
                 name=None,
                 verbose=1,
                 n_channels_out=1):
        super(Conv3DAutoEncoder, self).__init__()
        # encode()/decode() index n_filters[-1]; fail early with a clear message.
        if len(n_filters) == 0:
            raise ValueError('n_filters must contain at least one entry')
        self.sz_input = sz_input
        self.n_filters = n_filters
        self.sz_kernel = sz_kernel
        self.sz_pool = sz_pool
        self.sz_upsample = sz_upsample
        self.padding = padding
        self.name = self.__class__.__name__ if name is None else name
        if data_format is None:
            self.data_format = ks.backend.image_data_format()
        else:
            self.data_format = data_format
        self.loss = loss
        self.verbose = verbose
        self.n_channels_out = n_channels_out
        self.build()

    def build(self):
        """Create ``self.model``: ``Input(sz_input)`` -> encode -> decode."""
        x = ks.layers.Input(shape=self.sz_input, name='input')
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        self.model = ks.models.Model(x, decoded)

    def encode(self, x):
        """Apply the encoder layers to tensor ``x`` and return the code.

        With ``K = len(self.n_filters)`` the layers are named
        ``conv3d-1`` .. ``conv3d-K`` and ``pool3d-1`` .. ``pool3d-(K-1)``.
        """
        for i, f in enumerate(self.n_filters[:-1], start=1):
            x = ks.layers.Conv3D(
                f, self.sz_kernel,
                activation='relu',
                padding=self.padding,
                data_format=self.data_format,
                name='conv3d-' + str(i))(x)
            x = ks.layers.MaxPooling3D(
                self.sz_pool,
                padding=self.padding,
                data_format=self.data_format,
                name='pool3d-' + str(i))(x)
        # Named from len(n_filters) rather than the loop variable so that a
        # single-entry n_filters (loop body never runs) still works.
        encoded = ks.layers.Conv3D(
            self.n_filters[-1], self.sz_kernel,
            activation='relu',
            padding=self.padding,
            data_format=self.data_format,
            name='conv3d-' + str(len(self.n_filters)))(x)
        return encoded

    def decode(self, encoded):
        """Apply the decoder layers to ``encoded`` and return the output.

        Layer numbering continues after the encoder's last ``conv3d-K``.
        """
        idx_layer = len(self.n_filters) + 1
        x = ks.layers.Conv3D(
            self.n_filters[-1], self.sz_kernel,
            activation='relu',
            padding=self.padding,
            data_format=self.data_format,
            name='conv3d-' + str(idx_layer))(encoded)
        # ``f`` walks n_filters backwards (skipping the last entry) so the
        # decoder's filter counts mirror the encoder's.
        for f in self.n_filters[-2::-1]:
            idx_layer += 1
            x = ks.layers.Conv3D(
                f, self.sz_kernel,
                activation='relu',
                padding=self.padding,
                data_format=self.data_format,
                name='conv3d-' + str(idx_layer))(x)
            x = ks.layers.UpSampling3D(
                self.sz_upsample,
                data_format=self.data_format,
                name='upsample3d-' + str(idx_layer))(x)
        decoded = ks.layers.Conv3D(
            self.n_channels_out, self.sz_kernel,
            activation='sigmoid',
            data_format=self.data_format,
            padding=self.padding,
            name='conv3d-' + str(idx_layer + 1))(x)
        return decoded

    @staticmethod
    def _as_pair(data):
        """Return ``data`` if it is a tuple, otherwise ``(data, data)``.

        A lone array is used as both input and target (reconstruction).
        """
        if isinstance(data, tuple):
            return data
        return (data, data)

    @staticmethod
    def _n_steps(n_samples, sz_batch):
        """Return the number of full batches in ``n_samples``, at least 1."""
        # Guard against 0 steps when n_samples < sz_batch.
        return max(1, n_samples // sz_batch)

    def fit(self, train_set,
            n_epochs,
            sz_batch,
            valid_set=None,
            optimizer='adam',
            shuffle=True,
            generator=None,
            check_point=None,
            use_multiprocessing=False,
            workers=1,
            callbacks=None,
            max_queue_size=1,
            fname=None):
        """Compile and train ``self.model``, then save it to ``<fname>-end.h5``.

        Parameters
        ----------
        train_set : array-like or tuple
            ``(inputs, targets)`` or a single array used as both.
        n_epochs : int
            Number of training epochs.
        sz_batch : int
            Batch size for ``Model.fit``; for the generator path it is only
            used to derive ``steps_per_epoch``.
        valid_set : array-like, tuple or None
            Optional validation data, same convention as ``train_set``.
        optimizer : str or optimizer instance
            Passed to ``Model.compile``.
        shuffle : bool
            Passed to ``Model.fit``; the generator path always uses
            ``shuffle=False``.
        generator : callable or None
            If callable, ``generator(*train_set)`` is fed to
            ``Model.fit_generator``; otherwise ``Model.fit`` is used.
        check_point : object
            Accepted for backward compatibility; currently unused.
        use_multiprocessing, workers, max_queue_size
            Passed to ``Model.fit_generator``.
        callbacks : list or None
            Keras callbacks for either training path.
        fname : str or None
            Output file stem; defaults to ``<name>-YYYY-mm-dd-HH-MM-SS``.

        Returns
        -------
        The ``History`` object returned by Keras.
        """
        self.model.compile(optimizer=optimizer, loss=self.loss)
        train_set = self._as_pair(train_set)
        if valid_set is not None:
            valid_set = self._as_pair(valid_set)
        if fname is None:
            dt = datetime.datetime.now()
            fname = self.name + '-' + dt.strftime("%Y-%m-%d-%H-%M-%S")
        if callable(generator):
            if valid_set is None:
                validation_steps = None
            else:
                validation_steps = self._n_steps(valid_set[0].shape[0], sz_batch)
            history = self.model.fit_generator(
                generator=generator(*train_set),
                validation_data=valid_set,
                steps_per_epoch=self._n_steps(train_set[0].shape[0], sz_batch),
                validation_steps=validation_steps,
                epochs=n_epochs,
                verbose=self.verbose,
                shuffle=False,
                use_multiprocessing=use_multiprocessing,
                workers=workers,
                callbacks=callbacks,
                max_queue_size=max_queue_size)
        else:
            history = self.model.fit(
                *train_set,
                epochs=n_epochs,
                batch_size=sz_batch,
                verbose=self.verbose,
                shuffle=shuffle,
                validation_data=valid_set,
                callbacks=callbacks)
        self.model.save(fname + '-end' + '.h5')
        return history
# Demo: train the autoencoder on random data for a couple of epochs.
seed = 42
n_x = 1000
# channels_first layout: (channels, depth, height, width) per sample.
sz_input = (1, 204, 20, 20)
data_format = 'channels_first'
np.random.seed(seed)
# Cast to float32 to match floatX=float32 in THEANO_FLAGS; this halves the
# size of the n_x * 81,600-element array compared with float64.
x = np.random.rand(n_x, *sz_input).astype(np.float32)
n_filters = [30, 20, 10]
sz_kernel = (5, 5, 5)
sz_pool = (2, 2, 2)
sz_upsample = (2, 2, 2)
# define model
conv3d_ae = Conv3DAutoEncoder(
    sz_input,
    n_filters,
    sz_kernel,
    sz_pool,
    sz_upsample,
    data_format=data_format)
# print out model summary
conv3d_ae.model.summary()
n_epochs = 2
sz_batch = 50
shuffle = True
# x is both input and reconstruction target (single-array train_set).
conv3d_ae.fit(x,
              n_epochs, sz_batch,
              valid_set=None,
              shuffle=shuffle,
              generator=None,
              check_point=None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment