# A minimal example of using tf-encrypted for secure inference.
#
# Originally published as GitHub Gist 9bed115452aefcb919a279cf948c7f5e
# by @jvmncs (last active October 18, 2018 18:29).
import tensorflow as tf
import tf_encrypted as tfe
# Define the parties taking part in the computation and the host:port each
# one is reachable at. The closing parenthesis of set_config() was missing
# in the original, which made the whole script a SyntaxError.
# NOTE(review): "server0"/"server1" presumably are the two compute servers
# that run the secure protocol, and the non-localhost addresses look like
# placeholders -- replace them with real hosts before running.
tfe.set_config(tfe.RemoteConfig({
    "model-owner": "localhost:2222",
    "prediction-client": "1.2.3.4:2222",
    "server0": "1.1.1.1:4444",
    "server1": "1.1.1.2:4444",
}))
# generic party-local callbacks: load the model weights / input data on their owning party, and receive the output on the client
def provide_weights():
    """Load the model weights on the model-owner party.

    Placeholder: the body is intentionally empty and returns None. The call
    site unpacks six values (w0, b0, w1, b1, w2, b2) from the private input
    built from this callback, so a real implementation must supply six
    weight/bias tensors. NOTE(review): their exact types/shapes are not
    shown here -- confirm against the tf-encrypted version in use.
    """
    return None
def provide_input():
    """Load the input data on the prediction-client party.

    Placeholder: the body is intentionally empty and returns None. A real
    implementation must produce the client's input tensor ``x``.
    NOTE(review): the expected tensor type/shape is not shown here --
    confirm against the model weights and the tf-encrypted version in use.
    """
    return None
def receive_output(*outputs):
    """Receive and decrypt the output on the prediction-client party.

    Placeholder: currently ignores ``outputs`` and returns None.

    NOTE(review): tfe.define_output presumably calls this outputter with the
    revealed output tensors as positional arguments (one per entry of the
    list passed to define_output, here just the logits). The original
    zero-argument signature would then fail with TypeError. ``*outputs``
    accepts that call and remains callable with no arguments. A real
    implementation presumably also has to return a TF op that sess.run can
    execute -- confirm against the tf-encrypted version in use.
    """
    return None
# Each party contributes its own secret data as private tensors: the
# model-owner supplies the six weight/bias tensors via provide_weights, and
# the prediction-client supplies the input x via provide_input.
w0, b0, w1, b1, w2, b2 = tfe.define_private_input(
    "model-owner",
    provide_weights,
)
x = tfe.define_private_input(
    "prediction-client",
    provide_input,
)
# Secure forward pass over the private tensors: two dense layers with ReLU
# activations, followed by a final dense layer that yields the logits.
layer0 = tfe.relu(tfe.matmul(x, w0) + b0)
layer1 = tfe.relu(tfe.matmul(layer0, w1) + b1)
logits = tfe.matmul(layer1, w2) + b2
# Deliver the logits back to the prediction-client party, where the
# receive_output callback handles them; the returned op is run below.
prediction_op = tfe.define_output(
    "prediction-client",
    [logits],
    receive_output,
)
# Run the secure graph in a tfe.Session: first initialize the variables,
# then evaluate the prediction op. The body of the `with` statement must be
# indented (the original had lost its indentation -> IndentationError).
# NOTE(review): `tag=` is not a tf.Session.run parameter; it is presumably a
# tfe.Session.run extension that labels each run -- confirm.
with tfe.Session() as sess:
    sess.run(tf.global_variables_initializer(), tag="init")
    sess.run(prediction_op, tag="prediction")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment