Skip to content

Instantly share code, notes, and snippets.

@MInner
Created April 4, 2017 19:01
Show Gist options
  • Save MInner/19ba25f03ca4e112a5c5f9390ccd2ff9 to your computer and use it in GitHub Desktop.
Save MInner/19ba25f03ca4e112a5c5f9390ccd2ff9 to your computer and use it in GitHub Desktop.
One way to restore weights when the structure of the computation graph has changed
# example code for EC500 K1 / CS591 S2 Deep Learning (Spring 2017)
import tensorflow as tf
import numpy as np
def main_save(ckpt_path='./saved.ckp'):
    """Build a small graph, initialize it, and checkpoint the 'to_save' variables.

    Only trainable variables created under the ``to_save`` variable scope are
    handed to the Saver; the variable ``b`` (created outside that scope) is
    deliberately left out of the checkpoint.

    Args:
        ckpt_path: Checkpoint path prefix passed to ``tf.train.Saver.save``.
            Defaults to ``'./saved.ckp'``.
    """
    with tf.Graph().as_default():
        with tf.variable_scope('to_save'):
            a = tf.get_variable('a_name', [100, 100])
            # Two dense layers on top of `a`; their kernel/bias variables are
            # created under 'to_save/' and therefore end up in the checkpoint.
            tf.layers.dense(a, 10)
            tf.layers.dense(a, 20)
        # Created outside 'to_save', so it is not collected below and not saved.
        tf.get_variable('b', [200, 200])
        # get_collection filters names with re.match (a prefix match); the
        # trailing '/' keeps a sibling scope such as 'to_save_x' from matching.
        to_save_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'to_save/')
        for v in to_save_vars:
            # Static shape; tf.shape(v) would be a runtime Tensor without `.value`.
            print(v.name, v.get_shape().as_list())
        saver = tf.train.Saver(to_save_vars)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.save(sess, ckpt_path)
            # Checksum of `a`, to compare with the value printed after restoring.
            print(np.sum(sess.run(a)))
def main_load(ckpt_path='./saved.ckp'):
    """Rebuild a partly different graph and restore the 'to_save' variables into it.

    The variables under the ``to_save`` scope are recreated with the same
    names and shapes as at save time, so they can be restored from the
    checkpoint. Outside that scope the graph differs: a variable named ``b``
    now has shape ``[300, 300]``. It is not part of the Saver, so the
    mismatch does not break the restore.

    Args:
        ckpt_path: Checkpoint path prefix passed to ``tf.train.Saver.restore``.
            Defaults to ``'./saved.ckp'``.
    """
    with tf.Graph().as_default():
        with tf.variable_scope('to_save'):
            a = tf.get_variable('a_name', [100, 100])
            # The dense layers get auto-generated names in creation order, so
            # they must be built in the same order as when saving for the
            # checkpoint keys to line up.
            tf.layers.dense(a, 10)
            tf.layers.dense(a, 20)
        # Outside 'to_save' and not in the checkpoint; its new shape is harmless.
        tf.get_variable('b', [300, 300])
        # Trailing '/' so the re.match prefix filter only picks up this scope.
        to_save_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'to_save/')
        for v in to_save_vars:
            # Print the static shape rather than a symbolic tf.shape Tensor.
            print(v.name, v.get_shape().as_list())
        saver = tf.train.Saver(to_save_vars)
        with tf.Session() as sess:
            # Initialize everything first: variables outside the Saver (the
            # new 'b') get a value, and restore() then overwrites the
            # 'to_save' variables with the checkpointed values.
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, ckpt_path)
            # `a` was restored, so this should match the sum printed at save time.
            print(np.sum(sess.run(a)))
def main():
    """Run the demo: write a checkpoint, then restore it into a rebuilt graph."""
    for step in (main_save, main_load):
        step()
# Only run the demo when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment