Skip to content

Instantly share code, notes, and snippets.

@goyalankit
Last active February 2, 2021 22:54
Show Gist options
  • Save goyalankit/9be1e8752d67725635dd89c62c29545a to your computer and use it in GitHub Desktop.
Save goyalankit/9be1e8752d67725635dd89c62c29545a to your computer and use it in GitHub Desktop.
TFF Multi Node Toy Example
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A generic worker binary for deployment, e.g., on GCP."""
from absl import app
from absl import flags
import grpc
import tensorflow_federated as tff
import tensorflow as tf
# Lower the TensorFlow logger threshold to DEBUG for this worker process.
tf.get_logger().setLevel('DEBUG')

# Command-line configuration of the worker. The integer flags are given their
# defaults as strings; absl parses them into ints when the flag is defined.
FLAGS = flags.FLAGS

flags.DEFINE_integer(
    name='port',
    default='8000',
    help='port to listen on',
)
flags.DEFINE_integer(
    name='threads',
    default='10',
    help='number of worker threads in thread pool',
)
flags.DEFINE_string(
    name='private_key',
    default='',
    help='the private key for SSL/TLS setup',
)
flags.DEFINE_string(
    name='certificate_chain',
    default='',
    help='the cert for SSL/TLS setup',
)
flags.DEFINE_integer(
    name='clients',
    default='1',
    help='number of clients to host on this worker',
)
flags.DEFINE_integer(
    name='fanout',
    default='100',
    help='max fanout in the hierarchy of local executors',
)
def main(argv):
    """Build a local executor factory and serve it via `run_server`.

    The executor factory is created from `--clients` and `--fanout`. When
    both `--private_key` and `--certificate_chain` are given, the two files
    are read as bytes and turned into gRPC SSL server credentials; when
    neither is given, `credentials` is `None`. The factory, `--threads`,
    `--port` and the credentials are then handed to
    `tff.simulation.run_server`.

    Args:
      argv: Positional command-line arguments supplied by `app.run`; unused.

    Raises:
      ValueError: If exactly one of `--private_key` and
        `--certificate_chain` is specified.
    """
    del argv  # Unused.

    executor_factory = tff.framework.local_executor_factory(
        num_clients=FLAGS.clients, max_fanout=FLAGS.fanout)

    # TLS needs both halves of the key pair. Reject a half-configured setup
    # instead of silently starting the server without credentials.
    if FLAGS.private_key and not FLAGS.certificate_chain:
        raise ValueError(
            'Private key has been specified, but the certificate chain missing.')
    if FLAGS.certificate_chain and not FLAGS.private_key:
        raise ValueError(
            'Certificate chain has been specified, but the private key is '
            'missing.')

    if FLAGS.private_key:
        with open(FLAGS.private_key, 'rb') as f:
            private_key = f.read()
        with open(FLAGS.certificate_chain, 'rb') as f:
            certificate_chain = f.read()
        # ssl_server_credentials takes an iterable of (key, cert chain) pairs.
        credentials = grpc.ssl_server_credentials(((
            private_key,
            certificate_chain,
        ),))
    else:
        # NOTE(review): presumably `None` makes run_server listen without
        # TLS -- confirm against the TFF version in use.
        credentials = None

    tff.simulation.run_server(
        executor_factory, FLAGS.threads, FLAGS.port,
        credentials)
# Only start the worker when this file is executed as a script; `app.run`
# invokes `main` on our behalf.
if __name__ == '__main__':
    app.run(main)
import nest_asyncio
nest_asyncio.apply()
import sys
import numpy as np
import grpc
import tensorflow as tf
tf.get_logger().setLevel('DEBUG')
import collections
import time
import tensorflow as tf
import tensorflow_federated as tff
@tf.function
def print_me(ds):
    """Emit a marker message together with `ds` via `tf.print`.

    Args:
      ds: The value to print; it is passed through untouched.

    Returns:
      The same `ds` that was passed in.
    """
    marker = "should print on client.."
    tf.print(marker, ds)
    return ds
@tff.tf_computation(tff.SequenceType(tf.int32))
def process_data(ds):
    """Sum the elements of an int32 sequence.

    The sequence is first routed through `print_me`, then folded with
    addition starting from an int32 zero.

    Args:
      ds: The sequence of int32 values to add up.

    Returns:
      The result of reducing the sequence with `+` from `np.int32(0)`.
    """

    def _add(running_total, element):
        return running_total + element

    printed_ds = print_me(ds)
    initial_total = np.int32(0)
    return printed_ds.reduce(initial_total, _add)
@tff.federated_computation(
    tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS)
)
def process_data_on_clients(federated_ds):
    """Map `process_data` over a CLIENTS-placed value of int32 sequences.

    Args:
      federated_ds: Federated value whose members are int32 sequences.

    Returns:
      The result of `tff.federated_map(process_data, federated_ds)`.
    """
    per_client_totals = tff.federated_map(process_data, federated_ds)
    return per_client_totals
@tff.tf_computation(tf.int32)
def make_data(n):
    """Build an int32 dataset from `tf.data.Dataset.range(n)` shifted by one.

    Prints "hello..." to stderr via `tf.print`, then creates a range dataset
    of length `n` (cast to int64 for `Dataset.range`) and maps every element
    `x` to `x + 1` cast back to int32.

    Args:
      n: int32 number of elements to generate.

    Returns:
      The mapped `tf.data.Dataset`.
    """

    def _shift_to_int32(x):
        return tf.cast(x + 1, tf.int32)

    tf.print("hello...", output_stream=sys.stderr)
    length = tf.cast(n, tf.int64)
    return tf.data.Dataset.range(length).map(_shift_to_int32)
@tff.federated_computation(
    tff.FederatedType(tf.int32, tff.CLIENTS)
)
def make_data_on_clients(federated_n):
    """Map `make_data` over a CLIENTS-placed federated int32 value.

    Args:
      federated_n: Federated value holding one int32 per client.

    Returns:
      The result of `tff.federated_map(make_data, federated_n)`.
    """
    per_client_datasets = tff.federated_map(make_data, federated_n)
    return per_client_datasets
@tff.federated_computation(
    tff.FederatedType(tf.int32, tff.CLIENTS)
)
def make_and_process_data_on_clients(federated_n):
    """Chain `make_data_on_clients` and `process_data_on_clients`.

    Args:
      federated_n: Federated value holding one int32 per client.

    Returns:
      The output of `process_data_on_clients` applied to the datasets that
      `make_data_on_clients` builds from `federated_n`.
    """
    # NOTE(review): this print sits directly in the federated_computation
    # body rather than inside a tf_computation -- presumably it fires when
    # the body is traced, not on each invocation; confirm.
    tf.print("Running on server")
    return process_data_on_clients(make_data_on_clients(federated_n))
# Host used for every channel target below.
ip_address = '0.0.0.0'

# Open two insecure channels per port, port 8000 first and then port 8001,
# giving four channels in total.
channels = []
for port in (8000, 8000, 8001, 8001):
    target = f'{ip_address}:{port}'
    channels.append(grpc.insecure_channel(target))

# Route subsequent TFF computations through the channels created above.
tff.backends.native.set_remote_execution_context(
    channels,
    rpc_mode='STREAMING',
)

# One int32 per client; there are as many values as channels.
federated_n = [2, 3, 4, 5]
print(make_and_process_data_on_clients(federated_n))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment