Small Python script to rename variables in a TensorFlow checkpoint
import getopt
import sys

import tensorflow as tf

usage_str = 'python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ ' \
            '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run'


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Set the new name
            new_name = var_name
            if None not in [replace_from, replace_to]:
                new_name = new_name.replace(replace_from, replace_to)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                print('Renaming %s to %s.' % (var_name, new_name))
                # Rename the variable
                var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save the variables
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)


def main(argv):
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False

    try:
        # '--help' and '--dry_run' take no argument, so they have no trailing '='
        opts, args = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=', 'replace_from=',
                                               'replace_to=', 'add_prefix=', 'dry_run'])
    except getopt.GetoptError:
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True

    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)

    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)


if __name__ == '__main__':
    main(sys.argv[1:])
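
The renaming rule inside rename() is plain string manipulation, so it can be tried out without TensorFlow or a checkpoint. The sketch below isolates that rule in a hypothetical helper, compute_new_name, which is not part of the original script. It assumes the same semantics as the script: the substring replacement runs only when both replace_from and replace_to are given, and the prefix is added afterwards. The variable names in the demo are made up for illustration.

```python
def compute_new_name(var_name, replace_from=None, replace_to=None, add_prefix=None):
    """Hypothetical helper mirroring the renaming rule used in rename()."""
    new_name = var_name
    # Replace only when both arguments were supplied, as in the script.
    # str.replace substitutes every occurrence of replace_from.
    if None not in [replace_from, replace_to]:
        new_name = new_name.replace(replace_from, replace_to)
    # The prefix is applied after the substring replacement.
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


if __name__ == '__main__':
    # Illustrative names only; they do not come from a real checkpoint.
    for name in ['encoder/dense/kernel', 'global_step']:
        print(name, '->', compute_new_name(name, 'encoder/', 'enc/', 'model/'))
```

Note that a name which does not contain replace_from still receives the prefix, which matches what the script does with --add_prefix.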

@bver bver commented Feb 7, 2017

Thank you very much! I was able to reuse my models after upgrading to the recent TF version.

However, I had to comment out lines 23 and 24, since variables whose names did not match --replace_from were not saved back to the checkpoint file. It looks like previous TF versions modified checkpoints in place, but this logic has changed.

P.

@npit npit commented Jul 15, 2017

Thank you for this.

@spacetrain spacetrain commented Aug 3, 2017

Thanks for this!
One thing: this script will not save a variable if its name does not contain replace_from.
So I modified lines 23 and 24 to this:

        if new_name == var_name:
            print('%s remains unchanged' % var_name)
            var = tf.Variable(var, name=new_name)
            continue

Cheers

P.S. Ah, I just noticed that @bver already commented on this. =)

@batzner batzner (Owner, Author) commented Aug 24, 2017

Thanks @bver and @spacetrain, I changed it!

@liqing-ustc liqing-ustc commented Sep 15, 2017

Thanks a lot. It's exactly what I have been looking for for a long time!

@clarken92 clarken92 commented Nov 27, 2017

Thank you for sharing the code. It's very helpful.

@fvisin fvisin commented Dec 20, 2017

Thanks a lot, great script!!
I improved it a little by adding a few options to look up a specific key (variable name) and to compare the variables in two checkpoints.
In case it can be of some help: https://gist.github.com/fvisin/578089ae098424590d3f25567b6ee255

@jimmyljxy jimmyljxy commented Aug 29, 2018

Thank you for your great, helpful code!!!

@Yangyangii Yangyangii commented Oct 17, 2018

Thank you, this code is so coooooooooooool :)

@hellozjj hellozjj commented Dec 11, 2018

I got an error: ValueError: GraphDef cannot be larger than 2GB.

@claudehang claudehang commented Feb 10, 2019

@batzner
Thanks for your cool code, but I got this error:

    saver.save(sess, checkpoint.model_checkpoint_path)
    AttributeError: 'NoneType' object has no attribute 'model_checkpoint_path'

However, the conversion was actually done.
Why is that?

@cocopambag cocopambag commented Aug 27, 2019

Thank you

@zheyuanWang zheyuanWang commented Jul 22, 2020

I wrote a loop with the rename function:

    for i in range(1, 7):  # expanded_conv_i
        node_name = V_HEAD_EX + "_" + str(i + n_to_add) + "/"
        to_node_name = V_UPPER_HEAD_EX + "_" + str(i) + "/"
        rename(checkpointdir, node_name, to_node_name, dry_run=dry_run)

but the variables end up duplicated with a "_1" suffix:

    Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean to
    MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean
    Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean_1 to
    MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean_1
    Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance to
    MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance
    Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance_1 to
    MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance_1

Does anyone know why this is happening?
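
One possible explanation, offered as an assumption rather than a confirmed diagnosis: every call to rename() adds its tf.Variable objects to the same default graph. A second call in the same process then meets names that already exist, TensorFlow 1.x makes the new ones unique by appending "_1", and the Saver writes both copies. One way to avoid repeated calls is to compute all new names in a single pass. The sketch below shows only that name mapping in plain Python. The helper rename_many and the pair list are hypothetical, and the substrings and the offset of 8 are guessed from the log output in the comment, since the real constants are not shown.

```python
def rename_many(var_names, pairs):
    """Map each name to its new name using the first matching (old, new) pair.

    Hypothetical helper: it only computes names and does not touch a checkpoint.
    """
    mapping = {}
    for name in var_names:
        new_name = name
        for old, new in pairs:
            if old in name:
                new_name = name.replace(old, new)
                break  # apply at most one pair per variable
        mapping[name] = new_name
    return mapping


# Assumed values, inferred from the log output; the real constants are not shown.
n_to_add = 8
pairs = [('expanded_conv_%d/' % (i + n_to_add), 'upper_layers/expanded_conv_%d/' % i)
         for i in range(1, 7)]

names = ['MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean',
         'MobilenetV3/expanded_conv_1/depthwise/BatchNorm/moving_mean']
for old_name, new_name in rename_many(names, pairs).items():
    print(old_name, '->', new_name)
```

The resulting mapping could then drive a single variable-creation pass. Building that pass inside a fresh graph (for example with tf.Graph().as_default()) should keep names from an earlier call from colliding, though this has not been verified against the setup in the comment.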
