Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save flymin/4ce1a67d3b2e52297900511b6fbc5e3a to your computer and use it in GitHub Desktop.
Save flymin/4ce1a67d3b2e52297900511b6fbc5e3a to your computer and use it in GitHub Desktop.
[ckpt rename] Small python script to rename variables in a TensorFlow checkpoint; add output option and only display changing variables #Tensorflow
import sys, getopt
import os
import pdb
import tensorflow as tf
# One-line command synopsis printed on help requests and usage errors.
usage_str = (
    'python tensorflow_rename_variables.py'
    ' --checkpoint_dir=path/to/dir/'
    ' --replace_from=substr --replace_to=substr'
    ' --add_prefix=abc --dry_run'
    ' --output_dir=dir/to/output'
)
def _renamed(var_name, replace_from, replace_to, add_prefix):
    """Return the new name for ``var_name`` after substitution and prefixing."""
    new_name = var_name
    # An empty ``replace_from`` would make str.replace insert ``replace_to``
    # between every character, so only substitute for a non-empty pattern.
    if replace_from and replace_to is not None:
        new_name = new_name.replace(replace_from, replace_to)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run,
           output_dir=None):
    """Rename the variables stored in a TensorFlow checkpoint.

    Each variable name has every occurrence of ``replace_from`` replaced by
    ``replace_to`` (only when both are given and ``replace_from`` is
    non-empty), and then ``add_prefix`` prepended (when truthy).

    Args:
        checkpoint_dir: Checkpoint location passed to
            ``tf.contrib.framework.list_variables``/``load_variable`` and to
            ``tf.train.get_checkpoint_state``.
        replace_from: Substring to replace, or None.
        replace_to: Replacement substring, or None.
        add_prefix: Prefix to prepend to every name, or None.
        dry_run: When true, only print what would be renamed; nothing is
            loaded or written.
        output_dir: Directory to write the renamed checkpoint into. It is
            created if missing. Defaults to ``checkpoint_dir`` when None.

    Raises:
        ValueError: If not ``dry_run`` and ``tf.train.get_checkpoint_state``
            finds no checkpoint in ``checkpoint_dir``.
    """
    if dry_run:
        # Only names are needed to report the plan, so skip loading tensors.
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            new_name = _renamed(var_name, replace_from, replace_to, add_prefix)
            if var_name != new_name:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                print('%s would not change.' % (var_name))
        return

    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise ValueError('No checkpoint found in %s' % checkpoint_dir)
    if output_dir is None:
        output_dir = checkpoint_dir

    # Use a fresh graph so the Saver captures only the variables created here,
    # even if rename() is called more than once in the same process.
    with tf.Graph().as_default():
        with tf.Session() as sess:
            for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
                new_name = _renamed(var_name, replace_from, replace_to, add_prefix)
                if var_name != new_name:
                    print('Renaming %s to %s.' % (var_name, new_name))
                # Re-create every variable (renamed or not) under its new name,
                # initialised with the value stored in the checkpoint.
                value = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
                tf.Variable(value, name=new_name)

            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            # Saver.save needs the parent directory to exist. An empty string
            # means the current directory, which needs no creation.
            if output_dir and not os.path.isdir(output_dir):
                os.makedirs(output_dir)
            out_file = os.path.join(
                output_dir, os.path.basename(checkpoint.model_checkpoint_path))
            saver.save(sess, out_file)
            print('Model saved to %s' % out_file)
def main(argv):
    """Parse command-line options and run ``rename``.

    Args:
        argv: Command-line arguments without the program name.

    Exits with status 2 on a usage error: an unknown or malformed option, a
    missing ``--checkpoint_dir``, or only one of ``--replace_from`` and
    ``--replace_to``. Exits with status 0 after printing help for
    ``-h``/``--help``.
    """
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    # Must be initialised: it is only assigned when --output_dir is passed.
    output_dir = None
    try:
        # 'help' takes no value (no trailing '='), so a bare --help is accepted.
        opts, _ = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=', 'replace_from=',
                                            'replace_to=', 'add_prefix=', 'dry_run',
                                            'output_dir='])
    except getopt.GetoptError as err:
        print(err)
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
        elif opt == '--output_dir':
            output_dir = arg
    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)
    # A lone --replace_from or --replace_to would otherwise be silently ignored.
    if (replace_from is None) != (replace_to is None):
        print('--replace_from and --replace_to must be given together. Usage:')
        print(usage_str)
        sys.exit(2)
    # Default: write the renamed checkpoint next to the input one.
    if output_dir is None:
        output_dir = checkpoint_dir
    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run, output_dir)
if __name__ == '__main__':
    # Hand main() only the user-supplied arguments (drop the script path).
    cli_args = sys.argv[1:]
    main(cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment