Created November 22, 2019 05:34
-
-
Save zhiics/1d2219a60ff127fa7d988750ca279569 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
from tvm import relay | |
from tvm.relay.frontend.tensorflow import from_tensorflow | |
def check_equal(graph, tf_out):
    """Convert ``graph`` to Relay, run it on the Relay VM and compare with TF.

    The Relay module and its parameters come from ``from_tensorflow`` applied
    to the graph's GraphDef (with shapes attached). The module's ``main``
    function is printed for inspection, then evaluated with the converted
    parameters and compared element-wise against ``tf_out``.

    Args:
        graph: The ``tf.Graph`` that produced ``tf_out``.
        tf_out: Reference result(s) from running ``graph`` in a TensorFlow
            session: either a single array or a list/tuple of arrays.

    Raises:
        AssertionError: If the number of Relay outputs differs from the
            number of reference outputs, or if any output pair is not close
            under ``np.testing.assert_allclose`` default tolerances.
    """
    mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
    print("---------------")
    print(mod["main"])
    print("---------------")
    ex = relay.create_executor('vm', mod=mod)
    relay_out = ex.evaluate()(**params)

    # A single-tensor result exposes asnumpy() directly. Testing for the
    # method rather than the interpreter-only TensorValue class keeps this
    # path usable with the 'vm' executor created above, whose tensor result
    # type may differ. NOTE(review): confirm against the installed TVM.
    if hasattr(relay_out, "asnumpy"):
        np.testing.assert_allclose(tf_out, relay_out.asnumpy())
        return

    expected = list(tf_out) if isinstance(tf_out, (list, tuple)) else [tf_out]
    actual = [r.asnumpy() for r in relay_out]
    # zip() stops at the shorter sequence. Without this check, a missing or
    # extra output would pass silently.
    assert len(expected) == len(actual), (
        f"output count mismatch: TensorFlow produced {len(expected)}, "
        f"Relay produced {len(actual)}")
    for want, got in zip(expected, actual):
        np.testing.assert_allclose(want, got)
def test_1():
    """Collect every element of a random vector with a ``tf.while_loop``.

    Each iteration gathers ``array[i]`` and appends it to a growing 1-D
    tensor via ``tf.concat``. Because that loop-carried tensor changes length
    every step, its shape invariant is relaxed to ``[None]``. The loop visits
    indices 0..9 in order, so the collected tensor must equal ``array``.

    Returns:
        The collected values, as returned by ``Session.run``.
    """
    # Build in a private graph, like test_ta, instead of the global default.
    graph = tf.Graph()
    with graph.as_default():
        array = tf.Variable(tf.random.normal([10]))
        i = tf.constant(0)
        acc = tf.Variable([])

        def body(i, acc):
            temp = tf.gather(array, i)
            acc = tf.concat([acc, [temp]], 0)
            return i + 1, acc

        def cond(i, acc):
            return i < 10

        _, list_vals = tf.while_loop(
            cond,
            body, [i, acc],
            shape_invariants=[i.get_shape(), tf.TensorShape([None])])

        # The context manager releases the session's resources on exit.
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            # Fetch both in a single run so they observe the same values.
            expected, collected = sess.run([array, list_vals])

    np.testing.assert_allclose(collected, expected)
    return collected
def test_ta():
    """Fill a ``tf.TensorArray`` inside a ``tf.while_loop`` and check Relay.

    Each iteration writes ``array[step]`` (an int32 vector of ones) into slot
    ``step`` of a TensorArray, and the array is stacked after the loop. The
    TensorFlow results are printed. The stacked tensor is then compared with
    the Relay VM output via ``check_equal``.
    """
    size = 10
    graph = tf.Graph()
    with graph.as_default():
        array = tf.ones([size], dtype=tf.int32)
        step = tf.constant(0)
        output = tf.TensorArray(dtype=tf.int32, size=size)

        def cond(step, _):
            return step < size

        def body(step, output):
            output = output.write(step, tf.gather(array, step))
            return step + 1, output

        index, final_output = tf.while_loop(
            cond, body, loop_vars=[step, output])
        final_output = final_output.stack()

        # The graph contains no tf.Variable, so no initializer run is needed.
        # Fetch both tensors in one run rather than executing the loop twice.
        with tf.Session(graph=graph) as sess:
            tf_out, tf_index_out = sess.run([final_output, index])

    print(tf_out)
    print(tf_index_out)
    check_equal(graph, tf_out)
if __name__ == "__main__":
    # Run only when executed as a script, so that importing this module
    # (e.g. from a test runner that calls test_ta on its own) does not
    # trigger a full conversion run as a side effect.
    test_ta()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment