@willkurt
Created October 24, 2020 20:42
"""
This is just a quick example of how being able to compute derivatives easily
lets us think about mathematics in a very different, computationally focused way.
In this example we consider the definition of e as the value of x in
f(t) = x^t
for which the derivative of f, f', is equal to f:
f(t) = f'(t)
Since this must hold for every t, it is enough to compare the two sides at t = 0.
Treating the difference f(0) - f'(0) as a loss, we can use JAX and Newton's
method to "discover" e in a way that is very similar in spirit to the
analytical approach but solves the problem computationally.
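Worked out by hand (an added illustration, not part of the original
derivation): differentiating f(t) = x^t with respect to t gives
f'(t) = x^t * ln(x). At t = 0 this means f(0) = 1 and f'(0) = ln(x),
so the loss is 1 - ln(x), which is zero exactly when x = e.
Its derivative with respect to x is -1/x, so each Newton step is
x_new = x - (1 - ln(x)) / (-1/x) = x + x * (1 - ln(x)).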
This example should show up in Hacking Statistics with Python
https://www.countbayesie.com/blog/2020/9/16/writing-the-next-book-and-i-want-you-involved
"""
import jax.numpy as np
from jax import grad
# e is defined as the value of x for which f(t) = x^t satisfies f' = f
# start by defining f
def f(x, t):
    return np.power(x, t)
# use JAX to get our derivative
d_f_wrt_t = grad(f,argnums=1)
# the loss is the difference between f and its derivative at t = 0;
# it is zero exactly when x = e
def loss_f(x):
    return f(x, 0.0) - d_f_wrt_t(x, 0.0)
# we can now use Newton's method to find the root of our loss function,
# using JAX again to get the derivative of the loss with respect to x
d_loss_f = grad(loss_f)
guess = 4.0
for _ in range(10):
    guess -= loss_f(guess) / d_loss_f(guess)
# and tada! we found (float32) e!
print(guess)
#DeviceArray(2.718282, dtype=float32)
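# Cross-check (a hypothetical addition, not from the original gist): the
# loss above works out to 1 - ln(x), because d/dt x^t = x^t * ln(x) equals
# ln(x) at t = 0. Its derivative is -1/x, so each Newton step simplifies to
# x_new = x + x * (1 - ln(x)). Running that closed-form update in plain
# (float64) Python should land on the same root, math.e, which lets us
# confirm the JAX-derived result without any automatic differentiation.
import math


def newton_step(x):
    # x - loss(x) / loss'(x), with loss(x) = 1 - ln(x) and loss'(x) = -1/x
    return x + x * (1.0 - math.log(x))


x = 4.0
for _ in range(10):
    x = newton_step(x)
print(x, math.e)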