Skip to content

Instantly share code, notes, and snippets.

@zhiics
Created November 22, 2019 05:34
Show Gist options
  • Save zhiics/1d2219a60ff127fa7d988750ca279569 to your computer and use it in GitHub Desktop.
Save zhiics/1d2219a60ff127fa7d988750ca279569 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import numpy as np
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow
def check_equal(graph, tf_out):
    """Convert a TensorFlow graph to Relay, run it on the Relay VM, and compare.

    The graph is converted with ``from_tensorflow`` (with shapes attached via
    ``add_shapes=True``). Its ``main`` function is printed for inspection, and
    it is evaluated with the ``'vm'`` executor, using the converted params as
    keyword inputs.

    Args:
        graph: The ``tf.Graph`` that produced ``tf_out``.
        tf_out: Reference output from TensorFlow. This is a single array, or a
            list/tuple of arrays when the Relay function returns several
            tensors.

    Raises:
        AssertionError: If the number of outputs differs, or if any output is
            not close to its TensorFlow counterpart.
    """
    mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
    print("---------------")
    print(mod["main"])
    print("---------------")
    ex = relay.create_executor('vm', mod=mod)
    relay_out = ex.evaluate()(**params)
    # A single tensor result exposes ``asnumpy`` directly. Duck-type here
    # rather than checking for the interpreter's ``TensorValue`` class. The
    # tensor object type returned by the VM may differ across TVM versions
    # (NOTE(review): confirm for the targeted version).
    if hasattr(relay_out, "asnumpy"):
        np.testing.assert_allclose(tf_out, relay_out.asnumpy())
        return
    # Otherwise, treat the result as an iterable container of tensors.
    if not isinstance(tf_out, (list, tuple)):
        tf_out = [tf_out]
    relay_vals = [r.asnumpy() for r in relay_out]
    # zip() would silently ignore extra or missing outputs; require equal arity.
    assert len(tf_out) == len(relay_vals), (
        "output count mismatch: TF has {}, Relay has {}".format(
            len(tf_out), len(relay_vals)))
    for expected, actual in zip(tf_out, relay_vals):
        np.testing.assert_allclose(expected, actual)
def test_1():
    """Accumulate a variable's elements with ``tf.while_loop`` and ``tf.concat``.

    Each iteration gathers ``array[i]`` and appends it to a 1-D accumulator.
    The accumulator grows every iteration, so its shape invariant is
    ``[None]``. After 10 iterations the accumulator must equal ``array``.
    This runs in TensorFlow only; nothing is converted to Relay here.
    """
    # Build in a private graph so repeated runs don't pile nodes into the
    # default graph (matches the approach used by test_ta).
    graph = tf.Graph()
    with graph.as_default():
        array = tf.Variable(tf.random.normal([10]))
        i = tf.constant(0)
        acc = tf.Variable([])

        def cond(i, acc):
            return i < 10

        def body(i, acc):
            temp = tf.gather(array, i)
            acc = tf.concat([acc, [temp]], 0)
            return i + 1, acc

        _, list_vals = tf.while_loop(
            cond,
            body, [i, acc],
            shape_invariants=[i.get_shape(), tf.TensorShape([None])])
        init = tf.global_variables_initializer()

    # The context manager closes the session; the original leaked it.
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        # Fetch both in one run so the comparison sees a single variable state.
        array_val, list_out = sess.run([array, list_vals])
    np.testing.assert_allclose(list_out, array_val)
def test_ta():
    """Fill a ``tf.TensorArray`` inside ``tf.while_loop`` and check Relay agrees.

    The graph writes ``array[step]`` into a 10-element int32 TensorArray for
    ``step`` in ``0..9``, where ``array`` is all ones. It then stacks the
    TensorArray, runs the graph in TensorFlow, and sanity-checks the TF
    result. Finally, it compares the stacked tensor against the Relay VM via
    ``check_equal``.
    """
    graph = tf.Graph()
    with graph.as_default():
        array = tf.ones([10], dtype=tf.int32)
        step = tf.constant(0)
        output = tf.TensorArray(dtype=tf.int32, size=10)

        def cond(step, _):
            return step < 10

        def body(step, output):
            # write() returns the updated TensorArray, which must be threaded
            # back through the loop variables.
            output = output.write(step, tf.gather(array, step))
            return step + 1, output

        index, final_output = tf.while_loop(
            cond, body, loop_vars=[step, output])
        final_output = final_output.stack()

    # The graph has no variables, so no initializer is needed. Fetch both
    # tensors in one run instead of executing the loop twice.
    with tf.Session(graph=graph) as sess:
        tf_out, tf_index_out = sess.run([final_output, index])
    print(tf_out)
    print(tf_index_out)
    # Sanity-check the TF reference before trusting it as ground truth.
    assert tf_index_out == 10
    np.testing.assert_array_equal(tf_out, np.ones([10], dtype=np.int32))
    check_equal(graph, tf_out)
# Run only when executed as a script. An unguarded call would also run at
# import time, e.g. during pytest collection, on top of pytest's own run.
if __name__ == "__main__":
    test_ta()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment