Skip to content

Instantly share code, notes, and snippets.

@mattjj
Created August 19, 2021 04:20
Show Gist options
  • Save mattjj/aeceb3f2944f25d5805f5789771d27db to your computer and use it in GitHub Desktop.
Save mattjj/aeceb3f2944f25d5805f5789771d27db to your computer and use it in GitHub Desktop.
# zhangqiaorjc@google.com
import functools
from absl import app
from absl import flags
from absl import logging
import jax
from jax.lib import xla_extension as xc
# Command-line flags. Host 0 additionally starts the distributed runtime
# service; every host (host 0 included) connects to it at
# --server_ip:--server_port. Bounds reject impossible values at parse time.
flags.DEFINE_string('server_ip', '', help='server ip addr')
flags.DEFINE_integer(
    'server_port', 0, help='server ip port', lower_bound=0, upper_bound=65535)
flags.DEFINE_integer('num_hosts', 1, help='num of hosts', lower_bound=1)
flags.DEFINE_integer(
    'host_idx', 0, help='index of current host', lower_bound=0)
FLAGS = flags.FLAGS


def connect_to_gpu_cluster():
  """Join the distributed GPU runtime and register a cluster-wide GPU backend.

  Host 0 first starts the distributed runtime service. Every host (host 0
  included) then creates a distributed runtime client pointed at
  ``--server_ip:--server_port`` using its own ``--host_idx`` as node id. It
  registers a ``gpu`` backend factory with JAX that builds the GPU client from
  that distributed client.

  Returns:
    The distributed runtime service object on host 0, ``None`` on all other
    hosts. ``main`` drops its reference to this object to shut the cluster
    down, so keep it referenced while the cluster is in use.

  Raises:
    ValueError: if ``--server_ip`` is empty or ``--host_idx`` is not in
      ``[0, --num_hosts)``.
  """
  if not FLAGS.server_ip:
    raise ValueError('--server_ip must be set')
  if not 0 <= FLAGS.host_idx < FLAGS.num_hosts:
    raise ValueError(
        f'--host_idx must be in [0, {FLAGS.num_hosts}), got {FLAGS.host_idx}')

  service = None
  if FLAGS.host_idx == 0:
    # The listen address takes the *port* (the original interpolated the IP
    # here). Bind on all interfaces: the other hosts reach this service over
    # the network, which a loopback-only "localhost" listener would refuse.
    bind_addr = f'[::]:{FLAGS.server_port}'
    logging.info('starting service on %s', bind_addr)
    service = xc.get_distributed_runtime_service(bind_addr, FLAGS.num_hosts)

  server_addr = f'{FLAGS.server_ip}:{FLAGS.server_port}'
  logging.info('connecting to service on %s', server_addr)
  # Each process registers under its own node id. A hard-coded 0 would make
  # every host claim to be host 0, and it would disagree with the node id
  # passed to the GPU client factory below.
  dist_client = xc.get_distributed_runtime_client(server_addr, FLAGS.host_idx)

  # Register the distributed GPU backend, bound to this host's client and id.
  factory = functools.partial(
      jax.lib.xla_client.make_gpu_client, dist_client, FLAGS.host_idx)
  jax.lib.xla_bridge.register_backend_factory('gpu', factory, priority=300)
  return service


def main(argv):
  """Bring up the GPU cluster, log the devices it exposes, then drop it.

  Args:
    argv: Positional arguments left over after absl flag parsing; unused.
  """
  del argv  # Unused.
  cluster_service = connect_to_gpu_cluster()
  logging.info('gpu cluster connected')
  # Log the cluster-wide device list first, then this host's own devices.
  for fmt, list_devices in (('devices %s', jax.devices),
                            ('local devices %s', jax.local_devices)):
    logging.info(fmt, list_devices())
  logging.info('shutting down gpu cluster...')
  # This script shuts the cluster down by dropping its reference to the
  # service object (which is None on every host except host 0).
  del cluster_service


if __name__ == '__main__':
  app.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment