Skip to content

Instantly share code, notes, and snippets.

@tdavchev
Created March 8, 2019 17:51
Show Gist options
  • Save tdavchev/4db0d01409ae0c8966b1ee7511fe567c to your computer and use it in GitHub Desktop.
Save tdavchev/4db0d01409ae0c8966b1ee7511fe567c to your computer and use it in GitHub Desktop.
def get_model_params(self):
    """Snapshot every trainable variable of ``self.g`` as fixed-point integers.

    Each variable's value is multiplied by 10000 and rounded to the nearest
    integer (i.e. kept to 4 decimal places), presumably so it can be stored
    compactly as integers in JSON; ``set_model_params`` divides by the same
    factor when restoring.

    Returns:
        A ``(model_params, model_shapes, model_names)`` tuple of parallel
        lists, ordered like ``tf.trainable_variables()``:
        - model_params: nested Python lists of ints (the scaled values);
        - model_shapes: the numpy shape tuple of each variable's value;
        - model_names: each variable's ``var.name``.
    """
    model_names = []
    model_params = []
    model_shapes = []
    with self.g.as_default():
        t_vars = tf.trainable_variables()
        # One session call for all variables instead of one round trip each.
        values = self.sess.run(t_vars)
        for var, p in zip(t_vars, values):
            model_names.append(var.name)
            # Builtin ``int``: ``np.int`` was only an alias for it and was
            # removed in NumPy 1.24, where it raises AttributeError.
            params = np.round(p * 10000).astype(int).tolist()
            model_params.append(params)
            model_shapes.append(p.shape)
    return model_params, model_shapes, model_names
def set_model_params(self, params):
    """Load fixed-point parameters into the trainable variables of ``self.g``.

    This is the inverse of ``get_model_params``: each entry is divided by
    10000 and written into the matching variable.

    Args:
        params: one entry per trainable variable, in
            ``tf.trainable_variables()`` order. Each entry is array-like
            (e.g. nested lists of ints) holding the real value times 10000.
            Entries beyond the number of variables are ignored.

    Raises:
        AssertionError: if an entry's shape differs from its variable's
            shape. Variables before the offending entry have already been
            updated at that point.
        IndexError: if ``params`` has fewer entries than there are
            trainable variables.
    """
    with self.g.as_default():
        trainable_vars = tf.trainable_variables()
        # Fetch the current values once; only their shapes are used.
        current = self.sess.run(trainable_vars)
        for idx, (var, cur) in enumerate(zip(trainable_vars, current)):
            p_ = np.array(params[idx])
            if cur.shape != p_.shape:
                # Explicit raise: a bare ``assert`` is stripped under
                # ``python -O``, which would remove the check entirely.
                # AssertionError is kept for callers that may catch it.
                raise AssertionError(
                    "inconsistent shape for %s: variable %s, params %s"
                    % (var.name, cur.shape, p_.shape))
            # ``float`` replaces ``np.float`` (an alias for it that was
            # removed in NumPy 1.24).
            value = p_.astype(float) / 10000.
            # ``Variable.load`` feeds the value through the variable's
            # existing initializer op. ``var.assign(value)`` would add a new
            # assign op plus a constant holding a full copy of the array to
            # the graph on every call, so the graph grew without bound.
            var.load(value, self.sess)
def save_json(self, jsonfile='rnn.json'):
    """Write the model's trainable parameters to ``jsonfile`` as JSON.

    Only the parameter arrays from ``self.get_model_params()`` are stored,
    as a JSON list with one entry per variable. The shapes and names it also
    returns are not persisted, so loading relies on variable order alone.

    Args:
        jsonfile: destination path; overwritten if it already exists.
    """
    # Element 0 is the parameter list; shapes and names are not persisted.
    qparams = list(self.get_model_params()[0])
    # json.dump escapes non-ASCII by default, so the encoding only makes the
    # file mode explicit and platform-independent.
    with open(jsonfile, 'wt', encoding='utf-8') as outfile:
        json.dump(qparams, outfile, sort_keys=True, indent=0,
                  separators=(',', ': '))
def load_json(self, jsonfile='rnn.json'):
    """Read parameters from ``jsonfile`` and hand them to ``set_model_params``.

    The file is parsed as JSON and the decoded object is passed through
    unchanged. It is presumably the list of integer-quantized arrays that
    ``save_json`` writes; confirm for files from other sources.

    Args:
        jsonfile: path of the JSON file to read.
    """
    with open(jsonfile, 'r') as source:
        serialized = source.read()
    # Decode after the ``with`` so the file handle is released first.
    decoded = json.loads(serialized)
    self.set_model_params(decoded)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment