Skip to content

Instantly share code, notes, and snippets.

Last active October 24, 2023 11:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save innat/7b957f33b00d5d189f900d8e19d09976 to your computer and use it in GitHub Desktop.
Save innat/7b957f33b00d5d189f900d8e19d09976 to your computer and use it in GitHub Desktop.

UNet with an ImageNet classification model (Keras Applications) as the backbone

from tensorflow import keras
from tensorflow.keras import layers as nn
from tensorflow.keras import backend as K
from tensorflow.keras import applications

def short_summary(model):
    """Print the total, trainable and non-trainable parameter counts of ``model``.

    A compact alternative to ``model.summary()`` that prints only the three
    totals, formatted with thousands separators.

    Args:
        model: A Keras model (anything exposing ``trainable_weights`` and
            ``non_trainable_weights``).
    """
    # Builtin ``sum`` keeps the counts integral (``np.sum([])`` is 0.0, which
    # would print as "0.0") and needs no ``np`` name, which was never imported.
    trainable_count = sum(K.count_params(w) for w in model.trainable_weights)
    non_trainable_count = sum(
        K.count_params(w) for w in model.non_trainable_weights
    )
    print('Total params: {:,}'.format(trainable_count + non_trainable_count))
    print('Trainable params: {:,}'.format(trainable_count))
    print('Non-trainable params: {:,}'.format(non_trainable_count))


# Backbone name -> Keras Applications constructor used as the UNet encoder.
BACKBONE = {
    'efficientnetb0': applications.EfficientNetB0,
    'resnet50': applications.ResNet50,
    'densenet121': applications.DenseNet121,
    'convnextsmall': applications.ConvNeXtSmall,
}

# Backbone name -> encoder layers whose outputs feed the decoder's skip
# connections, deepest (lowest resolution) first, because the decoder
# consumes them in upsampling order. Strings are layer names; ints are
# layer indices into the backbone's ``layers`` list.
BACKBONE_ARGS = {
    'efficientnetb0': [
        'block6a_expand_activation', 'block4a_expand_activation',
        'block3a_expand_activation', 'block2a_expand_activation',
    ],
    'resnet50': [
        'conv4_block6_2_relu', 'conv3_block4_2_relu',
        'conv2_block3_2_relu', 'conv1_relu',
    ],
    'densenet121': [
        311, 139, 51, 4,
    ],
    'convnextsmall': [
        268, 51, 26,
    ],
}

def Conv3x3BNReLU(filters):
    """Return a callable applying 3x3 'same' Conv2D -> BatchNorm -> ReLU.

    Args:
        filters: Number of output channels of the convolution.

    Returns:
        A function mapping a feature tensor to the normalized, activated
        output of the same spatial size.
    """
    def apply(input):
        # The conv stays linear and bias-free: BatchNorm's learned offset
        # makes a bias redundant, and the block's single ReLU belongs after
        # normalization rather than on both sides of it.
        x = nn.Conv2D(
            filters, kernel_size=(3, 3), padding='same', use_bias=False,
        )(input)
        x = nn.BatchNormalization()(x)
        x = nn.ReLU()(x)
        return x
    return apply

def UpsampleBlock(filters):
    """Return one decoder step: 2x upsample, optional skip concat, two conv blocks.

    Args:
        filters: Channel count used by both ``Conv3x3BNReLU`` stages.

    Returns:
        A function ``apply(x, skip=None)``. When ``skip`` is given, it is
        concatenated (skip first) with the upsampled ``x`` along axis 3.
    """
    def apply(x, skip=None):
        out = nn.UpSampling2D((2, 2))(x)
        if skip is not None:
            out = nn.Concatenate(axis=3)([skip, out])
        for _ in range(2):
            out = Conv3x3BNReLU(filters)(out)
        return out
    return apply
def UNet(backbone, input_size, num_classes, activation,
         decoder_filters=(256, 128, 64, 32, 16), weights=None):
    """Build a UNet whose encoder is a Keras Applications backbone.

    Args:
        backbone: Key into ``BACKBONE`` / ``BACKBONE_ARGS`` selecting the
            encoder (e.g. ``"resnet50"``).
        input_size: Input shape passed to ``keras.Input``, e.g. ``(256, 256, 3)``.
        num_classes: Number of output channels of the final convolution.
        activation: Activation applied to the final layer (e.g. ``'sigmoid'``).
        decoder_filters: Filters for each decoder ``UpsampleBlock``. There is
            one 2x upsampling per entry. The first
            ``len(BACKBONE_ARGS[backbone])`` blocks receive a skip connection;
            the rest receive none.
        weights: Forwarded to the backbone constructor. The default ``None``
            (random initialization) matches the previous hard-coded value;
            pass ``'imagenet'`` for pretrained encoder weights.

    Returns:
        A ``keras.Model`` named ``'UNet'``.
    """
    inputs = keras.Input(input_size)
    base_model = BACKBONE[backbone](
        weights=weights, include_top=False, input_tensor=inputs
    )

    # BACKBONE_ARGS mixes layer names (str) and layer indices (int).
    # get_layer(name=...) cannot resolve an int, so dispatch on the type.
    skip_layers = [
        base_model.get_layer(index=ref).output
        if isinstance(ref, int)
        else base_model.get_layer(name=ref).output
        for ref in BACKBONE_ARGS[backbone]
    ]

    # Decoder: each block upsamples 2x and, while skips remain, concatenates
    # the matching encoder feature map.
    x = base_model.output
    for i, filters in enumerate(decoder_filters):
        skip = skip_layers[i] if i < len(skip_layers) else None
        x = UpsampleBlock(filters)(x, skip)

    # Final layer. The float32 activation follows the Keras mixed-precision
    # guidance to keep model outputs in float32.
    x = nn.Conv2D(filters=num_classes, kernel_size=(3, 3), padding='same')(x)
    final = nn.Activation(activation, dtype='float32')(x)
    return keras.Model(inputs=inputs, outputs=final, name='UNet')

# create model: ResNet50-encoder UNet, 256x256 RGB input, 3 sigmoid outputs
model = UNet(
    backbone="resnet50",
    input_size=(256, 256, 3),
    num_classes=3,
    activation='sigmoid',
)
Copy link

innat commented Aug 12, 2023

UNet Basic

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def EncoderBlock(filter):
    """Return a contracting-path step of the basic UNet.

    The step applies two 'valid' 3x3 Conv2D + ReLU layers, then a 2x2
    max-pool with stride 2.

    Args:
        filter: Number of filters in both convolutions.

    Returns:
        A function that returns only the pooled feature map. Each 'valid'
        3x3 conv trims 2 pixels from height and width before pooling.
    """
    def apply(input):
        out = input
        for _ in range(2):
            out = layers.Conv2D(filter, 3, padding = 'valid')(out)
            out = layers.Activation('relu')(out)
        return layers.MaxPool2D((2, 2), 2)(out)
    return apply

def DecoderBlock(filter):
    """Return an expansive-path step of the basic UNet.

    The step does the following, in order:

    1. Upsamples ``input`` with a stride-2, kernel-2 transposed conv.
    2. Resizes ``skip_feature`` to the upsampled height/width; with 'valid'
       padding the encoder and decoder sizes generally differ.
    3. Concatenates ``[upsampled, skip]``.
    4. Applies two 'valid' 3x3 Conv2D + ReLU layers.

    Args:
        filter: Number of filters for the transposed conv and both convs.
    """
    def apply(input, skip_feature):
        up = layers.Conv2DTranspose(filter, 2, 2, padding = 'valid')(input)
        target_hw = (up.shape[1], up.shape[2])
        resized_skip = tf.image.resize(skip_feature, size=target_hw)
        out = layers.Concatenate()([up, resized_skip])
        for _ in range(2):
            out = layers.Conv2D(filter, 3, padding='valid')(out)
            out = layers.Activation('relu')(out)
        return out
    return apply

def UNet(input_shape=(256, 256, 3), num_classes=1, activation=None):
    """Build the original-paper style UNet with 'valid' convolutions.

    Args:
        input_shape: Shape passed to ``keras.Input``, e.g. ``(224, 224, 3)``.
        num_classes: Output channels of the final 1x1 convolution.
        activation: Output activation. ``None`` (default) selects
            ``'sigmoid'`` when ``num_classes == 1`` and ``'softmax'``
            otherwise.

    Returns:
        A ``keras.Model`` named ``'UNet'``. Every 3x3 convolution is
        'valid'-padded, so the output is spatially smaller than the input
        (36x36 for a 224x224 input).
    """
    inputs = keras.Input(input_shape)
    # Contracting path: each block returns its pooled output.
    s1 = EncoderBlock(64)(inputs)
    s2 = EncoderBlock(128)(s1)
    s3 = EncoderBlock(256)(s2)
    s4 = EncoderBlock(512)(s3)
    # Bottleneck
    b1 = layers.Conv2D(1024, 3, padding='valid')(s4)
    b1 = layers.Activation('relu')(b1)
    b1 = layers.Conv2D(1024, 3, padding='valid')(b1)
    b1 = layers.Activation('relu')(b1)
    # Expansive path: DecoderBlock resizes each skip to the upsampled size.
    s5 = DecoderBlock(512)(b1, skip_feature=s4)
    s6 = DecoderBlock(256)(s5, skip_feature=s3)
    s7 = DecoderBlock(128)(s6, skip_feature=s2)
    s8 = DecoderBlock(64)(s7, skip_feature=s1)
    if activation is None:
        activation = 'sigmoid' if num_classes == 1 else 'softmax'
    # A 1x1 conv maps the decoder channels to per-pixel class scores.
    outputs = layers.Conv2D(num_classes, 1, activation=activation)(s8)
    model = keras.Model(inputs=inputs, outputs=outputs, name='UNet')
    return model

# Instantiate the basic UNet for 224x224 RGB inputs with 10 output classes.
model = UNet(num_classes=10, input_shape=(224, 224, 3))

Copy link

innat commented Aug 12, 2023

UNet General

def Conv3x3BNReLU(filter):
    """Return a callable applying 3x3 'same' Conv2D -> BatchNorm -> ReLU.

    Args:
        filter: Number of output channels of the convolution.

    Returns:
        A function mapping a feature tensor to the normalized, activated
        output of the same spatial size.
    """
    def apply(input):
        # The conv stays linear and bias-free: BatchNorm's learned offset
        # makes a bias redundant, and the block's single ReLU belongs after
        # normalization rather than on both sides of it.
        x = layers.Conv2D(filter, 3, padding='same', use_bias=False)(input)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        return x
    return apply

def EncoderBlock(filter):
    """Return a contracting-path step: two ``Conv3x3BNReLU`` stages, then 2x2 max-pool.

    Args:
        filter: Channel count for both convolution stages.

    Returns:
        A function that returns only the pooled feature map, with height and
        width halved (floored).
    """
    def apply(input):
        features = input
        for _ in range(2):
            features = Conv3x3BNReLU(filter)(features)
        return layers.MaxPooling2D(pool_size=(2, 2))(features)
    return apply

def BottleNeck(filter):
    """Return the bottleneck step: one 3x3 'same' Conv2D followed by ReLU.

    Args:
        filter: Number of filters of the bottleneck convolution. This value
            is now honored; the width was previously hard-coded to 1024.

    Returns:
        A function mapping a feature tensor to a same-sized tensor with
        ``filter`` channels.
    """
    def apply(input):
        x = layers.Conv2D(filter, 3, padding='same')(input)
        x = layers.Activation('relu')(x)
        return x
    return apply

def DecoderBlock(filter):
    """Return an expansive-path step: 2x upsample, optional concat, two conv stages.

    Args:
        filter: Channel count for both ``Conv3x3BNReLU`` stages.

    Returns:
        A function ``apply(input, skip=None)``. When ``skip`` is given, it is
        concatenated after the upsampled tensor on the last axis.
    """
    def apply(input, skip=None):
        up = layers.UpSampling2D(size=(2, 2))(input)
        merged = up if skip is None else layers.Concatenate()([up, skip])
        out = Conv3x3BNReLU(filter)(merged)
        return Conv3x3BNReLU(filter)(out)
    return apply

def UNet(
    input_shape=(256, 256, 3),
    num_classes=1,
    class_activation='sigmoid',
    encoder_filters=(64, 128, 256, 512),
    decoder_filters=(512, 256, 128, 64, 32),
    n_upsample_blocks=None,
    bottleneck_filters=1024,
):
    """Build a 'same'-padded UNet with configurable encoder/decoder widths.

    The pooled output of every encoder block is kept as a skip connection.
    An extra max-pool after the bottleneck makes the first decoder block
    upsample back to the deepest skip's resolution. With
    ``n_upsample_blocks == len(encoder_filters) + 1``, the output therefore
    has the input's spatial size. Input height and width should be divisible
    by ``2 ** (len(encoder_filters) + 1)`` so that every skip concat lines up.

    Args:
        input_shape: Shape passed to ``keras.Input``.
        num_classes: Output channels of the final 1x1 convolution.
        class_activation: Activation of the output layer.
        encoder_filters: Filters of each ``EncoderBlock``, shallow to deep.
        decoder_filters: Filters of each ``DecoderBlock``, deep to shallow.
        n_upsample_blocks: Number of decoder blocks to use. Defaults to
            ``len(decoder_filters)``.
        bottleneck_filters: Width of the bottleneck convolution.

    Returns:
        A ``keras.Model`` named ``'UNet'`` with a float32 output.

    Raises:
        ValueError: If ``n_upsample_blocks`` exceeds ``len(decoder_filters)``.
    """
    if n_upsample_blocks is None:
        n_upsample_blocks = len(decoder_filters)
    if n_upsample_blocks > len(decoder_filters):
        raise ValueError(
            f"n_upsample_blocks={n_upsample_blocks} needs at least that many "
            f"decoder_filters, got {len(decoder_filters)}"
        )

    inputs = keras.Input(input_shape)
    # Contracting path: record each block's (pooled) output for the skips.
    x = inputs
    skips = []
    for filters in encoder_filters:
        x = EncoderBlock(filters)(x)
        skips.append(x)
    skips = skips[::-1]  # deepest first, matching the decoder order
    # Bottleneck, then one more pool so the first upsample meets skips[0].
    x = BottleNeck(bottleneck_filters)(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    # Expansive path: blocks beyond the available skips get no concat.
    for i in range(n_upsample_blocks):
        skip = skips[i] if i < len(skips) else None
        x = DecoderBlock(decoder_filters[i])(x, skip)
    outputs = layers.Conv2D(
        num_classes, 1, padding='same', activation=class_activation, dtype='float32'
    )(x)
    return keras.Model(inputs=inputs, outputs=outputs, name='UNet')

# Build the general UNet for 512x512 RGB inputs; the remaining arguments
# use UNet's defaults.
model = UNet(
    input_shape=(512, 512, 3),
)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment