Created
April 8, 2013 10:34
-
-
Save cdeil/5335831 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "models" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Try out astropy.models\n", | |
"\n", | |
"Let's see what's possible with astropy.models ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"import numpy as np\n", | |
"from astropy.models import models, fitting, parameters" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# We use iminuit to check the results\n", | |
"# http://iminuit.github.io/iminuit/\n", | |
"import iminuit" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "heading", | |
"level": 2, | |
"metadata": {}, | |
"source": [ | |
"Chi^2 fit" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Let's use the example from here:\n", | |
"# http://nbviewer.ipython.org/5030045\n", | |
"\n", | |
"# Data\n", | |
"xdata = np.array([0, 1, 2, 3, 4, 5])\n", | |
"ydata = np.array([1, 1, 5, 7, 8, 12])\n", | |
"sigma = np.array([1, 2, 1, 2, 1, 2], dtype=float)\n", | |
"\n", | |
"# Model\n", | |
"def f(x, slope, intercept):\n", | |
" \"\"\"Linear function (two fit parameters)\"\"\"\n", | |
" return slope * x + intercept" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 3 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Compute parameters and errors with iminuit\n", | |
"def chi2(slope, intercept):\n", | |
" \"\"\"Define fit statistic, interpreting sigma as errors on ydata\"\"\"\n", | |
" chi = (ydata - f(xdata, slope, intercept)) / sigma\n", | |
" return np.sum(chi ** 2)\n", | |
"\n", | |
"minuit = iminuit.Minuit(chi2, pedantic=False, print_level=0)\n", | |
"minuit.migrad()\n", | |
"minuit.hesse()\n", | |
"print 'values:', minuit.values\n", | |
"print 'errors:', minuit.errors" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"values: {'slope': 1.952830188677304, 'intercept': 0.7704402515790711}\n", | |
"errors: {'slope': 0.3071475582636637, 'intercept': 0.8504530789683833}\n" | |
] | |
} | |
], | |
"prompt_number": 4 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Compute parameters and errors with astropy.models\n", | |
"model = models.Poly1DModel(1)\n", | |
"\n", | |
"# TODO: Is it possible to rename parameters to what I want?\n", | |
"#model.parnames = ['intercept', 'slope']\n", | |
"\n", | |
"fitter = fitting.LinearLSQFitter(model)\n", | |
"fitter(xdata, ydata, weights=sigma ** (-1))\n", | |
"print dict(slope=model.c1, intercept=model.c0)\n", | |
"\n", | |
"# TODO: Currently no way to compute parameter errors with astropy.models?" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"{'slope': [1.9528301886792441], 'intercept': [0.77044025157232865]}\n" | |
] | |
} | |
], | |
"prompt_number": 5 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"import astropy.models.parameters as p\n", | |
"p._Parameter" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"output_type": "pyout", | |
"prompt_number": 6, | |
"text": [ | |
"astropy.models.parameters._Parameter" | |
] | |
} | |
], | |
"prompt_number": 6 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Let's try implementing this as a user-defined model in astropy\n", | |
"from astropy.models.models import ParametricModel\n", | |
"from astropy.models.parameters import _Parameter\n", | |
"\n", | |
"class Line(ParametricModel):\n", | |
" parnames = ['slope', 'intercept'] \n", | |
" def __init__(self, slope, intercept, paramdim=1):\n", | |
" self.linear = True\n", | |
" self.ndim = 1\n", | |
" self.outdim = 1\n", | |
" self._slope = _Parameter(name='slope', val=slope, mclass=self, paramdim=paramdim)\n", | |
" self._intercept = _Parameter(name='intercept', val=intercept, mclass=self, paramdim=paramdim)\n", | |
" ParametricModel.__init__(self, self.parnames, paramdim=paramdim)\n", | |
" def eval(self, x, params):\n", | |
" return params[0] * x + params[1]\n", | |
" def __call__(self, x):\n", | |
" x, format = _convert_input(x, self.paramdim)\n", | |
" result = self.eval(x, self.psets)\n", | |
" return _convert_output(result, format)\n", | |
"\n", | |
"model = Line(slope=1, intercept=1)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "AttributeError", | |
"evalue": "can't set attribute", | |
"output_type": "pyerr", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-7-20332935293d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_convert_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mslope\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mintercept\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-7-20332935293d>\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, slope, intercept, paramdim)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mslope\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mintercept\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparamdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutdim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slope\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Parameter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'slope'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mslope\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmclass\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparamdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparamdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mAttributeError\u001b[0m: can't set attribute" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"ERROR: AttributeError: can't set attribute [IPython.core.interactiveshell]\n" | |
] | |
} | |
], | |
"prompt_number": 7 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# Let's try implementing this as a user-defined fitter in astropy\n", | |
"from astropy.models.fitting import Fitter\n", | |
"\n", | |
"MAXITER = 100\n", | |
"EPS = 1e-10\n", | |
"\n", | |
"class SLSQPFitter(Fitter):\n", | |
" def __init__(self, model, fixed=None, tied=None, bounds=None,\n", | |
" eqcons=None, ineqcons=None):\n", | |
" Fitter.__init__(self, model, fixed=fixed, tied=tied, bounds=bounds,\n", | |
" eqcons=eqcons, ineqcons=ineqcons)\n", | |
" if self.model.linear:\n", | |
" raise ModelLinearityException('Model is linear in parameters, '\n", | |
" 'non-linear fitting methods should not be used.')\n", | |
"\n", | |
" def errorfunc(self, fps, *args):\n", | |
" meas = args[0]\n", | |
" self.fitpars = fps\n", | |
" res = self.model(*args[1:]) - meas\n", | |
" return np.sum(res**2)\n", | |
"\n", | |
" def __call__(self, x, y , maxiter=MAXITER, epsilon=EPS):\n", | |
" self.fitpars = optimize.fmin_slsqp(self.errorfunc, p0=self.model.parameters[:], args=(y, x),\n", | |
" bounds=self.constraints._bounds, eqcons=self.constraints.eqcons,\n", | |
" ieqcons=self.constraints.ineqcons)\n", | |
"\n", | |
"model = models.Poly1DModel(1)\n", | |
"fitter = SLSQPFitter(model)" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "TypeError", | |
"evalue": "__init__() got an unexpected keyword argument 'ineqcons'", | |
"output_type": "pyerr", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-8-025b390ac3a8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPoly1DModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0mfitter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSLSQPFitter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-8-025b390ac3a8>\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, model, fixed, tied, bounds, eqcons, ineqcons)\u001b[0m\n\u001b[1;32m 9\u001b[0m eqcons=None, ineqcons=None):\n\u001b[1;32m 10\u001b[0m Fitter.__init__(self, model, fixed=fixed, tied=tied, bounds=bounds,\n\u001b[0;32m---> 11\u001b[0;31m eqcons=eqcons, ineqcons=ineqcons)\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m raise ModelLinearityException('Model is linear in parameters, '\n", | |
"\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'ineqcons'" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"stream": "stdout", | |
"text": [ | |
"ERROR: TypeError: __init__() got an unexpected keyword argument 'ineqcons' [IPython.core.interactiveshell]\n" | |
] | |
} | |
], | |
"prompt_number": 8 | |
}, | |
{ | |
"cell_type": "heading", | |
"level": 2, | |
"metadata": {}, | |
"source": [ | |
"Poisson likelihood fit" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Once the $\\chi^2$ examples work, let's see whether we can also do a Poisson likelihood fit (Cash statistic) with `astropy.models`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"def cash(D, M, safe=False):\n", | |
" \"\"\"Cash fit statistic, where D = data and M = model.\n", | |
" http://cxc.cfa.harvard.edu/sherpa/statistics/#cash\n", | |
"\n", | |
" This is simply 2 x the negative log Poisson likelihood:\n", | |
" P = (M ** D) * exp(-M) / (D!)\n", | |
" log(P) = D * log(M) - M - log(D!)\n", | |
" cash = - 2 * log(P) (model-independent term dropped)\n", | |
" \"\"\"\n", | |
" stat = 2 * (M - D * log(M))\n", | |
" if safe:\n", | |
" return np.where(M > 0, stat, 1)\n", | |
" else:\n", | |
" return stat" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 9 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [] | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment