Forked from batzner/tensorflow_rename_variables.py
Last active April 5, 2018 02:33
Save knwng/614e132b91ec31b4f7ca16df747e52d0 to your computer and use it in GitHub Desktop.
Small Python script to rename variables in a TensorFlow checkpoint.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys, getopt | |
import tensorflow as tf | |
# Command-line synopsis; main() prints it for -h/--help and on argument errors.
usage_str = ('python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ '
             '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run')
def _compute_new_name(var_name, replace_from, replace_to, add_prefix):
    """Return the name *var_name* should carry after renaming.

    The substring substitution is applied first, then the prefix.

    Args:
        var_name: Variable name as stored in the checkpoint.
        replace_from: Substring to replace; None or '' skips the substitution.
        replace_to: Replacement text; None skips the substitution.
        add_prefix: Text prepended to the (possibly substituted) name;
            None or '' adds nothing.

    Returns:
        The new name, equal to *var_name* when no rule applies.
    """
    new_name = var_name
    # str.replace('', x) would insert x between every character, so an empty
    # replace_from is treated as "no substitution requested".
    if replace_from and replace_to is not None:
        new_name = new_name.replace(replace_from, replace_to)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    """Rename the variables of the checkpoint in *checkpoint_dir*.

    Every variable listed in the checkpoint is recreated as a ``tf.Variable``
    under its new name, and all of them are then saved with a
    ``tf.train.Saver`` to the checkpoint state's ``model_checkpoint_path``.
    Unchanged variables are recreated too, so they are kept in the saved
    checkpoint. Only variables whose name changes are reported on stdout.

    Args:
        checkpoint_dir: Directory holding the checkpoint to rewrite.
        replace_from: Substring to replace in each name (None to skip).
        replace_to: Replacement for *replace_from* (None to skip).
        add_prefix: Prefix prepended to each name (None to skip).
        dry_run: When true, only print the planned renames; nothing is
            loaded into the graph and nothing is saved.

    Raises:
        ValueError: If no checkpoint state is found in *checkpoint_dir*
            (non-dry-run only), or if two variables would end up with the
            same new name.
    """
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    # Fail before doing any work rather than crashing on the final save.
    if not dry_run and checkpoint is None:
        raise ValueError('No checkpoint state found in %r; --checkpoint_dir '
                         'must be a directory containing a checkpoint.'
                         % (checkpoint_dir,))

    with tf.Session() as sess:
        # new name -> original name, to detect renames that collide.
        claimed = {}
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            new_name = _compute_new_name(var_name, replace_from, replace_to,
                                         add_prefix)
            if new_name in claimed:
                # tf.Variable would silently uniquify the name (e.g. "x_1"),
                # producing a checkpoint that no longer matches the model.
                raise ValueError('Variables %r and %r would both be renamed to %r.'
                                 % (claimed[new_name], var_name, new_name))
            claimed[new_name] = var_name

            if new_name != var_name:
                if dry_run:
                    print('%s would be renamed to %s.' % (var_name, new_name))
                else:
                    print('Renaming %s to %s.' % (var_name, new_name))

            if not dry_run:
                # Load the stored value only when it will actually be re-saved.
                value = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
                # Creating the variable registers it in the graph for the Saver.
                tf.Variable(value, name=new_name)

        if not dry_run:
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)
def main(argv):
    """Parse command-line options and run rename() on the given checkpoint.

    Args:
        argv: Command-line arguments without the program name
            (i.e. ``sys.argv[1:]``).

    Prints the usage string and exits with status 0 for ``-h``/``--help``.
    Exits with status 2 on an unrecognized option, a missing
    ``--checkpoint_dir``, or when only one of ``--replace_from`` /
    ``--replace_to`` is supplied.
    """
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    try:
        # 'help' and 'dry_run' are flags; a trailing '=' would make getopt
        # demand a value for them.
        opts, _ = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=', 'replace_from=',
                                            'replace_to=', 'add_prefix=', 'dry_run'])
    except getopt.GetoptError:
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)
    # A substitution needs both halves; silently ignoring one is surprising.
    if (replace_from is None) != (replace_to is None):
        print('Please specify both replace_from and replace_to. Usage:')
        print(usage_str)
        sys.exit(2)
    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)
if __name__ == '__main__':
    # Drop the program name; main() expects only the option arguments.
    cli_args = sys.argv[1:]
    main(cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.