Official Implementation of UNet++ (EfficientNets) in TensorFlow 2
# Decoder building blocks for UNet++: named Conv-BN-ReLU units and upsampling /
# transposed-convolution decoder blocks with (possibly multiple) skip connections.
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate


def handle_block_names(stage, cols):
    conv_name = 'decoder_stage{}-{}_conv'.format(stage, cols)
    bn_name = 'decoder_stage{}-{}_bn'.format(stage, cols)
    relu_name = 'decoder_stage{}-{}_relu'.format(stage, cols)
    up_name = 'decoder_stage{}-{}_upsample'.format(stage, cols)
    merge_name = 'merge_{}-{}'.format(stage, cols)
    return conv_name, bn_name, relu_name, up_name, merge_name


def ConvRelu(filters, kernel_size, use_batchnorm=False, conv_name='conv', bn_name='bn', relu_name='relu'):
    def layer(x):
        x = Conv2D(filters, kernel_size, padding='same', name=conv_name, use_bias=not use_batchnorm)(x)
        if use_batchnorm:
            x = BatchNormalization(name=bn_name)(x)
        x = Activation('relu', name=relu_name)(x)
        return x
    return layer


def Upsample2D_block(filters, stage, cols, kernel_size=(3, 3), upsample_rate=(2, 2),
                     use_batchnorm=False, skip=None):
    def layer(input_tensor):
        conv_name, bn_name, relu_name, up_name, merge_name = handle_block_names(stage, cols)
        x = UpSampling2D(size=upsample_rate, name=up_name)(input_tensor)
        # Concatenate the skip connection(s) if present (a single tensor or a list of tensors).
        if (not isinstance(skip, list) and skip is not None) or (isinstance(skip, list) and None not in skip):
            if isinstance(skip, list):
                x = Concatenate(name=merge_name)([x] + skip)
            else:
                x = Concatenate(name=merge_name)([x, skip])
        x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
                     conv_name=conv_name + '1', bn_name=bn_name + '1', relu_name=relu_name + '1')(x)
        x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
                     conv_name=conv_name + '2', bn_name=bn_name + '2', relu_name=relu_name + '2')(x)
        return x
    return layer


def Transpose2D_block(filters, stage, cols, kernel_size=(3, 3), upsample_rate=(2, 2),
                      transpose_kernel_size=(4, 4), use_batchnorm=False, skip=None):
    def layer(input_tensor):
        conv_name, bn_name, relu_name, up_name, merge_name = handle_block_names(stage, cols)
        x = Conv2DTranspose(filters, transpose_kernel_size, strides=upsample_rate,
                            padding='same', name=up_name, use_bias=not use_batchnorm)(input_tensor)
        if use_batchnorm:
            x = BatchNormalization(name=bn_name + '1')(x)
        x = Activation('relu', name=relu_name + '1')(x)
        # Concatenate the skip connection(s) if present (a single tensor or a list of tensors).
        if (not isinstance(skip, list) and skip is not None) or (isinstance(skip, list) and None not in skip):
            if isinstance(skip, list):
                x = Concatenate(name=merge_name)([x] + skip)
            else:
                x = Concatenate(name=merge_name)([x, skip])
        x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
                     conv_name=conv_name + '2', bn_name=bn_name + '2', relu_name=relu_name + '2')(x)
        return x
    return layer
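For reference, a minimal sketch (not part of the original gist) of how these block closures are used: the block is parameterized first, then applied to a tensor, upsampling it and concatenating the skip feature map before the two Conv-BN-ReLU layers. The shapes below are illustrative assumptions.

demo_in = Input(shape=(16, 16, 64), name='demo_input')    # coarse feature map
demo_skip = Input(shape=(32, 32, 32), name='demo_skip')   # skip feature map at 2x spatial resolution
demo_out = Upsample2D_block(32, stage=1, cols=1,
                            use_batchnorm=True, skip=demo_skip)(demo_in)
demo_model = Model([demo_in, demo_skip], demo_out)         # output shape: (32, 32, 32)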
# Helper utilities used by the decoder builder below.
def get_layer_number(model, layer_name):
    """
    Helps find a layer in a Keras model by name.

    Args:
        model: Keras `Model`
        layer_name: str, name of the layer

    Returns:
        index of the layer

    Raises:
        ValueError: if the model does not contain a layer with that name
    """
    for i, l in enumerate(model.layers):
        if l.name == layer_name:
            return i
    raise ValueError('No layer with name {} in model {}.'.format(layer_name, model.name))


def to_tuple(x):
    if isinstance(x, tuple):
        if len(x) == 2:
            return x
    elif np.isscalar(x):
        return (x, x)
    raise ValueError('Value should be tuple of length 2 or int value, got "{}"'.format(x))
# UNet++ decoder builder: wires dense skip connections between a backbone encoder
# and a grid of upsampling decoder blocks.
def build_xnet(backbone, classes, skip_connection_layers,
               decoder_filters=(256, 128, 64, 32, 16),
               upsample_rates=(2, 2, 2, 2, 2),
               n_upsample_blocks=5,
               block_type='upsampling',
               activation='sigmoid',
               use_batchnorm=True):
    input = backbone.input

    if block_type == 'transpose':
        up_block = Transpose2D_block
    else:
        up_block = Upsample2D_block

    if len(skip_connection_layers) > n_upsample_blocks:
        downsampling_layers = skip_connection_layers[int(len(skip_connection_layers) / 2):]
        skip_connection_layers = skip_connection_layers[:int(len(skip_connection_layers) / 2)]
    else:
        downsampling_layers = skip_connection_layers

    # convert layer names to indices
    skip_connection_idx = [get_layer_number(backbone, l) if isinstance(l, str) else l
                           for l in skip_connection_layers]
    skip_layers_list = [backbone.layers[skip_connection_idx[i]].output for i in range(len(skip_connection_idx))]
    downsampling_idx = [get_layer_number(backbone, l) if isinstance(l, str) else l
                        for l in downsampling_layers]
    downsampling_list = [backbone.layers[downsampling_idx[i]].output for i in range(len(downsampling_idx))]

    # downterm[i] holds the encoder feature map that feeds decoder row i (bottleneck last).
    downterm = [None] * (n_upsample_blocks + 1)
    for i in range(len(downsampling_idx)):
        downterm[n_upsample_blocks - i - 1] = downsampling_list[i]
    downterm[-1] = backbone.output

    # interm is the UNet++ node grid flattened row by row: the node at row i, column j
    # sits at index (n_upsample_blocks + 1) * i + j. Column 0 is filled with the encoder
    # skips and the bottleneck output.
    interm = [None] * (n_upsample_blocks + 1) * (n_upsample_blocks + 1)
    for i in range(len(skip_connection_idx)):
        interm[-i * (n_upsample_blocks + 1) + (n_upsample_blocks + 1) * (n_upsample_blocks - 1)] = skip_layers_list[i]
    interm[(n_upsample_blocks + 1) * n_upsample_blocks] = backbone.output

    # Build the grid column by column: each node upsamples the feature one row deeper
    # and concatenates every earlier node in its own row as dense skip connections.
    for j in range(n_upsample_blocks):
        for i in range(n_upsample_blocks - j):
            upsample_rate = to_tuple(upsample_rates[i])

            if i == 0 and j < n_upsample_blocks - 1 and len(skip_connection_layers) < n_upsample_blocks:
                interm[(n_upsample_blocks + 1) * i + j + 1] = None
            elif j == 0:
                if downterm[i + 1] is not None:
                    interm[(n_upsample_blocks + 1) * i + j + 1] = up_block(
                        decoder_filters[n_upsample_blocks - i - 2],
                        i + 1, j + 1, upsample_rate=upsample_rate,
                        skip=interm[(n_upsample_blocks + 1) * i + j],
                        use_batchnorm=use_batchnorm)(downterm[i + 1])
                else:
                    interm[(n_upsample_blocks + 1) * i + j + 1] = None
            else:
                interm[(n_upsample_blocks + 1) * i + j + 1] = up_block(
                    decoder_filters[n_upsample_blocks - i - 2],
                    i + 1, j + 1, upsample_rate=upsample_rate,
                    skip=interm[(n_upsample_blocks + 1) * i: (n_upsample_blocks + 1) * i + j + 1],
                    use_batchnorm=use_batchnorm)(interm[(n_upsample_blocks + 1) * (i + 1) + j])

    x = Conv2D(classes, (3, 3), padding='same', name='final_conv')(interm[n_upsample_blocks])
    x = Activation(activation, name=activation)(x)

    model = Model(input, x)
    return model
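The indexing above flattens the UNet++ node grid into the 1-D list interm: the node at row i, column j sits at index (n_upsample_blocks + 1) * i + j, column 0 holds the encoder skips and the bottleneck, and each later column is computed from the previous one. A small sketch (illustrative only, assuming the default of five upsample blocks) prints that layout:

n = 5  # i.e. n_upsample_blocks
for i in range(n + 1):
    row = [(n + 1) * i + j for j in range(n + 1)]
    print('row {}: {}'.format(i, row))
# row 0: [0, 1, 2, 3, 4, 5]        -> interm[5] is the final decoder node fed to final_conv
# ...
# row 5: [30, 31, 32, 33, 34, 35]  -> interm[30] is backbone.output (the bottleneck)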
# UNet++ model factory with EfficientNet encoders from tf.keras.applications.
from tensorflow.keras.applications import (EfficientNetB0, EfficientNetB1,
                                           EfficientNetB2, EfficientNetB3,
                                           EfficientNetB4, EfficientNetB5,
                                           EfficientNetB6, EfficientNetB7)

backbones = {
    "efficientnetb0": EfficientNetB0,
    "efficientnetb1": EfficientNetB1,
    "efficientnetb2": EfficientNetB2,
    "efficientnetb3": EfficientNetB3,
    "efficientnetb4": EfficientNetB4,
    "efficientnetb5": EfficientNetB5,
    "efficientnetb6": EfficientNetB6,
    "efficientnetb7": EfficientNetB7,
}


def get_backbone(name, *args, **kwargs):
    return backbones[name](*args, **kwargs)


def freeze_model(model):
    for layer in model.layers:
        layer.trainable = False
    return


# All EfficientNet variants in tf.keras.applications share the same block naming,
# so the skip-connection layer names are identical across B0-B7.
DEFAULT_SKIP_CONNECTIONS = {
    'efficientnetb0': ('block6a_expand_activation', 'block4a_expand_activation',
                       'block3a_expand_activation', 'block2a_expand_activation'),
    'efficientnetb1': ('block6a_expand_activation', 'block4a_expand_activation',
                       'block3a_expand_activation', 'block2a_expand_activation'),
    'efficientnetb2': ('block6a_expand_activation', 'block4a_expand_activation',
                       'block3a_expand_activation', 'block2a_expand_activation'),
    'efficientnetb3': ('block6a_expand_activation', 'block4a_expand_activation',
                       'block3a_expand_activation', 'block2a_expand_activation'),
    'efficientnetb4': ('block6a_expand_activation', 'block4a_expand_activation',
                       'block3a_expand_activation', 'block2a_expand_activation'),
    'efficientnetb5': ('block6a_expand_activation', 'block4a_expand_activation',
                       'block3a_expand_activation', 'block2a_expand_activation'),
    'efficientnetb6': ('block6a_expand_activation', 'block4a_expand_activation',
                       'block3a_expand_activation', 'block2a_expand_activation'),
    'efficientnetb7': ('block6a_expand_activation', 'block4a_expand_activation',
                       'block3a_expand_activation', 'block2a_expand_activation'),
}


def Xnet(backbone_name='efficientnetb0',
         input_shape=(None, None, 3),
         input_tensor=None,
         encoder_weights='imagenet',
         freeze_encoder=False,
         skip_connections='default',
         decoder_block_type='upsampling',
         decoder_filters=(256, 128, 64, 32, 16),
         decoder_use_batchnorm=True,
         n_upsample_blocks=5,
         upsample_rates=(2, 2, 2, 2, 2),
         classes=1,
         activation='sigmoid'):
    backbone = get_backbone(backbone_name,
                            input_shape=input_shape,
                            input_tensor=input_tensor,
                            weights=encoder_weights,
                            include_top=False)

    if skip_connections == 'default':
        skip_connections = DEFAULT_SKIP_CONNECTIONS[backbone_name]

    model = build_xnet(backbone,
                       classes,
                       skip_connections,
                       decoder_filters=decoder_filters,
                       block_type=decoder_block_type,
                       activation=activation,
                       n_upsample_blocks=n_upsample_blocks,
                       upsample_rates=upsample_rates,
                       use_batchnorm=decoder_use_batchnorm)

    # lock encoder weights for fine-tuning
    if freeze_encoder:
        freeze_model(backbone)

    return model
# Example: build UNet++ with an EfficientNet-B1 encoder and a 10-class logit output.
model = Xnet(backbone_name='efficientnetb1',
             encoder_weights='imagenet',
             classes=10,
             activation=None)
model.summary()
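A minimal training sketch (not part of the gist; the random data and the 224x224 input size are assumptions). Because the model above is built with activation=None, its output is per-pixel logits, so the loss is configured with from_logits=True:

import numpy as np
import tensorflow as tf

images = np.random.rand(4, 224, 224, 3).astype('float32')             # dummy images
masks = np.random.randint(0, 10, size=(4, 224, 224)).astype('int32')  # dummy class-index masks

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(images, masks, batch_size=2, epochs=1)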