"""Demonstrate keeping the TensorBoard summaries of several models apart.

Two ``Net`` instances share one graph and one session. Each registers its
summary ops under a collection keyed by its own name. As a result,
``tf.summary.merge_all(key=name)`` merges only that net's summaries instead
of every summary op in the shared graph.

Written against the TensorFlow 1.x API (``tf.placeholder``,
``tf.variable_scope``, ``tf.summary.FileWriter``).
"""

import tensorflow as tf


class Net:
    """Toy model computing ``x * multiplier`` that logs its own summaries.

    Use as a context manager. Entering recreates the log directory and builds
    the merged-summary op and the ``FileWriter``. Exiting closes the writer.
    ``predict`` must only be called inside the ``with`` block.
    """

    def __init__(self, session, name, multiplier):
        """Build the net's ops inside ``tf.variable_scope(name)``.

        Args:
            session: the ``tf.Session`` that ``predict`` runs ops in.
            name: used as the variable scope, the summary collection key and
                the log directory path.
            multiplier: constant the input is multiplied by.
        """
        self.session = session
        self.name = name
        # Number of predict() calls so far; used as the summary step.
        self._step = 0
        with tf.variable_scope(name):
            self.x1 = tf.placeholder(tf.int32)
            # Register every summary only in this net's own collection.
            # Without ``collections=[...]`` a summary goes to the default
            # SUMMARIES collection, and merge_all() without a key would mix
            # the summaries of all nets in the graph.
            tf.summary.histogram('in', self.x1, collections=[self.name])
            a1 = tf.constant(multiplier)
            self.y1 = tf.multiply(self.x1, a1)
            tf.summary.histogram('out', self.y1, collections=[self.name])

    def __enter__(self):
        """Prepare logging for this net and return ``self``."""
        self._prepare_log_dir()
        # Merge only the summaries registered under this net's key.
        self.merged_summary = tf.summary.merge_all(key=self.name)
        # Backward-compatible alias for the original (misspelled) attribute.
        self.merged_summery = self.merged_summary
        self.train_writer = tf.summary.FileWriter(self.name, self.session.graph)
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        """Close the summary writer.

        Returns None, so any exception raised in the block propagates.
        """
        self.train_writer.close()

    def predict(self, x):
        """Return ``x * multiplier`` and log this net's summaries.

        Args:
            x: value fed to the int32 placeholder.

        Returns:
            The evaluated output tensor value.
        """
        result, summary = self.session.run(
            [self.y1, self.merged_summary], feed_dict={self.x1: x})
        # The original evaluated the summary but discarded it; write it out.
        self.train_writer.add_summary(summary, self._step)
        self._step += 1
        return result

    def _prepare_log_dir(self):
        """Delete and recreate the directory named ``self.name``."""
        if tf.gfile.Exists(self.name):
            tf.gfile.DeleteRecursively(self.name)
        tf.gfile.MakeDirs(self.name)


def main():
    """Run two nets in one session, each logging to its own directory."""
    with tf.Session() as sess:
        net1 = Net(sess, 'graph-1', 2)
        net2 = Net(sess, 'graph-2', 3)
        with net1, net2:
            tf.global_variables_initializer().run()
            print(net1.predict(3))
            print(net2.predict(3))


if __name__ == '__main__':
    main()