@p-geon
Created August 10, 2021 09:29
import tensorflow as tf


def unet(x, ch, depth, cut_path=False):
    # Shared arguments for every convolution in this recursive U-Net builder.
    conv_kwargs = {'kernel_size': (3, 3), 'strides': (1, 1), 'dilation_rate': (1, 1),
                   'padding': 'same', 'use_bias': True, 'bias_initializer': 'zeros',
                   'kernel_initializer': 'he_normal'}
    # Encoder block: two conv -> BN -> ReLU stages at the current resolution.
    for _ in range(2):
        x = tf.keras.layers.Convolution2D(ch, **conv_kwargs)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation('relu')(x)
    if depth != 0:
        path = x  # keep this feature map for the skip connection
        x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
        # Recurse one level deeper with twice the channels, then upsample back.
        x = unet(x, ch=ch*2, depth=depth-1, cut_path=cut_path)
        x = tf.keras.layers.UpSampling2D((2, 2))(x)
        if not cut_path:
            x = tf.keras.layers.Concatenate(axis=-1)([x, path])
        # Decoder block: two conv -> BN -> ReLU stages after the skip connection.
        for _ in range(2):
            x = tf.keras.layers.Convolution2D(ch, **conv_kwargs)(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Activation('relu')(x)
    return x
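
A minimal usage sketch of how the function might be wrapped into a Keras model. The input shape, base channel count, depth, and the final 1x1 sigmoid convolution are illustrative assumptions, not part of the original gist; with depth=4 the spatial size must be divisible by 16.

# Assumed example: 256x256 single-channel input, segmentation-style output.
inputs = tf.keras.Input(shape=(256, 256, 1))
features = unet(inputs, ch=64, depth=4)
outputs = tf.keras.layers.Convolution2D(1, kernel_size=(1, 1), activation='sigmoid')(features)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()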