Skip to content

Instantly share code, notes, and snippets.

@timvieira
Last active December 19, 2016 19:24
Show Gist options
  • Save timvieira/3d3db3e5e78e17cdd103 to your computer and use it in GitHub Desktop.
Save timvieira/3d3db3e5e78e17cdd103 to your computer and use it in GitHub Desktop.
Complex-step derivative. Accompanies blog post http://timvieira.github.io/blog/post/2014/08/07/complex-step-derivative/
import numpy as np
import pylab as pl
from numpy import exp, cos, sin, sqrt
from arsenal.math import compare
def run_tests(h=1e-10, tol=1e-10, plot=True):
    """Check complex-step derivatives against hand-derived ones.

    ``tests`` below lists pairs of lines: an expression ``f`` in ``x``
    followed by its analytic derivative.  For each pair, ``f`` is evaluated
    on the complex grid ``x + i*h`` (100 points of ``x`` in [0, 2]).  The
    complex-step estimate ``Im f(x + i*h) / h`` is then compared with the
    analytic derivative evaluated on the real grid.

    Args:
        h: imaginary step size of the complex-step method.
        tol: largest allowed absolute difference between the estimate and
            the analytic derivative at any grid point.
        plot: if True (the default), draw one pylab figure per pair and call
            ``pl.show()`` at the end.  If False, pylab is never touched.

    Returns:
        A list of ``(expression, max_abs_error)`` tuples, in table order.

    Raises:
        AssertionError: if some pair's max error is not below ``tol``
            (a NaN error also fails).
        ValueError: if the test table does not consist of complete pairs.
    """
    tests = """
    x**2
    2*x

    exp(x)
    exp(x)

    exp(x)**2
    2*exp(2*x)

    exp(x-100)**2 + cos(x) + 100
    2*exp(2*(x-100)) - sin(x)

    exp(x) / sqrt(sin(x)**3 + cos(x)**3)
    (exp(x)* (3*cos(x) + 5*cos(3*x) + 9*sin(x) + sin(3*x))) / (8*(cos(x)**3 + sin(x)**3)**(3./2))
    """
    # Pair up consecutive non-blank lines instead of splitting on blank
    # lines, so the table does not depend on exact whitespace formatting.
    lines = [line.strip() for line in tests.splitlines() if line.strip()]
    if len(lines) % 2:
        raise ValueError('test table must hold (function, derivative) line pairs')

    grid = np.linspace(0, 2, 100)
    x = grid + 1j * h  # complex128: every grid point shifted by i*h

    results = []
    for f, d in zip(lines[0::2], lines[1::2]):
        # eval is acceptable only because these expressions come from the
        # hard-coded table above, never from external input.
        fs = eval(f, globals(), {'x': x})
        # Evaluate the analytic derivative at the real grid points, so the
        # reference is not perturbed by the imaginary step.
        ds = np.real(eval(d, globals(), {'x': grid}))
        estimate = fs.imag / h
        err = np.abs(estimate - ds).max()
        results.append((f, err))
        if plot:
            pl.figure()
            pl.plot(grid, estimate, alpha=0.5, label='complex-step', lw=2, c='r')
            pl.plot(grid, ds, alpha=0.5, label='analytic', c='b', lw=2)
            pl.title('derivative of %s' % f)
            pl.legend()
        # Raise explicitly (not `assert`) so the check survives `python -O`.
        if not err < tol:
            raise AssertionError('derivative of %s: max error %g is not below %g'
                                 % (f, err, tol))
    if plot:
        pl.show()
    return results
def complex_step(f, eps=1e-10):
    """Wrap a scalar function so it also reports its derivative.

    The returned callable evaluates ``f`` once at the complex point
    ``x + i*eps``.  For a real-analytic ``f``, the real part of the result
    approximates ``f(x)`` and the imaginary part divided by ``eps``
    approximates ``f'(x)``.  No difference of nearby values is taken, so there
    is no subtractive cancellation and ``eps`` may be tiny.

    Args:
        f: function that accepts a complex scalar.
        eps: size of the imaginary step.

    Returns:
        A function mapping a real scalar ``x`` to ``(value, derivative)``.
    """
    def value_and_derivative(x):
        shifted = complex(x, eps)
        out = f(shifted)
        value = out.real
        derivative = out.imag / eps
        return value, derivative
    return value_and_derivative
def fd(f, eps=1e-4):
    """Return a function computing the central finite-difference gradient of ``f``.

    The returned function takes an array-like ``x`` (at least 1-D) and
    returns an array shaped like ``x``.  Its entry at each index ``i`` is

        (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps),

    where ``e_i`` is the unit array that is 1 at index ``i``.  ``f`` must map
    an array shaped like ``x`` to a scalar.  The perturbations are applied to
    a private copy, so the caller's ``x`` is never modified.

    Args:
        f: scalar-valued function of an array.
        eps: finite-difference step size.

    Returns:
        The gradient function described above.
    """
    def g(x):
        # ``* 1.0`` produces a fresh array and promotes integer input to
        # float, so adding ``eps`` is not truncated.
        x = np.asarray(x) * 1.0
        grad = np.zeros_like(x)
        # np.ndindex visits every element of an N-d array.  For 1-D input
        # this is the same order as the original index loop, and unlike
        # ``xrange`` it also works on Python 3.
        for i in np.ndindex(*x.shape):
            was = x[i]
            x[i] = was + eps
            b = f(x)
            x[i] = was - eps
            a = f(x)
            x[i] = was  # restore before perturbing the next coordinate
            grad[i] = (b - a) / 2 / eps
        return grad
    return g
if __name__ == '__main__':
    # Script entry point: run every derivative check (with plots).
    run_tests()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment