@ByungSunBae
Created December 5, 2017 09:18
Simple TensorFlow graph edit example
# From: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/graph_editor/examples/edit_graph_example.py
# Requires TensorFlow 1.x: tf.contrib (including graph_editor) was removed in TensorFlow 2.0.
import numpy as np
import tensorflow as tf
from tensorflow.contrib import graph_editor as ge
# create a graph
g = tf.Graph()
with g.as_default():
    a = tf.constant(1.0, shape=[2, 3], name="a")
    b = tf.constant(2.0, shape=[2, 3], name="b")
    a_pl = tf.placeholder(dtype=tf.float32)
    b_pl = tf.placeholder(dtype=tf.float32)
    c = tf.add(a_pl, b_pl, name="c")
# First, list the operations in the graph:
g.get_operations()
# Out:
#[<tf.Operation 'a' type=Const>,
#<tf.Operation 'b' type=Const>,
#<tf.Operation 'Placeholder' type=Placeholder>,
#<tf.Operation 'Placeholder_1' type=Placeholder>,
#<tf.Operation 'c' type=Add>]
# Modify the graph in place: rewire c's inputs from the placeholders to the constants a and b
ge.swap_inputs(c.op, [a, b])
# Out:
#(<tensorflow.contrib.graph_editor.subgraph.SubGraphView at 0x7ff1938823c8>,
#<tensorflow.contrib.graph_editor.subgraph.SubGraphView at 0x7ff193882f98>)
# and print g.get_operations()
g.get_operations()
# Out:
#[<tf.Operation 'a' type=Const>,
#<tf.Operation 'b' type=Const>,
#<tf.Operation 'Placeholder' type=Placeholder>,
#<tf.Operation 'Placeholder_1' type=Placeholder>,
#<tf.Operation 'c' type=Add>]
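# Same list of operations! swap_inputs added no op; it rewired the existing 'c' op in place
# (its new inputs show up in the graph def printed next).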
# print the graph def
print(g.as_graph_def())
# and print the value of c (its inputs are now the constants, so nothing needs to be fed)
with tf.Session(graph=g) as sess:
    res = sess.run(c)
    print(res)
# But graph_replace is different from swap_inputs: it leaves c untouched and
# adds a copy of it whose inputs are replaced.
# Create the same graph again
g = tf.Graph()
with g.as_default():
    a = tf.constant(1.0, shape=[2, 3], name="a")
    b = tf.constant(2.0, shape=[2, 3], name="b")
    a_pl = tf.placeholder(dtype=tf.float32)
    b_pl = tf.placeholder(dtype=tf.float32)
    c = tf.add(a_pl, b_pl, name="c")
    c_ = ge.graph_replace(c, {a_pl: a, b_pl: b})
# and print g.get_operations()
g.get_operations()
#Out:
#[<tf.Operation 'a' type=Const>,
#<tf.Operation 'b' type=Const>,
#<tf.Operation 'Placeholder' type=Placeholder>,
#<tf.Operation 'Placeholder_1' type=Placeholder>,
#<tf.Operation 'c' type=Add>,
#<tf.Operation 'c_1' type=Add>]
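# A new op, 'c_1', has been added; the original 'c' is unchanged.
# and print the value of c_
with tf.Session(graph=g) as sess:
    res_ = sess.run(c_)
    print(res_)
# Same result as res. The original c still reads from the placeholders,
# so running c in this graph would require feeding a_pl and b_pl.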
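The difference between the two calls can be modeled without TensorFlow. The sketch below is a toy model under stated assumptions, not graph_editor's implementation: `ToyGraph`, `rewire_in_place` and `copy_with_replacements` are hypothetical names, and the copy step duplicates only the single target node, which is all the subgraph between the placeholders and c contains in this example. It reproduces the two op lists printed above: rewiring keeps the same five ops and changes c's inputs, while copying appends c_1 and leaves c reading from the placeholders.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    name: str
    op: str
    inputs: List[str] = field(default_factory=list)


class ToyGraph:
    """Hypothetical stand-in for a tf.Graph: an ordered map of node name -> Node."""

    def __init__(self) -> None:
        self.nodes: Dict[str, Node] = {}

    def add(self, name: str, op: str, inputs=()) -> str:
        # Mimic TF's default name uniquifying: 'c' becomes 'c_1' if 'c' is taken.
        unique, i = name, 0
        while unique in self.nodes:
            i += 1
            unique = f"{name}_{i}"
        self.nodes[unique] = Node(unique, op, list(inputs))
        return unique

    def op_names(self) -> List[str]:
        return list(self.nodes)


def rewire_in_place(graph: ToyGraph, target: str, new_inputs: List[str]) -> None:
    """swap_inputs-like (toy): mutate the existing node; no node is added."""
    graph.nodes[target].inputs = list(new_inputs)


def copy_with_replacements(graph: ToyGraph, target: str, mapping: Dict[str, str]) -> str:
    """graph_replace-like (toy): add a copy of the target with remapped inputs."""
    node = graph.nodes[target]
    remapped = [mapping.get(x, x) for x in node.inputs]
    return graph.add(node.name, node.op, remapped)


def build() -> ToyGraph:
    # Same five ops as the gist's graph.
    g = ToyGraph()
    for name, op in [("a", "Const"), ("b", "Const"),
                     ("Placeholder", "Placeholder"), ("Placeholder_1", "Placeholder")]:
        g.add(name, op)
    g.add("c", "Add", ["Placeholder", "Placeholder_1"])
    return g


g1 = build()
rewire_in_place(g1, "c", ["a", "b"])
print(g1.op_names())         # ['a', 'b', 'Placeholder', 'Placeholder_1', 'c']
print(g1.nodes["c"].inputs)  # ['a', 'b']

g2 = build()
new = copy_with_replacements(g2, "c", {"Placeholder": "a", "Placeholder_1": "b"})
print(new)                   # c_1
print(g2.nodes["c"].inputs)  # ['Placeholder', 'Placeholder_1']
```

In the toy, as in the gist, the observable difference is whether the op list grows: in-place rewiring mutates the one existing node, while the copy-based approach leaves the original intact and returns a handle to the new node.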
helinwang commented May 20, 2019

By "# Same thing!" on line 41, did you mean that ge.swap_inputs(c.op, [a, b]) did not do anything?
If you check the GraphDef, it actually changes

node {
  name: "c"
  op: "Add"
  input: "Placeholder"
  input: "Placeholder_1"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}

to

node {
  name: "c"
  op: "Add"
  input: "a"
  input: "b"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}

Edit: oh, I guess you wanted to highlight the difference between ge.swap_inputs and ge.graph_replace. Thanks!
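Comparing only the output of g.get_operations() misses this kind of edit, because the node names stay the same and only the input fields change. A minimal sketch of a check that compares inputs instead, assuming each graph has been summarized as a plain dict of node name to input names. The helper `changed_inputs` and the dict layout are hypothetical; with a real TF 1.x graph, such a summary could presumably be built with `{n.name: list(n.input) for n in g.as_graph_def().node}`, since every NodeDef carries `name` and `input` fields, as the dump above shows.

```python
from typing import Dict, List, Tuple

# Hypothetical summary format: {node name: list of input names}, mirroring
# the NodeDef "name" and "input" fields in the GraphDef dump above.
GraphSummary = Dict[str, List[str]]


def changed_inputs(before: GraphSummary,
                   after: GraphSummary) -> Dict[str, Tuple[List[str], List[str]]]:
    """Return {name: (old_inputs, new_inputs)} for nodes present in both
    summaries whose inputs differ."""
    return {
        name: (before[name], after[name])
        for name in before
        if name in after and before[name] != after[name]
    }


# The two states of node "c" from the comment, plus its four input ops.
before = {"a": [], "b": [], "Placeholder": [], "Placeholder_1": [],
          "c": ["Placeholder", "Placeholder_1"]}
after = {"a": [], "b": [], "Placeholder": [], "Placeholder_1": [],
         "c": ["a", "b"]}

print(sorted(before) == sorted(after))  # True: same op names
print(changed_inputs(before, after))
# {'c': (['Placeholder', 'Placeholder_1'], ['a', 'b'])}
```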
