# U-Net
# GitHub Gist by @ardamavi, created June 27, 2018
# Gist: ardamavi/2dd5886e241b3319b8b4ce5bcb6bd134
# Arda Mavi
import os
from keras.models import Model
from keras.optimizers import Adam
from keras.models import model_from_json
from keras.layers import Input, Conv2D, UpSampling2D, Activation, MaxPooling2D, Flatten, Dense, concatenate, Dropout
def save_model(model, path='Data/Model/', model_name='model', weights_name='weights.h5'):
    """Write a model's architecture (JSON) and weights (HDF5) into ``path``.

    The architecture goes to ``<path>/<model_name>.json``. The weights go to
    ``<path>/<weights_name>``, with ``.h5`` appended only when
    ``weights_name`` does not already end in it.

    Args:
        model: Object exposing ``to_json()`` and ``save_weights(filepath)``
            (e.g. a Keras ``Model``).
        path: Target directory; created (including parents) if missing.
            An empty string means the current directory.
        model_name: Base name, without extension, of the JSON file.
        weights_name: File name of the weights file.
    """
    # os.makedirs('') raises, so only create a directory when one is named.
    if path:
        os.makedirs(path, exist_ok=True)

    model_file_path = os.path.join(path, model_name + '.json')
    # Append '.h5' only when missing so the default is not doubled to 'weights.h5.h5'.
    if not weights_name.endswith('.h5'):
        weights_name += '.h5'
    weights_file_path = os.path.join(path, weights_name)

    with open(model_file_path, 'w') as model_file:
        model_file.write(model.to_json())
    # serialize weights to HDF5
    model.save_weights(weights_file_path)
    print('Model and weights saved to ' + model_file_path + ' and ' + weights_file_path)
def get_model(model_path, weights_path):
    """Rebuild a model from a JSON architecture file and an HDF5 weights file.

    Each path is checked in turn: architecture first, then weights. If one
    is missing, a message is printed and ``None`` is returned without
    loading anything.

    Args:
        model_path: Path to the JSON file produced by ``model.to_json()``.
        weights_path: Path to the weights file passed to ``load_weights``.

    Returns:
        The reconstructed model with weights loaded, or ``None`` when either
        file does not exist.
    """
    required_files = (
        (model_path, 'Model file not exists!'),
        (weights_path, 'Weights file not exists!'),
    )
    for candidate, complaint in required_files:
        if not os.path.exists(candidate):
            print(complaint)
            return None

    # Architecture first, then the learned parameters on top of it.
    with open(model_path, 'r') as json_handle:
        architecture_json = json_handle.read()
    restored = model_from_json(architecture_json)
    restored.load_weights(weights_path)
    return restored
def _conv_relu_pair(tensor, filters):
    """Apply two 3x3 'same'-padded Conv2D layers, each followed by a ReLU."""
    for _ in range(2):
        tensor = Conv2D(filters, (3, 3), strides=(1, 1), padding='same')(tensor)
        tensor = Activation('relu')(tensor)
    return tensor


def _adam(learning_rate):
    """Create an Adam optimizer that works across Keras versions.

    Recent Keras releases expect ``learning_rate`` and reject ``lr``. Older
    releases (e.g. Keras 2.2.x) only know ``lr`` and raise ``TypeError`` for
    unexpected keyword arguments, so fall back to ``lr`` in that case.
    """
    try:
        return Adam(learning_rate=learning_rate)
    except TypeError:
        return Adam(lr=learning_rate)


def get_unet(data_shape, learning_rate=1e-4):
    """Build and compile a four-level U-Net ending in a per-pixel sigmoid.

    Args:
        data_shape: Input shape, presumably ``(height, width, channels)``
            (channels last, since ``data_shape[-1]`` is used as the output
            channel count). Every known (non-``None``) spatial dimension must
            be divisible by 16. The encoder halves the resolution four times,
            and each decoder stage concatenates its upsampled map with the
            matching encoder map.
        learning_rate: Adam learning rate (default ``1e-4``, as before).

    Returns:
        A Keras ``Model`` compiled with Adam, binary cross-entropy and an
        accuracy metric. Its output has ``data_shape[-1]`` channels.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 16.
    """
    data_shape = tuple(data_shape)
    # Fail fast with a clear message instead of an opaque concatenate shape error.
    for dim in data_shape[:-1]:
        if dim is not None and dim % 16:
            raise ValueError(
                'U-Net spatial dimensions must be divisible by 16, got %r' % (data_shape,))

    inputs = Input(shape=data_shape)

    # Encoder: conv/ReLU pair, remember it as a skip connection, then 2x max-pool.
    skips = []
    tensor = inputs
    for filters in (32, 64, 128, 256):
        tensor = _conv_relu_pair(tensor, filters)
        skips.append(tensor)
        tensor = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(tensor)

    # Bottleneck at 1/16 of the input resolution.
    tensor = _conv_relu_pair(tensor, 256)

    # Decoder stages use (up-conv filters, conv-pair filters). Each stage:
    # 2x upsample, linear 3x3 conv, concatenate [skip, upsampled], conv/ReLU pair.
    # Filter counts and layer order are unchanged from the original layout
    # so previously saved weights keep matching.
    decoder_filters = ((512, 256), (256, 128), (128, 64), (32, 64))
    for skip, (up_filters, conv_filters) in zip(reversed(skips), decoder_filters):
        tensor = UpSampling2D((2, 2))(tensor)
        tensor = Conv2D(up_filters, (3, 3), strides=(1, 1), padding='same')(tensor)
        tensor = concatenate([skip, tensor])
        tensor = _conv_relu_pair(tensor, conv_filters)

    # 1x1 projection back to the input channel count, then sigmoid.
    outputs = Conv2D(data_shape[-1], (1, 1), strides=(1, 1), padding='same')(tensor)
    outputs = Activation('sigmoid')(outputs)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=_adam(learning_rate), loss='binary_crossentropy', metrics=['accuracy'])
    return model
if __name__ == '__main__':
    # Build the default 1024x1024 single-channel U-Net and show its layers.
    model = get_unet((1024, 1024, 1))
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    model.summary()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment