Skip to content

Instantly share code, notes, and snippets.

@batzner
Last active May 25, 2023 06:15
Show Gist options
  • Save batzner/7c24802dd9c5e15870b4b56e22135c96 to your computer and use it in GitHub Desktop.
Save batzner/7c24802dd9c5e15870b4b56e22135c96 to your computer and use it in GitHub Desktop.
Small Python script to rename variables in a TensorFlow checkpoint.
import sys, getopt
import tensorflow as tf
# Command-line synopsis shown by main(). Only --checkpoint_dir is required;
# --replace_from and --replace_to take effect only when both are given.
usage_str = ('python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ '
             '[--replace_from=substr --replace_to=substr] [--add_prefix=abc] '
             '[--dry_run] [-h|--help]')
def rename(checkpoint_dir, replace_from=None, replace_to=None, add_prefix=None,
           dry_run=False):
    """Rename the variables of the latest checkpoint in ``checkpoint_dir``.

    Each variable name is transformed in two steps. First, every occurrence of
    ``replace_from`` is replaced with ``replace_to``, but only when both are
    given and ``replace_from`` is non-empty. Then ``add_prefix`` is prepended
    when it is non-empty. Unless ``dry_run`` is set, the renamed variables are
    saved back to the checkpoint path reported by
    ``tf.train.get_checkpoint_state``.

    Safe to call several times in one process: each call builds its variables
    in its own fresh graph.

    Args:
        checkpoint_dir: Directory containing the checkpoint to rewrite.
        replace_from: Substring to replace in each variable name, or None.
        replace_to: Replacement for ``replace_from``, or None. May be '' to
            delete the substring.
        add_prefix: Prefix prepended to each (possibly replaced) name, or None.
        dry_run: If True, only print the planned renames; no values are loaded
            and nothing is saved.

    Raises:
        ValueError: If no checkpoint is found in ``checkpoint_dir``.
    """
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise ValueError('No checkpoint found in %s' % checkpoint_dir)

    # Use a fresh graph per call. Creating the variables in the process-wide
    # default graph means a second call re-creates names that already exist.
    # TensorFlow then uniquifies them with a "_1" suffix, and tf.train.Saver()
    # writes every earlier call's variables as well.
    with tf.Graph().as_default(), tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Compute the new name.
            new_name = var_name
            # An empty replace_from would make str.replace insert replace_to
            # between every character, so it is treated as "no replacement".
            if replace_from and replace_to is not None:
                new_name = new_name.replace(replace_from, replace_to)
            if add_prefix:
                new_name = add_prefix + new_name

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
                continue

            print('Renaming %s to %s.' % (var_name, new_name))
            # Load the stored value and re-create it under the new name; the
            # Variable registers itself in the current (fresh) graph.
            value = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
            tf.Variable(value, name=new_name)

        if not dry_run:
            # Save all variables of this call's graph over the checkpoint.
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)
def main(argv):
    """Parse command-line arguments and run rename() with them.

    Args:
        argv: Command-line arguments, excluding the program name.

    Prints the usage text and exits with status 0 for -h/--help. Exits with
    status 2 in these cases:
      * an unknown option or a missing option value;
      * a stray positional argument;
      * only one of --replace_from/--replace_to is given;
      * --checkpoint_dir is missing.
    """
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False
    try:
        # 'help' takes no value: a trailing '=' would turn a bare --help into
        # an "option --help requires argument" error.
        opts, args = getopt.getopt(argv, 'h', ['help', 'checkpoint_dir=', 'replace_from=',
                                               'replace_to=', 'add_prefix=', 'dry_run'])
    except getopt.GetoptError as err:
        print(err)
        print(usage_str)
        sys.exit(2)
    # getopt stops at the first positional argument, so any options after it
    # (e.g. --dry_run) would be silently ignored; refuse to run in that case.
    if args:
        print('Unexpected positional arguments: %s' % ' '.join(args))
        print(usage_str)
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True
    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)
    # rename() only replaces when both values are given; catch the mistake
    # here instead of silently leaving every name unchanged.
    if (replace_from is None) != (replace_to is None):
        print('--replace_from and --replace_to must be given together. Usage:')
        print(usage_str)
        sys.exit(2)
    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)
if __name__ == '__main__':
    # Hand everything after the program name to the option parser.
    main(argv=sys.argv[1:])
@Dev-HJYoo
Copy link

Thank you

@zheyuanWang
Copy link

I wrote a loop with the rename function:

    for i in range(1,7):  # expaned_conv_i
        node_name = V_HEAD_EX + "_" + str(i + n_to_add) + "/"
        to_node_name = V_UPPER_HEAD_EX + "_" + str(i) + "/"
        rename(checkpointdir, node_name, to_node_name, dry_run=dry_run)

but the variables end up duplicated with a "_1" suffix:

Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean
Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean_1 to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean_1
Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance
Renaming     MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance_1 to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance_1

Does anyone know why this is happening?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment