Skip to content

Instantly share code, notes, and snippets.

@vierja
Created July 15, 2017 23:20
Show Gist options
  • Save vierja/d0879ff294a784758730fa92a4973005 to your computer and use it in GitHub Desktop.
import os
import sonnet as snt
import tensorflow as tf
# Root directory for the TensorBoard event files written by the __main__
# section; each experiment writes into its own sub-directory of LOG_DIR.
# NOTE(review): '/tmp/' is POSIX-specific; confirm before running elsewhere.
LOG_DIR = '/tmp/'
class MainModule(snt.AbstractModule):
    """Sonnet module wrapping a single ``snt.Linear(10)`` layer.

    The only purpose of the module is to compare two places where a
    submodule can be constructed:

    * ``in_build=False``: the layer is created eagerly in ``__init__``,
      inside ``self._enter_variable_scope()``.
    * ``in_build=True``: the layer is created lazily, when the module is
      connected to the graph (i.e. inside ``_build``).

    Args:
      in_build: selects lazy (True) or eager (False) layer construction.
    """

    def __init__(self, in_build):
        super().__init__()
        self._in_build = in_build
        # Lazy mode: nothing to construct yet; _build will do it.
        if self._in_build:
            return
        # Eager mode: construct the layer now, inside this module's
        # variable scope.
        with self._enter_variable_scope():
            self._instantiate_layers()

    def _instantiate_layers(self):
        """Create the wrapped Linear layer and store it on ``self._linear``."""
        layer = snt.Linear(10)
        self._linear = layer

    def _build(self, inputs):
        """Apply the wrapped Linear layer to ``inputs`` and return the result.

        In lazy mode the layer is (re)created here before being applied.
        NOTE(review): with ``in_build=True`` every connection constructs a
        brand-new ``snt.Linear``; confirm that is intended if the module is
        ever connected more than once.
        """
        if self._in_build:
            self._instantiate_layers()
        linear = self._linear
        return linear(inputs)
def _run_experiment(in_build, run_name):
    """Build ``MainModule`` in a fresh graph, log that graph, print its variables.

    Args:
      in_build: forwarded to ``MainModule``. True creates the Linear layer in
        ``_build``; False creates it in ``__init__``.
      run_name: sub-directory of ``LOG_DIR`` that receives the TensorBoard
        event file for this run.
    """
    # A dedicated graph per run replaces the old tf.reset_default_graph()
    # call, so the two experiments never share variables or op names.
    with tf.Graph().as_default():
        inputs = tf.placeholder(tf.float32, [None, 10])
        model = MainModule(in_build=in_build)
        model(inputs)  # connecting the module is what creates its variables
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            writer = tf.summary.FileWriter(
                os.path.join(LOG_DIR, run_name), sess.graph
            )
            # close() flushes the pending graph event to disk. Without it the
            # event file may be left incomplete when the script exits.
            writer.close()
            print([v.name for v in tf.global_variables()])


if __name__ == '__main__':
    _run_experiment(in_build=True, run_name='instance_in_build')
    _run_experiment(in_build=False, run_name='instance_in_init')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment