Skip to content

Instantly share code, notes, and snippets.

@mattjj
Created November 15, 2021 05:59
Show Gist options
  • Save mattjj/d15bd97770ab5e6553ab198f33d5f67e to your computer and use it in GitHub Desktop.
Save mattjj/d15bd97770ab5e6553ab198f33d5f67e to your computer and use it in GitHub Desktop.
import itertools as it
import jax
import jax.numpy as jnp
# Enable 64-bit types and pin computation to the CPU platform.
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')

# Problem sizes. Each one is bound to a one-letter alias and a long name.
num_stages = L = 5        # pipeline stages (one weight matrix each)
num_feat = F = 3          # feature width of every stage
num_microbatches = M = 2  # pieces the full batch is split into
microbatch_size = B = 3   # examples per microbatch
batch_size = N = 6        # examples in the full batch
assert N == M * B

# One (F, F) weight matrix per stage, and a deterministic integer input batch.
params = jax.random.normal(jax.random.PRNGKey(0), (L, F, F))
inputs = jnp.reshape(jnp.arange(N * F), (N, F))
def fn(params, inputs):
    """Apply one dense stage with a tanh nonlinearity to a single example.

    Args:
      params: rank-2 weight matrix.
      inputs: rank-1 feature vector.

    Returns:
      ``tanh(dot(params, inputs))``.
    """
    # Written for one unbatched example; callers add batch axes with jax.vmap.
    assert params.ndim == 2 and inputs.ndim == 1
    return jnp.tanh(jnp.dot(params, inputs))
# Sequential reference: push the whole batch through the L stages in order.
state = inputs
for i in range(L):
    # vmap maps fn over axis 0 of state (the examples); in_axes=None shares
    # the stage's weight matrix across all of them.
    state = jax.vmap(fn, (None, 0))(params[i], state)
outputs = state
print(outputs)
def spmd_pipeline(fn, params, inputs):
    """Run ``inputs`` through a stack of stages as a software pipeline.

    Every step evaluates all stages together on a per-stage state buffer:
    the buffer is shifted one stage forward, the next microbatch is written
    into stage 0, and the output of the last stage is recorded.

    Args:
      fn: per-example stage function ``fn(stage_params, x)``. Its output must
        have the same shape as ``x`` because it is fed to the next stage.
      params: stage parameters with a leading axis of length ``num_stages``
        (an array, or a pytree of such arrays).
      inputs: array of shape ``(num_microbatches, microbatch_size, ...)``.

    Returns:
      Floating-point array shaped like ``inputs`` holding the last stage's
      output for each microbatch.
    """
    # Sizes are read from the arguments rather than from module-level
    # constants, so any stage count and microbatch layout is supported.
    num_stages = jax.tree_util.tree_leaves(params)[0].shape[0]
    num_microbatches = inputs.shape[0]
    num_steps = num_microbatches + num_stages - 1
    microbatch_shape = tuple(inputs.shape[1:])

    # Append num_stages - 1 all-zero microbatches so the pipeline keeps
    # stepping until the last real microbatch has left the final stage.
    # Only the microbatch axis is padded; shift_and_insert broadcasts each
    # microbatch, so no extra stage axis is needed on the inputs.
    inputs = jnp.pad(
        inputs, [[0, num_stages - 1]] + [[0, 0]] * (inputs.ndim - 1))
    outputs = jnp.zeros((num_steps,) + microbatch_shape)
    state = jnp.zeros((num_stages,) + microbatch_shape)
    for i in range(num_steps):
        state = shift_and_insert(state, inputs[i])
        # Outer vmap pairs stage k's params with state[k]; inner vmap maps
        # fn over the examples of that stage's microbatch.
        state = jax.vmap(jax.vmap(fn, (None, 0)))(params, state)
        outputs = outputs.at[i].set(state[-1])  # last layer output
    # The first num_stages - 1 steps only fill the pipeline; drop them.
    return outputs[num_stages - 1:]
def shift_and_insert(arr, x):
    """Shift ``arr`` one slot along axis 0 and write ``x`` into slot 0.

    The last slice along axis 0 is dropped and every other slice moves up by
    one index.

    Args:
      arr: array with at least one axis.
      x: value broadcastable to ``arr.shape``; only the part that lands at
        index 0 of axis 0 ends up in the result.

    Returns:
      Array with the same shape as ``arr``.
    """
    # Prepend one zero slice on axis 0 and drop the final slice: a shift by 1.
    padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
    shifted = jnp.pad(arr, padding)[:-1]
    # Take x where the axis-0 index is 0 and the shifted data everywhere else.
    iota = jax.lax.broadcasted_iota('int32', shifted.shape, 0)
    return jnp.where(iota == 0, x, shifted)
# Pipelined run: view the batch as M microbatches of B examples, run the
# pipeline, then flatten the result back to (N, F) before printing it.
outputs2 = jnp.reshape(
    spmd_pipeline(fn, params, jnp.reshape(inputs, (M, B, F))),
    (N, F),
)
print(outputs2)
@zhangqiaorjc
Copy link

The implementation can be simplified by replacing the input-padding line with a one-liner:

34: inputs = jnp.pad(inputs, [[0, L-1], [0, 0], [0, 0]])

@mattjj
Copy link
Author

mattjj commented Nov 20, 2021

Good point! I wondered about that, though I kind of assumed we had to add this extra dimension for GSPMD or something.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment