Skip to content

Instantly share code, notes, and snippets.

@dmmiller612
Created December 11, 2017 15:42
Show Gist options
  • Save dmmiller612/75e0f515083d3c61f0024a0e22662804 to your computer and use it in GitHub Desktop.
Save dmmiller612/75e0f515083d3c61f0024a0e22662804 to your computer and use it in GitHub Desktop.
TensorFlow: serialize and deserialize a GraphDef and set its weights.
import tensorflow as tf
import numpy as np
from google.protobuf import json_format
# Seeds NumPy's global RNG.
# NOTE(review): nothing in this script draws from NumPy's RNG (the XOR data in
# retrieve_xor is fixed). TensorFlow initializers use TensorFlow's own seeding
# (tf.set_random_seed), so this line probably does not make runs reproducible.
# Confirm the intent.
np.random.seed(12345)
def tensorflow_get_weights():
    """
    Read the current value of every trainable variable in the default session.

    @author https://github.com/maxim5

    Must be called while a session is the default session (e.g. inside a
    ``with tf.Session():`` block).

    :return: list of ``(variable, value)`` pairs, where ``value`` is the numpy
        array fetched for ``variable``. The variables remain bound to the
        graph they were read from.
    """
    variables = tf.trainable_variables()
    values = tf.get_default_session().run(variables)
    # Materialise the pairs: in Python 3 ``zip`` is a one-shot iterator, so a
    # second pass over a bare zip (e.g. logging the weights before restoring
    # them) would silently see no weights at all.
    return list(zip(variables, values))
def tensorflow_set_weights(weights):
    """
    Overwrite variables in the default session with the supplied values.

    @author https://github.com/maxim5

    :param weights: iterable of ``(variable, value)`` pairs. Each ``value`` is
        converted with ``np.asarray`` and fed through a placeholder of that
        array's shape, so it is expected to match the variable's shape.

    Note: every call adds one placeholder and one assign op per variable to
    the graph the variables live in.
    """
    updates = []
    placeholder_values = {}
    for variable, raw_value in weights:
        array = np.asarray(raw_value)
        slot = tf.placeholder(variable.dtype, shape=array.shape)
        placeholder_values[slot] = array
        updates.append(variable.assign(slot))
    # All assignments are applied in a single session.run call.
    tf.get_default_session().run(updates, feed_dict=placeholder_values)
def create_simple_graph():
    """
    Build a small fully connected XOR network in the current default graph.

    Architecture: 2 inputs -> dense(12, relu) -> dense(7, relu)
    -> dense(1, sigmoid). It is trained with Adam (learning rate 0.01) on the
    mean squared error.

    Nodes that other code in this script looks up by name:
      * ``x:0`` -- float32 input placeholder of shape ``[None, 2]``
      * ``y:0`` -- float32 target placeholder of shape ``[None, 1]``
      * ``outer/Sigmoid:0`` -- the network output
      * ``mini`` -- the training step, which also increments the global step

    :return: the operation returned by ``AdamOptimizer.minimize``.
    """
    inputs = tf.placeholder(tf.float32, shape=[None, 2], name='x')
    hidden = inputs
    for width in (12, 7):
        hidden = tf.layers.dense(hidden, width, activation=tf.nn.relu)
    output = tf.layers.dense(hidden, 1, name='outer', activation=tf.nn.sigmoid)
    optimizer = tf.train.AdamOptimizer(learning_rate=.01)
    targets = tf.placeholder(tf.float32, shape=[None, 1], name='y')
    squared_error = tf.square(targets - output)
    loss = tf.reduce_mean(squared_error)
    step_counter = tf.train.get_or_create_global_step()
    return optimizer.minimize(loss, global_step=step_counter, name='mini')
def retrieve_xor():
    """
    Build the four-row XOR truth table.

    :return: ``(features, labels)``. ``features`` is a float64 array of shape
        ``(4, 2)`` and ``labels`` is a float64 array of shape ``(4, 1)``, with
        rows in the order (0,0)->0, (1,1)->0, (1,0)->1, (0,1)->1.
    """
    features = np.array([[0.0, 0.0],
                         [1.0, 1.0],
                         [1.0, 0.0],
                         [0.0, 1.0]])
    labels = np.array([0.0, 0.0, 1.0, 1.0]).reshape((4, 1))
    return features, labels
def run_initial(opti, feed_dict):
    """
    Train the default graph from scratch for a few steps, then snapshot it.

    A new session on the default graph initialises all global variables and
    runs ``opti`` ten times.

    :param opti: training op to run (e.g. the result of create_simple_graph).
    :param feed_dict: feeds for each training step.
    :return: ``(json_string, weights)``. ``json_string`` is the default
        graph's GraphDef serialized with ``json_format.MessageToJson``;
        ``weights`` is what ``tensorflow_get_weights`` returned after
        training.
    """
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(10):
            sess.run(opti, feed_dict=feed_dict)
        trained_weights = tensorflow_get_weights()
        graph_json = json_format.MessageToJson(
            tf.get_default_graph().as_graph_def())
    return graph_json, trained_weights
def run_serialized(json_graph, weights, feed_dict, num_steps=100):
    """
    Deserialize a JSON GraphDef, restore the given weights and keep training.

    The GraphDef is imported into a brand-new ``tf.Graph`` with ``name=''``,
    so every node keeps its original name. Tensor names used here and by the
    caller (e.g. ``'x:0'`` in ``feed_dict``) therefore resolve to the
    *deserialized* nodes, not to whatever graph happens to be the default.

    :param json_graph: GraphDef serialized with ``json_format.MessageToJson``.
    :param weights: iterable of ``(variable, value)`` pairs (as returned by
        ``tensorflow_get_weights``). Variables are matched by name only, so
        they may belong to a different graph.
    :param feed_dict: feeds keyed by tensor name, e.g. ``{'x:0': ..., 'y:0': ...}``.
    :param num_steps: number of additional training steps to run.
    :return: value of ``outer/Sigmoid:0`` for ``feed_dict`` after training.
    :raises ValueError: if a weight names a variable absent from the graph.
    """
    graph_def = json_format.Parse(json_graph, tf.GraphDef())
    # Saved values keyed by variable op name, e.g. 'dense/kernel'.
    saved = {var.op.name: np.asarray(value) for var, value in weights}

    graph = tf.Graph()
    with graph.as_default():
        # name='' keeps original node names (the default adds 'import/').
        tf.import_graph_def(graph_def, name='')

    # import_graph_def does not restore collections, so
    # tf.global_variables_initializer() would initialise nothing. Instead run
    # each variable's '<name>/Assign' initializer directly. Where a saved value
    # exists, feed it in place of the initializer's value input.
    # NOTE(review): assumes TF1 ref variables ('VariableV2'), the default for
    # tf.layers; resource variables would need different handling.
    init_ops = []
    init_feeds = {}
    restored = set()
    for op in graph.get_operations():
        if op.type not in ('Variable', 'VariableV2'):
            continue
        assign_op = graph.get_operation_by_name(op.name + '/Assign')
        init_ops.append(assign_op)
        if op.name in saved:
            init_feeds[assign_op.inputs[1]] = saved[op.name]
            restored.add(op.name)
    missing = sorted(set(saved) - restored)
    if missing:
        raise ValueError('Variables not found in deserialized graph: %s' % missing)

    with tf.Session(graph=graph) as sess:
        sess.run(init_ops, feed_dict=init_feeds)
        train_step = graph.get_tensor_by_name('mini:0')
        predictions = graph.get_tensor_by_name('outer/Sigmoid:0')
        for _ in range(num_steps):
            sess.run(train_step, feed_dict=feed_dict)
        return sess.run(predictions, feed_dict=feed_dict)
def run_with_no_serialized_weights():
    """
    Run the demo end to end: build, train, serialize, deserialize, retrain.

    Only the graph structure goes through JSON; the weights are passed along
    as in-memory numpy arrays.

    Everything is built inside a fresh ``tf.Graph``, so repeated calls do not
    collide with nodes from an earlier call. Examples of such collisions: a
    second ``name='outer'`` dense layer clashing with the existing ``outer``
    variables, or ``'x:0'`` resolving to the first call's placeholder.

    :return: predictions for the four XOR inputs after the second training run.
    """
    with tf.Graph().as_default():
        train_op = create_simple_graph()
        features, labels = retrieve_xor()
        feed_dict = {'x:0': features, 'y:0': labels}
        json_graph, weights = run_initial(train_op, feed_dict)
        return run_serialized(json_graph, weights, feed_dict)
if __name__ == "__main__":
    # Script entry point: run the full round trip and show the predictions.
    final_predictions = run_with_no_serialized_weights()
    print(final_predictions)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment