@mwitiderrick
Created July 11, 2022 11:21
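import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    # Returns (scalar to differentiate, auxiliary data); with has_aux=True,
    # jax.grad differentiates only the first element and passes the second through.
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
# derivative_fn(x) returns a (gradient, aux) pair.
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print(derivative_fn(x_small))
# (DeviceArray([0.25      , 0.19661194, 0.10499357, 0.04517666, 0.01766271,
#              0.00664806], dtype=float32), DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float32))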
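As a sanity check on the printed gradients, the sketch below compares the `has_aux` gradient with the closed-form logistic derivative σ(x)(1 - σ(x)). It also shows `jax.value_and_grad` with `has_aux=True`, which returns `((value, aux), grad)`. The helper `logistic_derivative` is a hypothetical name introduced for this check and is not part of the original gist.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    # Same function as above: (scalar to differentiate, auxiliary output).
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), x + 1

def logistic_derivative(x):
    # Hypothetical helper: closed form d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s)

x_small = jnp.arange(6.)

# jax.grad with has_aux=True returns (gradient, aux).
grads, aux = jax.grad(sum_logistic, has_aux=True)(x_small)

# jax.value_and_grad with has_aux=True returns ((value, aux), gradient).
(value, aux2), grads2 = jax.value_and_grad(sum_logistic, has_aux=True)(x_small)

print(jnp.allclose(grads, logistic_derivative(x_small), atol=1e-6))  # True
print(aux)  # [1. 2. 3. 4. 5. 6.]
```

Because the auxiliary output is never differentiated, it can hold anything convenient to return alongside the loss (intermediate activations, metrics), without affecting the gradient.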