Skip to content

Instantly share code, notes, and snippets.

@saitejamalyala
Last active August 5, 2021 23:58
Show Gist options
  • Save saitejamalyala/e533c2cfd3a90cf9121d892a8a39d041 to your computer and use it in GitHub Desktop.
Save saitejamalyala/e533c2cfd3a90cf9121d892a8a39d041 to your computer and use it in GitHub Desktop.
Custom Keras layer to mask an input tensor
from tensorflow.keras import layers
from tensorflow.keras import models
import tensorflow as tf
class CustomMaskLayer(layers.Layer):
    """Layer that multiplies its input element-wise by a fixed mask.

    The mask is supplied as a (nested) list of numbers, typically 1 to keep
    an element and 0 to zero it out. It is combined with the input through
    ``tf.math.multiply``, so it must be broadcast-compatible with the input.

    Args:
        list_mask: Nested list holding the mask values.
        name: Optional layer name, forwarded to ``layers.Layer``.
        **kwargs: Additional keyword arguments forwarded to ``layers.Layer``.
    """

    def __init__(self, list_mask, name=None, **kwargs):
        # Initialise the Keras base class before attaching state to the layer.
        super().__init__(name=name, **kwargs)
        self.list_mask = list_mask

    def call(self, inputs):
        """Return ``inputs`` multiplied element-wise by the mask."""
        mask = tf.constant(self.list_mask, dtype=tf.float32)
        # tf.math.multiply requires both operands to share a dtype; cast the
        # mask so non-float32 inputs (e.g. float16 under mixed precision, or
        # float64) are masked instead of raising a dtype-mismatch error.
        mask = tf.cast(mask, inputs.dtype)
        return tf.math.multiply(mask, inputs)

    def get_config(self):
        """Return the base layer config extended with ``list_mask``."""
        config = super().get_config()
        config.update({"list_mask": self.list_mask})
        return config
# Mask over the 25 (x, y) rows of a path: 1.0 keeps a coordinate, 0.0 replaces
# it with zero once multiplied. Only the first and the last row are kept.
list_mask = [
    [1.0, 1.0] if row_idx in (0, 24) else [0.0, 0.0]
    for row_idx in range(25)
]
def nn(full_skip: bool = True):
    """Build the multi-input Keras model, print its summary and return it.

    The model takes a grid map image plus five auxiliary inputs (grid
    origin/resolution, left and right boundaries, car location and an
    initial path). It outputs a tensor of shape (25, 2) per sample.

    Args:
        full_skip: If True, the whole flattened initial path is added to the
            50-unit dense output. If False, the initial path is first passed
            through ``CustomMaskLayer`` with the module-level ``list_mask``
            and only the masked result is added.

    Returns:
        The constructed ``models.Model``.
    """
    # Grid map input, fed through the convolutional branch.
    grid_in = layers.Input(shape=(1536, 1536, 1))

    feat = grid_in
    # First two stages: strided conv -> ReLU -> batch norm -> 4x4 average pool.
    for n_filters, k_size in ((16, 7), (32, 5)):
        feat = layers.Conv2D(n_filters, kernel_size=k_size, strides=2)(feat)
        feat = layers.ReLU()(feat)
        feat = layers.BatchNormalization()(feat)
        feat = layers.AvgPool2D(pool_size=(4, 4))(feat)

    # Third stage applies batch norm before the ReLU and uses a 2x2 pool.
    feat = layers.Conv2D(64, kernel_size=3, strides=2)(feat)
    feat = layers.BatchNormalization()(feat)
    feat = layers.ReLU()(feat)
    feat = layers.AvgPool2D(pool_size=(2, 2))(feat)
    grid_features = layers.Flatten()(feat)

    # Auxiliary (non-image) inputs.
    grid_org_res_in = layers.Input(shape=(3,), name="Grid_origin_res")
    left_bnd_in = layers.Input(shape=(25, 2), name="Left_boundary")
    right_bnd_in = layers.Input(shape=(25, 2), name="Right_boundary")
    car_odo_in = layers.Input(shape=(3,), name="Car_loc")
    init_path_in = layers.Input(shape=(25, 2), name="Initial_path")

    # Grid origin/resolution and car location are joined into one vector.
    origin_and_odo = layers.concatenate([grid_org_res_in, car_odo_in])

    # Flatten each (25, 2) sequence into a 50-element vector.
    flat_init_path, flat_left_bnd, flat_right_bnd = [
        layers.Reshape((50,))(seq)
        for seq in (init_path_in, left_bnd_in, right_bnd_in)
    ]

    # Single feature vector feeding the dense head.
    merged = layers.concatenate([
        grid_features,
        flat_init_path,
        flat_left_bnd,
        flat_right_bnd,
        origin_and_odo,
    ])

    # Dense head: three Dense -> batch norm -> ReLU stages.
    head = merged
    for n_units in (128, 96, 64):
        head = layers.Dense(n_units, activation='linear')(head)
        head = layers.BatchNormalization()(head)
        head = layers.ReLU()(head)
    head = layers.Dense(50, activation='linear')(head)

    # Skip connection: either the whole initial path or only its masked rows.
    if full_skip:
        skip = flat_init_path
    else:
        masked_path = CustomMaskLayer(list_mask=list_mask)(init_path_in)
        skip = layers.Reshape((50,))(masked_path)

    head = layers.add([head, skip])
    head = layers.Dense(50, activation='linear')(head)
    path_out = layers.Reshape((25, 2))(head)

    nn_fun = models.Model(
        inputs=[
            grid_in,
            grid_org_res_in,
            left_bnd_in,
            right_bnd_in,
            car_odo_in,
            init_path_in,
        ],
        outputs=path_out,
    )
    nn_fun.summary(line_length=120)
    return nn_fun
# Build the model with the masked skip connection (full_skip=False routes the
# initial path through CustomMaskLayer); nn() also prints the model summary.
nn(full_skip=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment