Skip to content

Instantly share code, notes, and snippets.

@ajoydas
Created March 13, 2021 20:56
Show Gist options
  • Save ajoydas/cf39b552448927c3a74283128b626189 to your computer and use it in GitHub Desktop.
Save ajoydas/cf39b552448927c3a74283128b626189 to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
def _double_conv(x, filters, kshape):
    """Apply two ReLU-activated, 'same'-padded Conv2D layers with `filters` maps."""
    for _ in range(2):
        x = tf.keras.layers.Conv2D(filters, kshape, activation='relu',
                                   padding='same')(x)
    return x


def get_unet_mod(patch_size=(560, 560), learning_rate=1e-3,
                 learning_decay=1e-6, drop_out=0.1, nchannels=1, kshape=(3, 3),
                 out_channels=3):
    """Build and compile a residual U-Net with dropout.

    Architecture: four encoder levels (64, 128, 256, 512 filters), each two
    ReLU 'same'-padded convolutions followed by 2x2 max pooling; a 1024-filter
    bottleneck; and four decoder levels that upsample 2x, concatenate the
    matching encoder features, apply dropout and two convolutions.  Dropout is
    also applied once after the deepest pooling.  A final 1x1 linear
    convolution predicts a residual that is added to the input image
    (no Gaussian-noise layer is applied).  The model is compiled with Adam and
    mean-squared-error loss.

    Args:
        patch_size: (height, width) of the input.  Each non-None dim must be
            divisible by 16 (four 2x2 poolings) so decoder and encoder feature
            maps line up when concatenated.
        learning_rate: Initial Adam learning rate.
        learning_decay: Inverse-time decay per optimizer step, i.e.
            ``lr_t = learning_rate / (1 + learning_decay * step)``, matching
            the legacy Keras ``decay`` argument.  0 disables decay.
        drop_out: Dropout rate used by every Dropout layer.
        nchannels: Number of input channels.
        kshape: Kernel size of the 3x3-style convolutions.
        out_channels: Channels produced by the final 1x1 convolution before
            the residual add (default 3, the original hard-coded value).  Must
            equal ``nchannels`` or one of the two must be 1 so the Add layer
            can broadcast.

    Returns:
        A compiled ``tf.keras.Model``.

    Raises:
        ValueError: If ``patch_size`` is not divisible by 16 or the channel
            counts cannot be combined by the residual add.
    """
    encoder_filters = (64, 128, 256, 512)
    bottleneck_filters = 1024
    total_stride = 2 ** len(encoder_filters)

    height, width = patch_size[0], patch_size[1]
    for dim in (height, width):
        if dim is not None and dim % total_stride:
            raise ValueError(
                f"patch_size dims must be divisible by {total_stride} "
                f"(got {tuple(patch_size)})")
    if out_channels != nchannels and 1 not in (out_channels, nchannels):
        raise ValueError(
            f"out_channels ({out_channels}) must equal nchannels "
            f"({nchannels}), or one of them must be 1, for the residual Add")

    input_img = tf.keras.layers.Input((height, width, nchannels))

    # Contracting path: keep each level's pre-pool features for the skips.
    skips = []
    x = input_img
    for filters in encoder_filters:
        x = _double_conv(x, filters, kshape)
        skips.append(x)
        x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
    # On the encoder side dropout is applied only after the deepest pool.
    x = tf.keras.layers.Dropout(drop_out)(x)

    x = _double_conv(x, bottleneck_filters, kshape)

    # Expanding path: upsample, concatenate the matching skip, dropout, convs.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        up = tf.keras.layers.UpSampling2D(size=(2, 2))(x)
        x = tf.keras.layers.concatenate([up, skip], axis=-1)
        x = tf.keras.layers.Dropout(drop_out)(x)
        x = _double_conv(x, filters, kshape)

    # 1x1 linear projection predicts a residual that is added to the input.
    residual = tf.keras.layers.Conv2D(out_channels, (1, 1),
                                      activation='linear')(x)
    out = tf.keras.layers.Add()([residual, input_img])
    model = tf.keras.models.Model(inputs=input_img, outputs=out)

    # Newer Keras optimizers reject `decay` (and no longer honour `lr`), so
    # express the legacy lr / (1 + decay * iterations) rule as a schedule.
    if learning_decay:
        lr = tf.keras.optimizers.schedules.InverseTimeDecay(
            initial_learning_rate=learning_rate, decay_steps=1,
            decay_rate=learning_decay)
    else:
        lr = learning_rate
    opt = tf.keras.optimizers.Adam(learning_rate=lr)
    model.compile(optimizer=opt, loss='mse')
    return model
model = get_unet_mod()
# Model.summary() prints the layer table itself and returns None, so wrapping
# it in print() would emit a stray "None" line after the summary.
model.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment