Created
February 13, 2019 18:33
-
-
Save xilenteyex/97fafd210d73b30443db7dbdb73e6c80 to your computer and use it in GitHub Desktop.
Toy example of adding control dependencies to a graph after it has been created. Note: running this requires two GPUs.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Placeholder_2:0
Placeholder_4:0
Placeholder_6:0
Placeholder_8:0
Placeholder_10:0
Placeholder_12:0
Placeholder_14:0
Placeholder_16:0
Placeholder_3:0
Placeholder_5:0
Placeholder_7:0
Placeholder_9:0
Placeholder_11:0
Placeholder_13:0
Placeholder_15:0
Placeholder_17:0
Placeholder:0
Placeholder_1:0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
X0:0
X1:0
X2:0
X3:0
X4:0
X5:0
X6:0
X7:0
Y0:0
Y1:0
Y2:0
Y3:0
Y4:0
Y5:0
Y6:0
Y7:0
random_uniform:0
random_uniform_1:0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
MatMul_5
MatMul_8
MatMul_11
MatMul_14
MatMul_17
MatMul_20
MatMul_23
MatMul_26
MatMul_2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import json | |
import sys | |
from tensorflow.python.client import timeline | |
from protobuf_to_dict import protobuf_to_dict | |
from google.protobuf.json_format import MessageToJson | |
from tensorflow.core.protobuf import rewriter_config_pb2 | |
# Suffix shared by the metadata and timeline output file names.
fname_postfix = "frozengraph_modified"
def _read_names(fname):
    """Return the whitespace-stripped, non-blank lines of the text file *fname*.

    Each line is expected to hold one graph element name. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped rather than
    being passed on as an empty name.
    """
    with open(fname) as f:
        stripped = (line.strip() for line in f)
        return [name for name in stripped if name]


def get_ops_list(fname, graph=None):
    """Look up every operation named in *fname* (one name per line).

    Args:
        fname: path of a text file listing operation names, e.g. ``MatMul_5``.
        graph: graph to search; defaults to ``tf.get_default_graph()``.

    Returns:
        A list of the matching operations, in file order.

    Raises:
        Whatever ``graph.get_operation_by_name`` raises for an unknown name,
        and ``IOError``/``OSError`` if *fname* cannot be opened.
    """
    if graph is None:
        graph = tf.get_default_graph()
    return [graph.get_operation_by_name(name) for name in _read_names(fname)]


def get_tensor_list(fname, graph=None):
    """Look up every tensor named in *fname* (one name per line).

    Args:
        fname: path of a text file listing tensor names, e.g. ``X0:0``.
        graph: graph to search; defaults to ``tf.get_default_graph()``.

    Returns:
        A list of the matching tensors, in file order.

    Raises:
        Whatever ``graph.get_tensor_by_name`` raises for an unknown name,
        and ``IOError``/``OSError`` if *fname* cannot be opened.
    """
    if graph is None:
        graph = tf.get_default_graph()
    return [graph.get_tensor_by_name(name) for name in _read_names(fname)]
# NOTE(review): the export script alongside this one writes
# 'modified_cdep_graph.meta', not 'modified.meta' -- confirm which graph this
# run is meant to load.
new_saver = tf.train.import_meta_graph('modified.meta')

# Session configuration: build a cost model, run single-threaded, and switch
# off graph optimizations (optimizer opt_level -1, Grappler constant folding
# and arithmetic optimization) -- presumably so the traced execution reflects
# the imported graph as written.
config_proto = tf.ConfigProto(graph_options=tf.GraphOptions(build_cost_model=1))
config_proto.intra_op_parallelism_threads = 1
config_proto.inter_op_parallelism_threads = 1
config_proto.graph_options.optimizer_options.opt_level = -1
config_proto.graph_options.rewrite_options.constant_folding = (
    rewriter_config_pb2.RewriterConfig.OFF)
config_proto.graph_options.rewrite_options.arithmetic_optimization = (
    rewriter_config_pb2.RewriterConfig.OFF)
sess = tf.Session(config=config_proto)
sess.run(tf.global_variables_initializer())

# Step 1: evaluate the tensors named in 'to_exec1' to obtain input values.
place_holder_vals = sess.run(get_tensor_list('to_exec1'))
for out in place_holder_vals:
    print(out.shape)

# Step 2: feed those values, paired positionally, into the placeholders named
# in 'feed_dict', then run the ops named in 'to_exec2' under a full trace.
place_holder_names = get_tensor_list('feed_dict')
if len(place_holder_names) != len(place_holder_vals):
    # zip() would silently drop the unmatched tail and run with a short feed.
    raise ValueError(
        'feed_dict lists %d tensors but to_exec1 produced %d values'
        % (len(place_holder_names), len(place_holder_vals)))
feeds = dict(zip(place_holder_names, place_holder_vals))
matmul_ops = get_ops_list('to_exec2')
run_metadata = tf.RunMetadata()
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE,
                            output_partition_graphs=True)
W_ = sess.run(matmul_ops,
              feeds,
              options=run_options,
              run_metadata=run_metadata)

# MessageToJson already returns JSON text, so write it verbatim; passing it to
# json.dump would encode it a second time as one quoted string literal.
jsonObj = MessageToJson(run_metadata)
with open('metadata_%s.json' % (fname_postfix), 'w') as outfile:
    outfile.write(jsonObj)

# Chrome trace-format timeline of the traced step; the with-block guarantees
# the file is flushed and closed.
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
with open('timeline_%s.ctf.json' % (fname_postfix), 'w') as trace_file:
    trace_file.write(trace.generate_chrome_trace_format())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import json | |
import sys | |
from tensorflow.python.client import timeline | |
from protobuf_to_dict import protobuf_to_dict | |
from google.protobuf.json_format import MessageToJson | |
from tensorflow.core.protobuf import rewriter_config_pb2 | |
from tensorflow.python.framework import graph_io | |
import os | |
# Reference pair: one square matmul per GPU, then W9 multiplies the two
# products together.
dim = 16384
with tf.device('/gpu:0'):
    X9 = tf.random_uniform([dim, dim], 0, 10)
    _X9 = tf.placeholder(dtype=tf.float32, shape=[dim, dim])
    Z19 = tf.matmul(_X9, _X9)
with tf.device('/gpu:1'):
    Y9 = tf.random_uniform([dim, dim], 0, 10)
    _Y9 = tf.placeholder(dtype=tf.float32, shape=[dim, dim])
    Z29 = tf.matmul(_Y9, _Y9)
    # NOTE(review): the pasted source lost its indentation; W9 is reconstructed
    # here as part of the '/gpu:1' scope -- confirm against the original gist.
    W9 = tf.matmul(Z19, Z29)

# n independent chains on GPU:0 with square sizes 64, 128, ..., 8192.
# Creation order matters: TensorFlow numbers unnamed ops in creation order,
# so _X[i] / _Y[i] become Placeholder_{2+2i} / Placeholder_{3+2i} and W[i]
# becomes MatMul_{5+3i} -- the name lists used with this graph rely on that.
n = 8
dim = 32
Z1, Z2, W = [], [], []
X, _X, Y, _Y = [], [], [], []
with tf.device('/gpu:0'):
    for i in range(n):
        dim = dim * 2
        shape = [dim, dim]
        X.append(tf.random_uniform(shape, 0, 10, name='X%d' % i))
        Y.append(tf.random_uniform(shape, 0, 10, name='Y%d' % i))
        x_in = tf.placeholder(dtype=tf.float32, shape=shape)
        _X.append(x_in)
        z1 = tf.matmul(x_in, x_in)
        Z1.append(z1)
        y_in = tf.placeholder(dtype=tf.float32, shape=shape)
        _Y.append(y_in)
        z2 = tf.matmul(y_in, y_in)
        Z2.append(z2)
        W.append(tf.matmul(z1, z2))
# Serialize the default graph, then edit its NodeDefs directly: pin every node
# to GPU:0 and give the per-size input placeholders a control dependency.
meta_graph = tf.train.export_meta_graph()
nodes = meta_graph.graph_def.node

# Placeholders that receive the control dependency: every input in _X and _Y
# (auto-named Placeholder_2 .. Placeholder_17). Deriving the names from the
# tensors keeps this in sync if graph construction order ever changes, which a
# hard-coded list of generated names would not.
nodes_with_dep = {t.op.name for t in _X + _Y}
# A leading '^' in a NodeDef input denotes a control input; this points at the
# op producing Z19 (the first MatMul created, named 'MatMul').
deps = ['^' + Z19.op.name]

for node in nodes:
    print(node.device)
    # Overwrite whatever placement the graph was built with.
    node.device = u'/device:GPU:0'
    print(type(node.device))
    if node.name in nodes_with_dep:
        print('added')
        print(node.name)
        node.input.extend(deps)

print("save_graph")
# NOTE(review): the companion run script loads 'modified.meta'; one of the two
# file names presumably needs to change for them to work together.
filename = 'modified_cdep_graph.meta'
# The whole MetaGraphDef (not a tf.Graph) is handed to write_graph -- NOTE:
# relies on write_graph serializing a non-Graph proto as-is. An empty dirname
# is mapped to '.' so the directory argument is never ''.
graph_io.write_graph(
    meta_graph,
    os.path.dirname(filename) or '.',
    os.path.basename(filename),
    as_text=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment