Skip to content

Instantly share code, notes, and snippets.

@rgommers
Last active June 28, 2017 09:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rgommers/3d05325678af5e7d1fb5818e28b8fc9d to your computer and use it in GitHub Desktop.
Save rgommers/3d05325678af5e7d1fb5818e28b8fc9d to your computer and use it in GitHub Desktop.
A test of autograd for automatic differentiation of scipy.special functions
import numpy as np
import matplotlib.pyplot as plt
from autograd import grad
import autograd.scipy.special as special
# Use the ggplot look for every figure produced below.
plt.style.use('ggplot')

# Sample grid for the Bessel functions, symmetric about the origin.
x = np.linspace(-10, 10, num=1000)

# Function values of j0 and j1 on the grid.
y_j0, y_j1 = special.j0(x), special.j1(x)

# The same kind of function, evaluated through the generic-order entry point.
order = 1
y_jn = special.jn(order, x)
def gradvec(fun, x):
    """
    Evaluate the derivative of a one-argument function at every point of `x`.

    grad() needs a scalar-valued function, so the derivative is evaluated
    one element at a time (inefficient, but it works fine) - no need for
    speed just yet.

    Parameters
    ----------
    fun : callable
        Function of a single scalar argument that autograd can differentiate.
    x : array_like
        One-dimensional sequence of points at which to evaluate d(fun)/dx.

    Returns
    -------
    dfun_dx : ndarray
        Array created with ``np.empty_like(x)`` holding the derivative of
        `fun` at each element of `x`.
    """
    # Build the derivative function once; it does not depend on the loop.
    dfun = grad(fun)
    dfun_dx = np.empty_like(x)
    for ix, val in enumerate(x):
        dfun_dx[ix] = dfun(val)
    return dfun_dx
def gradvec2(fun, arg1, x):
    """
    Same as `gradvec` but for two-argument functions like jn.

    Computes the derivative of ``fun(arg1, x)`` with respect to its second
    argument, holding `arg1` fixed, at every point of `x`.

    Parameters
    ----------
    fun : callable
        Function called as ``fun(arg1, value)``, differentiable by autograd
        in its second argument.
    arg1 : object
        Fixed first argument passed through to `fun` (e.g. the order for jn).
    x : array_like
        One-dimensional sequence of points at which to evaluate the
        derivative.

    Returns
    -------
    dfun_dx : ndarray
        Array created with ``np.empty_like(x)`` holding d(fun)/dx at each
        element of `x`.
    """
    # grad() differentiates with respect to positional argument 0 unless told
    # otherwise; that would be `arg1`, not the evaluation point. Select the
    # second argument explicitly, and build the gradient function only once.
    dfun = grad(fun, argnum=1)
    dfun_dx = np.empty_like(x)
    for ix, val in enumerate(x):
        dfun_dx[ix] = dfun(arg1, val)
    return dfun_dx
# Derivatives of the Bessel functions on the same grid as the values.
dy_j0 = gradvec(special.j0, x)
dy_j1 = gradvec(special.j1, x)
# NOTE: autograd's grad() differentiates with respect to positional argument 0
# unless `argnum` is given. For jn(order, x) the x-derivative needs the second
# argument; if this call warns or returns zeros, check which argument
# gradvec2 differentiates.
dy_jn = gradvec2(special.jn, order, x)

# Figure 1: Bessel functions (solid) and their derivatives (dashed).
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(x, y_j0, '-', color='C0', label='$j_0$')
ax.plot(x, dy_j0, '--', color='C0', label='$dj_0/dx$')
ax.plot(x, y_j1, '-', color='C1', label='$j_1$')
ax.plot(x, dy_j1, '--', color='C1', label='$dj_1/dx$')
ax.set_xlabel('x')
ax.set_ylabel('f(x), df(x)/dx')

# Disabled by default: dy_jn was observed to come back as all zeros. Enable
# once dy_jn has been verified to be the derivative with respect to x.
plot_jn = False
if plot_jn:
    ax.plot(x, y_jn, '-', color='C2', label='$j_{n,%i}$' % order)
    ax.plot(x, dy_jn, '--', color='C2', label='$dj_{n,%i}/dx$' % order)

ax.legend(loc='upper right')

# Create a second figure for psi, the vertical scale is quite different
x2 = np.linspace(0.1, 10, num=1000)
y_psi = special.psi(x2)
dy_psi = gradvec(special.psi, x2)

fig2 = plt.figure()
ax2 = fig2.add_subplot(111)
# Raw strings: '\p' is not a valid Python escape sequence, and the backslash
# must reach matplotlib's mathtext unchanged.
ax2.plot(x2, y_psi, '-', color='C0', label=r'$\psi$')
ax2.plot(x2, dy_psi, '--', color='C0', label=r'$d\psi/dx$')
ax2.set_xlabel('x')
ax2.set_ylabel(r'$\psi(x), d\psi(x)/dx$')
ax2.legend()

fig.savefig('bessel.png')
fig2.savefig('psi.png')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment