@mwitiderrick
Last active July 11, 2022 11:15
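import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    # Sum of the logistic (sigmoid) function over all elements of x
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
# jax.grad returns a function that computes the gradient of sum_logistic
derivative_fn = jax.grad(sum_logistic)
print(derivative_fn(x_small))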
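
One way to sanity-check the result of `jax.grad` is to compare it with the closed-form derivative of the logistic function, s(x) * (1 - s(x)). The sketch below is not part of the original gist: the helper name `analytic_grad` and the variables `autodiff`, `closed_form`, and `max_abs_diff` are assumptions introduced purely for illustration.

```python
import jax
import jax.numpy as jnp

def sum_logistic(x):
    # Same function as in the gist: sum of the logistic over all elements
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def analytic_grad(x):
    # Hypothetical helper: closed-form derivative s(x) * (1 - s(x))
    s = 1.0 / (1.0 + jnp.exp(-x))
    return s * (1.0 - s)

x_small = jnp.arange(6.)
autodiff = jax.grad(sum_logistic)(x_small)   # gradient from automatic differentiation
closed_form = analytic_grad(x_small)         # gradient from the formula

# Largest elementwise disagreement between the two gradients
max_abs_diff = float(jnp.max(jnp.abs(autodiff - closed_form)))
print(max_abs_diff)
```

If the two gradients agree to within floating-point tolerance, that suggests `jax.grad` is differentiating the function as intended.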