Small Python script to rename variables in a TensorFlow checkpoint
import sys, getopt

import tensorflow as tf

usage_str = 'python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ ' \
            '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run'


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint is None and not dry_run:
        # get_checkpoint_state() returns None if checkpoint_dir has no
        # 'checkpoint' state file (e.g. when a file prefix is passed), so
        # there would be no path to save the renamed variables to.
        print('No checkpoint state file found in %s; cannot save.' % checkpoint_dir)
        sys.exit(2)
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Set the new name
            new_name = var_name
            if None not in [replace_from, replace_to]:
                new_name = new_name.replace(replace_from, replace_to)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                print('Renaming %s to %s.' % (var_name, new_name))
                # Rename the variable. Every variable is recreated, renamed or
                # not, because the saver only writes variables in the graph.
                var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save the variables
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)


def main(argv):
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False

    try:
        # '--help' takes no argument, so it is listed without a trailing '='
        opts, args = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=', 'replace_from=',
                                               'replace_to=', 'add_prefix=', 'dry_run'])
    except getopt.GetoptError:
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True

    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)

    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)


if __name__ == '__main__':
    main(sys.argv[1:])
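
The naming rule in `rename()` can be tried out without TensorFlow. The sketch below is a hypothetical helper (`plan_renames` is not part of the script). It applies the same replace-then-prefix rule to a list of names, as `--dry_run` does, and also flags names that collide after renaming. Collisions matter because, in TF1 graph mode, creating a second `tf.Variable` with an existing name does not fail. TensorFlow silently makes the name unique (e.g. `b/w_1`), so the saved checkpoint would not contain the name you asked for.

```python
from collections import Counter


def plan_renames(var_names, replace_from=None, replace_to=None, add_prefix=None):
    """Hypothetical helper: map old names to new ones using the script's rule.

    Returns (plan, collisions): plan maps each old name to its new name,
    collisions lists new names that are produced more than once.
    """
    plan = {}
    for var_name in var_names:
        new_name = var_name
        # Same order as rename(): substring replacement first, then the prefix
        if replace_from is not None and replace_to is not None:
            new_name = new_name.replace(replace_from, replace_to)
        if add_prefix:
            new_name = add_prefix + new_name
        plan[var_name] = new_name
    counts = Counter(plan.values())
    collisions = sorted(name for name, count in counts.items() if count > 1)
    return plan, collisions


if __name__ == '__main__':
    plan, collisions = plan_renames(['encoder/w', 'encoder/b', 'decoder/w'],
                                    replace_from='encoder', replace_to='enc',
                                    add_prefix='model/')
    for old, new in plan.items():
        print('%s would be renamed to %s.' % (old, new))
    print('collisions:', collisions)
```

A rename such as `--replace_from=a/ --replace_to=b/` on a checkpoint that already contains `b/w` would show up in `collisions` before anything is written.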

bver commented Feb 7, 2017

Thank you very much! I was able to reuse my models after upgrading to the latest TF version.

However, I had to comment out lines 23 and 24, since other variables whose names did not match --replace_from were not saved back to the checkpoint file. It looks like in previous TF versions checkpoints were modified in place, but this logic has changed.

P.


npit commented Jul 15, 2017

Thank you for this.


spacetrain commented Aug 3, 2017

Thanks for this!
One thing: this script will not save a variable whose name does not contain the "replace_from" substring.
So I modified lines 23 and 24 to this:

        if new_name == var_name:
            print('%s remains unchanged' % var_name)
            var = tf.Variable(var, name=new_name)
            continue

Cheers

P.S. Ah, I just noticed that @bver already commented on this. =)
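
The problem @bver and @spacetrain describe can be modelled with plain dictionaries. In this hypothetical sketch, a dict stands in for a checkpoint (name to value) and `save_renamed` stands in for one rename pass. `tf.train.Saver()` writes a new checkpoint containing only the variables that exist in the graph, so a variable that is skipped instead of recreated is simply missing from the output. `keep_unchanged=False` reproduces the old behaviour; `keep_unchanged=True` corresponds to the fix.

```python
def save_renamed(variables, replace_from, replace_to, keep_unchanged=True):
    """Model one rename pass; `variables` stands in for a checkpoint (name -> value)."""
    saved = {}
    for var_name, value in variables.items():
        new_name = var_name.replace(replace_from, replace_to)
        if new_name == var_name and not keep_unchanged:
            # Old behaviour: the variable is never recreated in the graph,
            # so the saver has nothing to write for it.
            continue
        saved[new_name] = value
    return saved


if __name__ == '__main__':
    checkpoint = {'rnn/weights': [1.0, 2.0], 'softmax/bias': [0.5]}
    print(save_renamed(checkpoint, 'rnn/', 'lstm/', keep_unchanged=False))
    print(save_renamed(checkpoint, 'rnn/', 'lstm/'))
```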

batzner (owner and author of this gist) commented Aug 24, 2017

Thanks @bver and @spacetrain, I changed it!


liqing-ustc commented Sep 15, 2017

Thanks a lot. It's exactly what I had long been looking for!


clarken92 commented Nov 27, 2017

Thank you for sharing the code. It's very helpful.


fvisin commented Dec 20, 2017

Thanks a lot, great script!!
I improved it a little by adding a few options: one to look up a specific key (variable name) and one to compare the variables in two checkpoints.
In case it can be of some help: https://gist.github.com/fvisin/578089ae098424590d3f25567b6ee255
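
Comparing two checkpoints mostly comes down to comparing their variable lists. `tf.contrib.framework.list_variables()`, which the script already calls, returns `(name, shape)` pairs. A minimal sketch of the comparison therefore only needs those lists. The `compare_variable_lists` helper is hypothetical and is not taken from the linked gist.

```python
def compare_variable_lists(vars_a, vars_b):
    """Compare two lists of (name, shape) pairs, e.g. from two checkpoints.

    Returns (only_in_a, only_in_b, shape_mismatches), each sorted by name.
    """
    shapes_a = {name: tuple(shape) for name, shape in vars_a}
    shapes_b = {name: tuple(shape) for name, shape in vars_b}
    only_a = sorted(shapes_a.keys() - shapes_b.keys())
    only_b = sorted(shapes_b.keys() - shapes_a.keys())
    # Names present in both checkpoints but stored with different shapes
    mismatched = sorted(name for name in shapes_a.keys() & shapes_b.keys()
                        if shapes_a[name] != shapes_b[name])
    return only_a, only_b, mismatched


if __name__ == '__main__':
    old = [('rnn/weights', [128, 512]), ('softmax/bias', [10])]
    new = [('lstm/weights', [128, 512]), ('softmax/bias', [20])]
    print(compare_variable_lists(old, new))
```

With TensorFlow installed, the two inputs could be the results of `tf.contrib.framework.list_variables()` on the original and the renamed checkpoint. After a successful rename, only the renamed names should differ.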


jimmyljxy commented Aug 29, 2018

Thank you for your great, helpful code!!!


Yangyangii commented Oct 17, 2018

Thank you, this code is so coooooooooooool :)


hellozjj commented Dec 11, 2018

I got an error: ValueError: GraphDef cannot be larger than 2GB.
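
A likely cause: `rename()` passes each loaded NumPy array directly to `tf.Variable`. In TF1 graph mode, that value is stored as a constant inside the graph definition. Protocol buffer messages are limited to 2GB, so once the checkpoint's tensors add up to roughly that size, serializing the graph fails, for example when the saver writes the `.meta` file. The hypothetical helper below estimates the bytes that would be embedded from a `(name, shape)` list, assuming 4-byte `float32` elements by default. A commonly suggested workaround for large checkpoints is to initialize each variable from a placeholder fed at run time instead of embedding the value.

```python
import math

# Protocol buffer messages, including a GraphDef, are limited to 2GB.
GRAPHDEF_LIMIT_BYTES = 2 ** 31


def estimate_embedded_bytes(variables, bytes_per_element=4):
    """Estimate bytes embedded in the graph if every value becomes a constant.

    `variables` is a list of (name, shape) pairs; scalars have shape [].
    Ignores the comparatively small per-node overhead of the GraphDef.
    """
    return sum(math.prod(shape) * bytes_per_element for _, shape in variables)


if __name__ == '__main__':
    variables = [('embedding', [600000, 1024]), ('softmax/bias', [600000])]
    size = estimate_embedded_bytes(variables)
    print(size, size >= GRAPHDEF_LIMIT_BYTES)
```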


claudehang commented Feb 10, 2019

@batzner
Thanks for your cool code, but I got this error:
saver.save(sess, checkpoint.model_checkpoint_path) AttributeError: 'NoneType' object has no attribute 'model_checkpoint_path'
However, the conversion was actually done.
Why is that?
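
`tf.train.get_checkpoint_state()` returns `None` when the given directory has no `checkpoint` state file. This typically happens when a checkpoint file prefix such as `path/to/model.ckpt-1000` is passed instead of its directory. `list_variables()` and `load_variable()` accept either form, which would explain why the renaming output appears and only the final `saver.save()` call fails. The hypothetical helper below resolves either form to a save path using only the standard library. It assumes the V2 checkpoint format, where a prefix has a matching `<prefix>.index` file.

```python
import os
import re


def resolve_save_path(path):
    """Hypothetical helper: return the checkpoint prefix to save to, or None.

    Accepts a directory containing a 'checkpoint' state file, or a V2
    checkpoint prefix (one for which '<prefix>.index' exists).
    """
    state_file = os.path.join(path, 'checkpoint')
    if os.path.isdir(path) and os.path.isfile(state_file):
        with open(state_file) as f:
            match = re.search(r'^model_checkpoint_path:\s*"(.*)"', f.read(), re.MULTILINE)
        if match:
            prefix = match.group(1)
            # Relative prefixes in the state file are relative to its directory
            return prefix if os.path.isabs(prefix) else os.path.join(path, prefix)
    if os.path.isfile(path + '.index'):
        return path
    return None
```

If `resolve_save_path(checkpoint_dir)` is not `None`, it could be passed to `saver.save()` in place of `checkpoint.model_checkpoint_path`.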


cocopambag commented Aug 27, 2019

Thank you
