Skip to content

Instantly share code, notes, and snippets.

@Geoyi
Forked from galtay/unet-deconv
Created February 17, 2017 02:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Geoyi/0f9c6bac2f49479fdf4fa0148ab3ed61 to your computer and use it in GitHub Desktop.
Save Geoyi/0f9c6bac2f49479fdf4fa0148ab3ed61 to your computer and use it in GitHub Desktop.
unet implementation
def _skip_cropping(skip_size, target_size):
    """Return the ``Cropping2D`` spec that center-crops a skip map to ``target_size``.

    Args:
        skip_size: spatial size (height == width) of the encoder feature map.
        target_size: spatial size of the upsampled decoder map it is
            concatenated with.

    Returns:
        ``((top, bottom), (left, right))`` crop amounts. When the size
        difference is odd the extra pixel is removed from the bottom/right
        edge, so the cropped map always matches ``target_size`` exactly.

    Raises:
        ValueError: if ``skip_size`` is smaller than ``target_size``.
    """
    diff = skip_size - target_size
    if diff < 0:
        raise ValueError(
            'cannot crop skip connection of %d px down to %d px'
            % (skip_size, target_size))
    before = diff // 2
    # A symmetric ``diff // 2`` on both sides would leave the map one pixel too
    # large when ``diff`` is odd and break the concat; give the remainder to
    # the trailing edge instead.
    after = diff - before
    return ((before, after), (before, after))


def _double_conv(x, n_out, activation, prefix):
    """Apply two 3x3 convolutions named ``<prefix>_conv1`` and ``<prefix>_conv2``."""
    x = Convolution2D(n_out, 3, 3, activation=activation, name=prefix + '_conv1')(x)
    return Convolution2D(n_out, 3, 3, activation=activation, name=prefix + '_conv2')(x)


def _up_block(x, skip, n_out, batch_size, activation, prefix):
    """Upsample ``x`` 2x, concatenate the center-cropped ``skip`` map, then double-conv.

    Creates layers named ``<prefix>_deconv``, ``<prefix>_crop``,
    ``<prefix>_concat``, ``<prefix>_conv1`` and ``<prefix>_conv2``, in that order.
    """
    outpix = x.get_shape()[1].value * 2
    # Computed up front so an impossible crop fails before any layer is created.
    cropping = _skip_cropping(skip.get_shape()[1].value, outpix)
    deconv = Deconvolution2D(
        n_out, 2, 2, output_shape=(batch_size, outpix, outpix, n_out),
        subsample=(2, 2), activation=activation, name=prefix + '_deconv')(x)
    crop = Cropping2D(cropping=cropping, name=prefix + '_crop')(skip)
    concat = merge([crop, deconv], mode='concat', concat_axis=3,
                   name=prefix + '_concat')
    return _double_conv(concat, n_out, activation, prefix)


def unet_model(batch_size, npix_in, n_channels, n_filters, n_classes, activation='relu'):
    """Build a U-Net-style encoder/decoder segmentation model.

    Four contracting levels (two 3x3 convs + 2x2 max-pool each, filters
    doubling per level) feed a bottleneck, followed by four expanding levels
    (2x2 strided deconvolution, concat with the center-cropped encoder map of
    the same level, two 3x3 convs). A 1x1 conv produces per-class logits,
    which are flattened over the spatial dimensions and passed through a
    sigmoid.

    The 3x3 convolutions pass no ``border_mode``, so they shrink the feature
    maps (Keras 1 defaults to 'valid'). That is why encoder maps must be
    cropped before concatenation and why the output is smaller than the input.

    Args:
        batch_size: fixed batch dimension. It is baked into the ``Input``
            shape and into every deconvolution ``output_shape``.
        npix_in: height and width of the square input images.
        n_channels: number of input channels (last axis; the layout is
            channels-last).
        n_filters: filters at the first level. This is doubled per level, up
            to ``16 * n_filters`` at the bottleneck.
        n_classes: number of output channels of the final 1x1 ``logits`` conv.
        activation: activation for every 3x3 conv and every deconvolution. It
            is not applied to ``logits``.

    Returns:
        A Keras ``Model`` mapping ``(batch_size, npix_in, npix_in, n_channels)``
        to ``(batch_size, outpix * outpix, n_classes)`` sigmoid outputs, where
        ``outpix`` is the spatial size of the ``logits`` map.

    Notes:
        - Written against the Keras 1.x API: ``Convolution2D(nb_filter, nb_row,
          nb_col)``, ``Deconvolution2D(output_shape=...)``,
          ``merge(mode='concat')`` and ``Model(input=, output=)``.
        - Static sizes are read via ``tensor.get_shape()[i].value``. That
          presumes TensorFlow-backed tensors; confirm before using another
          backend.
        - Layer names (``dblk1_conv1`` ... ``ublk1_conv2``, ``logits``,
          ``sigmoid``) are stable, so weights saved by name keep loading.
    """
    input_layer = Input(batch_shape=(batch_size, npix_in, npix_in, n_channels), name='input')

    # Contracting path. The map fed to each pool is kept as the skip connection.
    skips = []
    x = input_layer
    for level in range(1, 5):
        prefix = 'dblk%d' % level
        x = _double_conv(x, n_filters * 2 ** (level - 1), activation, prefix)
        if level == 4:
            # Only the deepest encoder level is regularised; its dropout
            # output (not the raw conv) is what gets pooled and skipped.
            x = Dropout(0.5, name=prefix + '_drop')(x)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2), name=prefix + '_pool')(x)

    # Bottleneck.
    x = _double_conv(x, n_filters * 16, activation, 'bottom')
    x = Dropout(0.5, name='bottom_drop')(x)

    # Expanding path: mirror of the encoder, deepest level first.
    for level in range(4, 0, -1):
        x = _up_block(x, skips[level - 1], n_filters * 2 ** (level - 1),
                      batch_size, activation, 'ublk%d' % level)

    # Per-pixel class scores, flattened to (pixels, classes) before the sigmoid.
    logits = Convolution2D(n_classes, 1, 1, name='logits')(x)
    outpix = logits.get_shape()[1].value
    flat = Reshape((outpix * outpix, n_classes))(logits)
    output = Activation('sigmoid', name='sigmoid')(flat)
    return Model(input=input_layer, output=output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment