Skip to content

Instantly share code, notes, and snippets.

@yaroslavvb
Created March 23, 2017 21:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yaroslavvb/d67410e240369736fc4ba0267250ef27 to your computer and use it in GitHub Desktop.
Save yaroslavvb/d67410e240369736fc4ba0267250ef27 to your computer and use it in GitHub Desktop.
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from tensorflow.contrib import graph_editor as ge
def make_conditional_initializer(v):
    """Makes the initializer of variable `v` lazy; returns a new conditional init op.

    Builds a Switch on `tf.is_variable_initialized(v)` and adds the
    "not initialized" branch as a control input of `v.initial_value.op`.
    The intent is that the initial value (and hence `v.initializer`) only
    executes while `v` is uninitialized.

    Args:
      v: the variable to initialize lazily. It must expose `.initial_value` and
        `.initializer`. Presumably this is a TF1 `tf.Variable`.

    Returns:
      The result of `control_flow_ops.merge` over two scalars:
        - the "not triggered" value `tf.zeros(())`, which depends on the
          initialized branch;
        - the "triggered" value `tf.ones(())`, which depends on `v.initializer`.
      Per the TF docs, `merge` returns an `(output, value_index)` pair.
      NOTE(review): confirm callers expect that pair rather than a single tensor.

    Side effects:
      Permanently adds a control input to `v.initial_value.op` in the graph.
      NOTE(review): this presumably also makes a plain `v.initializer` /
      `global_variables_initializer()` conditional, so re-initializing an
      already-initialized variable may no longer take effect. Confirm before
      relying on re-init.
    """
    cond = tf.is_variable_initialized(v)
    # Switch needs a data tensor to route; the predicate itself is reused as
    # the data, since only the liveness of each output branch matters here.
    dummy_data = cond
    output_false, output_true = control_flow_ops.switch(dummy_data, cond)
    # Identity ops give each branch a concrete op that can serve as a control
    # input / control dependency below.
    variable_uninited_op = tf.identity(output_false)
    variable_inited_op = tf.identity(output_true)
    # only evaluate initial value if variable is not initialized
    # NOTE(review): only initial_value.op itself gets this control input. Ops
    # upstream of it (e.g. a random_uniform feeding a tf.Print) are not gated
    # and presumably still execute on every run. Confirm this if the initial
    # value is expensive to compute.
    ge.reroute.add_control_inputs(v.initial_value.op, [variable_uninited_op.op])
    # "Triggered" marker: a scalar 1.0 that is only produced after v.initializer.
    with tf.control_dependencies([v.initializer]):
        initializer_triggered = tf.ones(())
    # Debug logging of which branch fired.
    initializer_triggered = tf.Print(initializer_triggered,
                                     [initializer_triggered],
                                     "triggered path")
    # "Not triggered" marker: a scalar 0.0 gated on the "already initialized"
    # branch of the switch.
    with tf.control_dependencies([variable_inited_op]):
        initializer_not_triggered = tf.zeros(())
    initializer_not_triggered = tf.Print(initializer_not_triggered,
                                         [initializer_not_triggered],
                                         "Non-triggered path")
    # Merge forwards whichever of the two markers is available on this run.
    return control_flow_ops.merge([initializer_not_triggered,
                                   initializer_triggered])
def conditional_initializer_test(num_runs=3):
    """Demo: run the conditional initializer several times in one session.

    Creates a scalar variable whose initial value is a `tf.random_uniform(())`
    wrapped in `tf.Print(..., "initializing")`. That log line therefore shows
    when the initial-value op actually executes. The variable's initializer is
    wrapped with `make_conditional_initializer`, and the result is run
    `num_runs` times. Before each run the script prints "Init <k>", and after
    each run it prints the fetched value.

    Args:
      num_runs: how many times to run the conditional initializer (default 3,
        matching the original hard-coded demo).

    NOTE(review): the expected behavior is not asserted here. Presumably the
    first run takes the "triggered path" and later runs take the
    "Non-triggered path". Verify this against the printed logs.
    """
    result0_ = tf.random_uniform(())
    result0 = tf.Print(result0_, [result0_], "initializing")
    var = tf.Variable(result0)
    conditional_init = make_conditional_initializer(var)
    # Use the session as a context manager so it is always closed and its
    # resources are released, even if a run raises.
    with tf.Session() as sess:
        for run_index in range(1, num_runs + 1):
            print("Init %d" % run_index)
            print(sess.run(conditional_init))
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    conditional_initializer_test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment