# Gist ageron/1a92faa5e72134f49e6bcbcbeecfe8af -- created September 21, 2018 17:05.
# (GitHub page notice: this file contains bidirectional Unicode text that may
# be interpreted or compiled differently than what appears below. To review,
# open the file in an editor that reveals hidden Unicode characters.)
import tensorflow as tf

# Eager execution is switched on right after importing TensorFlow, before any
# other TensorFlow call is made.
tf.enable_eager_execution()

import tensorflow.contrib.eager as tfe

# Predict the future: expose the contrib-eager APIs under the names proposed
# here, so the examples below read as if those names already existed.
tf.function = tfe.defun  # I vote for that name
tf.method = tfe.defun  # Okay for what we will use it for
tf.Checkpointable = tfe.Checkpointable
# 1. Nice, clean and object-oriented
print("Solution 1")


class MySimpleModel(tf.Checkpointable):
    """Model that multiplies its input by a weight held in a variable.

    The variable is created once, in the constructor, and never inside the
    decorated method.
    """

    def __init__(self, multiplier):
        """Create the weight variable `w`, initialized to `multiplier`."""
        self.w = tf.Variable(multiplier)

    @tf.method
    def call(self, x):
        """Return the product of the weight `w` and `x`."""
        return tf.multiply(self.w, x)


my_simple_model = MySimpleModel(3.)
print(my_simple_model.call(tf.constant(4.)))  # prints a tensor equal to 12.
# 2. Not too bad either
print("Solution 2")

# The variable is created at module level, outside the decorated function.
w = tf.Variable(3.)


@tf.function
def my_simple_model(x):
    """Return the product of the module-level variable `w` and `x`."""
    return tf.multiply(w, x)


print(my_simple_model(4.))  # prints a tensor equal to 12.

# I still prefer solution 1, since it encourages users to use objects to
# manage variables. It's much cleaner.
# 3. Okay, if specific init logic is needed
print("Solution 3")

# Placeholder until initialize_w() creates the actual variable.
w = None


def initialize_w(multiplier):
    """Create the global variable `w`, initialized to `multiplier`."""
    global w
    w = tf.Variable(multiplier)


@tf.function
def my_simple_model(x):
    """Return the product of the global `w` and `x`.

    `initialize_w()` must have been called first, otherwise `w` is still None.
    """
    return tf.multiply(w, x)


initialize_w(3.)
print(my_simple_model(4.))  # prints a tensor equal to 12.
# 4. But this is horrible coding style IMHO
print("Solution 4")

w = None


@tf.function
def my_simple_model(x, multiplier=1.):
    """Return `w * x`, lazily creating the global `w` on first use.

    `multiplier` is only read when `w` is still None; on every later call it
    is silently ignored. This is intentionally the "bad" example.
    """
    print("trace")  # shows when the Python body actually runs
    global w
    if w is None:
        w = tf.Variable(multiplier)
    return tf.multiply(w, x)


print("first eager mode")
print(my_simple_model(4., multiplier=3.))  # prints a tensor equal to 12.
print(my_simple_model(5.))  # prints a tensor equal to 15.

print("then graph mode")
with tf.Graph().as_default() as g:
    # Reset the global so a fresh variable gets created in this graph.
    w = None
    my_simple_model_4 = my_simple_model(tf.constant(4.))
    my_simple_model_5 = my_simple_model(tf.constant(5.))
    with tf.Session(graph=g) as sess:
        w.initializer.run()
        print(my_simple_model_4.eval())
        print(my_simple_model_5.eval())

# This is horrible in many ways: use of global, unclear that the multiplier
# parameter will only be used once, complexity of multiple traces (although
# right now there's just one trace in graph mode, not two), etc.
# The only apparent benefit is that users can call `my_simple_model()` without
# having to explicitly initialize the variables first. I'm not sure this is a
# good thing: it's basically the singleton pattern, which is often brittle.
# If you really absolutely want to have a function that initializes the
# variables if needed, then solution 5 seems preferable to me (but still
# horrible).
# 5. If you really need this, then this solution seems less horrible
print("Solution 5")

w = None


@tf.function
def my_simple_model(x):
    """Return the product of the global `w` and `x`; creates no state."""
    return w * x


def my_simple_model_maybe_init(x, multiplier=1.):
    """Create the global `w` if it does not exist yet, then apply the model.

    `multiplier` is only used when `w` is still None. The variable is created
    here, in plain Python, rather than inside the decorated function.
    """
    global w
    if w is None:
        w = tf.Variable(multiplier)
    return my_simple_model(x)


print(my_simple_model_maybe_init(4., multiplier=3.))  # prints a tensor = 12.
print(my_simple_model_maybe_init(5.))  # prints a tensor = 15.

# Only solution 4 requires the double-trace magic. I don't see the benefit of
# supporting it: it just encourages poor programming style.
# I think the rule should be that no TF state should be created in a function
# or method decorated by tf.function or tf.method.
# (GitHub page footer: Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.)