Skip to content

Instantly share code, notes, and snippets.

@innat
Last active October 24, 2023 11:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save innat/7b957f33b00d5d189f900d8e19d09976 to your computer and use it in GitHub Desktop.
Save innat/7b957f33b00d5d189f900d8e19d09976 to your computer and use it in GitHub Desktop.
UNet

UNet with a Keras Applications (ImageNet-classification architecture) backbone

from tensorflow import keras
from tensorflow.keras import layers as nn
from tensorflow.keras import backend as K
from tensorflow.keras import applications

def short_summary(model):
    """Print total, trainable and non-trainable parameter counts of ``model``.

    Args:
        model: Any object exposing ``trainable_weights`` and
            ``non_trainable_weights`` iterables whose items have a fully
            defined ``shape`` (for example a ``keras.Model``).

    Returns:
        ``(total, trainable, non_trainable)`` as Python ints.
    """
    import math  # local import: the file header does not import numpy/math

    def _count(weights):
        # Element count of a variable is the product of its shape dims
        # (math.prod(()) == 1 for scalar variables).
        return int(sum(math.prod(w.shape) for w in weights))

    trainable_count = _count(model.trainable_weights)
    non_trainable_count = _count(model.non_trainable_weights)
    total_count = trainable_count + non_trainable_count
    print('Total params: {:,}'.format(total_count))
    print('Trainable params: {:,}'.format(trainable_count))
    print('Non-trainable params: {:,}'.format(non_trainable_count))
    return total_count, trainable_count, non_trainable_count
    
    
# Encoder constructors from keras.applications, keyed by the name that
# UNet(backbone=...) looks up.
BACKBONE = dict(
    efficientnetb0=applications.EfficientNetB0,
    resnet50=applications.ResNet50,
    densenet121=applications.DenseNet121,
    convnextsmall=applications.ConvNeXtSmall,
)

# Skip-connection taps for each backbone, keyed like BACKBONE. The backbone
# UNet pairs entry i with decoder block i. Judging by the layer names and the
# decreasing indices, the taps are listed from the deepest layer to the
# shallowest.
# NOTE(review): the integer entries are not layer names. They look like
# positional indices into base_model.layers, and such indices shift between
# Keras releases; confirm them against the installed version.
BACKBONE_ARGS = dict(
    efficientnetb0=[
        'block6a_expand_activation', 'block4a_expand_activation',
        'block3a_expand_activation', 'block2a_expand_activation',
    ],
    resnet50=[
        'conv4_block6_2_relu', 'conv3_block4_2_relu',
        'conv2_block3_2_relu', 'conv1_relu',
    ],
    densenet121=[311, 139, 51, 4],
    convnextsmall=[268, 51, 26],
)


def Conv3x3BNReLU(filters):
    """Return a callable applying Conv2D(3x3, 'same') -> BatchNormalization -> ReLU.

    Args:
        filters: Number of output channels of the convolution.

    Returns:
        ``apply(input)``, which maps a feature tensor to the block output with
        unchanged spatial size ('same' padding).
    """
    def apply(input):
        # The conv has no activation of its own, so the only non-linearity
        # comes after BatchNormalization, matching the Conv-BN-ReLU name.
        x = nn.Conv2D(filters, kernel_size=(3, 3), padding='same')(input)
        x = nn.BatchNormalization()(x)
        x = nn.ReLU()(x)
        return x

    return apply

def UpsampleBlock(filters):
    """Return a decoder step: 2x upsample, optional skip concat, two conv units.

    Args:
        filters: Channel count used by both Conv3x3BNReLU units.

    Returns:
        ``apply(x, skip=None)``. When ``skip`` is given, it is placed in front
        of the upsampled tensor along axis 3. Its spatial size must therefore
        match the upsampled tensor.
    """
    def apply(x, skip=None):
        upsampled = nn.UpSampling2D((2, 2))(x)
        if skip is None:
            merged = upsampled
        else:
            merged = nn.Concatenate(axis=3)([skip, upsampled])
        out = merged
        for _ in range(2):
            out = Conv3x3BNReLU(filters)(out)
        return out

    return apply
        
    
def UNet(backbone, input_size, num_classes, activation,
         decoder_filters=(256, 128, 64, 32, 16), weights=None):
    """Build a UNet whose encoder is a ``keras.applications`` backbone.

    Args:
        backbone: Key into ``BACKBONE`` (constructor) and ``BACKBONE_ARGS``
            (skip taps).
        input_size: Shape passed to ``keras.Input``, e.g. ``(256, 256, 3)``.
        num_classes: Number of output channels.
        activation: Activation applied to the final output, e.g. ``'sigmoid'``.
        decoder_filters: Filters for each decoder block; one block per entry.
        weights: Forwarded to the backbone constructor. ``None`` (default)
            builds a randomly initialised backbone; ``'imagenet'`` requests
            pretrained weights.

    Returns:
        A ``keras.Model`` named ``'UNet'``.
    """
    inputs = keras.Input(input_size)

    base_model = BACKBONE[backbone](
        weights=weights, include_top=False, input_tensor=inputs
    )

    skip_layers = []
    for ref in BACKBONE_ARGS[backbone]:
        # Integer taps are positional layer indices. get_layer(name=...) can
        # only resolve strings, so look integers up via index=.
        if isinstance(ref, int):
            layer = base_model.get_layer(index=ref)
        else:
            layer = base_model.get_layer(name=ref)
        skip_layers.append(layer.output)

    # Decoder: start from the deepest backbone feature. Block i consumes skip i
    # when one exists; later blocks only upsample and convolve.
    x = base_model.output
    for i, filters in enumerate(decoder_filters):
        skip = skip_layers[i] if i < len(skip_layers) else None
        x = UpsampleBlock(filters)(x, skip)

    # Final projection to class channels; the activation is kept in float32 so
    # the output stays float32 even under a mixed-precision policy.
    x = nn.Conv2D(filters=num_classes, kernel_size=(3, 3), padding='same')(x)
    final = nn.Activation(activation, dtype='float32')(x)
    model = keras.Model(inputs=inputs, outputs=final, name='UNet')

    return model

# Build the ResNet50-backed UNet and print its parameter counts.
model = UNet(
    backbone="resnet50",
    input_size=(256, 256, 3),
    num_classes=3,
    activation='sigmoid',
)
short_summary(model)
@innat
Copy link
Author

innat commented Aug 12, 2023

UNet Basic

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def EncoderBlock(filter):
    """Return a contracting-path step: two unpadded 3x3 conv + ReLU, then 2x2 max-pool.

    Each 'valid' conv trims 2 pixels from height and width. The stride-2 pool
    then halves what remains (rounding down).

    NOTE(review): ``apply`` returns only the pooled tensor, so callers can use
    only post-pool features as skip connections. The original U-Net taps the
    pre-pool activations; confirm that this is intended.
    """
    def apply(input):
        x = input
        for _ in range(2):
            x = layers.Conv2D(filter, 3, padding = 'valid')(x)
            x = layers.Activation('relu')(x)
        return layers.MaxPool2D((2, 2), 2)(x)

    return apply



def DecoderBlock(filter):
    """Return an expansive-path step for the unpadded UNet.

    ``apply(input, skip_feature)`` does the following:

    * runs a 2x2 stride-2 transposed conv, which doubles height and width;
    * resizes ``skip_feature`` bilinearly to that size;
    * concatenates the two along the last axis;
    * applies two unpadded 3x3 conv + ReLU layers, each trimming 2 pixels
      from height and width.

    Args:
        filter: Channel count of the transposed conv and of both 3x3 convs.
    """
    def apply(input, skip_feature):
        x = layers.Conv2DTranspose(filter, 2, 2, padding='valid')(input)
        # Use a Resizing layer (bilinear, the same default as tf.image.resize)
        # instead of calling tf.image.resize on a symbolic tensor. Keras 3
        # functional models reject raw TF ops applied to KerasTensors.
        skip_feature = layers.Resizing(x.shape[1], x.shape[2])(skip_feature)
        x = layers.Concatenate()([x, skip_feature])
        for _ in range(2):
            x = layers.Conv2D(filter, 3, padding='valid')(x)
            x = layers.Activation('relu')(x)
        return x

    return apply



def UNet(input_shape = (256, 256, 3), num_classes = 1, class_activation='sigmoid'):
    """Build the unpadded ('valid') UNet with a 64-128-256-512 encoder.

    Every 3x3 conv is unpadded, so the output is spatially smaller than the
    input; a 224x224 input, for example, gives a 36x36 output. Skip tensors are
    resized to the decoder resolution inside DecoderBlock.

    Args:
        input_shape: Shape passed to ``keras.Input``, e.g. ``(256, 256, 3)``.
        num_classes: Output channels of the final 1x1 conv.
        class_activation: Activation of the final conv. Defaults to
            ``'sigmoid'``, the value that used to be hard-coded.

    Returns:
        A ``keras.Model`` named ``'UNet'`` whose output layer uses dtype float32.
    """
    inputs = keras.Input(input_shape)

    # Contracting path: each EncoderBlock's (pooled) output is also the skip
    # feature for the decoder stage at the same depth.
    s1 = EncoderBlock(64)(inputs)
    s2 = EncoderBlock(128)(s1)
    s3 = EncoderBlock(256)(s2)
    s4 = EncoderBlock(512)(s3)

    # Bottleneck: two unpadded 3x3 convs with 1024 channels.
    b1 = layers.Conv2D(1024, 3, padding='valid')(s4)
    b1 = layers.Activation('relu')(b1)
    b1 = layers.Conv2D(1024, 3, padding='valid')(b1)
    b1 = layers.Activation('relu')(b1)

    # Expansive path: consume the skips from the deepest one upward.
    s5 = DecoderBlock(512)(b1, skip_feature=s4)
    s6 = DecoderBlock(256)(s5, skip_feature=s3)
    s7 = DecoderBlock(128)(s6, skip_feature=s2)
    s8 = DecoderBlock(64)(s7, skip_feature=s1)

    # The output conv is pinned to float32 so the output stays float32 even
    # under a mixed-precision policy.
    outputs = layers.Conv2D(
        num_classes,
        1,
        padding='valid',
        activation=class_activation,
        dtype='float32',
    )(s8)

    return keras.Model(inputs=inputs, outputs=outputs, name='UNet')

# Basic UNet for 224x224 RGB inputs with 10 output channels.
model = UNet(num_classes=10, input_shape=(224, 224, 3))

@innat
Copy link
Author

innat commented Aug 12, 2023

UNet General

def Conv3x3BNReLU(filter):
    """Return a callable applying Conv2D(3x3, 'same') -> BatchNormalization -> ReLU.

    Args:
        filter: Number of output channels of the convolution.

    Returns:
        ``apply(input)``, which keeps the spatial size unchanged ('same' padding).
    """
    def apply(input):
        # The non-linearity comes only after BatchNormalization, as the block
        # name says. The conv output is normalised before any activation.
        x = layers.Conv2D(filter, 3, padding='same')(input)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        return x

    return apply

def EncoderBlock(filter):
    """Return an encoder step: two Conv3x3BNReLU units followed by 2x2 max-pooling.

    The returned ``apply(input)`` yields only the pooled tensor, whose height
    and width are half those of the input (rounded down).
    """
    def apply(input):
        features = input
        for _ in range(2):
            features = Conv3x3BNReLU(filter)(features)
        return layers.MaxPooling2D(pool_size=(2, 2))(features)

    return apply

def BottleNeck(filter):
    """Return the bottleneck step: one 3x3 'same' conv followed by ReLU.

    Args:
        filter: Number of output channels of the bottleneck conv.

    Returns:
        ``apply(input)``, which keeps the spatial size unchanged.
    """
    def apply(input):
        # Width comes from the caller's ``filter`` rather than a hard-coded 1024,
        # which silently ignored the argument.
        x = layers.Conv2D(filter, 3, padding='same')(input)
        x = layers.Activation('relu')(x)
        return x

    return apply

def DecoderBlock(filter):
    """Return a decoder step: 2x upsampling, optional skip concat, two conv units.

    ``apply(input, skip=None)`` upsamples ``input`` 2x. When ``skip`` is given,
    it is appended after the upsampled tensor along the last axis. Two
    Conv3x3BNReLU units with ``filter`` channels then follow.
    """
    def apply(input, skip=None):
        upsampled = layers.UpSampling2D(size=(2, 2))(input)
        if skip is None:
            merged = upsampled
        else:
            merged = layers.Concatenate()([upsampled, skip])
        refined = Conv3x3BNReLU(filter)(merged)
        return Conv3x3BNReLU(filter)(refined)

    return apply

def UNet(
    input_shape=(256, 256, 3),
    num_classes=1,
    class_activation='sigmoid',
    n_upsample_blocks=5,
    encoder_filters=(32, 64, 128, 256),
    decoder_filters=(256, 128, 64, 32, 16),
):
    """Build a configurable UNet made of 'same'-padded Conv-BN-ReLU blocks.

    Encoder: one EncoderBlock per entry of ``encoder_filters``. Each halves
    height and width, and its pooled output is kept as a skip. A bottleneck
    conv follows, then one extra 2x2 max-pool.

    Decoder: ``n_upsample_blocks`` DecoderBlocks. Each doubles height and width
    and concatenates the matching skip, deepest first, while skips remain.

    The input is pooled ``len(encoder_filters) + 1`` times (32x with the
    defaults). Height and width should be divisible by that factor so the
    upsampled tensors line up with the skips.

    Args:
        input_shape: Shape passed to ``keras.Input``, e.g. ``(256, 256, 3)``.
        num_classes: Output channels of the final 1x1 conv.
        class_activation: Activation of the final conv.
        n_upsample_blocks: Number of decoder blocks to build.
        encoder_filters: Filters for each encoder block.
        decoder_filters: Filters for each decoder block; needs at least
            ``n_upsample_blocks`` entries.

    Returns:
        A ``keras.Model`` named ``'UNet'`` whose output layer uses dtype float32.

    Raises:
        ValueError: if ``n_upsample_blocks`` exceeds ``len(decoder_filters)``.
    """
    if n_upsample_blocks > len(decoder_filters):
        raise ValueError(
            'n_upsample_blocks ({}) exceeds len(decoder_filters) ({})'.format(
                n_upsample_blocks, len(decoder_filters)
            )
        )

    inputs = keras.Input(input_shape)

    # Contracting path
    x = inputs
    skips = []
    for filters in encoder_filters:
        x = EncoderBlock(filters)(x)
        skips.append(x)
    skips = skips[::-1]  # deepest skip first, matching decoder order

    # Bottleneck, followed by one more downsampling step
    x = BottleNeck(512)(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Expansive path: blocks beyond the available skips only upsample + conv.
    for i in range(n_upsample_blocks):
        skip = skips[i] if i < len(skips) else None
        x = DecoderBlock(decoder_filters[i])(x, skip)

    # The output conv is pinned to float32 so the output stays float32 even
    # under a mixed-precision policy.
    outputs = layers.Conv2D(
        num_classes, 1, padding='same', activation=class_activation, dtype='float32'
    )(x)

    model = keras.Model(inputs=inputs, outputs=outputs, name='UNet')
    return model

# Clear Keras' global session state, then build the general UNet for
# 512x512 RGB inputs with 10 output channels.
keras.backend.clear_session()
model = UNet(num_classes=10, input_shape=(512, 512, 3))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment