Created
August 2, 2020 23:38
-
-
Save thierryherrmann/c6bc132b149c264342fa79bd912c172f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save_module(module, model_dir, num_features=8):
    """Export ``module`` as a SavedModel with explicit serving and training signatures.

    Saving a ``tf.keras.Model`` with ``model.save()``,
    ``tf.keras.models.save_model()`` or ``tf.saved_model.save()`` yields a
    ``serving_default`` signature that maps an input sample to the model
    output. Here a plain ``tf.Module`` is exported instead, so the signatures
    are specified explicitly. The training function is exported as well.

    Args:
        module: the ``tf.Module`` to export. Its ``__call__`` and ``my_train``
            attributes must support ``get_concrete_function`` (presumably
            ``tf.function``-decorated methods; not checked here).
        model_dir: directory the SavedModel is written to.
        num_features: size of the last input axis used to trace both
            signatures. Defaults to 8, the value previously hard-coded; it
            must match the input width the module expects.
    """
    # The batch axis is left as None so the exported signatures accept any batch size.
    features_spec = tf.TensorSpec([None, num_features], tf.float32)
    labels_spec = tf.TensorSpec([None], tf.float32)
    signatures = {
        'my_serve': module.__call__.get_concrete_function(features_spec),
        'my_train': module.my_train.get_concrete_function(features_spec, labels_spec),
    }
    tf.saved_model.save(module, model_dir, signatures=signatures)
def inspect_checkpoint(checkpoint, print_values=False, variables=None):
    """Print the name and shape (optionally the value) of checkpoint entries.

    Args:
        checkpoint: checkpoint path/prefix accepted by
            ``tf.train.load_checkpoint`` (for a SavedModel, e.g.
            ``'<model_dir>/variables/variables'``).
        print_values: when True, also print each entry's value.
        variables: names of the entries to inspect. When falsy (``None`` or
            an empty sequence), every entry in the checkpoint is inspected,
            in sorted name order.

    Entries whose value cannot be read are reported as ``ignored`` and
    skipped, so one unreadable entry does not abort the whole listing.
    """
    checkpoint_reader = tf.train.load_checkpoint(checkpoint)
    if not variables:
        # Same sorted name list tf.train.list_variables() returns, but reusing
        # this reader instead of opening the checkpoint a second time.
        variables = sorted(checkpoint_reader.get_variable_to_shape_map())
    for var_name in variables:
        try:
            tensor = checkpoint_reader.get_tensor(var_name)
        except Exception as e:  # deliberate best-effort: report and move on
            print('ignored : %s (exception %s)' % (var_name, str(type(e))))
            continue
        # Array-valued entries report their shape; anything else reports its type.
        if isinstance(tensor, np.ndarray):
            fields = ['tensor : ', var_name, tensor.shape]
        else:
            fields = ['non-tensor: ', var_name, type(tensor)]
        if print_values:
            fields.append(tensor)
        print(*fields)
# Export the module, then list the variables stored in the SavedModel's
# checkpoint. `module` is assumed to be built earlier in this file.
model_dir = 'saved_model'
os.makedirs(model_dir, exist_ok=True)
save_module(module, model_dir)
# The SavedModel's variables are read from <model_dir>/variables/ using the
# checkpoint prefix "variables".
inspect_checkpoint(os.path.join(model_dir, 'variables', 'variables'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment