@mwitiderrick
Created July 11, 2022 11:42
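import jax
import jax.numpy as jnp
import numpy as np

x = np.arange(5)
w = np.array([2., 3., 4.])

def convolve(x, w):
    # Slide the 3-element kernel w over x (no padding): one dot product per window.
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

convolve(x, w)  # [11., 20., 29.]

# Build batched inputs with one row per local device along the leading axis.
n_devices = jax.local_device_count()
xs = np.arange(5 * n_devices).reshape(-1, 5)
ws = np.stack([w] * n_devices)

# pmap maps convolve over the leading axis, running one row on each device.
jax.pmap(convolve)(xs, ws)
# Output with 8 devices (one row per device):
# ShardedDeviceArray([[ 11., 20., 29.],
# .................
# [326., 335., 344.]], dtype=float32)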