Skip to content

Instantly share code, notes, and snippets.

@aecker
Created November 4, 2016 12:24
Show Gist options
  • Save aecker/c379663c84fb60545a3f2921aed4b539 to your computer and use it in GitHub Desktop.
Save aecker/c379663c84fb60545a3f2921aed4b539 to your computer and use it in GitHub Desktop.
TensorFlow custom saver that keeps VGG variables out of checkpoint files
class MySaver(tf.train.Saver):
    """Saver that keeps a second group of variables out of its checkpoints.

    Variables in ``var_list`` are written to and restored from the regular
    checkpoints this saver manages. Variables in ``extra_vars`` are never
    written; whenever ``restore`` is called they are reloaded from the fixed
    checkpoint ``extra_chkpt_file`` instead (e.g. pretrained VGG weights that
    need not be duplicated into every training checkpoint).

    Args:
        var_list: Variables saved to / restored from regular checkpoints.
        extra_vars: Variables restored only from ``extra_chkpt_file``. If None
            or empty, this class behaves like a plain ``tf.train.Saver``.
        extra_chkpt_file: Path of the checkpoint that holds ``extra_vars``.
        **kwargs: Forwarded unchanged to ``tf.train.Saver``.
    """

    def __init__(self, var_list, extra_vars=None, extra_chkpt_file=None, **kwargs):
        super().__init__(var_list=var_list, **kwargs)
        self.extra_chkpt_file = extra_chkpt_file
        # tf.train.Saver(var_list=None) covers *every* saveable variable and an
        # empty var_list raises, so only build the extra saver when there is
        # actually a separate group of variables to restore.
        if extra_vars:
            self.extra_saver = tf.train.Saver(var_list=extra_vars)
        else:
            self.extra_saver = None

    def restore(self, sess, save_path):
        """Restore ``var_list`` from *save_path*, then the extra variables.

        Raises:
            ValueError: if extra variables were configured but no
                ``extra_chkpt_file`` was given.
        """
        super().restore(sess, save_path)
        if self.extra_saver is None:
            return
        if self.extra_chkpt_file is None:
            raise ValueError(
                "MySaver: extra_vars were given but extra_chkpt_file is None")
        self.extra_saver.restore(sess, self.extra_chkpt_file)
# For training.
# Everything except the VGG-19 scope is saved in the training checkpoints;
# the VGG-19 model variables are reloaded from VGG_CHECKPOINT_FILE by MySaver
# whenever it restores.
variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_19'])
vgg_vars = slim.get_model_variables('vgg_19')
saver = MySaver(var_list=variables_to_restore, extra_vars=vgg_vars,
                extra_chkpt_file=VGG_CHECKPOINT_FILE)
# slim.learning.train only calls saver.restore when log_dir already contains a
# checkpoint. On a fresh run the variables are set by init_op instead, so the
# VGG weights (never present in our checkpoints) must be loaded explicitly via
# init_fn; otherwise they would start from their random initial values.
init_vgg_fn = slim.assign_from_checkpoint_fn(VGG_CHECKPOINT_FILE, vgg_vars)
slim.learning.train(self.train_op, self.log_dir, saver=saver,
                    init_fn=init_vgg_fn)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment