Forked from batzner/tensorflow_rename_variables.py
Last active
June 16, 2020 01:22
-
-
Save flymin/4ce1a67d3b2e52297900511b6fbc5e3a to your computer and use it in GitHub Desktop.
[ckpt rename] Small Python script to rename variables in a TensorFlow checkpoint; this fork adds an output-directory option and only displays variables that change. #Tensorflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys, getopt | |
import os | |
import pdb | |
import tensorflow as tf | |
# One-line command-line synopsis printed for -h/--help and on bad arguments.
usage_str = ('python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ '
             '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run '
             '--output_dir=dir/to/output')
def _compute_new_name(var_name, replace_from, replace_to, add_prefix):
    """Return the renamed form of *var_name*.

    The substring replacement is applied only when both ``replace_from`` and
    ``replace_to`` are given; ``add_prefix`` (if non-empty) is then prepended
    to the result.
    """
    new_name = var_name
    if replace_from is not None and replace_to is not None:
        new_name = new_name.replace(replace_from, replace_to)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run,
           output_dir=None):
    """Rename the variables of the latest checkpoint in *checkpoint_dir*.

    Every variable is read from the checkpoint and re-created under its new
    name (unchanged names are kept as-is), then the whole set is saved as a
    new checkpoint.

    Args:
        checkpoint_dir: Directory containing the checkpoint to read.
        replace_from: Substring to replace in each variable name, or None.
        replace_to: Replacement substring, or None. Replacement only happens
            when both ``replace_from`` and ``replace_to`` are given.
        add_prefix: Prefix prepended to every (possibly replaced) name, or None.
        dry_run: If true, only print what would be renamed; nothing is saved.
        output_dir: Directory to write the renamed checkpoint into. When None,
            ``checkpoint_dir`` is used, so files with the same checkpoint
            basename there are overwritten. The directory is created if
            missing.

    Raises:
        ValueError: if ``tf.train.get_checkpoint_state`` finds no checkpoint
            in ``checkpoint_dir``.
    """
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint is None:
        raise ValueError('No checkpoint found in %s' % checkpoint_dir)
    if output_dir is None:
        # Without an explicit destination, write next to the input checkpoint.
        output_dir = checkpoint_dir

    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            new_name = _compute_new_name(var_name, replace_from, replace_to,
                                         add_prefix)
            changed = var_name != new_name

            if dry_run:
                if changed:
                    print('%s would be renamed to %s.' % (var_name, new_name))
                else:
                    print('%s would not change.' % (var_name))
                continue

            if changed:
                print('Renaming %s to %s.' % (var_name, new_name))
            # Re-create every variable, renamed or not, so the saved
            # checkpoint also keeps the ones whose name did not change.
            value = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
            tf.Variable(value, name=new_name)

        if not dry_run:
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            # Saver.save fails when the parent directory does not exist.
            if output_dir:
                tf.gfile.MakeDirs(output_dir)
            out_file = os.path.join(
                output_dir, os.path.basename(checkpoint.model_checkpoint_path))
            saver.save(sess, out_file)
            print('Model saved to %s' % out_file)
def main(argv):
    """Parse command-line options and call ``rename`` with them.

    Args:
        argv: Command-line arguments without the program name
            (e.g. ``sys.argv[1:]``).

    Prints the usage string and exits with status 2 on unparsable options or
    a missing ``--checkpoint_dir``; exits normally after ``-h``/``--help``.
    """
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    # Optional; stays None when --output_dir is not given.
    output_dir = None
    try:
        # 'help' takes no value: spelling it 'help=' would make a bare
        # --help a GetoptError.
        opts, _ = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=',
                                            'replace_from=', 'replace_to=',
                                            'add_prefix=', 'dry_run',
                                            'output_dir='])
    except getopt.GetoptError:
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
        elif opt == '--output_dir':
            output_dir = arg
    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)
    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run, output_dir)
if __name__ == '__main__':
    # getopt expects the arguments without the program name.
    cli_args = sys.argv[1:]
    main(cli_args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment