Created April 28, 2020 14:23
-
-
Save JackInTaiwan/16824d2d581c95f91a75f0e7002750f8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf

### Save the graph g1
# Build a tiny graph c = a * b, where `a` is a placeholder fed at run time and
# `b` is a variable initialised from a truncated normal, then write a
# checkpoint with the path prefix "g1" (this produces g1.meta, which the
# loading step below imports).
# TF1-style APIs are accessed through tf.compat.v1, as elsewhere in this
# script, so it works on TF >= 1.13 and on TF 2.x; the bare aliases
# tf.placeholder / tf.train.Saver / tf.global_variables_initializer were
# removed in TF 2.
g1 = tf.Graph()
with g1.as_default():
    a = tf.compat.v1.placeholder(tf.float32, name='a')
    b = tf.Variable(initial_value=tf.random.truncated_normal((1,)), name='b')
    c = tf.multiply(a, b, name='c')
    # Created inside g1 so that, with no var_list, it collects g1's variables.
    s1 = tf.compat.v1.train.Saver()

with tf.compat.v1.Session(graph=g1) as sess:
    # `b` must be initialised before its value can be saved.
    sess.run(tf.compat.v1.global_variables_initializer())
    s1.save(sess, 'g1')
### Load the graph g2 and manipulate it
# Re-create g1's ops from g1.meta inside a fresh graph, extend it with
# d = c * c, restore the checkpointed value of `b`, and evaluate d for a = 32.
g2 = tf.Graph()
with g2.as_default():
    # import_meta_graph returns a Saver bound to the imported variables; keep
    # it so the weights written by the save step can be restored.
    saver = tf.compat.v1.train.import_meta_graph("g1.meta")
    c = g2.get_tensor_by_name("c:0")
    d = tf.multiply(c, c, name='d')

with tf.compat.v1.Session(graph=g2) as sess:
    # Restore `b` from the "g1" checkpoint. Running
    # global_variables_initializer() here instead would overwrite the saved
    # value with a fresh random draw. `d` introduces no new variables, so
    # nothing else needs initialising.
    saver.restore(sess, 'g1')
    # The placeholder is fed by tensor name; the result is (32 * b) ** 2.
    d_value = sess.run(d, feed_dict={"a:0": 32})

print(d_value)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.