Skip to content

Instantly share code, notes, and snippets.

@alexwal
Forked from danijar/share_variables_decorator.py
Created December 27, 2018 22:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alexwal/7a6855158e64630113081fb3b7e6e494 to your computer and use it in GitHub Desktop.
Save alexwal/7a6855158e64630113081fb3b7e6e494 to your computer and use it in GitHub Desktop.
TensorFlow decorator to share variables between calls. Works for both functions and methods.
import functools
import tensorflow as tf
class share_variables(object):
    """Decorator that routes calls through a ``tf.make_template`` template.

    Works on plain functions and on methods. A decorated function gets a
    single template; a decorated method gets one template per instance, so
    each instance owns its own set of variables while repeated calls on the
    same instance go through the same template.

    Template (and therefore variable scope) names are derived from the
    module, class, instance and callable names; see ``_create_name``.
    """

    def __init__(self, callable_):
        """Store the callable to wrap.

        Args:
            callable_: Function or method body that creates variables.
        """
        # Copy __name__/__doc__/... first so that the private attributes
        # assigned below always win over anything found in the wrapped
        # callable's __dict__.
        functools.update_wrapper(self, callable_)
        self._callable = callable_
        # Maps instance -> template for decorated methods. The dict holds a
        # strong reference to every instance it has seen, and instances must
        # be hashable.
        self._wrappers = {}
        # Lazily created template for the plain-function case.
        self._wrapper = None

    def __call__(self, *args, **kwargs):
        """Call the decorated plain function through its shared template."""
        return self._function_wrapper(*args, **kwargs)

    def __get__(self, instance, owner=None):
        """Bind the decorator to ``instance`` when accessed as a method.

        Args:
            instance: Object the attribute was looked up on, or ``None``
                when it was looked up on the class itself.
            owner: Class the descriptor lives on (unused).

        Returns:
            The decorator itself for class-level access, otherwise a
            partial of ``_method_wrapper`` bound to ``instance`` that
            carries the metadata of the wrapped callable.
        """
        if instance is None:
            # Class-level access (e.g. ``Model.method``): follow the usual
            # descriptor convention instead of binding to ``None``, which
            # would build and cache a bogus "NoneType" template.
            return self
        bound = functools.partial(self._method_wrapper, instance)
        return functools.wraps(self._callable)(bound)

    def _method_wrapper(self, instance, *args, **kwargs):
        """Call the wrapped method via the template owned by ``instance``."""
        if instance not in self._wrappers:
            # Prefer a user-provided ``name`` attribute; ``id(instance)`` is
            # only a fallback and is not stable across runs.
            name = self._create_name(
                type(instance).__module__,
                type(instance).__name__,
                instance.name if hasattr(instance, 'name') else id(instance),
                self._callable.__name__)
            self._wrappers[instance] = tf.make_template(
                name, self._callable, create_scope_now_=True)
        return self._wrappers[instance](instance, *args, **kwargs)

    def _function_wrapper(self, *args, **kwargs):
        """Call the wrapped function via its single shared template."""
        if self._wrapper is None:
            name = self._create_name(
                self._callable.__module__,
                self._callable.__name__)
            self._wrapper = tf.make_template(
                name, self._callable, create_scope_now_=True)
        return self._wrapper(*args, **kwargs)

    def _create_name(self, *words):
        """Build a template name from ``words``.

        Each word is converted to ``str`` and stripped of underscores, then
        the words are joined with single underscores, e.g.
        ``('__main__', 'my_func')`` -> ``'main_myfunc'``.
        """
        return '_'.join(str(word).replace('_', '') for word in words)
class Model(object):
    """Example class demonstrating ``share_variables`` on a method."""

    def __init__(self, name):
        # ``share_variables`` reads ``name`` when it builds the template
        # name for this instance.
        self.name = name

    @share_variables
    def method(self, data):
        """Apply a dense layer with 100 units to ``data``."""
        return tf.layers.dense(data, units=100)
@share_variables
def function(data):
    """Apply a dense layer with 50 units to ``data``."""
    return tf.layers.dense(data, units=50)
# Demo: call the decorated function and methods several times, then list the
# trainable variables that were created.
data = tf.placeholder(dtype=tf.float32, shape=[None, 50])

# Two calls to the same decorated function.
function(data)
function(data)

# Two calls on the same instance.
foo = Model('foo')
foo.method(data)
foo.method(data)

# A call on a second instance with a different name.
bar = Model('bar')
bar.method(data)

for var in tf.trainable_variables():
    print(var.name)

# Output:
# main_function/dense/kernel:0
# main_function/dense/bias:0
# main_Model_foo_method/dense/kernel:0
# main_Model_foo_method/dense/bias:0
# main_Model_bar_method/dense/kernel:0
# main_Model_bar_method/dense/bias:0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment