Skip to content

Instantly share code, notes, and snippets.

@openerror
Last active July 1, 2019 03:44
Show Gist options
  • Save openerror/14b50889683cf352865fdb08ccc1ef92 to your computer and use it in GitHub Desktop.
Save openerror/14b50889683cf352865fdb08ccc1ef92 to your computer and use it in GitHub Desktop.
Sample do-while loop in TensorFlow
from math import fabs
import tensorflow as tf
def test_func_tf(x):
    """Build a TensorFlow op that emulates a do-while loop.

    Computes: take ``|x|``, double it at least once, keep doubling while the
    result is below 10, then divide the first value that reaches 10 or more
    by 1.2.

    TensorFlow has no do-while primitive, so the loop body is applied once
    unconditionally and ``tf.while_loop`` then repeats it only while the
    condition still holds. If the condition is already false after the first
    pass, ``tf.while_loop`` simply runs zero iterations, so no ``tf.cond`` is
    needed to choose between "exit" and "continue".

    Args:
        x: A Python number or a numeric tensor. Integer inputs are cast to
            float32 so they can be multiplied by float constants.

    Returns:
        A tensor holding the result (in graph mode, evaluate it with
        ``Session.run``). Python floats/ints yield float32; floating tensors
        keep their dtype.

    Raises:
        When the graph runs, the positivity assertion fails
        (``tf.errors.InvalidArgumentError``) if ``|x|`` is zero or NaN, since
        doubling such a value never reaches 10.
    """
    x = tf.convert_to_tensor(x)
    if x.dtype.is_integer:
        # ``int32_tensor * 2.0`` raises a TypeError in TensorFlow; promote ints.
        x = tf.cast(x, tf.float32)
    x = tf.math.abs(x)

    # Zero and NaN never grow under doubling, so the loop below would never
    # terminate; fail fast with an explicit assertion instead.
    positive_check = tf.debugging.assert_positive(
        x, message="test_func_tf: |x| must be nonzero and not NaN")
    with tf.control_dependencies([positive_check]):
        x = tf.identity(x)

    def cond(v):
        # Loop condition: keep doubling while the value is still below 10.
        return tf.math.less(v, 10.0)

    def body(v):
        # Loop body: one doubling step.
        return v * 2.0

    # Step 1 ("do"): execute the body once unconditionally.
    x = body(x)
    # Step 2 ("while"): repeat the body while the condition holds. Only the
    # doubled value is tested; dividing by 1.2 inside the body and re-testing
    # would shrink a value in [10, 12) back below 10 and double it again.
    (x,) = tf.while_loop(cond, body, loop_vars=[x], return_same_structure=True)
    # Step 3 ("break"): scale the first value >= 10 exactly once.
    return x / 1.2
def test_func_py(x):
x = fabs(x) # Make sure x is positve, so that while loop would terminate
while True:
x = x * 2.0
if x >= 10.0:
x = x/1.2
break
return x
def evaluation(x):
    """Print the pure-Python and TensorFlow results for ``x`` side by side.

    The TensorFlow ops are built in a private ``tf.Graph``. As a result,
    repeated calls neither accumulate ops in the caller's default graph nor
    wipe it. The private graph is released with its last reference, even if
    building or running the ops raises.

    Args:
        x: A number accepted by both ``test_func_py`` and ``test_func_tf``.

    Returns:
        tuple: ``(y_py, y_tf)``, the pure-Python result and the evaluated
        TensorFlow result.
    """
    graph = tf.Graph()
    with graph.as_default():
        y_tf_op = test_func_tf(x)
    with tf.Session(graph=graph) as sess:
        y_tf = sess.run(y_tf_op)
    y_py = test_func_py(x)
    print("Py: {}; Tf: {}".format(y_py, y_tf))
    return y_py, y_tf
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment