Last active
May 25, 2023 06:15
-
-
Save batzner/7c24802dd9c5e15870b4b56e22135c96 to your computer and use it in GitHub Desktop.
Small Python script to rename variables in a TensorFlow checkpoint.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import getopt
import sys

import tensorflow as tf
# Command-line synopsis, printed for -h/--help and whenever the arguments
# cannot be parsed or --checkpoint_dir is missing.
usage_str = (
    'python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir/ '
    '--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run'
)
def _renamed(var_name, replace_from, replace_to, add_prefix):
    """Return the name ``var_name`` gets after substitution and prefixing."""
    new_name = var_name
    # The substitution is only applied when both halves of the pair are given.
    if replace_from is not None and replace_to is not None:
        new_name = new_name.replace(replace_from, replace_to)
    if add_prefix:
        new_name = add_prefix + new_name
    return new_name


def rename(checkpoint_dir, replace_from=None, replace_to=None, add_prefix=None,
           dry_run=False):
    """Rename the variables stored in a TensorFlow checkpoint.

    Every variable listed for ``checkpoint_dir`` gets a new name: occurrences
    of ``replace_from`` are replaced by ``replace_to`` (only when both are
    given), then ``add_prefix`` is prepended (when it is non-empty). The
    renamed variables are saved to the ``model_checkpoint_path`` reported by
    the checkpoint state of ``checkpoint_dir``.

    Args:
        checkpoint_dir: Directory containing the checkpoint.
        replace_from: Substring to replace in each variable name, or None.
        replace_to: Replacement for ``replace_from``, or None.
        add_prefix: Prefix to prepend to each variable name, or None.
        dry_run: If true, only print what would be renamed; nothing is
            loaded into a graph and nothing is saved.

    Raises:
        ValueError: If ``dry_run`` is false and no checkpoint state can be
            found for ``checkpoint_dir``.
    """
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if not dry_run and checkpoint is None:
        # Fail before doing any work instead of hitting an AttributeError on
        # ``checkpoint.model_checkpoint_path`` at save time.
        raise ValueError('No checkpoint state found in %s.' % checkpoint_dir)

    # Build the renamed variables in a private graph. With the shared default
    # graph, variables left behind by an earlier call in the same process
    # collide with the new ones and get saved again under a "_1" suffix.
    with tf.Graph().as_default(), tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            new_name = _renamed(var_name, replace_from, replace_to, add_prefix)

            if dry_run:
                print('%s would be renamed to %s.' % (var_name, new_name))
                continue

            print('Renaming %s to %s.' % (var_name, new_name))
            # Only read the tensor when it is actually going to be re-saved.
            value = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
            # Creating the variable registers it, under its new name, in the
            # graph that the Saver below collects from.
            tf.Variable(value, name=new_name)

        if not dry_run:
            # Save the variables
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, checkpoint.model_checkpoint_path)
def main(argv):
    """Parse the command-line options in ``argv`` and run the rename.

    Recognised options are ``-h``/``--help``, ``--checkpoint_dir=``,
    ``--replace_from=``, ``--replace_to=``, ``--add_prefix=`` and
    ``--dry_run``.

    Args:
        argv: Command-line arguments without the program name.

    Exits with status 0 after printing the usage for ``-h``/``--help``, and
    with status 2 (after printing the usage) when the options cannot be
    parsed or ``--checkpoint_dir`` is missing.
    """
    checkpoint_dir = None
    replace_from = None
    replace_to = None
    add_prefix = None
    dry_run = False

    try:
        # 'help' is a flag: declaring it as 'help=' made getopt demand a
        # value, so a bare --help was treated as a parse error (exit 2).
        opts, _ = getopt.getopt(
            argv, 'h',
            ['help', 'checkpoint_dir=', 'replace_from=', 'replace_to=',
             'add_prefix=', 'dry_run'])
    except getopt.GetoptError:
        print(usage_str)
        sys.exit(2)

    for opt, arg in opts:
        if opt in ('-h', '--help'):
            print(usage_str)
            sys.exit()
        elif opt == '--checkpoint_dir':
            checkpoint_dir = arg
        elif opt == '--replace_from':
            replace_from = arg
        elif opt == '--replace_to':
            replace_to = arg
        elif opt == '--add_prefix':
            add_prefix = arg
        elif opt == '--dry_run':
            dry_run = True

    if not checkpoint_dir:
        print('Please specify a checkpoint_dir. Usage:')
        print(usage_str)
        sys.exit(2)

    rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run)


if __name__ == '__main__':
    main(sys.argv[1:])
I wrote a loop that calls the rename function:
for i in range(1,7): # expaned_conv_i
node_name = V_HEAD_EX + "_" + str(i + n_to_add) + "/"
to_node_name = V_UPPER_HEAD_EX + "_" + str(i) + "/"
rename(checkpointdir, node_name, to_node_name, dry_run=dry_run)
but the variables end up duplicated with a "_1" suffix:
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_mean_1 to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_mean_1
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance
Renaming MobilenetV3/expanded_conv_9/depthwise/BatchNorm/moving_variance_1 to
MobilenetV3/upper_layers/expanded_conv_1/depthwise/BatchNorm/moving_variance_1
Does anyone know why this is happening?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thank you