Skip to content

Instantly share code, notes, and snippets.

@innat
Last active April 3, 2023 07:49
Show Gist options
  • Save innat/2d2fd2e93b47dd5349095182eb984ef6 to your computer and use it in GitHub Desktop.
Save innat/2d2fd2e93b47dd5349095182eb984ef6 to your computer and use it in GitHub Desktop.
Official Implementation of UNet++ (EfficientNets) in TensorFlow 2
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
def handle_block_names(stage, cols):
conv_name = 'decoder_stage{}-{}_conv'.format(stage, cols)
bn_name = 'decoder_stage{}-{}_bn'.format(stage, cols)
relu_name = 'decoder_stage{}-{}_relu'.format(stage, cols)
up_name = 'decoder_stage{}-{}_upsample'.format(stage, cols)
merge_name = 'merge_{}-{}'.format(stage, cols)
return conv_name, bn_name, relu_name, up_name, merge_name
def ConvRelu(filters, kernel_size, use_batchnorm=False, conv_name='conv', bn_name='bn', relu_name='relu'):
def layer(x):
x = Conv2D(filters, kernel_size, padding="same", name=conv_name, use_bias=not(use_batchnorm))(x)
if use_batchnorm:
x = BatchNormalization(name=bn_name)(x)
x = Activation('relu', name=relu_name)(x)
return x
return layer
def Upsample2D_block(filters, stage, cols, kernel_size=(3,3), upsample_rate=(2,2),
use_batchnorm=False, skip=None):
def layer(input_tensor):
conv_name, bn_name, relu_name, up_name, merge_name = handle_block_names(stage, cols)
x = UpSampling2D(size=upsample_rate, name=up_name)(input_tensor)
if (type(skip) != list and skip is not None) or (type(skip) == list and None not in skip):
if type(skip) is list:
x = Concatenate(name=merge_name)([x] + skip)
else:
x = Concatenate(name=merge_name)([x, skip])
x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
conv_name=conv_name + '1', bn_name=bn_name + '1', relu_name=relu_name + '1')(x)
x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
conv_name=conv_name + '2', bn_name=bn_name + '2', relu_name=relu_name + '2')(x)
return x
return layer
def Transpose2D_block(filters, stage, cols, kernel_size=(3,3), upsample_rate=(2,2),
transpose_kernel_size=(4,4), use_batchnorm=False, skip=None):
def layer(input_tensor):
conv_name, bn_name, relu_name, up_name, merge_name = handle_block_names(stage, cols)
x = Conv2DTranspose(filters, transpose_kernel_size, strides=upsample_rate,
padding='same', name=up_name, use_bias=not(use_batchnorm))(input_tensor)
if use_batchnorm:
x = BatchNormalization(name=bn_name+'1')(x)
x = Activation('relu', name=relu_name+'1')(x)
if (type(skip) != list and skip is not None) or (type(skip) == list and None not in skip):
if type(skip) is list:
merge_list = []
merge_list.append(x)
for l in skip:
merge_list.append(l)
x = Concatenate(name=merge_name)(merge_list)
else:
x = Concatenate(name=merge_name)([x, skip])
x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
conv_name=conv_name + '2', bn_name=bn_name + '2', relu_name=relu_name + '2')(x)
return x
return layer
def get_layer_number(model, layer_name):
"""
Help find layer in Keras model by name
Args:
model: Keras `Model`
layer_name: str, name of layer
Returns:
index of layer
Raises:
ValueError: if model does not contains layer with such name
"""
for i, l in enumerate(model.layers):
if l.name == layer_name:
return i
raise ValueError('No layer with name {} in model {}.'.format(layer_name, model.name))
def to_tuple(x):
if isinstance(x, tuple):
if len(x) == 2:
return x
elif np.isscalar(x):
return (x, x)
raise ValueError('Value should be tuple of length 2 or int value, got "{}"'.format(x))
def build_xnet(backbone, classes, skip_connection_layers,
decoder_filters=(256,128,64,32,16),
upsample_rates=(2,2,2,2,2),
n_upsample_blocks=5,
block_type='upsampling',
activation='sigmoid',
use_batchnorm=True):
input = backbone.input
if block_type == 'transpose':
up_block = Transpose2D_block
else:
up_block = Upsample2D_block
if len(skip_connection_layers) > n_upsample_blocks:
downsampling_layers = skip_connection_layers[int(len(skip_connection_layers)/2):]
skip_connection_layers = skip_connection_layers[:int(len(skip_connection_layers)/2)]
else:
downsampling_layers = skip_connection_layers
# convert layer names to indices
skip_connection_idx = ([get_layer_number(backbone, l) if isinstance(l, str) else l
for l in skip_connection_layers])
skip_layers_list = [backbone.layers[skip_connection_idx[i]].output for i in range(len(skip_connection_idx))]
downsampling_idx = ([get_layer_number(backbone, l) if isinstance(l, str) else l
for l in downsampling_layers])
downsampling_list = [backbone.layers[downsampling_idx[i]].output for i in range(len(downsampling_idx))]
downterm = [None] * (n_upsample_blocks+1)
for i in range(len(downsampling_idx)):
downterm[n_upsample_blocks-i-1] = downsampling_list[i]
downterm[-1] = backbone.output
interm = [None] * (n_upsample_blocks+1) * (n_upsample_blocks+1)
for i in range(len(skip_connection_idx)):
interm[-i*(n_upsample_blocks+1)+(n_upsample_blocks+1)*(n_upsample_blocks-1)] = skip_layers_list[i]
interm[(n_upsample_blocks+1)*n_upsample_blocks] = backbone.output
for j in range(n_upsample_blocks):
for i in range(n_upsample_blocks-j):
upsample_rate = to_tuple(upsample_rates[i])
if i == 0 and j < n_upsample_blocks-1 and len(skip_connection_layers) < n_upsample_blocks:
interm[(n_upsample_blocks+1)*i+j+1] = None
elif j == 0:
if downterm[i+1] is not None:
interm[(n_upsample_blocks+1)*i+j+1] = up_block(decoder_filters[n_upsample_blocks-i-2],
i+1, j+1, upsample_rate=upsample_rate,
skip=interm[(n_upsample_blocks+1)*i+j],
use_batchnorm=use_batchnorm)(downterm[i+1])
else:
interm[(n_upsample_blocks+1)*i+j+1] = None
else:
interm[(n_upsample_blocks+1)*i+j+1] = up_block(decoder_filters[n_upsample_blocks-i-2],
i+1, j+1, upsample_rate=upsample_rate,
skip=interm[(n_upsample_blocks+1)*i : (n_upsample_blocks+1)*i+j+1],
use_batchnorm=use_batchnorm)(interm[(n_upsample_blocks+1)*(i+1)+j])
x = Conv2D(classes, (3,3), padding='same', name='final_conv')(interm[n_upsample_blocks])
x = Activation(activation, name=activation)(x)
model = Model(input, x)
return model
from tensorflow.keras.applications import (EfficientNetB0, EfficientNetB1,
EfficientNetB2, EfficientNetB3,
EfficientNetB4, EfficientNetB5,
EfficientNetB6, EfficientNetB7)
backbones = {
"efficientnetb0": EfficientNetB0,
"efficientnetb1": EfficientNetB1,
"efficientnetb2": EfficientNetB2,
"efficientnetb3": EfficientNetB3,
"efficientnetb4": EfficientNetB4,
"efficientnetb5": EfficientNetB5,
"efficientnetb6": EfficientNetB6,
"efficientnetb7": EfficientNetB7,
}
def get_backbone(name, *args, **kwargs):
return backbones[name](*args, **kwargs)
def freeze_model(model):
for layer in model.layers:
layer.trainable = False
return
DEFAULT_SKIP_CONNECTIONS = {
'efficientnetb0': ('block6a_expand_activation', 'block4a_expand_activation',
'block3a_expand_activation', 'block2a_expand_activation'),
'efficientnetb1': ('block6a_expand_activation', 'block4a_expand_activation',
'block3a_expand_activation', 'block2a_expand_activation'),
'efficientnetb2': ('block6a_expand_activation', 'block4a_expand_activation',
'block3a_expand_activation', 'block2a_expand_activation'),
'efficientnetb3': ('block6a_expand_activation', 'block4a_expand_activation',
'block3a_expand_activation', 'block2a_expand_activation'),
'efficientnetb4': ('block6a_expand_activation', 'block4a_expand_activation',
'block3a_expand_activation', 'block2a_expand_activation'),
'efficientnetb5': ('block6a_expand_activation', 'block4a_expand_activation',
'block3a_expand_activation', 'block2a_expand_activation'),
'efficientnetb6': ('block6a_expand_activation', 'block4a_expand_activation',
'block3a_expand_activation', 'block2a_expand_activation'),
'efficientnetb7': ('block6a_expand_activation', 'block4a_expand_activation',
'block3a_expand_activation', 'block2a_expand_activation'),
}
def Xnet(backbone_name='efficientnetb0',
input_shape=(None, None, 3),
input_tensor=None,
encoder_weights='imagenet',
freeze_encoder=False,
skip_connections='default',
decoder_block_type='upsampling',
decoder_filters=(256,128,64,32,16),
decoder_use_batchnorm=True,
n_upsample_blocks=5,
upsample_rates=(2,2,2,2,2),
classes=1,
activation='sigmoid'):
backbone = get_backbone(backbone_name,
input_shape=input_shape,
input_tensor=input_tensor,
weights=encoder_weights,
include_top=False)
if skip_connections == 'default':
skip_connections = DEFAULT_SKIP_CONNECTIONS[backbone_name]
model = build_xnet(backbone,
classes,
skip_connections,
decoder_filters=decoder_filters,
block_type=decoder_block_type,
activation=activation,
n_upsample_blocks=n_upsample_blocks,
upsample_rates=upsample_rates,
use_batchnorm=decoder_use_batchnorm)
# lock encoder weights for fine-tuning
if freeze_encoder:
freeze_model(backbone)
return model
model = Xnet(backbone_name='efficientnetb1',
encoder_weights='imagenet',
classes=10,
activation=None)
model.summary()
@innat
Copy link
Author

innat commented Nov 20, 2021

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment