Skip to content

Instantly share code, notes, and snippets.

@adrn
Last active December 17, 2015 11:59
Show Gist options
  • Save adrn/5606738 to your computer and use it in GitHub Desktop.
Save adrn/5606738 to your computer and use it in GitHub Desktop.
# coding: utf-8
""" Astropy modeling demo against leastsq """
from __future__ import division, print_function
__author__ = "adrn <adrn@astro.columbia.edu>"
# Standard library
import os, sys
sys.path = ["/Users/adrian/projects/astropy_adrn/build/lib.macosx-10.7-intel-2.7/"] + sys.path
# Third-party
from astropy.modeling import models, fitting
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import leastsq
def gaussian2d(p, x, y, normalize=False):
    """Evaluate a (possibly correlated) bivariate Gaussian at coordinates x, y.

    Parameters
    ----------
    p : sequence of 6 floats
        (A, xmean, ymean, xsigma, ysigma, rho): amplitude, centre, per-axis
        standard deviations, and the correlation coefficient between x and y.
        ``rho`` should satisfy -1 < rho < 1; at |rho| == 1 the expression
        divides by zero.
    x, y : float or numpy array
        Coordinates at which to evaluate; combined elementwise by numpy.
    normalize : bool, optional
        If False (default), ``A`` is the peak value at (xmean, ymean).
        If True, the profile is first scaled to unit volume, so ``A`` becomes
        the integrated volume under the surface.

    Returns
    -------
    float or numpy array
        ``A * exp(-Q / (2*(1 - rho**2)))`` (times the unit-volume factor when
        ``normalize`` is True), where
        ``Q = dx**2 + dy**2 - 2*rho*dx*dy`` with
        ``dx = (x - xmean)/xsigma`` and ``dy = (y - ymean)/ysigma``.
    """
    A, xmean, ymean, xsigma, ysigma, rho = p
    dx = (x - xmean) / xsigma
    dy = (y - ymean) / ysigma
    one_minus_rho2 = 1. - rho**2
    # Standard bivariate-normal quadratic form. The cross term is *minus*
    # 2*rho so that rho > 0 elongates the profile along the x = y diagonal,
    # i.e. rho really is the x-y correlation coefficient.
    quad = dx**2 + dy**2 - 2.*rho*dx*dy
    f = np.exp(-quad / (2.*one_minus_rho2))
    if normalize:
        # Unit-volume normalization of the bivariate normal density.
        f = f / (2.*np.pi*xsigma*ysigma*np.sqrt(one_minus_rho2))
    return A*f
def error_function(p, x, y, z):
    """Return the residuals ``z - gaussian2d(p, x, y)``.

    Used below as the objective for ``scipy.optimize.leastsq``, which
    minimises the sum of squares of the returned residual array.

    Parameters
    ----------
    p : sequence
        Model parameters, forwarded unchanged to ``gaussian2d``.
    x, y : numpy array
        Coordinates at which the model is evaluated.
    z : numpy array
        Observed values at those coordinates.
    """
    model = gaussian2d(p, x, y)
    residuals = z - model
    return residuals
# --- Build the synthetic image ------------------------------------------------
# 11 x 11 pixel grid of integer coordinates.
X, Y = np.meshgrid(np.arange(11), np.arange(11))

# True parameters (A, xmean, ymean, xsigma, ysigma, rho) for gaussian2d.
p0 = [137., 5.1, 5.4, 1.5, 2., np.pi/4]
data = gaussian2d(p0, X.ravel(), Y.ravel()).reshape(X.shape)
# Unit-variance noise from a seeded generator, so repeated runs of the demo
# produce the same image and therefore the same fit results.
rng = np.random.RandomState(42)
data += rng.normal(0., 1., size=data.shape)

# --- Fit with astropy.modeling ------------------------------------------------
# NOTE: Gaussian2DModel is parameterized by a rotation angle ``theta``, whereas
# gaussian2d uses a correlation coefficient ``rho``; the last printed parameter
# of the two fits is therefore not directly comparable.
gauss = models.Gaussian2DModel(amplitude=10., x_mean=5., y_mean=5.,
                               x_stddev=4., y_stddev=4., theta=0.5,
                               bounds={"x_mean" : [0.,11.],
                                       "y_mean" : [0.,11.],
                                       "x_stddev" : [1.,4],
                                       "y_stddev" : [1.,4]})
gauss_fit = fitting.NonLinearLSQFitter(gauss)
# The fitter's return value is discarded and the fitted values are read back
# from ``gauss`` below, which presumes this (early) astropy API updates the
# model in place -- confirm against the astropy build in use.
gauss_fit(X, Y, data)

# --- Fit with scipy.optimize.leastsq for comparison ---------------------------
# Same starting point as the astropy model (rho standing in for theta).
p_init = [10., 5., 5., 4., 4., 0.5]
p_opt, _cov_x, _infodict, mesg, ier = leastsq(
    error_function, p_init,
    args=(X.ravel(), Y.ravel(), data.ravel()),
    full_output=True, maxfev=10000)
# leastsq reports success with ier in {1, 2, 3, 4}; anything else means no
# solution was found, so warn instead of silently printing the result.
if ier not in (1, 2, 3, 4):
    print("leastsq did not converge (ier={0}): {1}".format(ier, mesg),
          file=sys.stderr)

print("true", ["{0:.3f}".format(x) for x in p0])
print("leastsq", ["{0:.3f}".format(x) for x in p_opt])
print("astropy.modeling", ["{0:.3f}".format(x) for x in gauss.parameters])

# --- Plot data and both fitted models side by side ----------------------------
fig, ax = plt.subplots(1, 3)
# One shared colour scale so the three panels are visually comparable.
vmin, vmax = data.min(), data.max()
panels = [("data", data),
          ("astropy.modeling", gauss(X, Y)),
          ("leastsq", gaussian2d(p_opt, X.ravel(), Y.ravel()).reshape(X.shape))]
for axis, (title, image) in zip(ax, panels):
    axis.imshow(image, interpolation="nearest", vmin=vmin, vmax=vmax)
    axis.set_title(title)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment