Skip to content

Instantly share code, notes, and snippets.

@herberteuler
Last active March 23, 2020 12:15
Show Gist options
  • Save herberteuler/7b2c9de3e35e635811e0caacd812f893 to your computer and use it in GitHub Desktop.
Save herberteuler/7b2c9de3e35e635811e0caacd812f893 to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
def input_fn():
    """Build a tiny synthetic regression dataset for the estimators.

    Draws a 10x4 float64 array of uniform samples in [0.0, 1.0). Each row is
    split into three feature values and one label value. All 10 rows form a
    single batch, and that batch is repeated for 5 passes.

    Returns:
      A ``tf.data.Dataset`` yielding ``(features, labels)`` pairs; each batch
      holds 10 rows, with 3 feature columns and 1 label column.
    """
    def split_row(row):
        # Leading three values are the features, the final one is the label.
        features, label = row[:3], row[-1:]
        return features, label

    samples = np.random.random_sample([10, 4])
    return (tf.data.Dataset.from_tensor_slices(samples)
            .map(split_row)
            .batch(10)
            .repeat(5))
class CheckEMAHook(tf.train.SessionRunHook):
    """Session hook that prints the value of one named tensor before each run.

    The script uses it to show whether the EMA shadow variable exists in the
    training graph and how its value evolves. Looking up a name that is not in
    the graph raises ``KeyError`` from ``get_tensor_by_name``.
    """

    def __init__(self, ema_name):
        """Store the full tensor name (``'<op_name>:<output_index>'``) to watch."""
        self.ema_name = ema_name

    def before_run(self, run_context):
        """Fetch the watched tensor with an extra ``session.run`` and print it."""
        session = run_context.session
        watched = session.graph.get_tensor_by_name(self.ema_name)
        current_value = session.run(watched)
        print(current_value)
def show(msg, color):
    """Print ``msg`` in bold, using ANSI foreground ``color`` on a black background.

    ``color`` is an ANSI SGR foreground code; this script passes 31 (red),
    32 (green) and 36 (cyan). The trailing escape resets all attributes.
    """
    line = '\033[1;%s;40m%s\033[00m' % (color, msg)
    print(line)
# Full tensor name of the EMA shadow variable for the dense layer's kernel;
# the first model's training graph is expected to contain it.
ema_name = 'weight/kernel/ExponentialMovingAverage:0'

with tf.Graph().as_default() as graph:
    def model_fn(features, labels, mode):
        """Single-unit linear model whose TRAIN op also updates EMA shadows.

        Args:
          features: batch of input features produced by ``input_fn``.
          labels: batch of regression targets; may be ``None`` in PREDICT mode.
          mode: one of ``tf.estimator.ModeKeys``.

        Returns:
          A ``tf.estimator.EstimatorSpec`` valid for ``mode``.
        """
        predictions = tf.layers.dense(features, 1, name='weight')
        # Defaults for modes that do not train. Previously loss/train_op were
        # only defined in TRAIN mode, so building the spec in EVAL/PREDICT
        # mode failed.
        loss = None
        train_op = None
        if mode != tf.estimator.ModeKeys.PREDICT:
            # Labels are available in TRAIN and EVAL, so a loss can be reported.
            loss = tf.losses.mean_squared_error(labels, predictions)
        if mode == tf.estimator.ModeKeys.TRAIN:
            optimizer = tf.train.GradientDescentOptimizer(0.1)
            train_op = tf.contrib.training.create_train_op(loss, optimizer)
            ema = tf.train.ExponentialMovingAverage(0.1)
            # Make the EMA update depend on the gradient step, so the shadow
            # variables are refreshed after the weights change.
            with tf.control_dependencies([train_op]):
                train_op = ema.apply()
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions, loss=loss, train_op=train_op)

    # Expected behaviour: a plain Estimator whose train_op also applies the EMA.
    estimator = tf.estimator.Estimator(model_fn, './models/expected_behavior')
    show('The moving average(s) are updated during training,', 36)
    show('as can be seen from here:', 36)
    estimator.train(input_fn, hooks=[CheckEMAHook(ema_name)])
with tf.Graph().as_default() as graph:
    # The same single-unit linear layer, built with the Keras functional API.
    # `inputs`/`outputs` avoid shadowing the builtin `input`.
    inputs = tf.keras.Input((3,))
    outputs = tf.keras.layers.Dense(1, name='weight')(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    ema = tf.train.ExponentialMovingAverage(0.1)
    # Register the EMA update op on the model. Presumably Keras runs it with
    # each training step; that is what this script checks.
    model.add_update(ema.apply())

    show('The update ops can also be added to a Keras model,', 32)
    show('as can be seen from here:', 32)
    # List the EMA ops in this graph to show they were created.
    for node in graph.as_graph_def().node:
        if node.name.endswith('ExponentialMovingAverage'):
            print(node.name)

    optimizer = tf.train.GradientDescentOptimizer(0.1)
    # NOTE(review): 'accuracy' is not meaningful for this regression target;
    # kept as in the original report.
    model.compile(optimizer=optimizer, loss='mse', metrics=['accuracy', 'mse'])
    estimator = tf.keras.estimator.model_to_estimator(
        keras_model=model, model_dir='./models/keras')
    try:
        estimator.train(input_fn, hooks=[CheckEMAHook(ema_name)])
    except KeyError:
        # Per the messages below, a KeyError here means the hook could not find
        # `ema_name` in the converted estimator's graph.
        show('But after the conversion to an estimator, the update ops are', 31)
        show('lost during the cloning of the model. And the attempt of', 31)
        show('retrieving the moving average(s) results in an exception:', 31)
        # Bare `raise` re-raises the active exception with its traceback intact.
        raise
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment