Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save tanzhenyu/22fadcfda66704199a5c5d4edf10c17e to your computer and use it in GitHub Desktop.
Save tanzhenyu/22fadcfda66704199a5c5d4edf10c17e to your computer and use it in GitHub Desktop.
test model cloning
# A gist of `Sequential` model cloning (cloning a functional model should work the same way).
def to_list(x):
    """Return *x* as a list, wrapping it only when it is not one already.

    Args:
        x: Any object. A ``list`` is returned unchanged (the same object);
            anything else, including a ``tuple`` or ``None``, is wrapped
            as the single element of a new list.

    Returns:
        A list.
    """
    return x if isinstance(x, list) else [x]
def is_keras_tensor(x):
    """Report whether *x* exposes a ``_keras_history`` attribute.

    The check is purely duck-typed. Any object on which attribute access to
    ``_keras_history`` succeeds counts, whatever the attribute's value.
    """
    try:
        x._keras_history
    except AttributeError:
        return False
    return True
def clone_sequential_model(model, input_tensors=None):
    """Clone a ``Sequential`` model by rebuilding each layer from its config.

    Every layer in ``model.layers`` is re-created with
    ``layer.__class__.from_config(layer.get_config())``. The clone therefore
    has the same architecture but new layer objects; this function does not
    copy weights.

    Args:
        model: The ``tf.keras.Sequential`` model to clone.
        input_tensors: Optional. A single tensor, or a one-element list, on
            which to build the clone.
            - If it is a Keras tensor (has ``_keras_history``), it must come
              from an ``InputLayer``, and that layer is reused as the clone's
              input layer.
            - Otherwise it is wrapped with ``tf.keras.Input(tensor=...)``.

    Returns:
        A new ``tf.keras.Sequential`` with the same name as ``model``.

    Raises:
        ValueError: If ``input_tensors`` does not contain exactly one tensor,
            or if it is a Keras tensor produced by a layer other than an
            ``InputLayer``.
    """
    def clone(layer):
        # New layer object with the same hyperparameters (config only).
        return layer.__class__.from_config(layer.get_config())

    layers = [clone(layer) for layer in model.layers]
    if input_tensors is None:
        return tf.keras.Sequential(layers=layers, name=model.name)

    tensors = to_list(input_tensors)
    if len(tensors) != 1:
        raise ValueError('To clone a `Sequential` model, we expect '
                         'at most one tensor '
                         'as part of `input_tensors`.')
    x = tensors[0]
    if is_keras_tensor(x):
        origin_layer = x._keras_history[0]
        # Qualified through `tf.keras.layers`: a bare `InputLayer` name is
        # not defined anywhere in this file.
        if isinstance(origin_layer, tf.keras.layers.InputLayer):
            return tf.keras.Sequential(layers=[origin_layer] + layers,
                                       name=model.name)
        raise ValueError('Cannot clone a `Sequential` model on top '
                         'of a tensor that comes from a Keras layer '
                         'other than an `InputLayer`. '
                         'Use the functional API instead.')

    # Plain (non-Keras) tensor: wrap it in an Input so Sequential gets an
    # InputLayer to start from.
    input_tensor = tf.keras.Input(tensor=x,
                                  name='input_wrapper_for_' + str(x.name))
    input_layer = input_tensor._keras_history[0]
    return tf.keras.Sequential(layers=[input_layer] + layers, name=model.name)
import tensorflow as tf
tf.reset_default_graph()

# Reference model: Dense -> BatchNormalization -> softmax, with an input
# created by Keras from `input_dim`.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, input_dim=1))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Activation('softmax'))

# Clone the model on top of a raw (non-Keras) placeholder tensor.
x = tf.placeholder(tf.float32, shape=(None, 1))
clone = clone_sequential_model(model, input_tensors=x)

model.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])
clone.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])

# These values were previously bare expressions, which are displayed only in
# an interactive session. Printing them makes the comparison visible when the
# file runs as a script.
# Original model, as observed in this gist: feed input is dense_input, and
# its updates include the batch-norm moving_mean / moving_variance ops.
print(model._feed_inputs)
print(model.get_updates_for(model._feed_inputs))
# Clone, as observed in this gist: empty feed inputs and empty updates.
print(clone._feed_inputs)
print(clone.get_updates_for(clone._feed_inputs))

# Original model's train function: the update ops include the batch-norm
# assign ops for moving mean/variance.
model._make_train_function()
print(model.train_function.updates_op)
# Clone's train function: the update ops exclude the batch-norm assign ops
# for moving mean/variance.
clone._make_train_function()
print(clone.train_function.updates_op)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment