Skip to content

Instantly share code, notes, and snippets.

@danyashorokh
Created October 4, 2021 13:40
Show Gist options
  • Save danyashorokh/8fad8cc94e99aca68b2aa8ecc3671ebb to your computer and use it in GitHub Desktop.
Save danyashorokh/8fad8cc94e99aca68b2aa8ecc3671ebb to your computer and use it in GitHub Desktop.
[KERAS] U-net
import tensorflow as tf
def conv2d_block(input_tensor, n_filters, kernel_size=3):
    """Apply two consecutive (Conv2D -> ReLU) stages to ``input_tensor``.

    Args:
        input_tensor: Keras tensor to convolve (presumably a 4-D image feature
            map in Keras' default channels-last layout -- confirm with callers).
        n_filters: Number of output channels of each Conv2D layer.
        kernel_size: Side length of the square convolution kernel.

    Returns:
        The Keras tensor produced by the second ReLU activation. Because both
        convolutions use stride 1 with ``padding='same'``, its height and
        width equal those of ``input_tensor``.
    """
    x = input_tensor
    for _ in range(2):
        # padding='same' keeps H and W unchanged. With the default 'valid'
        # padding every conv trims (kernel_size - 1) pixels. The encoder skip
        # tensors then no longer match the decoder's upsampled maps at the
        # concatenations. With the default 128x128 input, the bottleneck's
        # second conv would also receive a 2x2 map and could not be built.
        x = tf.keras.layers.Conv2D(filters=n_filters,
                                   kernel_size=(kernel_size, kernel_size),
                                   padding='same')(x)
        x = tf.keras.layers.Activation('relu')(x)
    return x
def encoder_block(inputs, n_filters, pool_size, dropout):
    """Run one U-Net contracting step.

    Applies ``conv2d_block`` and then downsamples the result with
    MaxPooling2D followed by Dropout.

    Args:
        inputs: Keras tensor fed into the step.
        n_filters: Channel count for the two convolutions.
        pool_size: Pool window passed to MaxPooling2D.
        dropout: Dropout rate applied after pooling.

    Returns:
        A ``(features, downsampled)`` pair. ``features`` is the pre-pooling
        convolution output, which the decoder reuses as a skip connection.
        ``downsampled`` is the pooled-and-dropped tensor for the next step.
    """
    features = conv2d_block(inputs, n_filters=n_filters)
    pool_layer = tf.keras.layers.MaxPooling2D(pool_size)
    drop_layer = tf.keras.layers.Dropout(dropout)
    downsampled = drop_layer(pool_layer(features))
    return features, downsampled
def encoder(inputs):
    """Build the four-stage U-Net contracting path.

    Each stage is an ``encoder_block`` with 2x2 pooling and dropout 0.3. The
    filter counts double per stage: 64, 128, 256, 512.

    Args:
        inputs: Keras tensor entering the encoder.

    Returns:
        ``(deepest, skips)``. ``deepest`` is the pooled output of the last
        stage. ``skips`` is a 4-tuple of the pre-pooling feature maps, ordered
        from the shallowest (64 filters) to the deepest (512 filters).
    """
    skip_features = []
    current = inputs
    # Stages are built in the same order as before, shallowest first.
    for width in (64, 128, 256, 512):
        skip, current = encoder_block(current, n_filters=width,
                                      pool_size=(2, 2), dropout=0.3)
        skip_features.append(skip)
    return current, tuple(skip_features)
def bottleneck(inputs):
    """Apply the U-Net bottleneck: a ``conv2d_block`` with 1024 filters.

    Args:
        inputs: Keras tensor produced by the encoder's last pooling stage.

    Returns:
        The Keras tensor output by the 1024-filter ``conv2d_block``.
    """
    return conv2d_block(inputs, n_filters=1024)
def decoder_block(inputs, conv_output, n_filters, kernel_size, strides, dropout):
    """Run one U-Net expanding step.

    The step performs, in order:
    1. Upsample ``inputs`` with a ``padding='same'`` Conv2DTranspose.
    2. Concatenate the result with the skip tensor ``conv_output``.
    3. Apply Dropout.
    4. Refine with a 3x3 ``conv2d_block``.

    Args:
        inputs: Keras tensor coming from the deeper decoder stage or the
            bottleneck.
        conv_output: Skip-connection tensor from the matching encoder stage.
        n_filters: Channel count for the transpose conv and the refining convs.
        kernel_size: Kernel size of the Conv2DTranspose only. The refining
            convs always use a size of 3.
        strides: Strides of the Conv2DTranspose (the upsampling factor).
        dropout: Dropout rate applied to the concatenated tensor.

    Returns:
        The Keras tensor output by the refining ``conv2d_block``.
    """
    upsample = tf.keras.layers.Conv2DTranspose(
        n_filters, kernel_size, strides=strides, padding='same')
    merged = tf.keras.layers.concatenate([upsample(inputs), conv_output])
    regularized = tf.keras.layers.Dropout(dropout)(merged)
    return conv2d_block(regularized, n_filters, kernel_size=3)
def decoder(inputs, convs, n_outputs, activation='softmax'):
    """Build the four-stage U-Net expanding path and the output head.

    Args:
        inputs: Keras tensor from the bottleneck.
        convs: Exactly four skip tensors ``(f1, f2, f3, f4)``, shallowest
            first, as returned by ``encoder``.
        n_outputs: Number of channels of the final 1x1 convolution.
        activation: Activation of the final 1x1 convolution.

    Returns:
        The Keras tensor produced by the 1x1 output convolution.
    """
    # Strict unpack: anything other than exactly four skip tensors raises.
    f1, f2, f3, f4 = convs
    x = inputs
    # Consume skips deepest-first while halving the filter count each stage.
    for skip, width in ((f4, 512), (f3, 256), (f2, 128), (f1, 64)):
        x = decoder_block(x, skip, n_filters=width, kernel_size=(3, 3),
                          strides=(2, 2), dropout=0.3)
    head = tf.keras.layers.Conv2D(n_outputs, (1, 1), activation=activation)
    return head(x)
def unet(input_shape=(128, 128, 3), n_outputs=10, output_activation='softmax'):
    """Assemble the full U-Net as a ``tf.keras.Model``.

    Args:
        input_shape: Shape of one input sample, without the batch dimension.
            NOTE(review): the encoder pools four times by 2x2, so height and
            width presumably need to be divisible by 16 for the skip
            concatenations to line up -- confirm.
        n_outputs: Number of output channels of the final 1x1 convolution.
        output_activation: Activation applied by that final convolution.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping the input tensor to the
        decoder output.
    """
    image = tf.keras.layers.Input(shape=input_shape)
    deepest, skips = encoder(image)
    prediction = decoder(bottleneck(deepest), skips, n_outputs,
                         activation=output_activation)
    return tf.keras.Model(inputs=image, outputs=prediction)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment