
Last active May 25, 2023 06:15
Gist: batzner/7c24802dd9c5e15870b4b56e22135c96
Small Python script to rename variables in a TensorFlow checkpoint
import sys, getopt

# Requires TensorFlow 1.x: tf.Session and tf.contrib no longer exist in TensorFlow 2.x
import tensorflow as tf

usage_str = 'python ' + sys.argv[0] + ' --checkpoint_dir=path/to/dir/ ' \
            '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run'


def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run):
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable
            var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

            # Set the new name
            new_name = var_name
            if None not in [replace_from, replace_to]:
                new_name = new_name.replace(replace_from, replace_to)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
            else:
                print('Renaming %s to %s.' % (var_name, new_name))
                # Rename the variable; unchanged names are re-created too, so every variable is saved
                var = tf.Variable(var, name=new_name)

        if not dry_run:
            # Save the variables, overwriting the latest checkpoint in checkpoint_dir
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)


def main(argv):
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False

    try:
        opts, args = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=', 'replace_from=',
                                               'replace_to=', 'add_prefix=', 'dry_run'])
    except getopt.GetoptError:
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True

    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)

    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)


if __name__ == '__main__':
    main(sys.argv[1:])
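The name mapping inside rename() is plain string handling, so it can be tried out without TensorFlow or a checkpoint. A minimal sketch, using a hypothetical helper new_variable_name (not part of the gist) that applies the same replace-then-prefix order; the variable names are made up:

```python
def new_variable_name(var_name, replace_from=None, replace_to=None, add_prefix=None):
    """Hypothetical helper mirroring the renaming rule in rename()."""
    new_name = var_name
    # Substring replacement only happens when both parts are given
    if None not in [replace_from, replace_to]:
        new_name = new_name.replace(replace_from, replace_to)
    # The prefix is added after the replacement
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


print(new_variable_name('rnn/lstm_cell/kernel', 'lstm_cell', 'basic_lstm_cell', 'model/'))
print(new_variable_name('conv10/weights', 'conv1', 'conv2'))
```

Because str.replace rewrites every occurrence of the substring, a pattern like conv1 also matches conv10 (the second call yields conv20/weights). Including a delimiter, e.g. --replace_from=conv1/, avoids that.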

npit commented Jul 15, 2017

Thank you for this.


spacetrain commented Aug 3, 2017

Thanks for this!
Just one thing: this script does not save a variable whose name does not contain "replace_from".
So I modified lines 23 and 24 as follows:

        if new_name == var_name:
            print('%s remains unchanged' % var_name)
            var = tf.Variable(var, name=new_name)


P.S. Ah, I just noticed that @bver already commented on this. =)


batzner commented Aug 24, 2017

Thanks @bver and @spacetrain, I changed it!


Thanks a lot. It's exactly what I have been looking for for a long time!


clarken92 commented Nov 27, 2017

Thank you for sharing the code. It's very helpful.


fvisin commented Dec 20, 2017

Thanks a lot, great script!!
I improved it a little by adding a few options: one to look up a specific key (variable name) and one to compare the variables in two checkpoints.
In case it is of some help:
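fvisin's version is not reproduced here. As a rough sketch of the comparison idea only: tf.contrib.framework.list_variables, used in the script above, returns (name, shape) pairs, so two checkpoints can be compared on those lists alone. The helpers find_variables and compare_variable_lists are hypothetical, and the example lists are made up:

```python
def find_variables(var_list, key):
    """Return the (name, shape) pairs whose name contains key."""
    return [(name, shape) for name, shape in var_list if key in name]


def compare_variable_lists(vars_a, vars_b):
    """Compare two lists of (name, shape) pairs, e.g. from two checkpoints."""
    shapes_a = {name: list(shape) for name, shape in vars_a}
    shapes_b = {name: list(shape) for name, shape in vars_b}
    only_in_a = sorted(set(shapes_a) - set(shapes_b))
    only_in_b = sorted(set(shapes_b) - set(shapes_a))
    # Names present in both checkpoints but stored with different shapes
    shape_mismatch = sorted(name for name in set(shapes_a) & set(shapes_b)
                            if shapes_a[name] != shapes_b[name])
    return only_in_a, only_in_b, shape_mismatch


vars_a = [('dense/kernel', [128, 10]), ('dense/bias', [10]), ('global_step', [])]
vars_b = [('dense/kernel', [128, 20]), ('dense/bias', [10]), ('model/dense/bias', [10])]
print(find_variables(vars_a, 'dense'))
print(compare_variable_lists(vars_a, vars_b))
```

Comparing names before and after a rename run is a quick way to confirm that no variable was dropped.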


Thank you for your great, helpful code!!!


Thank you, this code is so coooooooooooool :)


I got this error: ValueError: GraphDef cannot be larger than 2GB.
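A likely cause: tf.Variable(var, name=new_name) turns each loaded NumPy array into a constant stored inside the graph definition, and a serialized GraphDef cannot exceed 2GB, so very large checkpoints hit this limit. Feeding the values through placeholders instead of embedding them is a common TF 1.x workaround, but the script does not do that. A rough pre-check is to add up the variable sizes from the (name, shape) list. The sketch below is hypothetical and assumes every variable is float32 (4 bytes per element); the shapes are made up:

```python
from math import prod

# Roughly the 2GB protobuf limit on a serialized GraphDef
GRAPHDEF_LIMIT_BYTES = 2 * 1024 ** 3


def estimated_bytes(var_list, bytes_per_element=4):
    """Estimate the total size of (name, shape) variables, assuming a single dtype."""
    # prod([]) == 1, so scalar variables count as one element
    return sum(prod(shape) * bytes_per_element for _, shape in var_list)


var_list = [('embedding/weights', [1000000, 300]), ('softmax/weights', [300, 1000000]),
            ('softmax/bias', [1000000])]
total = estimated_bytes(var_list)
print(total, total > GRAPHDEF_LIMIT_BYTES)
```

The estimate ignores graph overhead and mixed dtypes, so a total close to the limit is already a warning sign.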


Thanks for your cool code, but I got this error:

    checkpoint.model_checkpoint_path)
    AttributeError: 'NoneType' object has no attribute 'model_checkpoint_path'

However, the conversion did actually seem to happen. Why is that?
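A likely explanation: tf.train.get_checkpoint_state(checkpoint_dir) returns None when checkpoint_dir contains no state file named checkpoint, which happens, for example, when --checkpoint_dir points at a checkpoint file prefix rather than its directory. Listing and loading still accept a file prefix, so the "Renaming ..." messages appear, which would explain why the conversion seems to have run, but the final save then fails and nothing is written. A minimal pre-check in plain Python, with a hypothetical helper has_checkpoint_state:

```python
import os
import tempfile


def has_checkpoint_state(checkpoint_dir):
    """True if checkpoint_dir holds the 'checkpoint' state file TensorFlow 1.x writes."""
    return os.path.isfile(os.path.join(checkpoint_dir, 'checkpoint'))


with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, 'model.ckpt-1000')
    print(has_checkpoint_state(tmp))     # no state file yet
    with open(os.path.join(tmp, 'checkpoint'), 'w') as f:
        f.write('model_checkpoint_path: "model.ckpt-1000"\n')
    print(has_checkpoint_state(tmp))     # the directory now qualifies
    print(has_checkpoint_state(prefix))  # a file prefix never does
```

In the script itself, an early "if checkpoint is None:" check with a clear message would stop before any work is done.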


Thank you


I wrote a loop with the rename function:

    for i in range(1, 7):  # expanded_conv_i
        node_name = V_HEAD_EX + "_" + str(i + n_to_add) + "/"
        to_node_name = V_UPPER_HEAD_EX + "_" + str(i) + "/"
        rename(checkpointdir, node_name, to_node_name, dry_run=dry_run)

but the variables get duplicated with a "_1" suffix:

Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean to
Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean_1 to
Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance to
Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance_1 to

Does anyone know why this is happening?
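A likely explanation: every rename() call adds a fresh set of tf.Variable ops to the same default graph. When a name is already taken in a graph, TensorFlow makes it unique by appending a suffix such as _1, and the Saver then writes those duplicates too. Calling tf.reset_default_graph() before each rename() call (TF 1.x) would avoid that. An alternative is to collect all replacements first and apply them in a single pass. The sketch below only plans such a pass in plain Python; apply_rules and plan_renames are hypothetical helpers, and the head/ and upper/ prefixes stand in for V_HEAD_EX and V_UPPER_HEAD_EX:

```python
def apply_rules(var_name, rules):
    """Rewrite var_name with the first (replace_from, replace_to) rule it matches."""
    for replace_from, replace_to in rules:
        if replace_from in var_name:
            return var_name.replace(replace_from, replace_to)
    return var_name  # unmatched names are kept as they are


def plan_renames(var_names, rules):
    """Map every variable to its new name and report target-name collisions."""
    mapping = {name: apply_rules(name, rules) for name in var_names}
    sources_by_target = {}
    for old, new in mapping.items():
        sources_by_target.setdefault(new, []).append(old)
    collisions = {new: olds for new, olds in sources_by_target.items() if len(olds) > 1}
    return mapping, collisions


# Hypothetical names; n_to_add = 2 mirrors the loop in the comment above
n_to_add = 2
rules = [('head/expanded_conv_%d/' % (i + n_to_add), 'upper/expanded_conv_%d/' % i)
         for i in range(1, 3)]
names = ['head/expanded_conv_3/w', 'head/expanded_conv_4/w', 'head/expanded_conv_1/w']
mapping, collisions = plan_renames(names, rules)
for old, new in mapping.items():
    print(old, '->', new)
print(collisions)
```

Using such a mapping inside one rename pass would require modifying rename(); the gist does not provide that. Any reported collision is worth resolving first, since two variables cannot share one name in the same graph.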
