Skip to content

Instantly share code, notes, and snippets.

@CatalyzeX
Last active October 1, 2020 16:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save CatalyzeX/d6af1fb287e84b8f73e30af775ff4b43 to your computer and use it in GitHub Desktop.
Save CatalyzeX/d6af1fb287e84b8f73e30af775ff4b43 to your computer and use it in GitHub Desktop.
Network source code for the paper "Learning Deep Similarity Metric for 3D MR-TRUS Registration" (https://www.catalyzex.com/paper/arxiv:1806.04548)
import keras
from keras import backend as K, optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint
#from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import Input, concatenate
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv3D, BatchNormalization
from keras.optimizers import RMSprop
from keras import regularizers
import matplotlib.pyplot as plt
#import numpy as np
from os import path
import sys
import time
sys.path.append("../utils")
import volume_data_generator as vdg
# %% input image dimensions
time_start = time.time()  # wall-clock start time of the script
img_rows, img_cols = 96, 96
depth = 32
# Channels per network input (2 here); presumably the MR and TRUS volumes
# stacked together -- TODO confirm against volume_data_generator.
img_channels = 2
# mini batch size
mbs = 32
# %% data generators and output paths
data_folder = '/home/data/uronav_data'
home_folder = path.expanduser('~')
tmp_folder = path.join(home_folder, 'tmp')
# The trained model (fn_model below) is written into tmp_folder during
# training. Create the folder up front: if it is missing, writing the .h5
# file fails, and that would only surface after the first epoch has already
# been trained.
import os
if not path.isdir(tmp_folder):
    os.makedirs(tmp_folder)
# Case-index ranges handed to VolumeDataGenerator: (71, 750) for training and
# (1, 70) for validation -- presumably case IDs; verify inclusivity in vdg.
vdg_train = vdg.VolumeDataGenerator(data_folder, (71, 750),
                                    max_registration_error=20)
print('{} cases for training'.format(vdg_train.get_num_cases()))
# NOTE(review): batches are requested with shape (cols, rows, depth) while the
# network input is declared as (depth, rows, cols, channels); presumably vdg
# reorders axes internally -- confirm before changing any of these sizes.
trainGen = vdg_train.generate_batch(batch_size=mbs,
                                    shape=(img_cols, img_rows, depth))
vdg_val = vdg.VolumeDataGenerator(data_folder, (1, 70),
                                  max_registration_error=20)
print('{} cases for validation'.format(vdg_val.get_num_cases()))
valGen = vdg_val.generate_batch(batch_size=mbs,
                                shape=(img_cols, img_rows, depth))
fn_model = path.join(tmp_folder, 'trained_3d_regression_Adagrad.h5')
# %% Create CNN model
# 3D CNN regressor: two-channel volume in, one scalar out. Shapes in the
# comments below assume channels_last (the Keras default image_data_format)
# with depth=32, img_rows=img_cols=96, img_channels=2 -- TODO confirm config.
# NOTE: statement order fixes Keras' auto-generated layer names, which saved
# weights are matched against; do not reorder layer construction.
# input: (32, 96, 96, 2)
input3D = Input(shape=(depth, img_rows, img_cols, img_channels), name='input')
# -> (32, 96, 96, 32)
x = Conv3D(32, 3, strides=(1,1,1), activation='relu', padding='same')(input3D)
# stride (1,2,2) halves rows/cols and keeps depth -> (32, 48, 48, 32)
x = Conv3D(32, 3, strides=(1,2,2), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
# -> (32, 48, 48, 64); L2 weight penalty plus L1 activity penalty
x = Conv3D(64, 3, strides=(1,1,1),
activation='relu',
padding='same',
kernel_regularizer=regularizers.l2(0.001),
activity_regularizer=regularizers.l1(0.001))(x)
#
# Downsample with a stride-2 convolution (used in place of a pooling layer);
# per the original author, needed to fit in Titan X memory -> (16, 24, 24, 64)
x_0 = Conv3D(64, 3, strides=(2,2,2), activation='relu', padding='same')(x)
x = Conv3D(128, 3, strides=(1,1,1), activation='relu', padding='same')(x_0)
x = Conv3D(128, 3, strides=(1,1,1), activation='relu', padding='same')(x)
# -> (16, 24, 24, 64), L1 weight penalty
x_1 = Conv3D(64, 3, strides=(1,1,1),
activation='relu',
padding='same',
kernel_regularizer=regularizers.l1(0.001))(x)
# Skip connection: concatenate the downsampled features with the deeper
# branch along the last axis -> (16, 24, 24, 128)
x_01 = concatenate([x_0, x_1])
x_01 = BatchNormalization()(x_01)
x2 = Conv3D(128, 3, strides=(1,1,1), activation='relu', padding='same')(x_01)
# 'valid' 5x5x5 kernel trims 4 from each spatial dim -> (12, 20, 20, 16)
x = Conv3D(16, 5, strides=(1,1,1), activation='relu', padding='valid')(x2)
# Fully connected head: 76800 -> 128 -> 32 -> 1, with dropout between layers
c = Flatten()(x)
c = Dropout(0.25)(c)
c = Dense(128, activation='relu')(c)
c = Dropout(0.25)(c)
#
c = Dense(32, activation='relu')(c)
c = Dropout(0.1)(c)
#
# Single linear unit (no activation): a scalar regression output, trained with
# mean squared error; presumably the registration error -- TODO confirm.
output = Dense(1, name='output')(c)
#
model = Model(inputs=[input3D], outputs = [output])
# Candidate optimizers, all built with the same small learning rate. Only the
# one assigned to chosen_optimizer is passed to model.compile(); change that
# assignment to experiment with a different optimizer.
base_lr = 5e-6
eps = K.epsilon()
rmsprop = RMSprop(lr=base_lr)
sgd = optimizers.SGD(lr=base_lr, decay=1e-6, momentum=0, nesterov=False)
Adam = optimizers.Adam(lr=base_lr, beta_1=0.9, beta_2=0.999,
                       epsilon=eps, decay=0.0)
Adadelta = optimizers.Adadelta(lr=base_lr, epsilon=eps, decay=0.0)
Adamax = optimizers.Adamax(lr=base_lr, beta_1=0.9, beta_2=0.999,
                           epsilon=eps, decay=0.0)
Adagrad = optimizers.Adagrad(lr=base_lr, epsilon=eps, decay=0.0)
# NOTE: fn_model's filename embeds 'Adagrad'; keep it in sync if another
# optimizer is selected here.
chosen_optimizer = Adagrad
# Regression objective: mean squared error loss, mean absolute error reported.
model.compile(loss='mean_squared_error',
              optimizer=chosen_optimizer,
              metrics=['mae'])
#
# print network structure: one entry per layer with its output shape and,
# for Conv3D/Dense layers, the shape of the first weight array (the kernel)
#
layers = model.layers
print('=' * 70)
for idx, layer in enumerate(layers):
    # Activation and Dropout layers are left out of the listing.
    if isinstance(layer, (keras.layers.Activation, keras.layers.Dropout)):
        continue
    print('Layer {}: type={}'.format(idx, type(layer)))
    print(' ' * 9 + 'output->{}'.format(layer.output_shape))
    if isinstance(layer, (keras.layers.Conv3D, keras.layers.Dense)):
        # get_weights() lists the kernel first, then the bias.
        weights_shape = layer.get_weights()[0].shape
        print(' ' * 9 + 'weights->{}'.format(weights_shape))
print('=' * 70)
# %% Start training
num_epoch = 100
# The log key for metrics=['mae'] depends on the Keras version: Keras < 2.3
# reports 'mean_absolute_error', Keras >= 2.3 keeps the user-supplied 'mae'.
# Monitoring a key that is never logged makes EarlyStopping never stop and
# ModelCheckpoint never save, so use the name this model actually reports.
_metric_names = list(getattr(model, 'metrics_names', None) or [])
mae_name = ('mean_absolute_error' if 'mean_absolute_error' in _metric_names
            else 'mae')
val_mae_name = 'val_' + mae_name
# Stop after 5 epochs without improvement in validation MAE (lower is better).
earlyStopping = EarlyStopping(monitor=val_mae_name,
                              patience=5,
                              verbose=0,
                              mode='min')
# Save only the model with the lowest validation MAE seen so far.
modelCheckPoint = ModelCheckpoint(fn_model,
                                  monitor=val_mae_name,
                                  save_best_only=True,
                                  mode='min',
                                  verbose=0)
# Each epoch draws 50 mini-batches from the training generator and 50 from
# the validation generator.
history = model.fit_generator(trainGen,
                              steps_per_epoch=50,
                              epochs=num_epoch,
                              validation_data=valGen,
                              validation_steps=50,
                              callbacks=[earlyStopping, modelCheckPoint])
# Per-epoch curves. The *_acc names are kept for compatibility, but they hold
# mean absolute error values, not accuracy.
train_loss = history.history['loss']
train_acc = history.history[mae_name]
val_loss = history.history['val_loss']
val_acc = history.history[val_mae_name]
print('Total run time: {:.1f} s'.format(time.time() - time_start))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment