Skip to content

Instantly share code, notes, and snippets.

@Abhishek-Shaw-Kolkata
Created March 13, 2021 12:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Abhishek-Shaw-Kolkata/e5f762e5d10383b7df2dc7c2c3faf523 to your computer and use it in GitHub Desktop.
def conv_block(input, num_filters):
    """Apply two successive Conv2D -> BatchNormalization -> ReLU stages.

    Both convolutions use a 3x3 kernel with "same" padding.

    Args:
        input: Keras tensor fed into the first convolution.
        num_filters: Number of output channels for both convolutions.

    Returns:
        The Keras tensor produced by the second ReLU activation.
    """
    features = input
    # Two identical stages; layers are created in the same order as an
    # unrolled Conv -> BN -> ReLU, Conv -> BN -> ReLU sequence.
    for _ in range(2):
        features = Conv2D(num_filters, 3, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features
def encoder_block(input, num_filters):
    """Run one contracting-path stage: a conv block followed by 2x2 max pooling.

    Args:
        input: Keras tensor entering this stage.
        num_filters: Number of channels produced by the conv block.

    Returns:
        A ``(skip, pooled)`` tuple. ``skip`` is the conv-block output before
        pooling (used later as a skip connection). ``pooled`` is that output
        after ``MaxPool2D`` with a (2, 2) pool.
    """
    skip = conv_block(input, num_filters)
    pooled = MaxPool2D(pool_size=(2, 2))(skip)
    return skip, pooled
def decoder_block(input, skip_features, num_filters):
    """Run one expanding-path stage: upsample, merge the skip, then conv block.

    Args:
        input: Keras tensor from the previous (deeper) stage.
        skip_features: Encoder tensor concatenated with the upsampled input.
            Its height/width are expected to match the upsampled tensor.
        num_filters: Channel count for the transposed conv and the conv block.

    Returns:
        The Keras tensor produced by ``conv_block`` on the merged features.
    """
    # A 2x2 transposed convolution with stride 2 doubles the spatial size.
    upsampled = Conv2DTranspose(
        num_filters, kernel_size=(2, 2), strides=2, padding="same"
    )(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)
def build_unet(input_shape):
    """Build a 4-level U-Net that outputs a single sigmoid channel.

    Architecture:
    - Four encoder stages (64, 128, 256, 512 filters).
    - A 1024-filter bottleneck conv block.
    - Four decoder stages (512, 256, 128, 64 filters), each merged with the
      matching encoder skip.
    - A final 1x1 ``Conv2D`` with sigmoid activation.

    Args:
        input_shape: Shape passed to ``Input``, e.g. ``(256, 256, 3)``.
            Presumably height and width should be divisible by 16 (four 2x2
            poolings) so that the skip concatenations line up; confirm for
            other sizes.

    Returns:
        A Keras ``Model`` named "U-Net".
    """
    inputs = Input(input_shape)

    # Contracting path: remember each stage's pre-pooling features.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bottleneck at the lowest resolution.
    x = conv_block(x, 1024)

    # Expanding path: pair decoder widths with skips, deepest skip first.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(x)
    return Model(inputs, outputs, name="U-Net")
# Build the segmentation model for 256x256 RGB inputs and compile it.
# `combined_loss` and `dice_coef` are expected to be defined elsewhere.
input_shape = (256, 256, 3)
model_seg = build_unet(input_shape)
# `lr` is a deprecated alias that newer tf.keras optimizers reject with a
# ValueError; `learning_rate` is the supported keyword.
model_seg.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=combined_loss,
    metrics=[dice_coef],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment