Skip to content

Instantly share code, notes, and snippets.

Created January 26, 2018 04:58
Show Gist options
  • Save glhfgg1024/6d54faf29ccaf5dc7cca8034287e39e0 to your computer and use it in GitHub Desktop.
Save glhfgg1024/6d54faf29ccaf5dc7cca8034287e39e0 to your computer and use it in GitHub Desktop.
Copy pretrained weights from "saved_model.pb" into a new model for fine-tuning or transfer learning.
# Copy pretrained weights from a SavedModel export into a new model so it can
# be fine-tuned / used for transfer learning (TensorFlow 1.x graph/session API).
#
# 1. Load the SavedModel in its own graph and snapshot every global variable
#    into ``variables`` as {variable_name: numpy value}.
# 2. Restore the new model from its checkpoint into the default graph.
# 3. For each variable of the new graph named '<scope_prefix><pretrained name>',
#    overwrite its value with the pretrained one.
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import loader

# Directory of the SavedModel export that holds the pretrained weights.
restore_from = 'pretrained_model/1513006564'
# Checkpoint of the new model, read from '<log_dir>/model.ckpt-<step>'.
log_dir = 'logs'
step = 19000
# The new graph wraps the pretrained layers in an extra scope, so every
# pretrained name gains this prefix there, e.g.
# 'conv1/conv2d/kernel:0' --> 'scope/conv1/conv2d/kernel:0'.
scope_prefix = 'scope/'

# First, read the pretrained weights into a dictionary keyed by variable name.
# A separate graph keeps the SavedModel's ops out of the new model's graph.
variables = {}
g1 = tf.Graph()
with g1.as_default():
    with tf.Session() as pretrained_sess:
        loader.load(pretrained_sess, ['serve'], restore_from)
        for v in tf.global_variables():
            variables[v.name] = v.eval(session=pretrained_sess)

# Build the new graph by importing its meta graph (into the default graph)
# and restoring its checkpoint. The session is deliberately left open so the
# caller can continue fine-tuning with it.
sess = tf.Session()
saver = tf.train.import_meta_graph('{}/model.ckpt-{}.meta'.format(log_dir, step))
saver.restore(sess, '{}/model.ckpt-{}'.format(log_dir, step))

# Copy matching pretrained values into the new graph's variables.
count = 0
for v in tf.global_variables():
    # Map the new-graph name back to the pretrained name; skip variables that
    # live outside the scope or have no pretrained counterpart.
    if not v.name.startswith(scope_prefix):
        continue
    name = v.name[len(scope_prefix):]
    if name not in variables:
        continue
    print('{}: {} ==> {}'.format(count, name, v.name))
    # Assign the pretrained numpy array directly to the variable (no new ops
    # are added to the graph). Raises if the shapes are incompatible.
    v.load(variables[name], sess)
    count += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment