Last active
July 1, 2019 03:44
-
-
Save openerror/14b50889683cf352865fdb08ccc1ef92 to your computer and use it in GitHub Desktop.
Sample do-while loop in TensorFlow, checked against the same loop written in plain Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import fabs | |
import tensorflow as tf | |
def test_func_tf(x):
    """TensorFlow counterpart of test_func_py: a do-while loop built with tf.while_loop.

    Takes |x|, then repeatedly doubles it; on the first doubling that reaches
    10.0 or more, divides that doubled value by 1.2 and stops.

    tf.while_loop evaluates its condition *before* every iteration (a plain
    "while"), so the do-while is expressed with a boolean ``done`` loop
    variable: it starts as False, which guarantees at least one pass through
    the body, and the body sets it from the *doubled* value -- the same value
    test_func_py's ``if x >= 10.0`` checks before dividing.

    Args:
        x: Scalar Python number or scalar tensor (tf.cond requires a scalar
            predicate). abs(x) must be positive and not NaN.

    Returns:
        A scalar tensor holding the loop result; evaluate it (e.g. with
        Session.run) to obtain the value.

    Raises:
        tf.errors.InvalidArgumentError: when abs(x) is 0 or NaN (raised when
            the graph is run, or immediately under eager execution). Doubling
            such a value never reaches 10.0, so the loop would never finish.
    """
    x = tf.math.abs(x)
    # Make the precondition an explicit op dependency so the check actually
    # runs before the loop in graph mode.
    with tf.control_dependencies(
            [tf.debugging.assert_greater(x, 0.0, message="abs(x) must be positive and not NaN")]):
        x = tf.identity(x)

    # Loop condition: keep iterating until the body reports it has finished.
    def keep_going(x, done):
        return tf.math.logical_not(done)

    # One do-while iteration: double x; if the doubled value reached the
    # threshold, divide it by 1.2 and signal completion.
    def body(x, _done):
        doubled = x * 2.0
        # Test the doubled value, not the divided result: e.g. 11.0 / 1.2 < 10,
        # so testing after the division would loop again where test_func_py stops.
        reached = tf.math.greater_equal(doubled, 10.0)
        x = tf.cond(reached,
                    true_fn=lambda: doubled / 1.2,
                    false_fn=lambda: doubled)
        return [x, reached]

    # Seeding done=False makes the first condition check pass, so the body
    # always executes at least once -- the "do" part of do-while.
    x, _ = tf.while_loop(keep_going, body, loop_vars=[x, tf.constant(False)])
    return x
def test_func_py(x): | |
x = fabs(x) # Make sure x is positve, so that while loop would terminate | |
while True: | |
x = x * 2.0 | |
if x >= 10.0: | |
x = x/1.2 | |
break | |
return x | |
def evaluation(x):
    """Print test_func_py(x) next to test_func_tf(x) so the two loops can be compared.

    Builds test_func_tf's ops in the default graph, evaluates them in a
    TF1-style tf.Session, computes the plain-Python result, and prints both.
    The default graph is always reset afterwards, even if something raised.

    Args:
        x: Python number passed to both implementations.
    """
    try:
        with tf.Session() as sess:
            y_tf = sess.run(test_func_tf(x))
        y_py = test_func_py(x)
        print("Py: {}; Tf: {}".format(y_py, y_tf))
    finally:
        # Purely for clean-up: drop this call's ops even when graph building,
        # sess.run or test_func_py raised, so the next call starts clean.
        tf.reset_default_graph()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment