Skip to content

Instantly share code, notes, and snippets.

Last active November 21, 2022 14:15
Show Gist options
  • Save zhangqiaorjc/0ae6e7114fb0b3e9243e6420e4d6f3e4 to your computer and use it in GitHub Desktop.
Save zhangqiaorjc/0ae6e7114fb0b3e9243e6420e4d6f3e4 to your computer and use it in GitHub Desktop.
# `jax.distributed.initialize` is available in jax-0.2.25.
# $ pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html  # Note: wheels only available on linux.
# Run this script on 2 GPU nodes, assuming <master_ip> is the master node:
# python <this_script>.py --server_addr="<master_ip>:<port>" --num_hosts=2 --host_idx=0
# python <this_script>.py --server_addr="<master_ip>:<port>" --num_hosts=2 --host_idx=1
from absl import app
from absl import flags
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.pjit import pjit, PartitionSpec as P
from jax.experimental import maps
# Global flag registry; values are populated when absl parses the command line.
FLAGS = flags.FLAGS

# Command-line flags; their values are forwarded, in this order, to
# jax.distributed.initialize() in main().
flags.DEFINE_string('server_addr', '', help='server ip addr')
flags.DEFINE_integer('num_hosts', 1, help='num of hosts')
flags.DEFINE_integer('host_idx', 0, help='index of current host')
def main(argv):
    """Run a small einsum locally, then again model-parallel with pjit.

    Initializes the JAX distributed runtime from the command-line flags,
    prints the global and local device lists, evaluates the einsum once
    without partitioning, and then evaluates it under a 1-D device mesh
    spanning every global device.

    Args:
        argv: Command-line arguments left over after flag parsing; unused.
    """
    del argv  # Unused.

    # Read the flags through flags.FLAGS directly so this function does not
    # depend on a module-level FLAGS alias being defined.
    cfg = flags.FLAGS
    jax.distributed.initialize(cfg.server_addr, cfg.num_hosts, cfg.host_idx)

    print('global devices=', jax.devices())
    print('local devices=', jax.local_devices())

    def f(x, w):
        # Contract the last axis of x with the first axis of w.
        return jnp.einsum('blm,md->bld', x, w)

    x = jnp.ones((2, 4, 20))
    w = jnp.ones((20, 4))
    # Unpartitioned reference run.
    print(f(x, w).shape)

    # Model parallelism via pjit: a 1-D mesh over all global devices whose
    # single axis is named 'mdl'.
    n = jax.device_count()
    device_mesh = np.array(jax.devices()).reshape((n,))
    with maps.Mesh(device_mesh, ('mdl',)):
        # The contracted dimension (size 20) is the one split over 'mdl':
        # the last axis of x and the first axis of w. The output is left
        # unpartitioned (out_axis_resources=None).
        sharded_f = pjit(
            f,
            in_axis_resources=(P(None, None, 'mdl'), P('mdl', None)),
            out_axis_resources=None,
        )
        result = sharded_f(x, w)

    # result is replicated on each chip
    print('print shapes of result on each chip locally')
    # NOTE(review): the loop body was missing from the source; printing each
    # per-device buffer's shape is reconstructed from the message above.
    for buf in result.device_buffers:
        print(buf.shape)
if __name__ == '__main__':
    # absl parses the command-line flags and then calls main(argv).
    app.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment