@SaifAlDilaimi
Created May 4, 2018 15:09
Keras Hourglass network
# -*- coding: utf-8 -*-
import GlobalParams as PARAMS
from keras import backend as K
from keras.layers import Input
from keras.layers import Lambda
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import Activation
from keras.layers import Conv2D
from keras.layers import SeparableConv2D
from keras.layers import BatchNormalization
from keras.layers import MaxPooling2D
from keras.layers import UpSampling2D
from keras.layers import Multiply
from keras.layers import Concatenate
from keras.layers import Add
from keras.models import Model
from keras.optimizers import RMSprop
from keras import losses
from keras import metrics
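
# Note: this script uses the standalone Keras API (keras.*) and the Keras 2
# optimizer signature RMSprop(lr=..., decay=...); under tf.keras the same
# layers and optimizers would be imported from tensorflow.keras instead.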

def conv(x, filters, kernel_size=1, strides=(1, 1), padding='same', name="conv"):
    """ Spatial Convolution (Conv2D)
    Args:
        x           : Input tensor (NHWC)
        filters     : Number of filters (output channels)
        kernel_size : Size of the kernel
        strides     : Stride
        padding     : Padding type ('valid'/'same')
        name        : Name of the block (currently unused)
    Returns:
        Output tensor (convolved input)
    """
    x = Conv2D(filters, kernel_size, strides=strides, padding=padding,
               use_bias=False)(x)
    return x

def max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding="valid"):
    """ Max pooling wrapper (halves the spatial resolution by default) """
    x = MaxPooling2D(pool_size=pool_size, strides=strides, padding=padding)(x)
    return x

def conv_bn(x, filters, kernel_size=1, strides=(1, 1), padding='same', name="conv_bn"):
    """ Spatial Convolution (Conv2D) + BatchNormalization """
    x = conv(x, filters, kernel_size, strides, padding, name)
    x = BatchNormalization(axis=-1, scale=False)(x)
    return x

def conv_bn_relu(x, filters, kernel_size=1, strides=(1, 1), padding="same", name="conv_bn_relu"):
    """ Spatial Convolution (Conv2D) + BatchNormalization + ReLU activation
    Args:
        x           : Input tensor (NHWC)
        filters     : Number of filters (output channels)
        kernel_size : Size of the kernel
        strides     : Stride
        padding     : Padding type ('valid'/'same')
        name        : Name of the block (currently unused)
    Returns:
        Output tensor
    """
    x = conv(x, filters, kernel_size, strides, padding, name)
    x = BatchNormalization(axis=-1, scale=False)(x)
    x = Activation('relu')(x)
    return x

def conv_block(x, numOut, name="conv_block"):
    """ Convolutional Block (BatchNorm + ReLU + 1x1 Conv)
    Args:
        x      : Input tensor
        numOut : Desired number of output channels
        name   : Name of the block
    Returns:
        Output tensor
    """
    x = BatchNormalization(axis=-1, scale=False)(x)
    x = Activation('relu')(x)
    x = conv(x, int(numOut))
    return x

def skip_layer(x, numOut, name='skip_layer'):
    """ Skip Layer
    Args:
        x      : Input tensor
        numOut : Desired number of output channels
        name   : Name of the block
    Returns:
        Tensor of shape (None, x.height, x.width, numOut)
    """
    # pass the input through unchanged if it already has numOut channels,
    # otherwise project it with a 1x1 convolution
    if K.int_shape(x)[-1] == numOut:
        return x
    x = conv(x, numOut)
    return x

def residual_block(x, numOut, name="residual_block"):
    """ Residual Unit (conv_block + skip_layer, added together)
    Args:
        x      : Input tensor
        numOut : Number of output features (channels)
        name   : Name of the block
    """
    convb = conv_block(x, numOut)
    skip_l = skip_layer(x, numOut)
    x = Add()([convb, skip_l])
    x = Activation('relu')(x)
    return x

def hourglass(x, n, numOut, name='hourglass'):
    """ Hourglass Module
    Args:
        x      : Input tensor
        n      : Number of downsampling steps
        numOut : Number of output features (channels)
        name   : Name of the block
    """
    # upper branch: keep the current resolution
    up_1 = residual_block(x, numOut, name="up_1")
    # lower branch: pool, recurse (or bottom out), then upsample back
    low_ = max_pool2d(x)
    low_1 = residual_block(low_, numOut, name="low_1")
    if n > 0:
        low_2 = hourglass(low_1, n - 1, numOut, name="low_2")
    else:
        low_2 = residual_block(low_1, numOut, name="low_2")
    low_3 = residual_block(low_2, numOut, name="low_3")
    up_2 = UpSampling2D((2, 2))(low_3)
    x = Add()([up_2, up_1])
    x = Activation('relu')(x)
    x = Dropout(0.2)(x)
    return x
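
# Shape note: every hourglass level pools by 2 on the way down and upsamples by 2
# on the way back up, so the feature map entering hourglass() should have spatial
# dimensions divisible by 2**(n + 1); e.g. a 64x64 map with n = 3 shrinks
# 64 -> 32 -> 16 -> 8 -> 4 and is restored to 64x64, which the final
# Add()([up_2, up_1]) requires.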

class HGKerasModel():
    def build_model(self):
        # note: BatchNormalization(axis=-1) and skip_layer's channel check
        # above assume channels_last data
        if K.image_data_format() == 'channels_first':
            input_shape = (3, PARAMS.ML_INPUT_IMAGE_HEIGHT, PARAMS.ML_INPUT_IMAGE_WIDTH)
        else:
            input_shape = (PARAMS.ML_INPUT_IMAGE_HEIGHT, PARAMS.ML_INPUT_IMAGE_WIDTH, 3)
        m_input = Input(shape=input_shape)

        # Storage Table
        hg = [None] * PARAMS.ML_DEEPPOSE_STAGES
        ll = [None] * PARAMS.ML_DEEPPOSE_STAGES
        ll_ = [None] * PARAMS.ML_DEEPPOSE_STAGES
        drop = [None] * PARAMS.ML_DEEPPOSE_STAGES
        out = [None] * PARAMS.ML_DEEPPOSE_STAGES
        out_ = [None] * PARAMS.ML_DEEPPOSE_STAGES
        sum_ = [None] * PARAMS.ML_DEEPPOSE_STAGES

        # preprocessing
        conv1 = conv_bn_relu(m_input, filters=64, kernel_size=6, strides=2)
        r1 = residual_block(conv1, 128)
        pool1 = max_pool2d(r1)
        r2 = residual_block(pool1, numOut=int(PARAMS.ML_INPUT_FEATURES / 2))
        r3 = residual_block(r2, numOut=PARAMS.ML_INPUT_FEATURES)

        # stage 0
        hg[0] = hourglass(r3, PARAMS.ML_HOURGLASS_DOWNSAMPLING, PARAMS.ML_INPUT_FEATURES)
        ll[0] = conv_bn_relu(hg[0], PARAMS.ML_INPUT_FEATURES)
        out[0] = conv_bn_relu(ll[0], PARAMS.ML_LABEL_CLASSES)
        out_[0] = conv(out[0], PARAMS.ML_INPUT_FEATURES)
        sum_[0] = Add()([out_[0], ll[0], r3])

        # build intermediate stages 1 to k-2
        for i in range(1, PARAMS.ML_DEEPPOSE_STAGES - 1):
            hg[i] = hourglass(sum_[i-1], PARAMS.ML_HOURGLASS_DOWNSAMPLING, PARAMS.ML_INPUT_FEATURES)
            ll[i] = conv_bn_relu(hg[i], PARAMS.ML_INPUT_FEATURES)
            out[i] = conv_bn_relu(ll[i], PARAMS.ML_LABEL_CLASSES)
            out_[i] = conv(out[i], PARAMS.ML_INPUT_FEATURES)
            sum_[i] = Add()([out_[i], ll[i], sum_[i-1]])

        # build the final stage k-1
        stages = PARAMS.ML_DEEPPOSE_STAGES
        hg[stages - 1] = hourglass(sum_[stages - 2], PARAMS.ML_HOURGLASS_DOWNSAMPLING, PARAMS.ML_INPUT_FEATURES)
        ll[stages - 1] = conv_bn_relu(hg[stages - 1], PARAMS.ML_INPUT_FEATURES)
        out[stages - 1] = conv_bn_relu(ll[stages - 1], PARAMS.ML_LABEL_CLASSES)

        # concatenate the per-stage heatmaps along the channel axis and squash with a sigmoid
        conc = Concatenate()(out)
        sigmoid = Activation('sigmoid')(conc)

        model = Model(inputs=m_input, outputs=sigmoid)
        rmsprop = RMSprop(lr=PARAMS.ML_HOURGLASS_LEARN_RATE, decay=PARAMS.ML_HOURGLASS_LEARN_RATE_DECAY)
        model.compile(rmsprop, loss=losses.binary_crossentropy, metrics=['accuracy'])
        model.summary()
        print("Input shape: ", m_input.shape)
        print("Output shape: ", sigmoid.shape)
        return model

def main():
    model = HGKerasModel().build_model()
    # the model has a single concatenated output tensor
    print(model.output)

if __name__ == '__main__':
    main()
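
The script imports a project-specific GlobalParams module that is not included in the gist. A minimal stub, using only the attribute names referenced above and placeholder values (assumptions, not the author's settings), could look like this:

# GlobalParams.py -- minimal stub with placeholder values (adjust to your data)
ML_INPUT_IMAGE_HEIGHT = 256       # input image height in pixels
ML_INPUT_IMAGE_WIDTH = 256        # input image width in pixels
ML_INPUT_FEATURES = 256           # number of channels inside the hourglass modules
ML_LABEL_CLASSES = 16             # number of predicted heatmaps (e.g. keypoints)
ML_DEEPPOSE_STAGES = 4            # number of stacked hourglass stages
ML_HOURGLASS_DOWNSAMPLING = 3     # recursion depth n of each hourglass
ML_HOURGLASS_LEARN_RATE = 2.5e-4  # RMSprop learning rate
ML_HOURGLASS_LEARN_RATE_DECAY = 0.0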