Last active: October 28, 2019 01:56
Save aidin-zadeh/58f7e9bd2b43778c5b600a0ac397c5df to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import datetime

import numpy as np

# Keras picks its backend from KERAS_BACKEND when it is first imported, and
# Theano reads THEANO_FLAGS on import, so both variables have to be in place
# before the `import keras` line below.
# NOTE(review): `optimizer=None` turns off Theano's graph optimizations, which
# largely cancels the benefit of `mode=FAST_RUN` -- confirm this is intended.
os.environ.update({
    'KERAS_BACKEND': 'theano',
    'THEANO_FLAGS': 'mode=FAST_RUN,device=cuda0,floatX=float32,optimizer=None',
})

import keras as ks  # noqa: E402 -- must come after the environment setup above
class Conv3DAutoEncoder(object):
    """Symmetric 3-D convolutional autoencoder built with Keras layers.

    Encoder: for every entry of ``n_filters`` except the last, a ReLU Conv3D
    followed by MaxPooling3D; then a bottleneck ReLU Conv3D with
    ``n_filters[-1]`` filters.

    Decoder: a ReLU Conv3D with ``n_filters[-1]`` filters, then one
    Conv3D + UpSampling3D stage per encoder stage with the encoder's filter
    counts in reverse order, and finally a sigmoid Conv3D whose channel count
    equals the input's channel count.

    Parameters
    ----------
    sz_input : tuple of int
        Shape of one sample without the batch axis, e.g. ``(1, 204, 20, 20)``
        for ``channels_first``.
    n_filters : sequence of int
        Encoder filter counts; the last entry is the bottleneck width.
        Must be non-empty.
    sz_kernel : int or tuple of int
        Kernel size of every Conv3D layer.
    sz_pool : int or tuple of int
        Pool size of every MaxPooling3D layer.
    sz_upsample : int or tuple of int
        Factor of every UpSampling3D layer.
    padding : str
        Padding mode for Conv3D and MaxPooling3D layers.
    data_format : str or None
        ``'channels_first'`` or ``'channels_last'``; ``None`` uses
        ``keras.backend.image_data_format()``.
    loss : str or callable
        Loss passed to ``Model.compile`` by :meth:`fit`.
    name : str or None
        Stem of the saved model file name; defaults to the class name.
    verbose : int
        Verbosity passed to Keras' ``fit`` / ``fit_generator``.
    """

    def __init__(self,
                 sz_input,
                 n_filters,
                 sz_kernel,
                 sz_pool,
                 sz_upsample,
                 padding='same',
                 data_format=None,
                 loss='mean_squared_error',
                 name=None,
                 verbose=1):
        super(Conv3DAutoEncoder, self).__init__()
        self.sz_input = sz_input
        self.n_filters = n_filters
        self.sz_kernel = sz_kernel
        self.sz_pool = sz_pool
        self.sz_upsample = sz_upsample
        self.padding = padding
        self.name = self.__class__.__name__ if name is None else name
        if data_format is None:
            data_format = ks.backend.image_data_format()
        self.data_format = data_format
        self.loss = loss
        self.verbose = verbose
        self.build()

    def build(self):
        """Assemble the encoder/decoder graph and store it as ``self.model``."""
        x = ks.layers.Input(shape=self.sz_input, name='input')
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        self.model = ks.models.Model(x, decoded)

    def encode(self, x):
        """Apply the encoder to tensor ``x`` and return the bottleneck tensor.

        Layers are named ``conv3d-1 .. conv3d-<len(n_filters)>`` and
        ``pool3d-1 .. pool3d-<len(n_filters) - 1>``.
        """
        for i, filters in enumerate(self.n_filters[:-1], 1):
            x = self._conv3d(filters, 'conv3d-' + str(i))(x)
            x = ks.layers.MaxPooling3D(
                self.sz_pool,
                padding=self.padding,
                data_format=self.data_format,
                name='pool3d-' + str(i))(x)
        # Name the bottleneck from len(n_filters) instead of the loop variable,
        # so a single-entry n_filters (no pooling stage) does not raise NameError.
        return self._conv3d(self.n_filters[-1],
                            'conv3d-' + str(len(self.n_filters)))(x)

    def decode(self, encoded):
        """Apply the decoder to the bottleneck tensor and return the reconstruction.

        Conv layer numbering continues after the encoder's, starting at
        ``len(n_filters) + 1``.
        """
        idx_layer = len(self.n_filters) + 1
        x = self._conv3d(self.n_filters[-1], 'conv3d-' + str(idx_layer))(encoded)
        for filters in self._decoder_filters():
            idx_layer += 1
            x = self._conv3d(filters, 'conv3d-' + str(idx_layer))(x)
            x = ks.layers.UpSampling3D(
                self.sz_upsample,
                data_format=self.data_format,
                name='upsample3d-' + str(idx_layer))(x)
        # The output must have as many channels as the input, because fit()
        # uses the input itself as the reconstruction target by default.
        return self._conv3d(self._n_channels(), 'conv3d-' + str(idx_layer + 1),
                            activation='sigmoid')(x)

    def fit(self, train_set,
            n_epochs,
            sz_batch,
            valid_set=None,
            optimizer='adam',
            shuffle=True,
            generator=None,
            check_point=None,
            use_multiprocessing=False,
            workers=1,
            callbacks=None,
            max_queue_size=1,
            fname=None):
        """Compile and train the model, then save it as ``<fname>-end.h5``.

        Parameters
        ----------
        train_set : array or tuple
            Training data. A non-tuple value ``x`` becomes ``(x, x)``, so the
            model learns to reconstruct its input. ``train_set[0]`` is assumed
            to expose ``.shape`` (e.g. a NumPy array).
        n_epochs : int
            Number of training epochs.
        sz_batch : int
            Batch size for ``Model.fit``. In the generator path it is only
            used to derive the number of steps per epoch.
        valid_set : array, tuple or None
            Validation data, expanded the same way as ``train_set``.
        optimizer : str or keras optimizer
            Passed to ``Model.compile``.
        shuffle : bool
            Passed to ``Model.fit``. Ignored in the generator path, where the
            generator controls sample order.
        generator : callable or None
            If callable, it is called as ``generator(x, y)`` and the result is
            passed to ``Model.fit_generator``; otherwise ``Model.fit`` is used.
        check_point : object
            Currently unused; accepted for backward compatibility.
        use_multiprocessing, workers, max_queue_size
            Passed to ``Model.fit_generator``.
        callbacks : list or None
            Keras callbacks for either training path.
        fname : str or None
            Output file stem; defaults to ``<name>-<YYYY-mm-dd-HH-MM-SS>``.
        """
        self.model.compile(optimizer=optimizer, loss=self.loss)
        if not isinstance(train_set, tuple):
            train_set = (train_set,) * 2
        if valid_set is not None and not isinstance(valid_set, tuple):
            valid_set = (valid_set,) * 2
        if fname is None:
            stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            fname = self.name + '-' + stamp

        if callable(generator):
            if valid_set is None:
                validation_steps = None
            else:
                validation_steps = self._n_steps(valid_set[0].shape[0], sz_batch)
            self.model.fit_generator(
                generator=generator(*train_set),
                validation_data=valid_set,
                steps_per_epoch=self._n_steps(train_set[0].shape[0], sz_batch),
                validation_steps=validation_steps,
                epochs=n_epochs,
                verbose=self.verbose,
                shuffle=False,
                use_multiprocessing=use_multiprocessing,
                workers=workers,
                callbacks=callbacks,
                max_queue_size=max_queue_size)
        else:
            self.model.fit(
                *train_set,
                epochs=n_epochs,
                batch_size=sz_batch,
                verbose=self.verbose,
                shuffle=shuffle,
                validation_data=valid_set,
                callbacks=callbacks)
        self.model.save(fname + '-end' + '.h5')

    def _conv3d(self, filters, name, activation='relu'):
        """Return a Conv3D layer using this model's kernel, padding and data format."""
        return ks.layers.Conv3D(
            filters, self.sz_kernel,
            activation=activation,
            padding=self.padding,
            data_format=self.data_format,
            name=name)

    def _decoder_filters(self):
        """Return the filter counts of the decoder's upsampling stages (encoder stages reversed)."""
        return list(self.n_filters[-2::-1])

    def _n_channels(self):
        """Return the channel count of one input sample according to data_format."""
        if self.data_format == 'channels_first':
            return self.sz_input[0]
        return self.sz_input[-1]

    @staticmethod
    def _n_steps(n_samples, sz_batch):
        """Return the number of full batches in n_samples, and at least 1."""
        return max(1, n_samples // sz_batch)
# Demo: train the autoencoder on random data to check that the graph builds
# and runs end to end.
seed = 42
n_x = 1000
sz_input = (1, 204, 20, 20)  # (channels, depth, height, width) for channels_first
data_format = 'channels_first'
np.random.seed(seed)
x = np.random.rand(n_x, *sz_input)
n_filters = [30, 20, 10]
sz_kernel = (5, 5, 5)
sz_pool = (2, 2, 2)
sz_upsample = (2, 2, 2)

# define model
conv3d_ae = Conv3DAutoEncoder(
    sz_input,
    n_filters,
    sz_kernel,
    sz_pool,
    sz_upsample,
    data_format=data_format)
# print out model summary
conv3d_ae.model.summary()

n_epochs = 2
sz_batch = 50
shuffle = True  # previously misspelled `suffle` and never passed to fit()
conv3d_ae.fit(x,
              n_epochs, sz_batch,
              valid_set=None,
              shuffle=shuffle,
              generator=None,
              check_point=None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.