Small Python script to rename variables in a TensorFlow checkpoint
# Requires TensorFlow 1.x: tf.Session and tf.contrib were removed in TensorFlow 2.
import sys, getopt

import tensorflow as tf

usage_str = 'python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ ' \
            '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run'


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    # get_checkpoint_state() returns None if checkpoint_dir has no 'checkpoint' file
    if checkpoint is None:
        raise ValueError('No checkpoint state found in %s; pass the directory that '
                         'contains the "checkpoint" file.' % checkpoint_dir)
    # Build the renamed variables in a fresh graph so that repeated calls to
    # rename() do not clash with variables from earlier calls (TensorFlow would
    # otherwise make the names unique by appending "_1", "_2", ...)
    with tf.Graph().as_default(), tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Set the new name
            new_name = var_name
            if None not in [replace_from, replace_to]:
                new_name = new_name.replace(replace_from, replace_to)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                print('Renaming %s to %s.' % (var_name, new_name))
                # Re-create every variable, renamed or not, so that the Saver
                # writes all of them to the checkpoint
                var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save the variables
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)


def main(argv):
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    try:
        # '--help' takes no value, so it is registered as 'help' (not 'help=')
        opts, args = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=', 'replace_from=',
                                               'replace_to=', 'add_prefix=', 'dry_run'])
    except getopt.GetoptError:
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)
    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)


if __name__ == '__main__':
    main(sys.argv[1:])
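The renaming rule in `rename()` can be tried out without TensorFlow. This is a minimal sketch that assumes the script's semantics: the substring replacement runs only when both `--replace_from` and `--replace_to` are given, and `--add_prefix` is prepended afterwards. `compute_new_name` is a hypothetical helper and is not part of the script.

```python
def compute_new_name(var_name, replace_from=None, replace_to=None, add_prefix=None):
    """Apply the same renaming rule as rename(): replace first, then prefix."""
    new_name = var_name
    # The replacement only happens when both substrings are given
    if None not in [replace_from, replace_to]:
        new_name = new_name.replace(replace_from, replace_to)
    # The prefix is prepended after the replacement
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


if __name__ == '__main__':
    for name in ['layer1/weights', 'layer1/biases', 'global_step']:
        # e.g. layer1/weights -> model/encoder/weights
        print(name, '->', compute_new_name(name, 'layer1', 'encoder', 'model/'))
```

`str.replace` substitutes every occurrence of the substring, so a short `--replace_from` value can also match inside unrelated names. Running with `--dry_run` first shows the full mapping before anything is saved.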
@bver commented Feb 7, 2017

Thank you very much! I was able to reuse my models after upgrading to the latest TF version.

However, I had to comment out lines 23 and 24, because variables whose names did not match --replace_from were not saved back to the checkpoint file. It looks like earlier TF versions modified checkpoints in place, but this behavior has changed.

P.
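This problem comes from how `tf.train.Saver()` works. When it is created without a `var_list`, it saves only the variables that exist in the graph at that point, so any variable that is not re-created is missing from the new checkpoint. The sketch below shows the correct mapping, using a plain `dict` (name to value) as a stand-in for the checkpoint. `rename_all` is hypothetical, and it also rejects mappings where two variables end up with the same name.

```python
def rename_all(values, rename_fn):
    """Return a new name->value mapping that keeps every variable.

    values: dict standing in for a checkpoint (name -> value).
    rename_fn: function mapping an old name to a new name.
    """
    renamed = {}
    for old_name, value in values.items():
        new_name = rename_fn(old_name)
        # In a plain dict a duplicate name would overwrite the earlier entry
        if new_name in renamed:
            raise ValueError('Name collision: %s' % new_name)
        # Unchanged names are copied too; skipping them would drop the variable
        renamed[new_name] = value
    return renamed


ckpt = {'old/w': 1.0, 'old/b': 0.5, 'global_step': 100}
print(rename_all(ckpt, lambda n: n.replace('old/', 'new/')))
```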

@npit commented Jul 15, 2017

Thank you for this.

@spacetrain commented Aug 3, 2017

Thanks for this!
Just one issue: this script does not save a variable if its name does not contain "replace_from".
So I changed lines 23 and 24 to this:

        if new_name == var_name:
            print('%s remains unchanged' % var_name)
            var = tf.Variable(var, name=new_name)
            continue

Cheers

P.S. Ah, I just noticed that @bver already commented on this. =)

@batzner (owner and author) commented Aug 24, 2017

Thanks @bver and @spacetrain, I changed it!

@liqing-ustc commented Sep 15, 2017

Thanks a lot. It's exactly what I had been looking for for a long time!

@clarken92 commented Nov 27, 2017

Thank you for sharing the code. It's very helpful.

@fvisin commented Dec 20, 2017

Thanks a lot, great script!!
I improved it a little by adding a few options: one to look up a specific key (variable name) and one to compare the variables in two checkpoints.
In case it can be of some help: https://gist.github.com/fvisin/578089ae098424590d3f25567b6ee255
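Comparing two checkpoints mostly comes down to comparing the `(name, shape)` pairs returned by `tf.contrib.framework.list_variables` (in newer releases, `tf.train.list_variables`). This is a minimal pure-Python sketch that assumes both lists are already loaded. `compare_variable_lists` is hypothetical and is not the code in the linked gist.

```python
def compare_variable_lists(vars_a, vars_b):
    """Compare two lists of (name, shape) pairs.

    Returns (only_in_a, only_in_b, shape_mismatch), each a sorted list of names.
    """
    shapes_a = {name: list(shape) for name, shape in vars_a}
    shapes_b = {name: list(shape) for name, shape in vars_b}
    only_in_a = sorted(set(shapes_a) - set(shapes_b))
    only_in_b = sorted(set(shapes_b) - set(shapes_a))
    # Names present in both checkpoints but stored with different shapes
    shape_mismatch = sorted(name for name in set(shapes_a) & set(shapes_b)
                            if shapes_a[name] != shapes_b[name])
    return only_in_a, only_in_b, shape_mismatch


a = [('conv1/weights', [3, 3, 3, 32]), ('fc/weights', [512, 10])]
b = [('conv1/weights', [3, 3, 3, 32]), ('fc/weights', [512, 100]), ('fc/biases', [100])]
print(compare_variable_lists(a, b))
```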

@jimmyljxy commented Aug 29, 2018

Thank you for your great, helpful code!!!

@Yangyangii commented Oct 17, 2018

Thank you, this code is so coooooooooooool :)

@hellozjj commented Dec 11, 2018

I got an error: ValueError: GraphDef cannot be larger than 2GB.
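A probable cause, though the thread does not confirm it: `tf.Variable(var, name=new_name)` receives a NumPy array, which TensorFlow 1.x embeds as a constant in the GraphDef. A GraphDef is a protocol buffer capped at 2 GB, so the error appears once the checkpoint's tensors together approach that size. The sketch estimates the total from `(name, shape)` pairs. The helper name, the 4-bytes-per-element (`float32`) default, and the example shapes are assumptions.

```python
from math import prod  # Python 3.8+

GRAPHDEF_LIMIT = 2 * 1024 ** 3  # 2 GB protocol buffer size limit


def estimated_graphdef_bytes(var_list, bytes_per_element=4):
    """Rough size of all values embedded as constants (assumes float32)."""
    # prod([]) == 1, so scalars count as a single element
    return sum(prod(shape) * bytes_per_element for _, shape in var_list)


variables = [('embedding', [1000000, 300]), ('softmax/weights', [300, 1000000])]
total = estimated_graphdef_bytes(variables)
print(total, total > GRAPHDEF_LIMIT)  # 2400000000 True
```

If the estimate is near the limit, one common workaround is to create each variable from a `tf.placeholder` and feed the loaded value when running the initializer. That way the data never enters the GraphDef.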

@claudehang commented Feb 10, 2019

@batzner
Thanks for your cool code, but I got this error:

    saver.save(sess, checkpoint.model_checkpoint_path)
    AttributeError: 'NoneType' object has no attribute 'model_checkpoint_path'

However, the conversion seems to have run anyway. Why is that?
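`tf.train.get_checkpoint_state(checkpoint_dir)` returns `None` when `checkpoint_dir` contains no `checkpoint` file. This happens, for example, when a checkpoint prefix such as `model.ckpt-1000` is passed instead of its directory. `list_variables` accepts such a prefix, which may explain why the "Renaming ..." lines are printed even though `saver.save` then fails and nothing is written. The sketch below mimics that lookup in pure Python to show which path is expected. `find_model_checkpoint_path` is hypothetical and assumes the usual `key: "value"` per-line text format of the `checkpoint` file.

```python
import os
import re


def find_model_checkpoint_path(checkpoint_dir):
    """Return the model_checkpoint_path recorded in checkpoint_dir/checkpoint, or None."""
    state_file = os.path.join(checkpoint_dir, 'checkpoint')
    # No 'checkpoint' file (e.g. checkpoint_dir is a prefix, not a directory)
    if not os.path.isfile(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            # re.match anchors at the start, so 'all_model_checkpoint_paths' is skipped
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                path = match.group(1)
                # Relative paths are resolved against the checkpoint directory
                return path if os.path.isabs(path) else os.path.join(checkpoint_dir, path)
    return None
```

If it returns `None`, pass the directory that contains the `checkpoint` file as `--checkpoint_dir`.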

@cocopambag commented Aug 27, 2019

Thank you

@zheyuanWang commented Jul 22, 2020

I wrote a loop with the rename function:

    for i in range(1,7):  # expanded_conv_i
        node_name = V_HEAD_EX + "_" + str(i + n_to_add) + "/"
        to_node_name = V_UPPER_HEAD_EX + "_" + str(i) + "/"
        rename(checkpointdir, node_name, to_node_name, dry_run=dry_run)

but the variables get duplicated with a "_1" suffix:

    Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean to MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean
    Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean_1 to MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean_1
    Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance to MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance
    Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance_1 to MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance_1

Does anyone know why this is happening?
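A likely explanation: every call to `rename()` adds variables to the same default graph. When a name is already taken, TensorFlow makes it unique by appending `_1`, `_2`, and so on. The second call therefore creates `..._1` variables, and its Saver writes both copies to the checkpoint. The toy sketch below imitates that naming rule in pure Python. `UniqueNamer` is hypothetical, not TensorFlow code, and it ignores details such as case-insensitive matching. In TF 1.x the usual remedy is to build each call's variables in a fresh graph (`with tf.Graph().as_default():`) or to call `tf.reset_default_graph()` between calls.

```python
class UniqueNamer:
    """Toy model of how a TF 1.x graph de-duplicates op/variable names."""

    def __init__(self):
        self.names_in_use = set()

    def unique_name(self, name):
        # The first use keeps the name; later uses get _1, _2, ...
        new_name = name
        i = 1
        while new_name in self.names_in_use:
            new_name = '%s_%d' % (name, i)
            i += 1
        self.names_in_use.add(new_name)
        return new_name


graph = UniqueNamer()  # one shared "default graph" across calls
for call in range(2):  # two rename() calls without resetting the graph
    # prints BatchNorm/moving_mean, then BatchNorm/moving_mean_1
    print(graph.unique_name('BatchNorm/moving_mean'))
```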
