# CNN parts in keras (in tf 1.x)
# Source: GitHub gist TeraBytesMemory/da068194615e6d7dcf83bf1bbd2a4d55
# (last active December 27, 2020).
import numpy as np
import tensorflow as tf

# NOTE(review): LambdaLayer, _apply_regularizer and DropBlock2D are used
# below but are neither defined nor imported in this file; DropBlock2D
# presumably comes from keras-drop-block (see the reference comment above
# tresnet_resnext_block) -- confirm and import them.
def octave_conv2d(high_x, low_x, channels, alpha_in=0.5, alpha_out=0.5, stride=1, kernel_size=(3, 3), weight_decay=1e-4):
    """Apply an octave convolution to a (high-frequency, low-frequency) pair.

    Four bias-free convolutions exchange information between the two
    branches (high->high, high->low, low->high, low->low). Their results
    are summed per output branch.

    Args:
        high_x: high-frequency feature map. Presumably NHWC with twice the
            spatial size of ``low_x`` (TODO confirm with callers); the
            pooling/upsampling below assume a factor-2 relationship.
        low_x: low-frequency feature map.
        channels: total number of output channels across both branches.
        alpha_in: accepted for API symmetry but currently unused.
        alpha_out: fraction of ``channels`` assigned to the low branch.
        stride: stride applied by every one of the four convolutions.
        kernel_size: convolution kernel size.
        weight_decay: L2 factor for the kernels (the bias regularizer is
            inert because ``use_bias=False``).

    Returns:
        Tuple ``(yh, yl)`` with ``int(channels * (1 - alpha_out))`` and
        ``int(channels * alpha_out)`` channels respectively. Because of the
        ``int`` truncation the two counts may not sum to ``channels``.
    """
    high_ch = int(channels * (1 - alpha_out))
    low_ch = int(channels * alpha_out)

    def _conv(filters, inputs):
        # Shared configuration for all four octave paths.
        return tf.keras.layers.Conv2D(
            filters, kernel_size, stride, padding='same', use_bias=False,
            kernel_initializer=tf.keras.initializers.he_normal(),
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            bias_regularizer=tf.keras.regularizers.l2(weight_decay),
        )(inputs)

    conv_hh = _conv(high_ch, high_x)
    # high -> low: halve the resolution first so the result lines up with low_x.
    conv_hl = _conv(low_ch, tf.keras.layers.AveragePooling2D()(high_x))
    # low -> high: double the resolution first so the result lines up with high_x.
    conv_lh = _conv(high_ch, tf.keras.layers.UpSampling2D()(low_x))
    conv_ll = _conv(low_ch, low_x)
    # Use Keras Add layers (not raw `+`) so the outputs remain Keras-layer
    # outputs; older TF 1.x functional models reject raw-op tensors.
    yh = tf.keras.layers.Add()([conv_hh, conv_lh])
    yl = tf.keras.layers.Add()([conv_hl, conv_ll])
    return (yh, yl)
def octave_conv2d_with_bn(high_x, low_x, channels, stride=1, kernel_size=(3, 3), weight_decay=1e-4, use_se_block=True):
    """Octave convolution followed by BatchNorm + ReLU on each branch.

    Args:
        high_x: high-frequency input feature map.
        low_x: low-frequency input feature map.
        channels: total output channels passed to ``octave_conv2d`` (split
            between the branches with its default ``alpha_out``).
        stride: convolution stride.
        kernel_size: convolution kernel size.
        weight_decay: L2 factor for conv kernels and SE dense layers.
        use_se_block: if True, gate each branch with its own ``se_block``.

    Returns:
        Tuple ``(high_x, low_x)`` of activated (and optionally SE-gated)
        feature maps.
    """
    high_x, low_x = octave_conv2d(
        high_x, low_x, channels, stride=stride, kernel_size=kernel_size, weight_decay=weight_decay
    )
    high_x = tf.nn.relu(tf.keras.layers.BatchNormalization()(high_x))
    low_x = tf.nn.relu(tf.keras.layers.BatchNormalization()(low_x))
    if use_se_block:
        # Each branch carries only part of `channels`, and se_block's gate
        # must match the branch's own channel count to be multiplied in.
        high_x = se_block(high_x, int(high_x.shape[-1]), weight_decay=weight_decay)
        low_x = se_block(low_x, int(low_x.shape[-1]), weight_decay=weight_decay)
    return (high_x, low_x)
def residual_octave_conv_block(high_x, low_x, before_ch, after_ch, stride=1, weight_decay=1e-4, use_se_block=True, dropout=0.):
    """Pre-activation residual block built from two 3x3 octave convolutions.

    Args:
        high_x: high-frequency input feature map.
        low_x: low-frequency input feature map.
        before_ch: total channels of the first octave conv.
        after_ch: total channels of the second octave conv (block output).
        stride: stride of the second octave conv and of the projection.
        weight_decay: L2 factor for the conv kernels.
        use_se_block: currently ignored -- no squeeze-and-excitation is applied.
        dropout: SpatialDropout2D rate between the two convs; falsy disables it.

    Returns:
        Tuple ``(high, low)`` = shortcut + residual for each branch. When
        ``stride == 1`` and ``before_ch == after_ch`` the shortcut is the
        identity, so the inputs must already match the residual's shapes.
    """
    def _bn_relu(tensor):
        return tf.nn.relu(tf.keras.layers.BatchNormalization()(tensor))

    # Pre-activate both branches (high first, then low).
    res_high = _bn_relu(high_x)
    res_low = _bn_relu(low_x)
    res_high, res_low = octave_conv2d(
        res_high, res_low, before_ch, stride=1, kernel_size=(3, 3), weight_decay=weight_decay
    )
    res_high = _bn_relu(res_high)
    res_low = _bn_relu(res_low)
    if dropout:
        res_high = tf.keras.layers.SpatialDropout2D(dropout)(res_high)
        res_low = tf.keras.layers.SpatialDropout2D(dropout)(res_low)
    res_high, res_low = octave_conv2d(
        res_high, res_low, after_ch, stride=stride, kernel_size=(3, 3), weight_decay=weight_decay
    )

    # Project the shortcut with a 1x1 octave conv whenever shapes change.
    needs_projection = stride > 1 or before_ch != after_ch
    if needs_projection:
        high_x, low_x = octave_conv2d(
            high_x, low_x, after_ch, stride=stride, kernel_size=(1, 1), weight_decay=weight_decay
        )
    return (
        tf.keras.layers.Add()([high_x, res_high]),
        tf.keras.layers.Add()([low_x, res_low]),
    )
def conv2d_with_bn(x, channels, stride=1, kernel_size=(3, 3), weight_decay=1e-4, use_se_block=True):
    """Conv2D -> BatchNorm -> ReLU, optionally followed by an SE gate.

    Args:
        x: input feature map.
        channels: number of output channels.
        stride: stride used on both spatial axes.
        kernel_size: convolution kernel size.
        weight_decay: L2 factor for the conv kernel and SE dense layers.
        use_se_block: if True, apply ``se_block`` to the activated output.

    Returns:
        Feature map with ``channels`` channels.
    """
    conv = tf.keras.layers.Conv2D(
        channels, kernel_size, (stride, stride), padding='same', use_bias=False,
        kernel_initializer=tf.keras.initializers.he_normal(),
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay),
    )
    activated = tf.nn.relu(tf.keras.layers.BatchNormalization()(conv(x)))
    if not use_se_block:
        return activated
    return se_block(activated, channels, weight_decay=weight_decay)
def se_block(in_block, ch, ratio=16, weight_decay=1e-4):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools ``in_block``, passes the result through a ReLU
    bottleneck Dense layer and a sigmoid Dense layer, and multiplies the
    resulting per-channel gates back onto ``in_block``.

    Args:
        in_block: feature map whose last axis has ``ch`` channels. The
            (batch, ch) gate is multiplied channel-wise onto it, so the
            counts must match.
        ch: channel count of ``in_block``.
        ratio: reduction ratio; the hidden layer has ``max(1, ch // ratio)``
            units.
        weight_decay: L2 factor for the dense kernels and biases.

    Returns:
        ``in_block`` scaled channel-wise by gates in (0, 1).
    """
    # For ch < ratio, ch // ratio would be 0: a zero-unit hidden layer makes
    # the gate a constant sigmoid(bias) that ignores the input. Keep >= 1 unit.
    hidden_units = max(1, ch // ratio)
    squeezed = tf.keras.layers.GlobalAveragePooling2D()(in_block)
    excited = tf.keras.layers.Dense(
        hidden_units, activation='relu',
        kernel_initializer=tf.keras.initializers.he_normal(),
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay),
    )(squeezed)
    gate = tf.keras.layers.Dense(
        ch, activation='sigmoid',
        kernel_initializer=tf.keras.initializers.he_normal(),
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay),
    )(excited)
    return tf.keras.layers.Multiply()([in_block, gate])
def residual_conv_block(x, before_ch, after_ch, stride=1, weight_decay=1e-4, use_se_block=True, dropout=0.):
    """Pre-activation residual block: (BN-ReLU-Conv3x3) twice, plus shortcut.

    Args:
        x: input feature map.
        before_ch: channels of the first 3x3 conv.
        after_ch: channels of the second 3x3 conv (block output).
        stride: stride of the second conv and of the projection shortcut.
        weight_decay: L2 factor for conv kernels and SE dense layers.
        use_se_block: if True, gate the residual with ``se_block``.
        dropout: SpatialDropout2D rate between the convs; falsy disables it.

    Returns:
        Tensor with ``after_ch`` channels. When ``stride == 1`` and
        ``before_ch == after_ch`` the shortcut is the identity, so ``x`` is
        assumed to already have ``after_ch`` channels.
    """
    def _conv(filters, kernel_size, strides, inputs):
        return tf.keras.layers.Conv2D(
            filters, kernel_size, strides, padding='same', use_bias=False,
            kernel_initializer=tf.keras.initializers.he_normal(),
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            bias_regularizer=tf.keras.regularizers.l2(weight_decay),
        )(inputs)

    def _bn_relu(tensor):
        return tf.nn.relu(tf.keras.layers.BatchNormalization()(tensor))

    residual = _conv(before_ch, 3, (1, 1), _bn_relu(x))
    # The second pre-activation must consume the first conv's output;
    # normalizing `x` again here would leave the first conv disconnected.
    residual = _bn_relu(residual)
    if dropout:
        residual = tf.keras.layers.SpatialDropout2D(dropout)(residual)
    residual = _conv(after_ch, 3, (stride, stride), residual)
    if use_se_block:
        residual = se_block(residual, after_ch, weight_decay=weight_decay)
    if stride > 1 or before_ch != after_ch:
        x = _conv(after_ch, 1, (stride, stride), x)
    return tf.keras.layers.Add()([x, residual])
def bottleneck_residual_conv_block(x, before_ch, after_ch, stride=1, weight_decay=1e-4, use_se_block=True):
    """Pre-activation bottleneck residual block (1x1 -> 3x3 -> 1x1).

    Args:
        x: input feature map.
        before_ch: width of the two inner convolutions.
        after_ch: number of output channels.
        stride: spatial stride, applied by the final 1x1 conv and by the
            projection shortcut.
        weight_decay: L2 factor for conv kernels and SE dense layers.
        use_se_block: if True, gate the residual with ``se_block``.

    Returns:
        Tensor with ``after_ch`` channels. When ``stride == 1`` and
        ``before_ch == after_ch`` the shortcut is the identity, so ``x`` is
        assumed to already have ``after_ch`` channels.
    """
    def _conv(filters, kernel_size, strides, inputs):
        return tf.keras.layers.Conv2D(
            filters, kernel_size, strides, padding='same', use_bias=False,
            kernel_initializer=tf.keras.initializers.he_normal(),
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            bias_regularizer=tf.keras.regularizers.l2(weight_decay),
        )(inputs)

    def _bn_relu(tensor):
        return tf.nn.relu(tf.keras.layers.BatchNormalization()(tensor))

    # Each pre-activation consumes the previous conv's output. Normalizing
    # `x` again at each stage would disconnect the earlier convolutions.
    residual = _conv(before_ch, 1, (1, 1), _bn_relu(x))
    residual = _conv(before_ch, 3, (1, 1), _bn_relu(residual))
    residual = _conv(after_ch, 1, (stride, stride), _bn_relu(residual))
    if use_se_block:
        residual = se_block(residual, after_ch, weight_decay=weight_decay)
    if stride > 1 or before_ch != after_ch:
        x = _conv(after_ch, 1, (stride, stride), x)
    return tf.keras.layers.Add()([x, residual])
# https://openreview.net/forum?id=xTJEN-ggl1b
def bottleneck_residual_lambda_block(x, ch, stride=1, r=23, weight_decay=1e-4):
    """Pre-activation bottleneck block whose spatial mixing is a LambdaLayer.

    Layout: BN-ReLU-Conv1x1(ch) -> BN-ReLU-LambdaLayer(ch) -> BN-ReLU
    -> [AvgPool3x3 with stride, if stride > 1] -> Conv1x1(ch * 4).

    Args:
        x: input feature map.
        ch: bottleneck width; the block outputs ``ch * 4`` channels.
        stride: downsampling factor (average pooling on the residual,
            strided 1x1 conv on the projection shortcut).
        r: forwarded to ``LambdaLayer`` as its ``r`` argument (defined
            elsewhere -- presumably the local-context size; confirm).
        weight_decay: L2 factor for the conv kernels, also applied to the
            LambdaLayer through ``_apply_regularizer``.

    Returns:
        Tensor with ``ch * 4`` channels.
    """
    out_ch = ch * 4

    def _conv1x1(filters, strides, inputs):
        return tf.keras.layers.Conv2D(
            filters, 1, strides, padding='same', use_bias=False,
            kernel_initializer=tf.keras.initializers.he_normal(),
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            bias_regularizer=tf.keras.regularizers.l2(weight_decay),
        )(inputs)

    def _bn_relu(tensor):
        return tf.nn.relu(tf.keras.layers.BatchNormalization()(tensor))

    residual = _conv1x1(ch, (1, 1), _bn_relu(x))
    residual = _bn_relu(residual)
    lmd = LambdaLayer(dim_out=ch, r=r, dim_k=16, heads=4, dim_u=1)
    _apply_regularizer(lmd, weight_decay)
    residual = _bn_relu(lmd(residual))
    if stride > 1:
        residual = tf.keras.layers.AveragePooling2D((3, 3), (stride, stride), padding='same')(residual)
    residual = _conv1x1(out_ch, (1, 1), residual)

    # The residual ends with `ch * 4` channels, so compare the input against
    # that. Comparing against `ch` crashed for inputs with exactly `ch`
    # channels and needlessly projected inputs that already had `ch * 4`.
    in_ch = x.shape[-1]
    in_ch = getattr(in_ch, 'value', in_ch)  # TF1 Dimension -> int; TF2 already int
    if in_ch != out_ch or stride > 1:
        x = _conv1x1(out_ch, (stride, stride), x)
    return tf.keras.layers.Add()([x, residual])
def aa_downsample(x):
    """Anti-aliased 2x downsampling: fixed 3x3 binomial blur with stride 2.

    Reference: https://arxiv.org/abs/2003.13630 (TResNet).

    Args:
        x: NHWC feature map whose channel count is statically known.

    Returns:
        ``x`` reflect-padded by 1 pixel, blurred per channel and subsampled
        by 2 on both spatial axes; the channel count is unchanged.
    """
    # 3x3 binomial kernel ([1, 2, 1] outer [1, 2, 1]), normalized to sum to 1.
    blur = np.array([[1., 2., 1.], [2., 4., 2.], [1., 2., 1.]])
    blur /= blur.sum()
    channels = x.shape[-1]
    channels = getattr(channels, 'value', channels)  # TF1 Dimension -> int
    # DepthwiseConv2D kernel layout: (kh, kw, in_channels, depth_multiplier).
    kernel = np.repeat(blur[:, :, np.newaxis, np.newaxis], channels, axis=2).astype(np.float32)
    # Bake the blur in through the initializer and freeze it at construction.
    # The previous set_weights() call was undone by any later variable
    # (re)initialization, silently replacing the blur with random weights.
    sampler = tf.keras.layers.DepthwiseConv2D(
        3, (2, 2), use_bias=False,
        depthwise_initializer=tf.keras.initializers.Constant(kernel),
        trainable=False,
    )
    padded = tf.pad(x, tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]], dtype=tf.int32), "REFLECT")
    return sampler(padded)
# ref: https://arxiv.org/abs/2003.13630, https://github.com/Alibaba-MIIL/TResNet
# ref: https://arxiv.org/abs/1810.12890, https://github.com/CyberZHG/keras-drop-block
def tresnet_resnext_block(x, ch, stride=1, n_splits=32, weight_decay=1e-4, use_se=True, drop_block_raito=0., prefix=''):
    """TResNet-style ResNeXt bottleneck block.

    Layout: [aa_downsample if stride > 1] -> Conv1x1(ch)-BN-LeakyReLU
    -> grouped Conv3x3 (n_splits groups)-BN-LeakyReLU -> [SE, ratio 8]
    -> Conv1x1(ch * 4)-BN -> [DropBlock] -> add shortcut -> ReLU.

    TODO: use In-Place Activated BatchNorm
    (https://openaccess.thecvf.com/content_cvpr_2018/papers/Bulo_In-Place_Activated_BatchNorm_CVPR_2018_paper.pdf).

    Args:
        x: input feature map.
        ch: bottleneck width; the block outputs ``ch * 4`` channels.
        stride: if > 1, ``x`` is first downsampled by 2 with
            ``aa_downsample`` (the stride value itself is not otherwise used).
        n_splits: cardinality of the grouped 3x3 conv; must divide ``ch``.
        weight_decay: L2 factor for conv kernels and SE dense layers.
        use_se: if True, apply ``se_block`` (ratio 8) after the grouped conv.
        drop_block_raito: DropBlock drop ratio (name kept as-is for backward
            compatibility); values <= 0 disable DropBlock.
        prefix: layer-name prefix. When empty, Keras auto-generates names,
            so the block can be used more than once in a model.

    Returns:
        Tensor with ``ch * 4`` channels.

    Raises:
        ValueError: if ``ch`` is not divisible by ``n_splits``.
    """
    if ch % n_splits:
        raise ValueError(
            'ch ({}) must be divisible by n_splits ({})'.format(ch, n_splits))
    group_width = ch // n_splits
    out_ch = ch * 4

    def _name(suffix):
        # An empty prefix would give every call identical names
        # ('_pointwise_conv', ...), which Keras rejects within one model.
        return '{prefix}_{suffix}'.format(prefix=prefix, suffix=suffix) if prefix else None

    def _conv(filters, kernel_size, name, inputs):
        return tf.keras.layers.Conv2D(
            filters, kernel_size, (1, 1), padding='same', use_bias=False,
            kernel_initializer=tf.keras.initializers.he_normal(),
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            bias_regularizer=tf.keras.regularizers.l2(weight_decay),
            name=name,
        )(inputs)

    if stride > 1:
        x = aa_downsample(x)

    path = _conv(ch, 1, _name('pointwise_conv'), x)
    path = tf.keras.layers.BatchNormalization(name=_name('batchnorm1'))(path)
    path = tf.nn.leaky_relu(path, 1e-3)

    # Grouped 3x3 conv: each group convolves only its own channel slice.
    # Feeding the full tensor to every group would just be one dense conv.
    groups = []
    for i in range(n_splits):
        if n_splits == 1:
            piece = path
        else:
            start = i * group_width
            piece = tf.keras.layers.Lambda(
                lambda t, start=start: t[..., start:start + group_width],
                name=_name('{}th_split'.format(i)),
            )(path)
        groups.append(_conv(group_width, 3, _name('{}th_middle_conv'.format(i)), piece))
    if len(groups) == 1:
        path = groups[0]  # Concatenate requires at least two inputs
    else:
        path = tf.keras.layers.Concatenate(name=_name('merge_pathes'))(groups)
    path = tf.keras.layers.BatchNormalization(name=_name('batchnorm2'))(path)
    path = tf.nn.leaky_relu(path, 1e-3)
    if use_se:
        path = se_block(path, ch, 8, weight_decay)
    path = _conv(out_ch, 1, _name('expand_conv'), path)
    path = tf.keras.layers.BatchNormalization(name=_name('batchnorm3'))(path)
    if drop_block_raito > 0.:
        path = DropBlock2D(5, 1. - drop_block_raito)(path)

    # Project the shortcut when its channel count differs from the output.
    shortcut = x
    in_ch = x.shape[-1]
    in_ch = getattr(in_ch, 'value', in_ch)  # TF1 Dimension -> int
    if in_ch != out_ch:
        shortcut = _conv(out_ch, 1, _name('shortcut_conv'), x)
    return tf.nn.relu(tf.keras.layers.Add()([shortcut, path]))
# (End of gist. GitHub page footer: "Sign up for free to join this
# conversation on GitHub. Already have an account? Sign in to comment.")