Skip to content

Instantly share code, notes, and snippets.

@realwecan
Created January 17, 2018 11:27
Show Gist options
  • Save realwecan/2a30885f3adf84064fe0b82ae0e09ae0 to your computer and use it in GitHub Desktop.
Save realwecan/2a30885f3adf84064fe0b82ae0e09ae0 to your computer and use it in GitHub Desktop.
class DictRestore(SessionInit):
    """
    Restore variables from a dictionary.
    """

    def __init__(self, variable_dict):
        """
        Args:
            variable_dict (dict): a dict of {name: value}. Each key is
                normalized with ``get_op_tensor_name(name)[1]`` (the
                ':0'-suffixed tensor-name form) so it can be compared
                against ``Variable.name`` at restore time.

        NOTE(review): keys that normalize to the same tensor name
        (presumably e.g. 'W' and 'W:0') collapse into a single entry; the one
        that comes later in the dict's iteration order silently wins.
        """
        assert isinstance(variable_dict, dict), type(variable_dict)
        # use varname (with :0) for consistency
        self._prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(variable_dict)}

    def _run_init(self, sess):
        """
        Assign dict values to the global variables whose names match a key.

        Graph variables with no dict entry, and dict entries with no graph
        variable, are collected into a ``MismatchLogger`` and logged through
        its ``log()`` method. Graph-only names for which ``is_training_name``
        returns true are not reported.

        Args:
            sess: the session passed to ``SessionUpdate`` to perform the
                assignment.
        """
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        variable_names = {v.name for v in variables}
        param_names = set(self._prms)

        # Only names present on both sides are restored.
        intersect = variable_names & param_names

        # Sort before joining: set iteration order depends on string hashing,
        # which is randomized per process, so an unsorted join would print the
        # names in a different order on every run.
        logger.info("Variables to restore from dict: {}".format(
            ', '.join(map(str, sorted(intersect)))))

        # Variables in the graph that have no value in the dict.
        mismatch = MismatchLogger('graph', 'dict')
        for k in sorted(variable_names - param_names):
            if not is_training_name(k):
                mismatch.add(k)
        mismatch.log()

        # Dict entries that match no variable in the graph.
        mismatch = MismatchLogger('dict', 'graph')
        for k in sorted(param_names - variable_names):
            mismatch.add(k)
        mismatch.log()

        upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
        logger.info("Restoring from dict ...")
        upd.update({name: value for name, value in six.iteritems(self._prms) if name in intersect})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment