Skip to content

Instantly share code, notes, and snippets.

@snakers4
Created September 12, 2017 09:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save snakers4/fb0dc2eb260635608bad05f001ccc1e0 to your computer and use it in GitHub Desktop.
Save snakers4/fb0dc2eb260635608bad05f001ccc1e0 to your computer and use it in GitHub Desktop.
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Activation, UpSampling2D, BatchNormalization,multiply
from keras.optimizers import RMSprop
from model.losses import bce_dice_loss, dice_loss, weighted_bce_dice_loss, weighted_dice_loss, dice_coeff
import params
# Original image width/height used to build the default model input shape.
# The model builders pad the width by 2 (1918 + 2 = 1920) so that both
# spatial dimensions (1280 and 1920) are divisible by 2**7 = 128, which the
# seven 2x2 max-pooling stages plus the 2x-upsample-and-multiply gates need.
orig_width, orig_height = 1918, 1280
def gate_unit(f_i, f_i1, filters=8):
    """Gate a feature map with the next-coarser feature map (ReLU variant).

    Each input goes through a 3x3 Conv2D -> BatchNormalization -> ReLU
    block with `filters` channels. The coarser branch is then upsampled 2x
    with UpSampling2D and multiplied element-wise with the finer branch.

    Args:
        f_i: Keras tensor at the finer resolution.
        f_i1: Keras tensor whose height and width are exactly half of
            `f_i`'s; otherwise the 2x-upsampled branch cannot be multiplied
            with the `f_i` branch.
        filters: Number of channels in both conv branches and therefore in
            the output. Defaults to 8, the value previously hard-coded.

    Returns:
        Keras tensor with the spatial size of `f_i` and `filters` channels.
    """
    # Finer branch (layers are created in the same order as before, so
    # auto-generated layer names are unchanged).
    f_i = Conv2D(filters, (3, 3), padding='same')(f_i)
    f_i = BatchNormalization()(f_i)
    f_i = Activation('relu')(f_i)

    # Coarser branch, brought up to the finer branch's resolution.
    f_i1 = Conv2D(filters, (3, 3), padding='same')(f_i1)
    f_i1 = BatchNormalization()(f_i1)
    f_i1 = Activation('relu')(f_i1)
    f_i1 = UpSampling2D((2, 2))(f_i1)

    # Element-wise product acts as the gate.
    output = multiply([f_i, f_i1])
    return output
def gated_refinement_unit(r_f, m_f, num_classes, output_activation='sigmoid'):
    """Refine a coarse prediction map with a gated feature map (ReLU variant).

    `m_f` is reduced to `num_classes` channels (3x3 Conv2D -> BN -> ReLU),
    concatenated with the incoming prediction `r_f` along axis 3, mixed by
    a 3x3 Conv2D, and upsampled 2x for the next (finer) refinement step.

    Previously the final Conv2D had no activation, so every refinement
    output was unbounded. Those tensors are model outputs compiled with the
    same `bce_dice_loss` / `dice_coeff` as `ru0`, which uses
    `activation='sigmoid'`, and they are fed forward as `r_f` exactly like
    `ru0` is. The output is therefore squashed with a sigmoid by default.
    NOTE(review): assumes `bce_dice_loss` expects probabilities rather than
    logits; pass `output_activation=None` to restore the old linear output.

    Args:
        r_f: Prediction map from the previous (coarser) step with
            `num_classes` channels, same spatial size as `m_f`.
        m_f: Gated feature map from `gate_unit`.
        num_classes: Number of output channels.
        output_activation: Keras activation applied to the mixed
            prediction; `None` means linear.

    Returns:
        Keras tensor with `num_classes` channels at twice the spatial size
        of the inputs.
    """
    m_f = Conv2D(num_classes, (3, 3), padding='same')(m_f)
    m_f = BatchNormalization()(m_f)
    m_f = Activation('relu')(m_f)
    # axis=3 is the channel axis for channels_last tensors.
    output = concatenate([r_f, m_f], axis=3)
    # The activation lives inside the existing Conv2D so the layer/weight
    # topology is unchanged and previously saved weights still load.
    output = Conv2D(num_classes, (3, 3), padding='same',
                    activation=output_activation)(output)
    output = UpSampling2D((2, 2))(output)
    return output
def g_frnet(input_shape=(orig_height, orig_width + 2, 3),
            num_classes=1):
    """Build and compile the gated refinement segmentation model (ReLU).

    Encoder: seven stages of two (3x3 Conv2D -> BN -> ReLU) blocks with
    8, 16, ..., 512 filters, each followed by a 2x2 max-pool. An initial
    sigmoid prediction `ru0` is taken from the last (unpooled) stage. It is
    then refined six times, each time gated by a pair of adjacent pooled
    encoder outputs, doubling the resolution at every step.

    The default `input_shape` used to be `(orig_width + 2, orig_height, 3)`,
    i.e. width and height swapped for channels_last input (the refinement
    units concatenate on axis 3). It now matches `g_frnet_selu`:
    (height, width, channels).

    Args:
        input_shape: (height, width, channels). Height and width must be
            divisible by 128 (seven 2x2 poolings followed by 2x upsampling
            and element-wise multiplication in the gates).
        num_classes: Number of output channels per prediction map.

    Returns:
        A compiled keras Model with outputs [ru6, ru5, ..., ru0], finest
        first, trained with RMSprop and `bce_dice_loss`.
    """
    def conv_bn_relu(x, filters):
        # One 3x3 conv -> batch norm -> ReLU block.
        x = Conv2D(filters, (3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        return Activation('relu')(x)

    inputs = Input(shape=input_shape)

    # Encoder. pooled[k] is stage k+1's output after pooling (stride
    # 2**(k+1)); `features` ends up holding the last stage before pooling.
    x = inputs
    pooled = []
    for filters in (8, 16, 32, 64, 128, 256, 512):
        features = conv_bn_relu(conv_bn_relu(x, filters), filters)
        x = MaxPooling2D((2, 2), strides=(2, 2))(features)
        pooled.append(x)

    # Coarsest prediction, taken from the deepest unpooled features.
    ru = Conv2D(num_classes, (1, 1), activation='sigmoid')(features)
    refined = [ru]

    # Refinement chain: gate each pooled map with the next-coarser one,
    # from (stage 6, stage 7) down to (stage 1, stage 2).
    for k in range(len(pooled) - 1, 0, -1):
        gate = gate_unit(f_i=pooled[k - 1], f_i1=pooled[k])
        ru = gated_refinement_unit(r_f=ru, m_f=gate, num_classes=num_classes)
        refined.append(ru)

    # Outputs finest-first: [ru6, ru5, ..., ru0].
    model = Model(inputs=inputs, outputs=refined[::-1])
    model.compile(optimizer=RMSprop(lr=0.0001), loss=bce_dice_loss,
                  metrics=[dice_coeff],
                  loss_weights=[0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.1])
    return model
def gate_unit_selu(f_i, f_i1):
    """Gate a feature map with the next-coarser feature map (SELU variant).

    Both inputs pass through two (3x3 Conv2D with 64 filters -> BN -> SELU)
    blocks. The coarser branch is then upsampled 2x and multiplied
    element-wise with the finer branch.

    Args:
        f_i: Keras tensor at the finer resolution.
        f_i1: Keras tensor whose height and width are half of `f_i`'s.

    Returns:
        Keras tensor with the spatial size of `f_i` and 64 channels.
    """
    def double_conv(tensor):
        # Two stacked conv -> batch norm -> SELU blocks.
        for _ in range(2):
            tensor = Conv2D(64, (3, 3), padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('selu')(tensor)
        return tensor

    fine = double_conv(f_i)
    coarse = UpSampling2D((2, 2))(double_conv(f_i1))
    return multiply([fine, coarse])
def gated_refinement_unit_selu(r_f, m_f, num_classes,
                               output_activation='sigmoid'):
    """Refine a coarse prediction map with a gated feature map (SELU variant).

    `m_f` is reduced to `num_classes` channels (3x3 Conv2D -> BN -> SELU),
    concatenated with the incoming prediction `r_f` along axis 3, mixed by
    a 3x3 Conv2D, and upsampled 2x for the next refinement step.

    Previously the final Conv2D had no activation, so the refinement
    outputs were unbounded even though they are compiled with the same
    `bce_dice_loss` / `dice_coeff` as the sigmoid `ru0` output, and they are
    fed forward as `r_f` just like `ru0`. A sigmoid is now applied by
    default. NOTE(review): assumes `bce_dice_loss` expects probabilities,
    not logits; pass `output_activation=None` for the old linear output.

    Args:
        r_f: Prediction map from the previous (coarser) step with
            `num_classes` channels, same spatial size as `m_f`.
        m_f: Gated feature map from `gate_unit_selu`.
        num_classes: Number of output channels.
        output_activation: Keras activation for the mixed prediction;
            `None` means linear.

    Returns:
        Keras tensor with `num_classes` channels at twice the spatial size
        of the inputs.
    """
    m_f = Conv2D(num_classes, (3, 3), padding='same')(m_f)
    m_f = BatchNormalization()(m_f)
    m_f = Activation('selu')(m_f)
    # axis=3 is the channel axis for channels_last tensors.
    output = concatenate([r_f, m_f], axis=3)
    # Activation folded into the existing Conv2D: same layers and weights.
    output = Conv2D(num_classes, (3, 3), padding='same',
                    activation=output_activation)(output)
    output = UpSampling2D((2, 2))(output)
    return output
def g_frnet_selu(input_shape=(orig_height, orig_width + 2, 3),
                 num_classes=1):
    """Build and compile the gated refinement segmentation model (SELU).

    Encoder: seven stages of two (3x3 Conv2D -> BN -> SELU) blocks with
    8, 16, ..., 512 filters, each followed by a 2x2 max-pool. The encoder
    convolutions are named 'f<stage>_conv_<1|2>'. A sigmoid prediction
    ('ru_0_conv') is taken from the last unpooled stage and refined six
    times with `gate_unit_selu` / `gated_refinement_unit_selu`.

    Args:
        input_shape: (height, width, channels); height and width must be
            divisible by 128.
        num_classes: Number of output channels per prediction map.

    Returns:
        A compiled keras Model with outputs [ru6, ru5, ..., ru0], finest
        first.
    """
    inputs = Input(shape=input_shape)

    # Encoder; layer names are kept stable (e.g. for load_weights by name).
    x = inputs
    pooled = []
    for stage, width in enumerate((8, 16, 32, 64, 128, 256, 512), start=1):
        for conv_idx in (1, 2):
            x = Conv2D(width, (3, 3), padding='same',
                       name='f%d_conv_%d' % (stage, conv_idx))(x)
            x = BatchNormalization()(x)
            x = Activation('selu')(x)
        deepest = x
        x = MaxPooling2D((2, 2), strides=(2, 2))(x)
        pooled.append(x)

    refined = [Conv2D(num_classes, (1, 1), activation='sigmoid',
                      name='ru_0_conv')(deepest)]

    # Gate pooled stage k with pooled stage k+1, coarsest pair first.
    for k in range(len(pooled) - 1, 0, -1):
        gate = gate_unit_selu(f_i=pooled[k - 1], f_i1=pooled[k])
        refined.append(gated_refinement_unit_selu(
            r_f=refined[-1], m_f=gate, num_classes=num_classes))

    model = Model(inputs=inputs, outputs=refined[::-1])
    model.compile(optimizer=RMSprop(lr=0.0001), loss=bce_dice_loss,
                  metrics=[dice_coeff],
                  loss_weights=[0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.1])
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment