-
-
Save danijar/8663d3bbfd586bffecf6a0094cd116f2 to your computer and use it in GitHub Desktop.
# Working example for my blog post at:
# https://danijar.github.io/structuring-your-tensorflow-models
import functools | |
import tensorflow as tf | |
from tensorflow.examples.tutorials.mnist import input_data | |
def doublewrap(function):
    """
    Let a decorator be applied either bare or with arguments.

    The returned wrapper inspects how it was invoked:

    * ``@decorator``: exactly one positional argument that is callable and no
      keyword arguments. That callable is handed straight to ``function``.
    * ``@decorator(...)``: any other invocation. The arguments are captured
      and a one-argument callable is returned which later calls
      ``function(target, *args, **kwargs)``.

    Every parameter of ``function`` after the first must therefore be
    optional. A decorator invoked with a single callable positional argument
    is indistinguishable from the bare form and is treated as such.
    """
    @functools.wraps(function)
    def decorator(*args, **kwargs):
        used_bare = len(args) == 1 and not kwargs and callable(args[0])
        if used_bare:
            return function(args[0])
        # Arguments were supplied: defer until the decorated target arrives.
        return lambda target: function(target, *args, **kwargs)
    return decorator
@doublewrap
def define_scope(function, scope=None, *args, **kwargs):
    """
    Turn a graph-building method into a lazily evaluated, cached property.

    On the first read of the property on an instance, ``function(self)`` runs
    inside ``tf.variable_scope(scope_name, *args, **kwargs)``. Its return value
    is stored on the instance under the attribute ``_cache_<function name>``.
    Every later read returns that stored value without calling ``function``
    again, so the function's operations are added to the graph only once.

    May be used bare (``@define_scope``) or with arguments
    (``@define_scope(initializer=...)``). Extra arguments are forwarded to
    ``tf.variable_scope``. ``scope`` defaults to the function's name.
    """
    cache_attr = '_cache_' + function.__name__
    scope_name = scope or function.__name__

    @functools.wraps(function)
    def cached(self):
        # Fast path: the value was already built for this instance.
        if hasattr(self, cache_attr):
            return getattr(self, cache_attr)
        with tf.variable_scope(scope_name, *args, **kwargs):
            setattr(self, cache_attr, function(self))
        return getattr(self, cache_attr)

    return property(cached)
class Model:
    """
    MNIST classifier whose graph parts are lazily built, cached properties.

    Args:
        image: input tensor for the network. Presumably a float tensor of
            shape ``[batch, 784]``, matching the placeholder built in
            ``main`` (confirm for other callers).
        label: target tensor. Presumably one-hot with shape ``[batch, 10]``.
    """

    def __init__(self, image, label):
        self.image = image
        self.label = label
        # Read each lazy property once, left to right, so the whole graph is
        # built at construction time rather than on first use.
        _ = (self.prediction, self.optimize, self.error)

    @define_scope(initializer=tf.contrib.slim.xavier_initializer())
    def prediction(self):
        """Two 200-unit fully connected layers, then a 10-unit softmax layer."""
        hidden = self.image
        for _ in range(2):
            hidden = tf.contrib.slim.fully_connected(hidden, 200)
        return tf.contrib.slim.fully_connected(
            hidden, 10, activation_fn=tf.nn.softmax)

    @define_scope
    def optimize(self):
        """Training op: RMSProp (rate 0.03) minimising summed cross-entropy."""
        # The small offset keeps tf.log from being evaluated at exactly 0.
        log_probs = tf.log(self.prediction + 1e-12)
        total_xent = -tf.reduce_sum(self.label * log_probs)
        rmsprop = tf.train.RMSPropOptimizer(0.03)
        return rmsprop.minimize(total_xent)

    @define_scope
    def error(self):
        """Fraction of rows whose arg-max prediction differs from the label."""
        true_class = tf.argmax(self.label, 1)
        guessed_class = tf.argmax(self.prediction, 1)
        wrong = tf.not_equal(true_class, guessed_class)
        return tf.reduce_mean(tf.cast(wrong, tf.float32))
def main():
    """
    Build the MNIST model, then alternate evaluation and training.

    MNIST is read via ``input_data.read_data_sets('./mnist/', one_hot=True)``.
    Training runs for 10 rounds. Each round prints the error on the full test
    set, then runs 60 optimisation steps on mini-batches of 100 training
    examples.
    """
    mnist = input_data.read_data_sets('./mnist/', one_hot=True)
    image = tf.placeholder(tf.float32, [None, 784])
    label = tf.placeholder(tf.float32, [None, 10])
    model = Model(image, label)
    # The test feed is identical every round, so build it once.
    test_feed = {image: mnist.test.images, label: mnist.test.labels}
    # The context manager closes the session (and frees its resources) when
    # training finishes or raises.
    with tf.Session() as sess:
        # Non-deprecated replacement for tf.initialize_all_variables().
        sess.run(tf.global_variables_initializer())
        for _ in range(10):
            error = sess.run(model.error, test_feed)
            print('Test error {:6.2f}%'.format(100 * error))
            for _ in range(60):
                images, labels = mnist.train.next_batch(100)
                sess.run(model.optimize, {image: images, label: labels})


if __name__ == '__main__':
    main()
Hey, I tried extending the idea of the decorator here to set the variable names as well.
Thanks for this gist, @danijar. Could you please explain how the feed_dict on line 87 (or any similar line) works? Are the feed_dict keys the same as the arguments to `__init__`? Thanks.
I got that part! But now I have another question: is it possible to define attributes such as "optimize" in another file that imports this class? If so, how? Thanks.
Hello, I changed the code a little bit.
I want to change the cost function dynamically in each iteration; the code is here: https://github.com/jren2017/accessAccuracyFromLossFunc/blob/master/from_gist_example.py
I want to access the error from the last iteration and use it to calculate the cross-entropy. This sounds weird, but it is just an example.
I hope you can take a look.
I changed line 65 of the code:
logprob = tf.log(self.prediction + 1e-12) *(1-current_error) #Here changed ????????????
I get the error message below when I run the example. How do I fix it? It appears to be missing the _gru_ops library.
Great article. Thank you.
File "C:\ProgramData\Anaconda3\envs\TensorFlow\lib\site-packages\tensorflow\python\framework\load_library.py", line 56, in load_op_library
lib_handle = py_tf.TF_LoadLibrary(library_filename)
NotFoundError: C:\ProgramData\Anaconda3\envs\TensorFlow\lib\site-packages\tensorflow\contrib\rnn\python\ops\_gru_ops.so not found
@rllin Have you found a way to use Saver with this model?
@danijar It would be very useful if you could incorporate the ability to save and load a trained model using this class. Do you have any ideas on how this can be done?