Skip to content

Instantly share code, notes, and snippets.

@jvmncs
Last active November 28, 2018 23:18
Show Gist options
Save jvmncs/4f5af5307f6a615b7857ca2bfef6fcf3 to your computer and use it in GitHub Desktop.
TensorFlow equivalent of tfe_minimal.py
import tensorflow as tf
# generic functions for loading model weights and input data
def provide_weights():
    """Load model weights as TensorFlow objects.

    Placeholder to be replaced with real loading logic. The caller unpacks
    the result as ``w0, b0, w1, b1, w2, b2``, so an implementation must
    return six objects in that order.

    Raises:
        NotImplementedError: always, until a real loader is supplied.
    """
    # An implicit ``None`` return would make the caller's six-way unpacking
    # fail with an unhelpful TypeError; fail loudly and explain instead.
    raise NotImplementedError(
        "provide_weights() must load and return (w0, b0, w1, b1, w2, b2)"
    )
def provide_input():
    """Load input data as TensorFlow objects.

    Placeholder to be replaced with real loading logic. The caller uses the
    result directly as the left operand of ``tf.matmul(x, w0)``.

    Raises:
        NotImplementedError: always, until a real loader is supplied.
    """
    # An implicit ``None`` return would only fail later, inside tf.matmul,
    # with a less obvious error; fail loudly and explain instead.
    raise NotImplementedError("provide_input() must load and return the input x")
# get model weights/input data (both unencrypted)
w0, b0, w1, b1, w2, b2 = provide_weights()
x = provide_input()

# compute prediction: two ReLU layers followed by a linear output layer
layer0 = tf.nn.relu(tf.matmul(x, w0) + b0)
layer1 = tf.nn.relu(tf.matmul(layer0, w1) + b1)
# The output layer consumes layer1 -- no `layer2` is ever defined.
logits = tf.matmul(layer1, w2) + b2

# get result of prediction and print
# tf.Print returns an identity of its first argument and prints the listed
# tensors (with the given message prefix) when that op is evaluated.
prediction_op = tf.Print(logits, [logits], message="prediction: ", summarize=10)

# run graph execution in a tf.Session
# NOTE(review): tf.Session / tf.Print / tf.global_variables_initializer are
# TensorFlow 1.x graph-mode APIs (tf.compat.v1.* under TF 2.x).
# tf.Session.run takes no `tag` keyword; it was presumably carried over from
# the tf-encrypted script this mirrors, so it is omitted here.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(prediction_op)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment