Skip to content

Instantly share code, notes, and snippets.

Forked from mattjj/
Last active November 25, 2021 22:26
Show Gist options
  • Save zhangqiaorjc/acaa00aceba71c59b3b8511d764a0abd to your computer and use it in GitHub Desktop.
Save zhangqiaorjc/acaa00aceba71c59b3b8511d764a0abd to your computer and use it in GitHub Desktop.
import itertools as it
import jax
import jax.numpy as jnp
# Turn on the x64 flag and select the 'cpu' platform before any arrays are built.
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')

# Problem sizes. Each one is bound to a descriptive name and a one-letter alias.
num_stages = L = 5
batch_size = N = 6
num_microbatches = M = 2
microbatch_size = B = 3
assert M * B == N  # the batch must split evenly into microbatches
num_feat = F = 3

# One (F, F) weight matrix per stage, and a batch of N feature vectors
# holding the integers 0 .. N*F-1.
params = jax.random.normal(key=jax.random.PRNGKey(0), shape=(L, F, F))
inputs = jnp.reshape(jnp.arange(N * F), (N, F))
def fn(params, inputs):
    """Apply a single stage to a single example.

    Args:
        params: rank-2 weight matrix for the stage.
        inputs: rank-1 feature vector for one example.

    Returns:
        ``tanh(params @ inputs)``, a rank-1 array.

    Raises:
        AssertionError: if ``params`` is not rank 2 or ``inputs`` is not rank 1.
    """
    assert params.ndim == 2 and inputs.ndim == 1
    # NOTE(review): the inner call was lost when this file was extracted
    # ("jnp.tanh(, inputs))"); a matrix-vector product is the reconstruction
    # consistent with the rank check above — confirm against the original.
    return jnp.tanh(jnp.dot(params, inputs))
# Reference result: run the whole batch through the L stages one after
# another. `fn` handles a single example, so it is vmapped over the batch
# axis (axis 0 of `state`) while the stage's weights are shared (None).
state = inputs
for i in range(L):
    state = jax.vmap(fn, (None, 0))(params[i], state)
outputs = state
def spmd_pipeline(fn, params, inputs):
    """Run microbatches through a stack of stages on a pipelined schedule.

    At every step the per-stage state is shifted down by one stage, the next
    microbatch is written into stage 0, and all stages are applied at once.
    The last stage's state is recorded as that step's output.

    Args:
        fn: per-example stage function, called as ``fn(stage_params, example)``.
        params: array whose leading axis indexes the stages.
        inputs: rank-3 array ``(num_microbatches, microbatch_size, num_feat)``.

    Returns:
        Array of shape ``(num_microbatches, microbatch_size, num_feat)`` with
        the last stage's output for each microbatch.
    """
    # Sizes are read from the arguments rather than from module-level globals.
    num_stages = params.shape[0]
    num_microbatches, microbatch_size, num_feat = inputs.shape
    num_steps = num_microbatches + num_stages - 1

    # Append num_stages - 1 all-zero microbatches so there is something to
    # feed into stage 0 while the real microbatches drain through the stages.
    inputs = jnp.pad(inputs, [[0, num_stages - 1], [0, 0], [0, 0]])
    outputs = jnp.zeros((num_steps, microbatch_size, num_feat))
    state = jnp.zeros((num_stages, microbatch_size, num_feat))
    for i in range(num_steps):
        state = shift_and_insert(state, inputs[i])
        # Outer vmap: one stage per slot (axis 0 of both params and state).
        # Inner vmap: every example in the microbatch shares that stage's params.
        state = jax.vmap(jax.vmap(fn, (None, 0)))(params, state)
        outputs = outputs.at[i].set(state[-1])  # last layer output
    # The first num_stages - 1 steps record the last stage before the first
    # real microbatch has reached it, so those entries are dropped.
    return outputs[num_stages - 1:]
def shift_and_insert(arr, x):
    """Shift ``arr`` by one slot along axis 0 and put ``x`` in slot 0.

    The last slot of ``arr`` is dropped, every other slot moves to the next
    index, and slot 0 is filled with ``x``.

    Args:
        arr: array with at least one axis.
        x: value for slot 0; it is broadcast against ``arr`` by ``jnp.where``.

    Returns:
        Array with the same shape as ``arr``.
    """
    # Prepend one zero slot on axis 0 only, then cut the trailing slot so the
    # shape is unchanged.
    padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
    arr = jnp.pad(arr, padding)[:-1]
    # iota holds each element's index along axis 0, so `iota == 0` selects
    # exactly the freshly padded slot.
    iota = jax.lax.broadcasted_iota(jnp.int32, arr.shape, 0)
    return jnp.where(iota == 0, x, arr)
# Same computation on the pipelined schedule: view the batch as M microbatches
# of B examples each, then flatten the result back to (N, F).
outputs2 = jnp.reshape(
    spmd_pipeline(fn, params, jnp.reshape(inputs, (M, B, F))),
    (N, F),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment