Skip to content

Instantly share code, notes, and snippets.

@saitejamalyala
Last active June 27, 2021 12:49
Show Gist options
  • Save saitejamalyala/033856765991d1f4703a32aca088573c to your computer and use it in GitHub Desktop.
Save saitejamalyala/033856765991d1f4703a32aca088573c to your computer and use it in GitHub Desktop.
Keras custom layer to multiply input by a scalar
import tensorflow as tf
class CustomLayer(tf.keras.layers.Layer):
    """Keras layer that multiplies its input element-wise by a fixed scalar ``k``.

    The layer creates no weights; ``k`` is kept as a plain attribute and is
    written into the layer config so the layer can be rebuilt when a saved
    model is loaded with ``custom_objects={'CustomLayer': CustomLayer}``.
    """

    def __init__(self, k, name=None, **kwargs):
        """Create the layer.

        Args:
            k: Scalar that every input element is multiplied by.
            name: Optional layer name.
            **kwargs: Remaining base ``Layer`` arguments (e.g. ``trainable``,
                ``dtype``), as passed back by Keras when it rebuilds the
                layer from ``get_config()``.
        """
        # Initialise the base Layer exactly once, forwarding ``name`` together
        # with the other base arguments. A second ``__init__`` call without
        # ``name`` would reset the layer to an auto-generated name.
        super().__init__(name=name, **kwargs)
        self.k = k

    def get_config(self):
        """Return the layer config, extended with ``k`` for serialization."""
        config = super().get_config()
        config.update({"k": self.k})
        return config

    def call(self, input):
        """Return ``input * k`` computed element-wise."""
        # Use the configured scalar rather than a hard-coded constant, so the
        # ``k`` passed to the constructor (and restored from config) is applied.
        return tf.multiply(input, self.k)
# Build a small model: 10-feature input -> CustomLayer(k=10) -> one sigmoid unit.
model = tf.keras.models.Sequential([
    tf.keras.Input(name='input_layer', shape=(10,)),
    CustomLayer(10, name='custom_layer'),
    tf.keras.layers.Dense(1, activation='sigmoid', name='output_layer'),
])

# Save to 'model.h5' and load it back. CustomLayer must be registered through
# custom_objects so Keras can rebuild it from the dict returned by get_config().
tf.keras.models.save_model(model, 'model.h5')
new_model = tf.keras.models.load_model('model.h5', custom_objects={'CustomLayer': CustomLayer})

# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" after the table.
new_model.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment