Skip to content

Instantly share code, notes, and snippets.

@enijkamp
Created July 12, 2022 21:46
Show Gist options
  • Save enijkamp/a057a036a3df6196936419a63d3435bf to your computer and use it in GitHub Desktop.
Save enijkamp/a057a036a3df6196936419a63d3435bf to your computer and use it in GitHub Desktop.
nan.py
# (loss, grad) of train_apply_f; jax.value_and_grad differentiates w.r.t. the
# first positional argument by default (here: the params pytree).
# allow_int=True makes integer-valued leaves in that argument legal instead of
# raising. NOTE(review): JAX returns float0 gradients for integer leaves —
# confirm the accumulation in grad_sum copes with that if params_bf16 has any.
value_and_grad_f = jax.value_and_grad(train_apply_f, has_aux=False, allow_int=True)

# All-zero bfloat16 accumulator with the same tree structure and leaf shapes as
# params_bf16; used as the initial carry of the gradient-summing scans below.
# jax.tree_map is deprecated/removed in current JAX, so use jax.tree_util.tree_map.
# zeros_like(..., dtype=...) builds the bf16 buffer directly rather than
# materializing zeros in the original dtype and then casting.
grad_init = jax.tree_util.tree_map(
    lambda leaf: jnp.zeros_like(leaf, dtype=jnp.bfloat16), params_bf16
)
def scan(f, init, xs):
    """Eager, pure-Python reference implementation of ``jax.lax.scan``.

    Runs ``f`` step by step over the leading axis of ``xs`` without tracing or
    compiling the loop. Its results can then be compared against
    ``jax.lax.scan``, e.g. when hunting NaNs that appear in only one path.

    Args:
      f: Step function ``(carry, x) -> (carry, y)``.
      init: Initial carry (any pytree).
      xs: Pytree of arrays scanned over their leading axis. Every leaf must
        have the same leading dimension. A plain ``(x, y)`` pair is the
        two-leaf tuple case.

    Returns:
      ``(final_carry, ys)``. ``ys`` has the pytree structure of the per-step
      ``y``, with each leaf stacked along a new leading axis.

    Raises:
      ValueError: If ``xs`` has no array leaves, a leaf has no leading axis,
        or the leaves disagree on the leading dimension.
    """
    leaves = jax.tree_util.tree_leaves(xs)
    if not leaves:
        raise ValueError("scan: xs must contain at least one array")
    if any(jnp.ndim(leaf) == 0 for leaf in leaves):
        raise ValueError("scan: every leaf of xs needs a leading axis")
    lengths = {jnp.shape(leaf)[0] for leaf in leaves}
    if len(lengths) != 1:
        # Mirrors lax.scan: mismatched leaves would otherwise be silently
        # truncated or raise IndexError mid-loop.
        raise ValueError(f"scan: inconsistent leading dimensions {sorted(lengths)}")
    (length,) = lengths

    if length == 0:
        # Nothing to iterate. Infer the per-step output structure abstractly
        # so ys gets the same zero-length shape lax.scan would return.
        x_spec = jax.tree_util.tree_map(
            lambda a: jax.ShapeDtypeStruct(jnp.shape(a)[1:], jnp.result_type(a)), xs
        )
        _, y_spec = jax.eval_shape(f, init, x_spec)
        ys = jax.tree_util.tree_map(
            lambda s: jnp.zeros((0, *s.shape), s.dtype), y_spec
        )
        return init, ys

    carry = init
    ys = []
    for i in range(length):
        # Slice step i out of every leaf of xs. The lambda runs immediately,
        # so late binding of i is not an issue.
        carry, y = f(carry, jax.tree_util.tree_map(lambda a: a[i], xs))
        ys.append(y)
    # Stack leaf-wise so a pytree-valued y works; for an array y this equals
    # jnp.stack(ys).
    return carry, jax.tree_util.tree_map(lambda *steps: jnp.stack(steps), *ys)
def grad_sum(grad_old, sample):
    """Scan step: add one sample's gradient to the running gradient sum.

    Args:
      grad_old: Running gradient sum (the scan carry). It has the same tree
        structure as the gradient returned by ``value_and_grad_f``.
      sample: ``(x, y)`` pair for one scan step. Both are passed straight to
        ``value_and_grad_f`` together with the module-level ``params_bf16``.

    Returns:
      ``(grad_new, loss)``. ``grad_new`` is the leaf-wise sum
      ``grad_old + grad`` and is the new carry. ``loss`` is emitted as the
      per-step scan output.
    """
    x, y = sample
    loss, grad = value_and_grad_f(params_bf16, x, y)
    # jax.tree_multimap was removed from JAX. jax.tree_util.tree_map accepts
    # several trees and maps leaf-wise over all of them.
    grad_new = jax.tree_util.tree_map(lambda a, b: a + b, grad_old, grad)
    return grad_new, loss
# Accumulate gradients over the leading axis of (x, y) in two ways, starting
# from the same zero carry:
#   grad_0: the traced/compiled jax.lax.scan
#   grad_1: the eager Python `scan` reference loop
# The per-step losses are discarded. NOTE(review): grad_0 and grad_1 are
# presumably compared afterwards (not visible here) to see which path NaNs.
grad_0, _ = jax.lax.scan(grad_sum, init=grad_init, xs=(x, y))
grad_1, _ = scan(grad_sum, init=grad_init, xs=(x, y))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment