# Gist ageron/1a92faa5e72134f49e6bcbcbeecfe8af by @ageron,
# created September 21, 2018 17:05.
import tensorflow as tf
tf.enable_eager_execution()
import tensorflow.contrib.eager as tfe

# Preview the API names the author hopes TensorFlow will adopt, by aliasing the
# current tf.contrib.eager implementations onto the `tf` namespace:
#   tf.function       -> tfe.defun           (the author's preferred name)
#   tf.method         -> tfe.defun           (same decorator, used on methods)
#   tf.Checkpointable -> tfe.Checkpointable
# NOTE(review): these assignments mutate the imported `tf` module object, so any
# attribute of the same name that this TF build already defines (e.g. a native
# `tf.function`, if present) is overwritten for the whole process.
tf.function = tf.method = tfe.defun
tf.Checkpointable = tfe.Checkpointable
# 1. Nice, clean and object-oriented
print("Solution 1")


class MySimpleModel(tf.Checkpointable):
    """Multiply inputs by a weight variable owned by the model object.

    The variable is created once, in ``__init__``. The traced ``call`` method
    therefore never creates TF state itself; it only reads ``self.w``.

    Args:
        multiplier: initial value of the weight variable ``self.w``.
    """

    def __init__(self, multiplier):
        self.w = tf.Variable(multiplier)

    @tf.method
    def call(self, x):
        """Return ``self.w * x`` (element-wise, via ``tf.multiply``)."""
        return tf.multiply(self.w, x)


my_simple_model = MySimpleModel(3.)
print(my_simple_model.call(tf.constant(4.)))  # prints a tensor equal to 12.
# 2. Not too bad either
print("Solution 2")
# Module-level variable, created up front, before the traced function uses it.
w = tf.Variable(3.)


@tf.function
def my_simple_model(x):
    """Return ``w * x`` using the module-level variable ``w``.

    The function only reads ``w``; it does not create any TF state itself.
    """
    return tf.multiply(w, x)


print(my_simple_model(4.))  # prints a tensor equal to 12.
# I still prefer solution 1, since it encourages users to use objects to manage
# variables. It's much cleaner.
# 3. Okay, if specific init logic is needed
print("Solution 3")
w = None  # placeholder until initialize_w() creates the variable


def initialize_w(multiplier):
    """Create the module-level weight variable ``w`` with value ``multiplier``.

    Call this before ``my_simple_model`` so that ``w`` is a variable rather
    than ``None`` when the model function reads it.
    """
    global w
    w = tf.Variable(multiplier)


@tf.function
def my_simple_model(x):
    """Return ``w * x``; relies on ``initialize_w`` having been called first."""
    return tf.multiply(w, x)


initialize_w(3.)
print(my_simple_model(4.))  # prints a tensor equal to 12.
# 4. But this is horrible coding style IMHO
print("Solution 4")
w = None


@tf.function
def my_simple_model(x, multiplier=1.):
    """Return ``w * x``, lazily creating the global ``w`` when it is ``None``.

    ``multiplier`` is only used while ``w`` is still ``None``. Once the
    variable exists, the argument is silently ignored.

    ``print("trace")`` is a Python-side effect, so it shows each time the
    Python body actually runs, i.e. each time the function is traced.
    """
    global w
    print("trace")
    if w is None:
        w = tf.Variable(multiplier)
    return tf.multiply(w, x)


print("first eager mode")
print(my_simple_model(4., multiplier=3.))  # prints a tensor equal to 12.
print(my_simple_model(5.))  # prints a tensor equal to 15.
print("then graph mode")
with tf.Graph().as_default() as g:
    # Reset the global so the graph-mode trace creates a new variable instead
    # of reusing the one created during the eager calls above.
    w = None
    my_simple_model_4 = my_simple_model(tf.constant(4.))
    my_simple_model_5 = my_simple_model(tf.constant(5.))
    with tf.Session(graph=g) as sess:
        # Graph-mode variables must be initialized explicitly before use.
        w.initializer.run()
        print(my_simple_model_4.eval())
        print(my_simple_model_5.eval())
# This is horrible in many ways: use of global, unclear that the multiplier
# parameter will only be used once, complexity of multiple traces (although
# right now there's just one trace in graph mode, not two), etc.
# The only apparent benefit is that users can call `my_simple_model()` without
# having to explicitly initialize the variables first. I'm not sure this is a
# good thing, it's basically the singleton pattern, which is often brittle.
# If you really absolutely want to have a function that initializes the
# variables if needed, then solution 5 seems preferable to me (but still
# horrible).
# 5. If you really need this, then this solution seems less horrible
print("Solution 5")
w = None


@tf.function
def my_simple_model(x):
    """Return ``w * x``; only reads the global ``w``, never creates it."""
    return tf.multiply(w, x)


def my_simple_model_maybe_init(x, multiplier=1.):
    """Create the global ``w`` if needed, then return ``my_simple_model(x)``.

    Variable creation happens here, outside the traced function, so
    ``my_simple_model`` itself never creates TF state. ``multiplier`` is only
    used on the call that finds ``w`` still ``None``.
    """
    global w
    if w is None:
        w = tf.Variable(multiplier)
    return my_simple_model(x)


print(my_simple_model_maybe_init(4., multiplier=3.))  # prints a tensor = 12.
print(my_simple_model_maybe_init(5.))  # prints a tensor = 15.
# Only solution 4 requires the double-trace magic. I don't see the benefit of
# supporting it, it just encourages poor programming style.
# I think the rule should be that no TF state should be created in a function
# or method decorated by tf.function or tf.method.
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment