Skip to content

Instantly share code, notes, and snippets.

@thierryherrmann
Created August 2, 2020 23:38
Show Gist options
  • Save thierryherrmann/c6bc132b149c264342fa79bd912c172f to your computer and use it in GitHub Desktop.
Save thierryherrmann/c6bc132b149c264342fa79bd912c172f to your computer and use it in GitHub Desktop.
def save_module(module, model_dir):
    """Export a ``tf.Module`` as a SavedModel with explicitly listed signatures.

    A ``tf.keras.Model`` saved through ``model.save()``,
    ``tf.keras.models.save_model()`` or ``tf.saved_model.save()`` gets a
    ``serving_default`` signature for running inference on an input sample.
    A plain ``tf.Module`` does not, so its signatures are declared here by
    hand. The training function is exported alongside the serving one.

    Args:
        module: the ``tf.Module`` to export. It must provide ``__call__``
            (traced for one float32 input of shape ``[None, 8]``) and
            ``my_train`` (traced for that input plus a float32 input of
            shape ``[None]``).
        model_dir: directory the SavedModel is written to.

    Exported signatures:
        ``my_serve`` -> ``module.__call__``; ``my_train`` -> ``module.my_train``.
    """
    # NOTE(review): the trailing dimension 8 is hard-coded; it presumably
    # matches the module's expected feature count — confirm against the module.
    spec_2d = tf.TensorSpec([None, 8], tf.float32)
    spec_1d = tf.TensorSpec([None], tf.float32)

    serve_fn = module.__call__.get_concrete_function(spec_2d)
    train_fn = module.my_train.get_concrete_function(spec_2d, spec_1d)

    signature_map = dict(my_serve=serve_fn, my_train=train_fn)
    tf.saved_model.save(module, model_dir, signatures=signature_map)
def inspect_checkpoint(checkpoint, print_values=False, variables=None):
    """Print the name and shape/type (optionally the value) of checkpoint entries.

    Args:
        checkpoint: checkpoint path or prefix accepted by
            ``tf.train.list_variables`` and ``tf.train.load_checkpoint``.
        print_values: when True, each entry's value is printed as well.
        variables: names of the entries to inspect. When falsy (``None`` or an
            empty sequence), every variable listed in the checkpoint is used.

    Entries whose value is a ``np.ndarray`` are printed as ``tensor`` with
    their shape; anything else is printed as ``non-tensor`` with its Python
    type. An entry that cannot be read is reported as ``ignored``, together
    with the exception type and message, and inspection moves on to the next
    entry.
    """
    if not variables:
        # list_variables yields (name, shape) pairs; only the names are needed.
        variables = [var_name for var_name, _ in tf.train.list_variables(checkpoint)]
    checkpoint_reader = tf.train.load_checkpoint(checkpoint)
    for var_name in variables:
        try:
            tensor = checkpoint_reader.get_tensor(var_name)
        except Exception as e:  # deliberate best-effort: report and skip
            # Include the message, not just the class, so the reason the entry
            # could not be read is visible.
            print('ignored : %s (exception %s: %s)' % (var_name, type(e), e))
            continue
        if isinstance(tensor, np.ndarray):
            label, details = 'tensor : ', (tensor.shape,)
        else:
            label, details = 'non-tensor: ', (type(tensor),)
        if print_values:
            details += (tensor,)
        print(label, var_name, *details)
# Export `module` as a SavedModel under ./saved_model, then list what was
# written to the SavedModel's variables checkpoint.
model_dir = 'saved_model'
os.makedirs(name=model_dir, exist_ok=True)
save_module(module, model_dir)
# Checkpoint prefix of the variables inside the exported SavedModel directory.
inspect_checkpoint('%s/variables/variables' % model_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment