Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
TensorFlow decorator to share variables between calls. Works for both functions and methods.
import functools
import tensorflow as tf
class share_variables(object):
    """Decorator that makes a function or method reuse its TensorFlow variables.

    The wrapped callable is turned into a ``tf.make_template`` template the
    first time it is called, so later calls reuse the variables created by the
    first call instead of creating new ones.

    * Plain functions share a single template whose scope name is built from
      ``<module>`` and ``<function name>``.
    * Methods get one template per instance whose scope name is built from
      ``<module>``, ``<class name>``, ``<instance.name or id(instance)>`` and
      ``<method name>``, so different instances own separate variables.

    Underscores are stripped from every scope-name component before the parts
    are joined with ``_`` (for example ``__main__`` becomes ``main``).
    """

    def __init__(self, callable_):
        # Copy __name__, __doc__, __module__, ... onto the decorator object so
        # a decorated plain function keeps its identity for callers and
        # introspection tools. Done first so it cannot clobber the private
        # attributes assigned below.
        functools.update_wrapper(self, callable_)
        self._callable = callable_
        # Per-instance templates for the method case.
        # NOTE(review): keyed by the instance itself, so instances must be
        # hashable and stay alive as long as this decorator does.
        self._wrappers = {}
        # Single shared template for the plain-function case; created lazily.
        self._wrapper = None

    def __call__(self, *args, **kwargs):
        """Call the decorated plain function through its shared template."""
        return self._function_wrapper(*args, **kwargs)

    def __get__(self, instance, owner):
        """Return the decorated method bound to ``instance``.

        When accessed on the class itself (``instance is None``), return a
        callable that expects the instance as its first argument, like an
        ordinary unbound method (``Model.method(obj, data)``).
        """
        if instance is None:
            def unbound(obj, *args, **kwargs):
                return self._method_wrapper(obj, *args, **kwargs)
            return functools.wraps(self._callable)(unbound)
        bound = functools.partial(self._method_wrapper, instance)
        return functools.wraps(self._callable)(bound)

    def _method_wrapper(self, instance, *args, **kwargs):
        """Call the method through the template that belongs to ``instance``."""
        if instance not in self._wrappers:
            # Prefer a user-provided ``name`` attribute: it gives scope names
            # that are stable across runs (needed to restore checkpoints),
            # whereas id() differs from run to run.
            if hasattr(instance, 'name'):
                identifier = instance.name
            else:
                identifier = id(instance)
            name = self._create_name(
                type(instance).__module__,
                type(instance).__name__,
                identifier,
                self._callable.__name__)
            self._wrappers[instance] = tf.make_template(
                name, self._callable, create_scope_now_=True)
        return self._wrappers[instance](instance, *args, **kwargs)

    def _function_wrapper(self, *args, **kwargs):
        """Call the plain function through its single shared template."""
        if self._wrapper is None:
            name = self._create_name(
                self._callable.__module__,
                self._callable.__name__)
            self._wrapper = tf.make_template(
                name, self._callable, create_scope_now_=True)
        return self._wrapper(*args, **kwargs)

    def _create_name(self, *words):
        """Join ``words`` with ``_`` after stringifying them and removing ``_``."""
        return '_'.join(str(word).replace('_', '') for word in words)
class Model(object):
    """Example model whose ``method`` keeps separate variables per instance.

    The ``name`` attribute is read by ``share_variables`` to build the scope
    name, so two models with different names get different variables.
    """

    def __init__(self, name):
        # Used by share_variables as the instance part of the scope name.
        self.name = name

    @share_variables
    def method(self, data):
        """Apply a 100-unit dense layer using this instance's variables."""
        return tf.layers.dense(data, units=100)
@share_variables
def function(data):
    """Apply a 50-unit dense layer whose variables are reused on every call."""
    return tf.layers.dense(data, units=50)
# Input batch with 50 features per example.
data = tf.placeholder(dtype=tf.float32, shape=[None, 50])

# Calling the plain function twice creates its variables only once.
function(data)
function(data)

# Calling the method twice on one instance reuses that instance's variables...
foo = Model('foo')
foo.method(data)
foo.method(data)

# ...while a second instance gets a separate set of variables.
bar = Model('bar')
bar.method(data)

for var in tf.trainable_variables():
    print(var.name)

# Output:
# main_function/dense/kernel:0
# main_function/dense/bias:0
# main_Model_foo_method/dense/kernel:0
# main_Model_foo_method/dense/bias:0
# main_Model_bar_method/dense/kernel:0
# main_Model_bar_method/dense/bias:0
@albertz

This comment has been minimized.

Copy link

commented Jan 15, 2018

So this will create different variable names on each run, because id() is used in the variable names. This makes it hard to store/load models from disk.

@danijar

This comment has been minimized.

Copy link
Owner Author

commented Jan 15, 2018

@albertz Yes, I use the object ID for the variable scope so that different instances of the same class get their own variables. You could adopt the convention that model classes must define a model.name attribute, and use that for the scope name. Alternatively, save/load only the variables inside a scope using export_scoped_meta_graph() and import_scoped_meta_graph().

@danijar

This comment has been minimized.

Copy link
Owner Author

commented Jun 26, 2018

I've updated the code to include a fix and to use the self.name attribute of model classes if available, and fall back to id(self) otherwise.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.