Skip to content

Instantly share code, notes, and snippets.

@scturtle
Created March 24, 2017 09:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save scturtle/715900f56d09b01558de2a0d47e0201c to your computer and use it in GitHub Desktop.
Save scturtle/715900f56d09b01558de2a0d47e0201c to your computer and use it in GitHub Desktop.
Levenberg-Marquardt algorithm
import numpy as np
import numpy.linalg as npl
from easydict import EasyDict
# Default stopping parameters for lm(): the iteration cap (max_iter), the
# gradient infinity-norm tolerance (eps1) and the relative step-size
# tolerance (eps2).
opts = EasyDict(max_iter=10, eps1=1e-8, eps2=1e-8)
def fJ(x, p, y=0):
    """Evaluate the Gaussian-plus-offset residual and its Jacobian.

    Model: ``m(x) = p[0] * exp(-(x - p[1])**2 / (2 * p[2]**2)) + p[3]``.

    Args:
        x: array of sample locations (``x.size`` is used, so an ndarray).
        p: parameter vector ``[amplitude, center, width, offset]``; the
            Jacobian gets ``p.size`` rows, so ``p`` is expected to have
            exactly 4 entries.
        y: observations subtracted from the model (default 0, which returns
            the bare model values).

    Returns:
        ``(r, J)`` where ``r = m(x) - y`` and ``J`` has shape
        ``(p.size, x.size)``: row ``i`` holds ``dm/dp[i]`` at every sample
        (parameters-by-samples, i.e. the transpose of the usual layout).
    """
    amp, center, width, offset = p[0], p[1], p[2], p[3]
    d = x - center
    # Evaluate the exponential once and use it directly as dm/d(amp);
    # dividing f by the amplitude (the old form) yields 0/0 = NaN at amp == 0.
    e = np.exp(-d ** 2 / (2 * width ** 2))
    f = amp * e
    # Plain ``float`` instead of ``np.float``: that alias was removed in
    # NumPy 1.24 and raises AttributeError there.
    J = np.empty((p.size, x.size), dtype=float)
    J[0, :] = e
    J[1, :] = f * d / width ** 2
    J[2, :] = f * d ** 2 / width ** 3
    J[3, :] = 1
    return f + offset - y, J
def lm(func, x, y, p, max_iter=None, eps1=None, eps2=None, verbose=True):
    """Minimize ``F(p) = 0.5 * ||r(p)||**2`` with Levenberg-Marquardt.

    Args:
        func: callable ``func(x, p, y) -> (r, J)`` returning the residual
            vector ``r`` and its Jacobian laid out parameters-by-samples,
            shape ``(p.size, r.size)``.
        x, y: data passed through unchanged to ``func``.
        p: initial parameter vector (not modified; a new array is returned).
        max_iter: iteration cap; defaults to ``opts.max_iter``.
        eps1: stop once ``||g||_inf`` (gradient of F) is at most this;
            defaults to ``opts.eps1``.
        eps2: stop once ``||h|| <= eps2 * (||p|| + eps2)``; defaults to
            ``opts.eps2``.
        verbose: print per-iteration diagnostics (default True, matching the
            original behavior).

    Returns:
        The final parameter vector.
    """
    # Fall back to the module-level settings only when not given explicitly.
    max_iter = opts.max_iter if max_iter is None else max_iter
    eps1 = opts.eps1 if eps1 is None else eps1
    eps2 = opts.eps2 if eps2 is None else eps2

    f, J = func(x, p, y)
    # J is (n_params, n_samples), so np.inner gives A = J J^T and g = J r.
    A, g = np.inner(J, J), np.inner(J, f)
    I = np.eye(p.size)
    mu = 1e-3 * np.max(np.diag(A))
    v = 2
    k = 0
    while npl.norm(g, np.inf) > eps1 and k < max_iter:
        k += 1
        h = npl.solve(A + mu * I, -g)
        if verbose:
            print('k:', k, 'mu:', mu, 'f:', npl.norm(f))
            print(' p:', p, '\n h:', h)
        if npl.norm(h) <= eps2 * (npl.norm(p) + eps2):
            break
        p2 = p + h
        f2, J2 = func(x, p2, y)
        # Gain ratio = actual / predicted decrease of F. The predicted
        # decrease is L(0) - L(h) = 0.5 * h^T (mu*h - g), so the actual
        # decrease must use *squared* norms (the 0.5 factors cancel).
        # Plain norms would mis-scale rho and distort the mu update.
        rho = (np.inner(f, f) - np.inner(f2, f2)) / np.inner(h, mu * h - g)
        if verbose:
            print(' rho:', rho)
        if rho > 0:
            # Step accepted: move and relax the damping (Nielsen's rule).
            p, f, J = p2, f2, J2
            A, g = np.inner(J, J), np.inner(J, f)
            mu *= max(1. / 3, 1 - (2 * rho - 1) ** 3)
            v = 2
        else:
            # Step rejected: increase damping, growing faster each time.
            mu *= v
            v *= 2
    return p
def test():
    """Demo: fit noisy Gaussian-plus-offset data and print the estimate.

    Samples the model at 1001 points on [-3, 3] with known parameters, adds
    noise drawn from ``np.random.randn`` scaled by 0.01, then runs ``lm``
    from a perturbed starting guess.
    """
    grid = np.linspace(-3, 3, 1001)
    true_params = np.array([1, 0.1, 1, 0.5])
    clean, _ = fJ(grid, true_params)
    observed = clean + 1e-2 * np.random.randn(grid.size)
    start = np.array([1.1, 0.15, 1.3, 0.2])
    fitted = lm(fJ, grid, observed, start)
    print('final p:', fitted)
if __name__ == '__main__':
    # Run the demo fit only when executed as a script, not when imported.
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment