Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save luckmoon/3ff508298d6ae10fac52fc4d33918499 to your computer and use it in GitHub Desktop.
Save luckmoon/3ff508298d6ae10fac52fc4d33918499 to your computer and use it in GitHub Desktop.
Copy pretrained weights from "saved_model.pb" into a new model for fine-tuning or transfer learning.
import numpy as np
import tensorflow as flow
import tensorflow as tf  # the script below refers to TensorFlow as `tf`
from tensorflow.python.saved_model import loader
# First, read the pretrained weights from the SavedModel into a dictionary
# mapping variable name (e.g. 'conv1/conv2d/kernel:0') -> value.


def load_pretrained_variables(export_dir, tags=('serve',)):
    """Return ``{variable_name: value}`` for every global variable in a SavedModel.

    The SavedModel is loaded into its own private graph and session, so none
    of its ops end up in the default graph where the new model is built.

    Args:
        export_dir: SavedModel export directory passed to ``loader.load``.
        tags: MetaGraph tags to load (a tuple, so the default is immutable).

    Returns:
        dict mapping each global variable's name to the value fetched for it
        by ``Session.run``.
    """
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph) as session:
        loader.load(session, list(tags), export_dir)
        global_vars = tf.global_variables()
        # A single run call fetches every value, instead of one
        # Session.run round-trip per variable (which is what v.eval() did).
        values = session.run(global_vars)
    return {v.name: value for v, value in zip(global_vars, values)}


restore_from = 'pretrained_model/1513006564'
variables = load_pretrained_variables(restore_from)
# Build a new graph or import a graph from some graph definition.
...
# NOTE(review): the session is left open, presumably so the fine-tuning code
# that follows can keep using it.
sess = tf.Session()
# NOTE: these initializers only cover variables that already exist at this
# point. Variables added by import_meta_graph below get their values from
# saver.restore instead.
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

log_dir = 'logs'
step = 19000
saver = tf.train.import_meta_graph('{}/model.ckpt-{}.meta'.format(log_dir, step))
saver.restore(sess, '{}/model.ckpt-{}'.format(log_dir, step))

# The new graph nests the pretrained layers under an extra name scope, e.g.
# 'conv1/conv2d/kernel:0' --> 'scope/conv1/conv2d/kernel:0'.
scope_prefix = 'scope/'

count = 0
for v in tf.global_variables():
    # Equivalent to testing `scope_prefix + name == v.name` for every
    # pretrained name, but uses one dict lookup instead of a linear scan.
    if not v.name.startswith(scope_prefix):
        continue
    name = v.name[len(scope_prefix):]
    if name not in variables:
        continue
    print('{}: {} ==> {}'.format(count, name, v.name))
    # Variable.load feeds the numpy value (the pretrained weight) directly
    # into the variable's existing initializer. Calling v.assign(...) here
    # would add a new op, with the whole array baked in as a graph constant,
    # for every variable on every run.
    v.load(variables[name], sess)
    count += 1

# Make a prefix mismatch (nothing copied) visible instead of silent.
print('copied {} of {} pretrained variables'.format(count, len(variables)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment