CNN parts in Keras (TF 1.x)
import numpy as np
import tensorflow as tf
# Octave convolution (https://arxiv.org/abs/1904.05049): feature maps are split
# into a full-resolution (high-frequency) path and a half-resolution
# (low-frequency) path, and each output mixes both inputs.
def octave_conv2d(high_x, low_x, channels, alpha_in=0.5, alpha_out=0.5,
                  stride=1, kernel_size=(3, 3), weight_decay=1e-4):
    # high -> high
    conv_hh = tf.keras.layers.Conv2D(int(channels * (1 - alpha_out)),
                                     kernel_size, stride, padding='same', use_bias=False,
                                     kernel_initializer=tf.keras.initializers.he_normal(),
                                     kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                     bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                     )(high_x)
    # high -> low: pool to half resolution, then convolve
    conv_hl = tf.keras.layers.AveragePooling2D()(high_x)
    conv_hl = tf.keras.layers.Conv2D(int(channels * alpha_out),
                                     kernel_size, stride, padding='same', use_bias=False,
                                     kernel_initializer=tf.keras.initializers.he_normal(),
                                     kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                     bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                     )(conv_hl)
    # low -> high: upsample to full resolution, then convolve
    conv_lh = tf.keras.layers.UpSampling2D()(low_x)
    conv_lh = tf.keras.layers.Conv2D(int(channels * (1 - alpha_out)),
                                     kernel_size, stride, padding='same', use_bias=False,
                                     kernel_initializer=tf.keras.initializers.he_normal(),
                                     kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                     bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                     )(conv_lh)
    # low -> low
    conv_ll = tf.keras.layers.Conv2D(int(channels * alpha_out),
                                     kernel_size, stride, padding='same', use_bias=False,
                                     kernel_initializer=tf.keras.initializers.he_normal(),
                                     kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                     bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                     )(low_x)
    yh, yl = conv_hh + conv_lh, conv_hl + conv_ll
    return (yh, yl)
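
# Usage sketch (not from the original gist; shapes and names below are
# illustrative assumptions): split a stem's output into a high- and a
# half-resolution low-frequency path, then apply the octave convolution once.
def _example_octave_conv2d():
    inputs = tf.keras.layers.Input((32, 32, 3))
    stem = tf.keras.layers.Conv2D(16, 3, padding='same')(inputs)
    high = stem
    low = tf.keras.layers.AveragePooling2D()(stem)  # half-resolution branch
    # with the default alpha_out=0.5:
    # high -> (None, 32, 32, 16), low -> (None, 16, 16, 16)
    high, low = octave_conv2d(high, low, channels=32)
    return high, low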
def octave_conv2d_with_bn(high_x, low_x, channels, stride=1, kernel_size=(3, 3),
                          weight_decay=1e-4, use_se_block=True):
    high_x, low_x = octave_conv2d(
        high_x, low_x, channels, stride=stride, kernel_size=kernel_size, weight_decay=weight_decay
    )
    high_x = tf.keras.layers.BatchNormalization()(high_x)
    high_x = tf.nn.relu(high_x)
    low_x = tf.keras.layers.BatchNormalization()(low_x)
    low_x = tf.nn.relu(low_x)
    if use_se_block:
        # each path carries only a fraction of `channels`, so size the SE
        # blocks from the tensors' actual channel counts
        high_x = se_block(high_x, high_x.shape[-1].value, weight_decay=weight_decay)
        low_x = se_block(low_x, low_x.shape[-1].value, weight_decay=weight_decay)
    return (high_x, low_x)
# Pre-activation residual block (BN -> ReLU -> conv) built from octave convolutions.
def residual_octave_conv_block(high_x, low_x, before_ch, after_ch, stride=1,
                               weight_decay=1e-4, use_se_block=True, dropout=0.):
    high_residual = tf.keras.layers.BatchNormalization()(high_x)
    high_residual = tf.nn.relu(high_residual)
    low_residual = tf.keras.layers.BatchNormalization()(low_x)
    low_residual = tf.nn.relu(low_residual)
    high_residual, low_residual = octave_conv2d(
        high_residual, low_residual, before_ch, stride=1, kernel_size=(3, 3), weight_decay=weight_decay
    )
    high_residual = tf.keras.layers.BatchNormalization()(high_residual)
    high_residual = tf.nn.relu(high_residual)
    low_residual = tf.keras.layers.BatchNormalization()(low_residual)
    low_residual = tf.nn.relu(low_residual)
    if dropout:
        high_residual = tf.keras.layers.SpatialDropout2D(dropout)(high_residual)
        low_residual = tf.keras.layers.SpatialDropout2D(dropout)(low_residual)
    high_residual, low_residual = octave_conv2d(
        high_residual, low_residual, after_ch, stride=stride, kernel_size=(3, 3), weight_decay=weight_decay
    )
    # if use_se_block:
    #     residual = se_block(residual, after_ch, weight_decay=weight_decay)
    if stride > 1 or before_ch != after_ch:
        # 1x1 octave projection so the shortcut matches the residual's shape
        high_x, low_x = octave_conv2d(
            high_x, low_x, after_ch, stride=stride, kernel_size=(1, 1), weight_decay=weight_decay
        )
    high_x = tf.keras.layers.Add()([high_x, high_residual])
    low_x = tf.keras.layers.Add()([low_x, low_residual])
    return (high_x, low_x)
def conv2d_with_bn(x, channels, stride=1, kernel_size=(3, 3), weight_decay=1e-4, use_se_block=True):
    x = tf.keras.layers.Conv2D(channels, kernel_size, (stride, stride), padding='same', use_bias=False,
                               kernel_initializer=tf.keras.initializers.he_normal(),
                               kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                               bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                               )(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.nn.relu(x)
    if use_se_block:
        x = se_block(x, channels, weight_decay=weight_decay)
    return x
# Squeeze-and-excitation block (https://arxiv.org/abs/1709.01507):
# global-average-pool to per-channel statistics, squeeze through a bottleneck,
# and rescale the input channels with sigmoid gates.
def se_block(in_block, ch, ratio=16, weight_decay=1e-4):
    z = tf.keras.layers.GlobalAveragePooling2D()(in_block)
    x = tf.keras.layers.Dense(
        ch // ratio, activation='relu',
        kernel_initializer=tf.keras.initializers.he_normal(),
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay)
    )(z)
    x = tf.keras.layers.Dense(
        ch, activation='sigmoid',
        kernel_initializer=tf.keras.initializers.he_normal(),
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay)
    )(x)
    return tf.keras.layers.Multiply()([in_block, x])
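
# Usage sketch (illustrative, not part of the gist): the SE block returns a
# tensor of the same shape as its input, rescaled channel-wise by gates in
# [0, 1]; Keras's Multiply broadcasts the (None, ch) gates over H and W.
def _example_se_block():
    x = tf.keras.layers.Input((8, 8, 64))
    y = se_block(x, ch=64, ratio=16)  # same shape as x: (None, 8, 8, 64)
    return y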
# Pre-activation residual block (BN -> ReLU -> 3x3 conv, twice).
def residual_conv_block(x, before_ch, after_ch, stride=1, weight_decay=1e-4,
                        use_se_block=True, dropout=0.):
    residual = tf.keras.layers.BatchNormalization()(x)
    residual = tf.nn.relu(residual)
    residual = tf.keras.layers.Conv2D(before_ch, 3, (1, 1), padding='same', use_bias=False,
                                      kernel_initializer=tf.keras.initializers.he_normal(),
                                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                      bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                      )(residual)
    residual = tf.keras.layers.BatchNormalization()(residual)
    residual = tf.nn.relu(residual)
    if dropout:
        residual = tf.keras.layers.SpatialDropout2D(dropout)(residual)
    residual = tf.keras.layers.Conv2D(after_ch, 3, (stride, stride), padding='same', use_bias=False,
                                      kernel_initializer=tf.keras.initializers.he_normal(),
                                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                      bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                      )(residual)
    if use_se_block:
        residual = se_block(residual, after_ch, weight_decay=weight_decay)
    if stride > 1 or before_ch != after_ch:
        # 1x1 projection so the shortcut matches the residual's shape
        x = tf.keras.layers.Conv2D(after_ch, 1, (stride, stride), padding='same', use_bias=False,
                                   kernel_initializer=tf.keras.initializers.he_normal(),
                                   kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                   bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                   )(x)
    x = tf.keras.layers.Add()([x, residual])
    return x
# Pre-activation bottleneck residual block (1x1 reduce -> 3x3 -> 1x1 expand).
def bottleneck_residual_conv_block(x, before_ch, after_ch, stride=1, weight_decay=1e-4, use_se_block=True):
    residual = tf.keras.layers.BatchNormalization()(x)
    residual = tf.nn.relu(residual)
    residual = tf.keras.layers.Conv2D(before_ch, 1, (1, 1), padding='same', use_bias=False,
                                      kernel_initializer=tf.keras.initializers.he_normal(),
                                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                      bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                      )(residual)
    residual = tf.keras.layers.BatchNormalization()(residual)
    residual = tf.nn.relu(residual)
    residual = tf.keras.layers.Conv2D(before_ch, 3, (1, 1), padding='same', use_bias=False,
                                      kernel_initializer=tf.keras.initializers.he_normal(),
                                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                      bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                      )(residual)
    residual = tf.keras.layers.BatchNormalization()(residual)
    residual = tf.nn.relu(residual)
    residual = tf.keras.layers.Conv2D(after_ch, 1, (stride, stride), padding='same', use_bias=False,
                                      kernel_initializer=tf.keras.initializers.he_normal(),
                                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                      bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                      )(residual)
    if use_se_block:
        residual = se_block(residual, after_ch, weight_decay=weight_decay)
    if stride > 1 or before_ch != after_ch:
        # 1x1 projection so the shortcut matches the residual's shape
        x = tf.keras.layers.Conv2D(after_ch, 1, (stride, stride), padding='same', use_bias=False,
                                   kernel_initializer=tf.keras.initializers.he_normal(),
                                   kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                   bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                   )(x)
    x = tf.keras.layers.Add()([x, residual])
    return x
# LambdaNetworks bottleneck block (https://openreview.net/forum?id=xTJEN-ggl1b).
# `LambdaLayer` and the `_apply_regularizer` helper are not defined in this
# gist; they are assumed to come from a Keras port of LambdaNetworks and from
# the author's own utilities.
def bottleneck_residual_lambda_block(x, ch, stride=1, r=23, weight_decay=1e-4):
    residual = tf.keras.layers.BatchNormalization()(x)
    residual = tf.nn.relu(residual)
    residual = tf.keras.layers.Conv2D(ch, 1, (1, 1), use_bias=False,
                                      kernel_initializer=tf.keras.initializers.he_normal(),
                                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                      bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                      )(residual)
    residual = tf.keras.layers.BatchNormalization()(residual)
    residual = tf.nn.relu(residual)
    lmd = LambdaLayer(dim_out=ch, r=r, dim_k=16, heads=4, dim_u=1)
    _apply_regularizer(lmd, weight_decay)
    residual = lmd(residual)
    residual = tf.keras.layers.BatchNormalization()(residual)
    residual = tf.nn.relu(residual)
    if stride > 1:
        residual = tf.keras.layers.AveragePooling2D((3, 3), (stride, stride), padding='same')(residual)
    residual = tf.keras.layers.Conv2D(ch * 4, 1, (1, 1), use_bias=False,
                                      kernel_initializer=tf.keras.initializers.he_normal(),
                                      kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                      bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                      )(residual)
    # project whenever the shortcut would not match the 4x-expanded residual
    if x.shape[-1].value != ch * 4 or stride > 1:
        x = tf.keras.layers.Conv2D(ch * 4, 1, (stride, stride), padding='same', use_bias=False,
                                   kernel_initializer=tf.keras.initializers.he_normal(),
                                   kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                   bias_regularizer=tf.keras.regularizers.l2(weight_decay)
                                   )(x)
    return x + residual
# Anti-aliased downsampling, "blur, then subsample" (https://arxiv.org/abs/2003.13630):
# a stride-2 depthwise convolution whose kernel is a frozen, normalized blur filter.
def aa_downsample(x):
    sampler = tf.keras.layers.DepthwiseConv2D(3, (2, 2), use_bias=False)
    y = tf.pad(x, tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]], dtype=tf.int32), "REFLECT")
    y = sampler(y)
    # define the fixed blur kernel and freeze it; the layer is already built
    # at this point, so set_weights is valid
    kernel = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]])
    kernel = kernel / kernel.sum()
    kernel = np.stack([kernel] * x.shape[-1].value, axis=2)[:, :, :, np.newaxis]
    sampler.set_weights([kernel])
    sampler.trainable = False
    return y
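
# Usage sketch (illustrative): aa_downsample halves the spatial resolution
# while keeping the channel count. The frozen kernel above is the normalized
# binomial blur [1, 2, 1]^T [1, 2, 1] / 16 applied per channel.
def _example_aa_downsample():
    x = tf.keras.layers.Input((32, 32, 64))
    y = aa_downsample(x)  # -> (None, 16, 16, 64)
    return y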
# ref: https://arxiv.org/abs/2003.13630, https://github.com/Alibaba-MIIL/TResNet
# ref: https://arxiv.org/abs/1810.12890, https://github.com/CyberZHG/keras-drop-block
# `DropBlock2D` is assumed to come from keras-drop-block (linked above).
# `ch` is assumed to be divisible by `n_splits`.
def tresnet_resnext_block(x, ch, stride=1, n_splits=32, weight_decay=1e-4,
                          use_se=True, drop_block_ratio=0., prefix=''):
    ## TODO: use In-Place Activated BatchNorm ##
    # https://openaccess.thecvf.com/content_cvpr_2018/papers/Bulo_In-Place_Activated_BatchNorm_CVPR_2018_paper.pdf
    if stride > 1:
        x = aa_downsample(x)
    path = x
    path = tf.keras.layers.Conv2D(ch, 1, (1, 1), padding='same', use_bias=False,
                                  kernel_initializer=tf.keras.initializers.he_normal(),
                                  kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                  bias_regularizer=tf.keras.regularizers.l2(weight_decay),
                                  name='{prefix}_pointwise_conv'.format(prefix=prefix)
                                  )(path)
    path = tf.keras.layers.BatchNormalization(name='{prefix}_batchnorm1'.format(prefix=prefix))(path)
    path = tf.nn.leaky_relu(path, 1e-3)
    # grouped 3x3 convolution, ResNeXt style: one small conv per split
    paths = []
    for i in range(n_splits):
        group = tf.keras.layers.Conv2D(ch // n_splits, 3, (1, 1), padding='same', use_bias=False,
                                       kernel_initializer=tf.keras.initializers.he_normal(),
                                       kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                       bias_regularizer=tf.keras.regularizers.l2(weight_decay),
                                       name='{prefix}_{i}th_middle_conv'.format(prefix=prefix, i=i)
                                       )(path)
        paths.append(group)
    path = tf.keras.layers.Concatenate(name='{prefix}_merge_paths'.format(prefix=prefix))(paths)
    path = tf.keras.layers.BatchNormalization(name='{prefix}_batchnorm2'.format(prefix=prefix))(path)
    path = tf.nn.leaky_relu(path, 1e-3)
    if use_se:
        path = se_block(path, ch, 8, weight_decay)
    path = tf.keras.layers.Conv2D(ch * 4, 1, (1, 1), padding='same', use_bias=False,
                                  kernel_initializer=tf.keras.initializers.he_normal(),
                                  kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                  bias_regularizer=tf.keras.regularizers.l2(weight_decay),
                                  name='{prefix}_expand_conv'.format(prefix=prefix)
                                  )(path)
    path = tf.keras.layers.BatchNormalization(name='{prefix}_batchnorm3'.format(prefix=prefix))(path)
    if drop_block_ratio > 0.:
        path = DropBlock2D(5, 1. - drop_block_ratio)(path)
    y = path
    if x.shape[-1].value != ch * 4:
        # 1x1 projection so the shortcut matches the 4x-expanded residual
        x = tf.keras.layers.Conv2D(ch * 4, 1, (1, 1), padding='same', use_bias=False,
                                   kernel_initializer=tf.keras.initializers.he_normal(),
                                   kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
                                   name='{prefix}_shortcut_conv'.format(prefix=prefix)
                                   )(x)
    return tf.nn.relu(x + y)
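
# End-to-end sketch (assumptions, not part of the original gist): stacking
# conv2d_with_bn and residual_conv_block into a small CIFAR-style classifier.
# Widths, depths, and the dropout rate are arbitrary example values; since the
# blocks apply tf.nn.relu directly to symbolic tensors, building the Model
# assumes a TF version whose Keras accepts raw ops on Keras tensors.
def _example_small_classifier():
    inputs = tf.keras.layers.Input((32, 32, 3))
    x = conv2d_with_bn(inputs, 16)
    x = residual_conv_block(x, 16, 32, stride=2)
    x = residual_conv_block(x, 32, 32)
    x = residual_conv_block(x, 32, 64, stride=2, dropout=0.1)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
    return tf.keras.Model(inputs, outputs)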