TensorFlow Scope Decorator

```python
# Working example for my blog post at:
# https://danijar.github.io/structuring-your-tensorflow-models
import functools
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data


def doublewrap(function):
    """
    A decorator decorator, allowing the decorator to be used without
    parentheses if no arguments are provided. All arguments must be optional.
    """
    @functools.wraps(function)
    def decorator(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
            return function(args[0])
        else:
            return lambda wrapee: function(wrapee, *args, **kwargs)
    return decorator


@doublewrap
def define_scope(function, scope=None, *args, **kwargs):
    """
    A decorator for functions that define TensorFlow operations. The wrapped
    function will only be executed once. Subsequent calls to it will directly
    return the result so that operations are added to the graph only once.
    The operations added by the function live within a tf.variable_scope(). If
    this decorator is used with arguments, they will be forwarded to the
    variable scope. The scope name defaults to the name of the wrapped
    function.
    """
    attribute = '_cache_' + function.__name__
    name = scope or function.__name__
    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            with tf.variable_scope(name, *args, **kwargs):
                setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return decorator


class Model:

    def __init__(self, image, label):
        self.image = image
        self.label = label
        # Access each property once so that all operations are added to the
        # graph when the model is constructed.
        self.prediction
        self.optimize
        self.error

    @define_scope(initializer=tf.contrib.slim.xavier_initializer())
    def prediction(self):
        x = self.image
        x = tf.contrib.slim.fully_connected(x, 200)
        x = tf.contrib.slim.fully_connected(x, 200)
        x = tf.contrib.slim.fully_connected(x, 10, tf.nn.softmax)
        return x

    @define_scope
    def optimize(self):
        logprob = tf.log(self.prediction + 1e-12)
        cross_entropy = -tf.reduce_sum(self.label * logprob)
        optimizer = tf.train.RMSPropOptimizer(0.03)
        return optimizer.minimize(cross_entropy)

    @define_scope
    def error(self):
        mistakes = tf.not_equal(
            tf.argmax(self.label, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))


def main():
    mnist = input_data.read_data_sets('./mnist/', one_hot=True)
    image = tf.placeholder(tf.float32, [None, 784])
    label = tf.placeholder(tf.float32, [None, 10])
    model = Model(image, label)
    sess = tf.Session()
    # tf.initialize_all_variables() is deprecated; this is its replacement.
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        images, labels = mnist.test.images, mnist.test.labels
        error = sess.run(model.error, {image: images, label: labels})
        print('Test error {:6.2f}%'.format(100 * error))
        for _ in range(60):
            images, labels = mnist.train.next_batch(100)
            sess.run(model.optimize, {image: images, label: labels})


if __name__ == '__main__':
    main()
```
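
The doublewrap helper is what lets define_scope be written both bare (@define_scope) and with arguments (@define_scope(initializer=...)). The following TensorFlow-free sketch copies the same dispatch logic onto a hypothetical decorator named tag, whose label argument exists only to show which branch of doublewrap runs.

```python
import functools


def doublewrap(function):
    """Same dispatch logic as doublewrap in the gist above."""
    @functools.wraps(function)
    def decorator(*args, **kwargs):
        # Bare use: Python calls decorator(func) with one callable argument.
        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
            return function(args[0])
        # Use with arguments: return a one-argument decorator that receives
        # the function later and forwards the stored arguments to it.
        return lambda wrapee: function(wrapee, *args, **kwargs)
    return decorator


@doublewrap
def tag(function, label=None):
    # Hypothetical decorator: attaches a label, defaulting to the function name.
    function.label = label or function.__name__
    return function


@tag
def bare():
    pass


@tag(label='custom')
def with_args():
    pass


print(bare.label)       # bare
print(with_args.label)  # custom
```

One caveat of this design: a call such as @tag(some_function) with a single positional callable would be mistaken for bare use, so arguments are safest passed by keyword.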

lfwin Jul 15, 2016

Hi,
I debugged this code and found the following:
1. If I set a breakpoint at line number 15, 'setattr(self, attribute, function(self))', which is executed during the Model initialization 'model = Model(data, target)' at line number 53, I can step into this line. But if I set a breakpoint before this setattr line is executed, such as at line number 52 or line number 25, I cannot step into the setattr line. Why?
2. After executing up to the line that sets self.optimize (line number 25) and before executing self.error, self already has three attributes: self._optimize, self._error, and self._prediction. I don't understand how these are generated.

danijar Oct 4, 2016 (Owner)

Hi @lfwin, I don't know why your debugger can't step into setattr() statements. Maybe it's because setattr() is a built-in function. Regarding your second question, that's the intended behavior of the decorator. Please read my blog post for an explanation; the URL is at the beginning of the file.
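
The behavior behind the second question can be made concrete with a small sketch. It assumes a simplified, argument-free copy of define_scope and replaces tf.variable_scope with a hypothetical fake_variable_scope so it runs without TensorFlow. Each wrapped function runs only on the first access of its property, and its result is stored on the instance; in the version of the code above, the attribute name is '_cache_' plus the function name. Those stored results are the extra attributes that show up on self. It also follows that a breakpoint inside a wrapped function, or on the setattr line, is hit only on the first access of each property.

```python
import contextlib
import functools

# Hypothetical stand-in for tf.variable_scope: it only records the scope
# names that were entered, so this sketch runs without TensorFlow.
entered_scopes = []


@contextlib.contextmanager
def fake_variable_scope(name):
    entered_scopes.append(name)
    yield


def define_scope(function):
    # Simplified, argument-free version of the gist's decorator.
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            # First access: run the body once and store the result.
            with fake_variable_scope(function.__name__):
                setattr(self, attribute, function(self))
        # Later accesses return the stored object without re-running the body.
        return getattr(self, attribute)
    return decorator


class Model:
    def __init__(self):
        self.calls = 0
        self.prediction  # Triggers the first (and only) run of the body.

    @define_scope
    def prediction(self):
        self.calls += 1
        return object()  # Placeholder for the tensor TensorFlow would build.


model = Model()
first = model.prediction
second = model.prediction
print(model.calls)          # 1
print(first is second)      # True
print(sorted(vars(model)))  # ['_cache_prediction', 'calls']
print(entered_scopes)       # ['prediction']
```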

mshvartsman Jan 23, 2017

Hi -- is there a license on this code? I'd like to use it in a package that is Apache 2.0 licensed. Happy to have a dependency on it if you prefer to package it up instead.

rllin Jan 26, 2017

Great design! Any idea how it works with the TensorFlow Saver? How would you restore a checkpoint of a session that's baked into this class?

mihaic Jan 27, 2017

@danijar, is it OK to use this code in a project licensed under Apache 2.0?

pucktada Apr 19, 2017

Hi -- will "define_scope" work when the wrapped function has arguments other than self? For example, if I try

@define_scope
def prediction(self, batch_size):
    ...

I get "TypeError: prediction() missing 1 required positional argument".

yuh8 May 23, 2017

@pucktada, you may want to redefine "def decorator(self)" as "def decorator(self, *args, **kwargs)". Likewise, change the call inside the wrapper function to function(self, *args, **kwargs).

spk921 May 28, 2017

@pucktada Have you solved the problem? @yuh8, how can I redefine it?
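
A property getter is always called with self alone, so as long as define_scope applies @property, adding *args and **kwargs to decorator() still leaves no way to pass batch_size; the original decorator calls function(self) without it, which raises the TypeError. One possible approach, sketched here, drops @property and turns the wrapped function into a regular method that caches one result per distinct argument set. The name define_scope_with_args and the fake_variable_scope stand-in for tf.variable_scope are hypothetical, and the arguments must be hashable to serve as cache keys.

```python
import contextlib
import functools


@contextlib.contextmanager
def fake_variable_scope(name):
    # Hypothetical stand-in for tf.variable_scope so the sketch runs alone.
    yield


def define_scope_with_args(function):
    """Hypothetical variant: a cached method instead of a property."""
    attribute = '_cache_' + function.__name__

    @functools.wraps(function)
    def decorator(self, *args, **kwargs):
        # One cache dict per instance, keyed by the (hashable) call arguments.
        cache = self.__dict__.setdefault(attribute, {})
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            with fake_variable_scope(function.__name__):
                cache[key] = function(self, *args, **kwargs)
        return cache[key]
    return decorator


class Model:
    def __init__(self):
        self.calls = 0

    @define_scope_with_args
    def prediction(self, batch_size):
        self.calls += 1
        return ('prediction', batch_size)  # Placeholder for a real tensor.


model = Model()
print(model.prediction(32))  # ('prediction', 32)
print(model.prediction(32))  # Cached: the body is not run again.
print(model.prediction(64))  # ('prediction', 64)
print(model.calls)           # 2
```

Note that the wrapped function is now called as model.prediction(32) rather than accessed as model.prediction. With real TensorFlow, each new argument set would add another copy of the operations to the graph; whether variables should be shared between those copies (for example, through the reuse setting of the variable scope) is a separate decision this sketch does not handle.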

lakshayg May 31, 2017

@rllin Have you found a way to use Saver with this model?
@danijar It would be very useful if you could incorporate the ability to save and load a trained model using this class. Do you have any ideas on how this could be done?

tartavull Jul 16, 2017

Hey, I tried extending the idea of the decorator here to set the variable names as well

madvn Sep 18, 2017

Thanks for this gist, @danijar. Could you please explain how the feed_dict in line 87 (or any similar line) works? Are the feed_dict keys the same as the args in __init__? Thanks.

madvn Sep 18, 2017

I got that part! But I have another question now: is it possible to define attributes such as "optimize" in another file that imports this class? If yes, how? Thanks.
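
One way to do this, assuming subclassing fits the project, is to subclass Model in the other file and redefine the decorated method there. Because define_scope turns each method into a property on the class, the subclass's property shadows the base class's one, and the self.optimize access in the base __init__ picks up the subclass version through normal attribute lookup. The sketch uses a simplified define_scope without TensorFlow scopes, hypothetical file names, and strings in place of tensors.

```python
import functools


# Simplified stand-in for the gist's define_scope (assumption: TensorFlow
# scope handling is omitted so the sketch runs on its own).
def define_scope(function):
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return decorator


# In the original module, e.g. model.py (hypothetical file name).
class Model:
    def __init__(self, image, label):
        self.image = image
        self.label = label
        self.prediction
        self.optimize

    @define_scope
    def prediction(self):
        return 'prediction({})'.format(self.image)

    @define_scope
    def optimize(self):
        return 'default optimizer'


# In another file that does `from model import Model` (hypothetical).
class MyModel(Model):
    @define_scope
    def optimize(self):
        # Overrides the base property; self.prediction is still the base one.
        return 'custom optimizer on ' + self.prediction


m = MyModel('x', 'y')
print(m.optimize)  # custom optimizer on prediction(x)
```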

jren2017 Oct 4, 2017

Hello, I changed the code a little bit.
I want to change the cost function dynamically in each iteration. The code is here: https://github.com/jren2017/accessAccuracyFromLossFunc/blob/master/from_gist_example.py

I want to access the error from the last iteration and use it to calculate the cross-entropy. This sounds weird, but it's just an example.
I hope you can have a look.
I changed the 65th row of the code to:
logprob = tf.log(self.prediction + 1e-12) * (1 - current_error)  # Here changed ????????????
