@kriskorrel-cw · Created April 9, 2020 09:37
Saving a model is not possible with a custom layer that uses a different call() signature
from tensorflow import keras
from tensorflow.keras.layers import BatchNormalization, Dense


class BatchNormalization1(BatchNormalization):
    # Forwards call() with an explicit `inputs` argument, as in the parent layer.
    def call(self, inputs, **kwargs):
        return super(BatchNormalization1, self).call(inputs, **kwargs)


class BatchNormalization2(BatchNormalization):
    # Forwards call() through a generic *args/**kwargs signature instead.
    def call(self, *args, **kwargs):
        return super(BatchNormalization2, self).call(*args, **kwargs)


class ThreeLayerMLP(keras.Model):
    def __init__(self, name=None):
        super(ThreeLayerMLP, self).__init__(name=name)
        self.dense_1 = Dense(64, activation='relu', name='dense_1')
        self.dense_2 = Dense(64, activation='relu', name='dense_2')
        # Swap in BatchNormalization2 here to exercise the other call() signature.
        self.batch_norm = BatchNormalization1()
        self.pred_layer = Dense(10, name='predictions')

    def call(self, inputs):
        x = self.dense_1(inputs)
        x = self.dense_2(x)
        x = self.batch_norm(x)
        return self.pred_layer(x)
def get_model():
    return ThreeLayerMLP(name='3_layer_mlp')


model = get_model()

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=1)

# Reset metrics before saving so that loaded model has same state,
# since metric states are not preserved by Model.save_weights
model.reset_metrics()
model.save('path_to_my_model', save_format='tf')
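
If the save above completes without error, the resulting SavedModel can be reloaded and sanity-checked against the in-memory model. The lines below are not part of the original gist; they are a minimal sketch assuming the script ran to completion, using only keras.models.load_model and NumPy.

import numpy as np

# Reload the SavedModel written by model.save above and compare predictions.
loaded = keras.models.load_model('path_to_my_model')
np.testing.assert_allclose(model.predict(x_test), loaded.predict(x_test),
                           rtol=1e-5, atol=1e-5)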