Last active: November 30, 2018 22:59
-
-
Save jvmncs/3f241e351ad986d5f8ddca4317b0a540 to your computer and use it in GitHub Desktop.
TensorFlow equivalent of neurips_tfe_minimal.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
# generic function stubs | |
def provide_weights():
    """Load model weights.

    Stub: replace with a real loader before running this script. The caller
    unpacks the result as ``w0, b0, w1, b1, w2, b2``, so an implementation
    must return those six weight/bias values in that order.

    Raises:
        NotImplementedError: always, until a real implementation is supplied.
            (An empty stub would return ``None`` and fail later with an
            unhelpful unpacking error.)
    """
    raise NotImplementedError(
        "provide_weights() is a stub; supply the model weights"
    )
def provide_input():
    """Load input data.

    Stub: replace with a real loader before running this script. The result
    is used as the left operand of ``tf.matmul(x, w0)``.

    Raises:
        NotImplementedError: always, until a real implementation is supplied.
            (An empty stub would return ``None`` and fail later inside
            TensorFlow with an unhelpful error.)
    """
    raise NotImplementedError(
        "provide_input() is a stub; supply the input data"
    )
# Get model weights and input data (both unencrypted).
w0, b0, w1, b1, w2, b2 = provide_weights()
x = provide_input()

# Compute the prediction: two ReLU hidden layers followed by a linear
# output layer.
layer0 = tf.nn.relu(tf.matmul(x, w0) + b0)
layer1 = tf.nn.relu(tf.matmul(layer0, w1) + b1)
# The output layer consumes the last hidden layer (layer1); no layer2 exists.
logits = tf.matmul(layer1, w2) + b2

# tf.Print is an identity op on its first argument that prints the tensors in
# its data list when evaluated; wrapping `logits` makes running
# `prediction_op` print the prediction.
prediction_op = tf.Print(logits,
                         [logits],
                         summarize=10)

# Run graph execution in a tf.Session. Plain TensorFlow's Session.run()
# takes no `tag` keyword (presumably carried over from the tf-encrypted
# version of this script), so it is omitted here.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(prediction_op)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.