# Working example for my blog post at:
# https://danijar.github.io/structuring-your-tensorflow-models
import functools
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def doublewrap(function):
    """
    A decorator decorator, allowing the decorator to be used without
    parentheses if no arguments are provided. All arguments must be optional.
    """
    @functools.wraps(function)
    def decorator(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
            return function(args[0])
        else:
            return lambda wrappee: function(wrappee, *args, **kwargs)
    return decorator


@doublewrap
def define_scope(function, scope=None, *args, **kwargs):
    """
    A decorator for functions that define TensorFlow operations. The wrapped
    function will only be executed once. Subsequent calls to it will directly
    return the result so that operations are added to the graph only once.

    The operations added by the function live within a tf.variable_scope().
    If this decorator is used with arguments, they will be forwarded to the
    variable scope. The scope name defaults to the name of the wrapped
    function.
    """
    attribute = '_cache_' + function.__name__
    name = scope or function.__name__
    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            with tf.variable_scope(name, *args, **kwargs):
                setattr(self, attribute, function(self))
        return getattr(self, attribute)
    return decorator


class Model:

    def __init__(self, image, label):
        self.image = image
        self.label = label
        self.prediction
        self.optimize
        self.error

    @define_scope(initializer=tf.contrib.slim.xavier_initializer())
    def prediction(self):
        x = self.image
        x = tf.contrib.slim.fully_connected(x, 200)
        x = tf.contrib.slim.fully_connected(x, 200)
        x = tf.contrib.slim.fully_connected(x, 10, tf.nn.softmax)
        return x

    @define_scope
    def optimize(self):
        logprob = tf.log(self.prediction + 1e-12)
        cross_entropy = -tf.reduce_sum(self.label * logprob)
        optimizer = tf.train.RMSPropOptimizer(0.03)
        return optimizer.minimize(cross_entropy)

    @define_scope
    def error(self):
        mistakes = tf.not_equal(
            tf.argmax(self.label, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))


def main():
    mnist = input_data.read_data_sets('./mnist/', one_hot=True)
    image = tf.placeholder(tf.float32, [None, 784])
    label = tf.placeholder(tf.float32, [None, 10])
    model = Model(image, label)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        images, labels = mnist.test.images, mnist.test.labels
        error = sess.run(model.error, {image: images, label: labels})
        print('Test error {:6.2f}%'.format(100 * error))
        for _ in range(60):
            images, labels = mnist.train.next_batch(100)
            sess.run(model.optimize, {image: images, label: labels})


if __name__ == '__main__':
    main()
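A quick usage sketch (illustrative only, not part of the gist): as the docstring says, the decorator works both bare and with arguments, and arguments are forwarded to tf.variable_scope. The class and method names below are made up for the example.

# Illustrative sketch: both forms rely on doublewrap.
class Example:

    @define_scope  # bare form: scope name defaults to 'network'
    def network(self):
        return tf.constant(1.0)

    @define_scope(scope='my_loss')  # argument form: forwarded to variable_scope
    def loss(self):
        return tf.constant(0.0)

example = Example()
assert example.network.name.startswith('network/')
assert example.loss.name.startswith('my_loss/')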
Hi @lfwin, I don't know why your debugger can't step into
mshvartsman commented Jan 23, 2017
Hi -- is there a license on this code? I'd like to use it in a package that is Apache 2.0 licensed. Happy to have a dependency on it if you prefer to package it up instead.
rllin commented Jan 26, 2017
Great design! Any idea how it works with the TensorFlow Saver? How would you restore a checkpoint of a session that's baked into this class?
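(One hedged sketch, assuming the sess and model built in main() above: the decorators only add ops and variables to the default graph, so a plain tf.train.Saver should checkpoint and restore them as usual.)

# Sketch: standard Saver usage; nothing about the class changes this.
saver = tf.train.Saver()                   # gathers the model's variables
save_path = saver.save(sess, './model.ckpt')
# Later, after rebuilding the same graph in a fresh session:
saver.restore(sess, save_path)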
mihaic commented Jan 27, 2017
@danijar, is it OK to use this code in a project licensed under Apache 2.0?
pucktada commented Apr 19, 2017
Hi -- will define_scope work when the wrapped function has arguments other than self? For example, if I try
yuh8 commented May 23, 2017
@pucktada, you may want to redefine def decorator(self) as def decorator(self, *args, **kwargs), and likewise call function(self, *args, **kwargs) inside the wrapper function.
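(Applied to the decorator above, that suggestion might look like the untested sketch below. It assumes the doublewrap helper from the gist; note that @property has to go, since properties cannot take arguments, so decorated methods are then called like model.prediction(...). Also note the cache ignores the arguments on later calls.)

@doublewrap
def define_scope(function, scope=None, *args, **kwargs):
    # Variant sketch: forwards the method's own arguments on first call.
    attribute = '_cache_' + function.__name__
    name = scope or function.__name__
    @functools.wraps(function)
    def decorator(self, *method_args, **method_kwargs):
        if not hasattr(self, attribute):
            with tf.variable_scope(name, *args, **kwargs):
                setattr(self, attribute,
                        function(self, *method_args, **method_kwargs))
        return getattr(self, attribute)
    return decorator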
tartavull commented Jul 16, 2017
Hey, I tried extending the idea of the decorator here to set the variable names as well
madvn commented Sep 18, 2017
Thanks for this gist @danijar. Could you please explain how the feed_dict in line 87 (or any similar line) works? Are the feed_dict keys the same as the args in __init__? Thanks.
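(For other readers, a sketch of the answer, assuming a session and a batch of images/labels as in main(): the feed_dict keys are the placeholder tensors themselves, which here happen to be the same objects passed into Model.__init__.)

# Keys are tensor objects, not strings or constructor argument names.
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])
model = Model(image, label)
error = sess.run(model.error, {image: images, label: labels})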
madvn commented Sep 18, 2017
I got that part! But I have another question now: is it possible to define attributes such as "optimize" in another file that imports this class? If yes, how? Thanks.
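(One approach, sketched here without testing: subclass Model in the other file and add more decorated methods there. The module and file names below are hypothetical.)

# other_file.py -- hypothetical file name.
import tensorflow as tf
from model import Model, define_scope  # assumes the gist is saved as model.py

class ExtendedModel(Model):

    @define_scope
    def confidence(self):
        # Highest predicted class probability for each example.
        return tf.reduce_max(self.prediction, 1)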
jren2017 commented Oct 4, 2017
Hello, I changed the code a little bit. I want to access the error of the last iteration and use it to calculate the cross-entropy; this sounds weird, but it's just an example.
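(If the goal is just to read the error from the previous iteration, one hedged sketch, reusing the names from main() above, is to fetch the train op and the error tensor in a single run call and carry the value forward in Python:)

# Sketch: fetch both tensors together; last_error holds the newest value.
last_error = None
for _ in range(60):
    images, labels = mnist.train.next_batch(100)
    _, last_error = sess.run(
        [model.optimize, model.error], {image: images, label: labels})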
richdevboston commented Nov 13, 2018
I get the error message below when I run the example. How do I fix it? It appears to be missing the _gru_ops library.
File "C:\ProgramData\Anaconda3\envs\TensorFlow\lib\site-packages\tensorflow\python\framework\load_library.py", line 56, in load_op_library
NotFoundError: C:\ProgramData\Anaconda3\envs\TensorFlow\lib\site-packages\tensorflow\contrib\rnn\python\ops\_gru_ops.so not found
lfwin commented Jul 15, 2016
Hi,
I debugged this code and found two things:
1. If I set a breakpoint at line 15, 'setattr(self, attribute, function(self))', which is executed during model initialization ('model = Model(data, target)' at line 53), I can step into that line. But if I set a breakpoint before that line executes, such as at line 52 or line 25, I cannot step into the setattr line. Why?
2. After executing the setattr for self.optimize at line 25 and before executing self.error, self already has three attributes: self._optimize, self._error, self._prediction. I don't understand how these are generated.
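(On the second question, a sketch of the mechanism: those attributes are created by the decorator itself, via its setattr call, the first time each property runs. In the code above the cache attributes carry a _cache_ prefix; the plain underscore names in the question presumably come from an earlier version of the post.)

# Sketch: the setattr inside the decorator creates the cache attributes.
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])
model = Model(image, label)  # __init__ accesses all three properties
print(hasattr(model, '_cache_prediction'))  # True
print(hasattr(model, '_cache_optimize'))    # True
print(hasattr(model, '_cache_error'))       # True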