Skip to content

Instantly share code, notes, and snippets.

@keunwoochoi
Created October 11, 2017 00:49
Show Gist options
  • Save keunwoochoi/6d9e7d200582384a3bdc2ca69b35d4f9 to your computer and use it in GitHub Desktop.
Save keunwoochoi/6d9e7d200582384a3bdc2ca69b35d4f9 to your computer and use it in GitHub Desktop.
Keras-unet
import keras
from keras import backend as K
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers import Input, Dense, Activation
from keras.layers import concatenate # functional interface
from keras.models import Model
from keras.layers.advanced_activations import LeakyReLU
# Height and width of the square, single-channel input that get_unet() builds
# its Input layer for.
N_INPUT = 512
def get_unet():
    """Build a U-Net style convolutional encoder-decoder as a Keras model.

    The encoder is a stack of stride-2 ``Conv2D`` layers, each followed by a
    ``LeakyReLU``; every activation is kept for use as a skip connection.
    The decoder applies stride-2 ``Conv2DTranspose`` layers and, after each
    one, concatenates the result with the matching encoder activation along
    the channel axis.  A final single-filter ``Conv2DTranspose`` with a
    sigmoid activation produces the output tensor.

    Returns:
        An uncompiled ``Model`` whose input shape is ``(1, N_INPUT, N_INPUT)``
        or ``(N_INPUT, N_INPUT, 1)``, depending on the backend's image data
        format.
    """
    # Each conv layer uses 2 ** exponent filters.
    filter_exponents = [4, 5, 6, 6, 7, 7]
    kernel_size = (5, 5)
    stride = (2, 2)

    # The channel axis and input layout follow the backend configuration.
    data_format = K.image_data_format()
    if data_format == 'channels_first':
        input_shape = (1, N_INPUT, N_INPUT)
        ch_axis = 1
    elif data_format == 'channels_last':
        input_shape = (N_INPUT, N_INPUT, 1)
        ch_axis = 3

    inp = Input(shape=input_shape)

    # Encoder: remember every activation so the decoder can reuse it.
    skips = []
    x = inp
    for depth, exponent in enumerate(filter_exponents):
        x = Conv2D(
            2 ** exponent,
            kernel_size,
            strides=stride,
            padding='same',
            kernel_initializer='he_normal',
        )(x)
        x = LeakyReLU(alpha=0.2, name='encoded_{}'.format(depth))(x)
        skips.append(x)

    # Decoder: walk the encoder back to front, skipping its deepest layer
    # (that activation is the decoder's starting tensor, not a skip).
    up_exponents = filter_exponents[-2::-1]
    skip_inputs = skips[-2::-1]
    for depth, (exponent, skip) in enumerate(zip(up_exponents, skip_inputs)):
        upsampled = Conv2DTranspose(
            2 ** exponent,
            kernel_size,
            strides=stride,
            padding='same',
            kernel_initializer='he_normal',
            activation='relu',
            name='decoded_{}'.format(depth),
        )(x)
        x = concatenate([upsampled, skip], axis=ch_axis)

    # Output layer: numbered as the next decoder layer after the loop.
    outp = Conv2DTranspose(
        1,
        kernel_size,
        strides=stride,
        padding='same',
        kernel_initializer='glorot_normal',
        activation='sigmoid',
        name='decoded_{}'.format(len(up_exponents)),
    )(x)

    return Model(inputs=inp, outputs=outp)
if __name__ == "__main__":
    # When executed as a script, build the network and print its layer table.
    model = get_unet()
    model.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment