@jvmncs
Last active November 30, 2018 22:41
A minimal example of using tf-encrypted for secure inference, shortened for our NeurIPS poster
import tensorflow as tf
import tf_encrypted as tfe
# generic remote procedure calls
def provide_weights(): """Load model weights."""
def provide_input(): """Load input data."""
def receive_output(): """Receive and decrypt output."""
# get model weights/input data
# as private tensors from each party
weights = tfe.define_private_input("model-owner", provide_weights)
w0, b0, w1, b1, w2, b2 = weights
x = tfe.define_private_input("prediction-client", provide_input)
# compute secure prediction
layer0 = tfe.relu(tfe.matmul(x, w0) + b0)
layer1 = tfe.relu(tfe.matmul(layer0, w1) + b1)
logits = tfe.matmul(layer1, w2) + b2
# send prediction output back to client
prediction_op = tfe.define_output("prediction-client", [logits], receive_output)
# run secure graph execution in a tf.Session
with tfe.Session() as sess:
    sess.run(tf.global_variables_initializer(), tag="init")
    sess.run(prediction_op, tag="prediction")
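
The three local functions are left as stubs in the poster version. Below is a minimal sketch of what each party might implement on its own machine; the weight file name, variable names, tensor shapes, and return values are illustrative assumptions, not part of the original example.

import numpy as np

def provide_weights():
    """Load model weights."""
    # Hypothetical: the model owner reads pretrained weights from a local
    # .npz file; the file name and array names are assumptions.
    npz = np.load("model_weights.npz")
    return [tf.constant(npz[name]) for name in
            ("w0", "b0", "w1", "b1", "w2", "b2")]

def provide_input():
    """Load input data."""
    # Hypothetical: the prediction client supplies one flattened 28x28
    # image; the shape only needs to match the first weight matrix.
    return tf.constant(np.random.randn(1, 784), dtype=tf.float32)

def receive_output(logits):
    """Receive and decrypt output."""
    # Hypothetical: after decryption on the client, return the predicted
    # class index; in practice the client might print or store it.
    return tf.argmax(logits, axis=-1)

Note that the player names "model-owner" and "prediction-client" must match players declared in the tf-encrypted configuration; for local testing the library's default configuration simulates all parties in one process.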