Skip to content

Instantly share code, notes, and snippets.

@urialon
Created September 22, 2019 06:02
Show Gist options
  • Save urialon/8f1d34983e9f62c260d723cf94a841f0 to your computer and use it in GitHub Desktop.
Save urialon/8f1d34983e9f62c260d723cf94a841f0 to your computer and use it in GitHub Desktop.
from argparse import ArgumentParser
import os
import pickle
import tensorflow as tf
from tensorflow.contrib.framework.python.framework import checkpoint_utils
# Rename table consulted by the script below: each key is a variable name as
# it appears in the old checkpoint, each value is the name that variable is
# re-created under for the new checkpoint. Variables whose names are not
# listed here are re-created under their original name.
#
# The current entries move the attention wrapper's LSTM cell weights under
# the 'multi_rnn_cell/cell_0/' scope.
vars_to_rename = {
    # 'model/old_name': 'model/new_name',
    'model/decoder/attention_wrapper/lstm_cell/bias':
        'model/decoder/attention_wrapper/multi_rnn_cell/cell_0/lstm_cell/bias',
    'model/decoder/attention_wrapper/lstm_cell/kernel':
        'model/decoder/attention_wrapper/multi_rnn_cell/cell_0/lstm_cell/kernel',
}
def _parse_args():
    """Parse the command line.

    Returns:
        The argparse namespace with `old_model` (--old, required),
        `new_model` (--new, required) and `dry_run` (--dry_run flag).
    """
    parser = ArgumentParser()
    parser.add_argument("--old", dest="old_model", required=True)
    parser.add_argument("--new", dest="new_model", required=True)
    parser.add_argument('--dry_run', action='store_true')
    return parser.parse_args()


def _ensure_parent_dir(path):
    """Create the directory that will contain `path` if it does not exist.

    Does nothing when `path` has no directory component (a bare file name in
    the current directory): os.path.dirname() returns '' in that case and
    os.makedirs('') would raise.
    """
    dirname = os.path.dirname(path)
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)


def _warn_unmatched(rename_map, seen_names, checkpoint_path):
    """Report keys of `rename_map` that matched no variable in the checkpoint.

    Such keys (for example a typo in an old name) would otherwise be ignored
    silently and the new checkpoint would be written without that rename.
    """
    for missing in sorted(set(rename_map) - set(seen_names)):
        print('Warning: %s is listed in vars_to_rename but was not found in %s.'
              % (missing, checkpoint_path))


def main():
    """Copy a checkpoint, renaming the variables listed in `vars_to_rename`.

    Every variable listed by the old checkpoint (--old) is re-created as a
    tf.Variable in the default graph, under its mapped name when it is a key
    of `vars_to_rename` and under its original name otherwise; the resulting
    variables are then saved to --new. With --dry_run the planned actions are
    only printed: no tensor is loaded, no directory is created and nothing is
    saved.
    """
    args = _parse_args()

    # A dry run must not touch the file system.
    if not args.dry_run:
        _ensure_parent_dir(args.new_model)

    with tf.Session() as sess:
        seen_names = []
        for var_name, _ in checkpoint_utils.list_variables(args.old_model):
            seen_names.append(var_name)
            is_renamed = var_name in vars_to_rename
            new_name = vars_to_rename[var_name] if is_renamed else var_name

            if args.dry_run:
                if is_renamed:
                    print('%s would be renamed to %s.' % (var_name, new_name))
                else:
                    print('%s is kept the same unchanged.' % (var_name))
                # The tensor value is not needed for a dry run, so skip the load.
                continue

            if is_renamed:
                print('Renaming %s to %s.' % (var_name, new_name))
            else:
                print('Re-creating %s.' % (var_name))

            # Re-create the variable in the default graph, initialised from
            # the value stored in the old checkpoint.
            value = checkpoint_utils.load_variable(args.old_model, var_name)
            tf.Variable(value, name=new_name)

        _warn_unmatched(vars_to_rename, seen_names, args.old_model)

        if not args.dry_run:
            # Initialise the re-created variables and write them out once,
            # after the whole checkpoint has been processed.
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, args.new_model)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment