Skip to content

Instantly share code, notes, and snippets.

@saitejamalyala
Created June 28, 2021 17:43
Show Gist options
  • Save saitejamalyala/4c584155bee0b25475d60ca90d9a1778 to your computer and use it in GitHub Desktop.
Save saitejamalyala/4c584155bee0b25475d60ca90d9a1778 to your computer and use it in GitHub Desktop.
You can manually provide the `custom_objects` mapping to the `load_model` method, as described in the answer https://stackoverflow.com/a/62326857/8056572, but this becomes tedious when you have many custom layers (or other custom callables, e.g. metrics, losses, optimizers, ...). TensorFlow provides a utility function that does this automatically: `tf.keras.utils.register_keras_serializable`.
import tensorflow as tf
@tf.keras.utils.register_keras_serializable()
class CustomLayer(tf.keras.layers.Layer):
    """Toy layer that multiplies its input element-wise by 2.

    The class is registered with ``tf.keras.utils.register_keras_serializable``
    so that ``tf.keras.models.load_model`` can rebuild it from a saved model
    without an explicit ``custom_objects`` mapping.

    Args:
        k: Value stored on the layer and round-tripped through ``get_config``.
            NOTE(review): ``call`` does not use ``k``; the multiplier is
            hard-coded to 2. Confirm whether ``k`` was meant to be the factor.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, k, **kwargs):
        # Initialize the base Layer first so its attribute tracking is in
        # place before any attribute is assigned on ``self``.
        super().__init__(**kwargs)
        self.k = k

    def get_config(self):
        """Return the base Layer config extended with ``k``.

        Keras uses this dict to re-instantiate the layer on load, so every
        constructor argument must appear in it.
        """
        config = super().get_config()
        config["k"] = self.k
        return config

    def call(self, input):
        """Return ``input`` multiplied element-wise by 2."""
        return tf.multiply(input, 2)
def main():
    """Build a model that contains ``CustomLayer``, save it to HDF5, and reload it.

    The script demonstrates that a layer registered with
    ``tf.keras.utils.register_keras_serializable`` can be restored by
    ``tf.keras.models.load_model`` without passing ``custom_objects``.
    It writes ``model.h5`` to the current working directory and prints the
    summary of both the created and the reloaded model.
    """
    model = tf.keras.models.Sequential(
        [
            tf.keras.Input(name='input_layer', shape=(10,)),
            CustomLayer(10, name='custom_layer'),
            tf.keras.layers.Dense(1, activation='sigmoid', name='output_layer'),
        ]
    )
    print("SUMMARY OF THE MODEL CREATED")
    print("-" * 60)
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line after the table.
    model.summary()
    model.save('model.h5')
    # Discard the in-memory model so the next summary comes from the copy
    # reloaded from disk.
    del model

    print()
    print()
    model = tf.keras.models.load_model('model.h5')
    print("SUMMARY OF THE MODEL LOADED")
    print("-" * 60)
    model.summary()
if "__main__" == __name__:
    # Run the save/load round-trip only when executed as a script, not on import.
    main()
# Reference: https://stackoverflow.com/questions/62280161/saving-keras-models-with-custom-layers
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment