# TFF Multi Node Toy Example
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A generic worker binary for deployment, e.g., on GCP."""
from absl import app
from absl import flags
import grpc
import tensorflow as tf
import tensorflow_federated as tff

# Verbose TensorFlow logging as a debugging aid.
tf.get_logger().setLevel('DEBUG')

FLAGS = flags.FLAGS

# Integer flags get integer defaults so the default matches the declared type.
flags.DEFINE_integer('port', 8000, 'port to listen on')
flags.DEFINE_integer('threads', 10, 'number of worker threads in thread pool')
# Both TLS flags are file paths: `main` opens and reads them in binary mode.
flags.DEFINE_string('private_key', '',
                    'path to the private key file for SSL/TLS setup')
flags.DEFINE_string('certificate_chain', '',
                    'path to the certificate chain file for SSL/TLS setup')
flags.DEFINE_integer('clients', 1, 'number of clients to host on this worker')
flags.DEFINE_integer('fanout', 100,
                     'max fanout in the hierarchy of local executors')
def _load_server_credentials(private_key_path, certificate_chain_path):
  """Builds gRPC server credentials from a key file and a certificate chain.

  Args:
    private_key_path: Path to the server's private key file, or '' if unset.
    certificate_chain_path: Path to the matching certificate chain file, or
      '' if unset.

  Returns:
    `grpc.ssl_server_credentials(...)` built from both files when both paths
    are set, or `None` when neither is set.

  Raises:
    ValueError: If exactly one of the two paths is set. A partial TLS
      configuration is rejected instead of silently falling back to serving
      without credentials.
  """
  if not private_key_path and not certificate_chain_path:
    return None
  if not certificate_chain_path:
    raise ValueError(
        'Private key has been specified, but the certificate chain missing.')
  if not private_key_path:
    raise ValueError(
        'Certificate chain has been specified, but the private key is missing.')
  with open(private_key_path, 'rb') as f:
    private_key = f.read()
  with open(certificate_chain_path, 'rb') as f:
    certificate_chain = f.read()
  # `ssl_server_credentials` takes a sequence of (key, chain) pairs.
  return grpc.ssl_server_credentials(((private_key, certificate_chain),))


def main(argv):
  """Runs a TFF worker that serves a local executor stack over gRPC.

  Configuration comes from the command-line flags: `--port`, `--threads`,
  `--clients`, `--fanout`, and optionally `--private_key` together with
  `--certificate_chain` for TLS.

  Args:
    argv: Leftover command-line arguments after flag parsing; unused.

  Raises:
    ValueError: If only one of `--private_key` / `--certificate_chain` is set.
  """
  del argv  # Unused.
  # Resolve TLS settings first so a bad configuration fails before any
  # executors are built.
  credentials = _load_server_credentials(FLAGS.private_key,
                                         FLAGS.certificate_chain)
  executor_factory = tff.framework.local_executor_factory(
      num_clients=FLAGS.clients, max_fanout=FLAGS.fanout)
  tff.simulation.run_server(executor_factory, FLAGS.threads, FLAGS.port,
                            credentials)


# NOTE(review): absl's `app.run` is expected to exit once `main` returns, so
# when this file is executed directly the client code that follows never
# runs. That code appears to be a separate script from the same gist; consider
# moving it to its own module.
if __name__ == '__main__':
  app.run(main)
# ---------------------------------------------------------------------------
# Client driver. NOTE(review): this section appears to be a separate script
# concatenated after the worker binary; its statements run at module level.
# ---------------------------------------------------------------------------
import nest_asyncio
# Patch asyncio to allow nested event loops. This is done before the remaining
# imports on purpose, so the patch is in place before TFF is loaded.
nest_asyncio.apply()

import collections
import sys
import time

import grpc
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# Verbose TensorFlow logging as a debugging aid.
tf.get_logger().setLevel('DEBUG')
@tf.function
def print_me(ds):
  """Hands `ds` back unchanged after a `tf.print` debugging call.

  The printed label records where the author expects this to run.
  """
  label = "should print on client.."
  tf.print(label, ds)
  return ds
@tff.tf_computation(tff.SequenceType(tf.int32))
def process_data(ds):
  """Returns the int32 sum of the elements of `ds`.

  The dataset is first routed through `print_me` (a `tf.print` debugging
  hook). An empty dataset reduces to the initial value, 0.
  """
  logged_ds = print_me(ds)
  initial_total = np.int32(0)

  def add(total, element):
    return total + element

  return logged_ds.reduce(initial_total, add)
@tff.federated_computation(
    tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS))
def process_data_on_clients(federated_ds):
  """Applies `process_data` to every client's int32 dataset.

  Args:
    federated_ds: CLIENTS-placed sequences of int32.

  Returns:
    CLIENTS-placed int32 values: one element sum per client.
  """
  per_client_sum = process_data
  return tff.federated_map(per_client_sum, federated_ds)
@tff.tf_computation(tf.int32)
def make_data(n):
  """Builds the int32 dataset 1, 2, ..., n (empty when n <= 0).

  Also contains a `tf.print` of "hello..." to stderr as a debugging aid.
  """
  tf.print("hello...", output_stream=sys.stderr)
  count = tf.cast(n, tf.int64)

  def one_based(index):
    # Dataset.range yields int64 values 0..n-1; shift to 1..n and narrow.
    return tf.cast(index + 1, tf.int32)

  return tf.data.Dataset.range(count).map(one_based)
@tff.federated_computation(
    tff.FederatedType(tf.int32, tff.CLIENTS))
def make_data_on_clients(federated_n):
  """Applies `make_data` to each client's count.

  Args:
    federated_n: CLIENTS-placed int32 counts.

  Returns:
    CLIENTS-placed int32 datasets; the client holding n gets 1, 2, ..., n.
  """
  per_client_builder = make_data
  return tff.federated_map(per_client_builder, federated_n)
@tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
def make_and_process_data_on_clients(federated_n):
  """Builds a dataset on each client from its count and sums it there.

  Args:
    federated_n: CLIENTS-placed int32 counts.

  Returns:
    CLIENTS-placed int32 sums; the client holding n >= 0 yields 1 + ... + n.
  """
  # TFF traces the Python body of a federated computation once, when this
  # decorator is applied. It is not re-executed "on the server" for each call.
  # A print here therefore fires once at definition time, and the label says
  # so instead of claiming server-side execution.
  tf.print("Tracing make_and_process_data_on_clients")
  federated_ds = make_data_on_clients(federated_n)
  return process_data_on_clients(federated_ds)
# Client driver: connect to the worker processes and run the computation.
# NOTE(review): this executes at module level, i.e. on import.
#
# '0.0.0.0' means "any interface" when *binding* a server. As a destination to
# connect to, it is not reliably routable on all platforms, so dial the
# workers through localhost instead.
ip_address = 'localhost'
# Worker processes assumed to be launched with --port=8000 and --port=8001
# (confirm against how the workers are started).
worker_ports = (8000, 8001)
# Two channels per worker gives four channels, matching the four entries in
# `federated_n` below (presumably one TFF client per channel with the worker's
# default --clients=1; confirm).
channels_per_worker = 2
channels = [
    grpc.insecure_channel(f'{ip_address}:{worker_port}')
    for worker_port in worker_ports
    for _ in range(channels_per_worker)
]
tff.backends.native.set_remote_execution_context(channels, rpc_mode='STREAMING')

# One count per client; the client holding n should report 1 + 2 + ... + n.
federated_n = [2, 3, 4, 5]
print(make_and_process_data_on_clients(federated_n))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment