Skip to content

Instantly share code, notes, and snippets.

@vierja
Created July 13, 2017 03:11
Show Gist options
  • Save vierja/668d8985746b0d5ba15fc038e130d150 to your computer and use it in GitHub Desktop.
Save vierja/668d8985746b0d5ba15fc038e130d150 to your computer and use it in GitHub Desktop.
import sonnet as snt
import tensorflow as tf
class SubmoduleA(snt.AbstractModule):
    """Sonnet module whose output is a scalar variable initialised to 1.0."""

    def __init__(self, name='submodule_a'):
        """Create the module.

        Args:
          name: Module name, forwarded to ``snt.AbstractModule``.
        """
        super().__init__(name=name)

    def _build(self):
        """Return the module's scalar variable (initial value 1.0).

        ``tf.get_variable`` is used rather than ``tf.Variable`` so that
        connecting this module more than once reuses the same variable
        instead of creating a new, unshared one on every call.
        NOTE(review): Sonnet 1 connects ``_build`` through a
        ``tf.make_template``-based wrapper, which rejects trainable variables
        made with ``tf.Variable`` after the first call — confirm for the
        installed Sonnet version.
        """
        # 'Variable' mirrors the default name tf.Variable would have chosen,
        # so printed variable names keep the same final component.
        return tf.get_variable('Variable', initializer=tf.constant(1.0))
class SubmoduleB(snt.AbstractModule):
    """Sonnet module whose output is a scalar variable initialised to 2.0."""

    def __init__(self, name='submodule_b'):
        """Create the module.

        Args:
          name: Module name, forwarded to ``snt.AbstractModule``.
        """
        super().__init__(name=name)

    def _build(self):
        """Return the module's scalar variable (initial value 2.0).

        ``tf.get_variable`` is used rather than ``tf.Variable`` so that
        connecting this module more than once reuses the same variable
        instead of creating a new, unshared one on every call.
        NOTE(review): Sonnet 1 connects ``_build`` through a
        ``tf.make_template``-based wrapper, which rejects trainable variables
        made with ``tf.Variable`` after the first call — confirm for the
        installed Sonnet version.
        """
        # 'Variable' mirrors the default name tf.Variable would have chosen,
        # so printed variable names keep the same final component.
        return tf.get_variable('Variable', initializer=tf.constant(2.0))
class MainModule(snt.AbstractModule):
    """Module whose output is the sum of two submodules' outputs.

    Where the submodules are constructed depends on ``in_build``:

    * ``False`` (default): once, in ``__init__``, inside this module's own
      variable scope (via ``_enter_variable_scope``).
    * ``True``: inside ``_build``, i.e. anew on every connection.
    """

    def __init__(self, in_build=False, name='main_module'):
        """Configure where the submodules are constructed.

        Args:
          in_build: If true, defer submodule construction to ``_build``.
          name: Module name, forwarded to ``snt.AbstractModule``.
        """
        super().__init__(name=name)
        self._in_build = in_build
        if self._in_build:
            # Deferred: ``_build`` creates the submodules when connected.
            return
        with self._enter_variable_scope():
            self._submodule_a = SubmoduleA()
            self._submodule_b = SubmoduleB()

    def _build(self):
        """Return ``submodule_a() + submodule_b()``.

        NOTE(review): with ``in_build=True`` new submodule instances replace
        the previous ones on every call, so a second connection does not reuse
        the first connection's submodules — confirm that is intended.
        """
        if self._in_build:
            self._submodule_a = SubmoduleA()
            self._submodule_b = SubmoduleB()
        # Connect A before B, matching left-to-right evaluation of `a() + b()`.
        output_a = self._submodule_a()
        output_b = self._submodule_b()
        return output_a + output_b
if __name__ == '__main__':
    # Build each construction strategy in its own fresh graph and print the
    # names of the global variables it creates. The first pass constructs the
    # submodules in __init__; the second constructs them inside _build.
    for in_build in (False, True):
        graph = tf.Graph()
        with graph.as_default():
            model = MainModule(in_build=in_build)
            model()  # connect the module so its variables are created
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                print([v.name for v in tf.global_variables()])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment