Skip to content

Instantly share code, notes, and snippets.

@JeanOlivier
Created July 21, 2020 01:17
Show Gist options
  • Save JeanOlivier/597679307ce0517378bafd364cb23579 to your computer and use it in GitHub Desktop.
Save JeanOlivier/597679307ce0517378bafd364cb23579 to your computer and use it in GitHub Desktop.
Fit many functions/data at once with individual and shared parameters
#!/bin/python
# -*- coding: utf-8 -*-
# from pylab import * # Might be required by the loading script
from numpy import array, sqrt, diag, concatenate, mean, size, ndarray, iscomplex
from scipy.optimize import leastsq
class nlfit:
    '''Fit several datasets at once with individual and shared parameters.

    Dataset ``i`` (abscissas ``xs[i]``, ordinates ``ys[i]``) is modelled by
    ``fs[i](x, p_i)``, where ``p_i`` keeps the entries of the global parameter
    vector whose flag in ``pmasks[i]`` is truthy. A parameter enabled in
    several masks is shared by those datasets. The residuals of all datasets
    are concatenated and minimised together with ``scipy.optimize.leastsq``.

    Each dataset's residuals are divided by its y-errors and by the range
    (max - min) of its ordinates, so that datasets of very different
    magnitudes weigh comparably in the global fit.
    '''

    def __init__(self, xs, ys, p0, fs, pmasks = None, fullo=1, xerrs=None, yerrs = None, verbose = True):
        '''Store the datasets, the models and the starting parameters.

        Parameters
        ----------
        xs, ys : sequence of array_like, or one array_like each
            Abscissas and ordinates, one entry per dataset. Datasets may have
            different lengths. When *fs* is a single callable, *xs* and *ys*
            describe a single dataset.
        p0 : sequence of float
            Initial values of the global parameter vector.
        fs : callable or sequence of callables
            Models, called as ``f(x, p)`` with ``p`` the masked parameter list.
        pmasks : sequence of sequences, optional
            One mask per dataset, aligned with *p0*. Truthy entries select the
            parameters passed to that dataset's model. By default every
            dataset uses every parameter.
        fullo : int, optional
            Stored only, for backward compatibility. :meth:`leastsq` always
            requests full output from scipy because it needs the covariance.
        xerrs : sequence of array_like, optional
            Stored only. x-errors are not used by the fit (that would require
            orthogonal distance regression).
        yerrs : sequence of array_like, optional
            Per-point ordinate errors, one array per dataset. Defaults to 1.
        verbose : bool, optional
            Print a report after :meth:`leastsq`.

        Raises
        ------
        AssertionError
            If the per-dataset arguments do not all have the same length.
        '''
        if callable(fs):
            # A single model: xs/ys are one dataset, so wrap everything once.
            xs, ys, fs = [xs], [ys], [fs]
        # Abscissas
        self.xs = self._stack(xs)
        # x-errors are not used yet (fitting them would require ODR).
        if xerrs is None:
            self.xerrs = [[1.]*len(x) for x in self.xs]
        else:
            self.xerrs = self._stack(xerrs)
        # Ordinates
        self.ys = self._stack(ys)
        if yerrs is None:
            self.yerrs = [[1.]*len(y) for y in self.ys]
        else:
            self.yerrs = self._stack(yerrs)
        if pmasks is None:
            # One independent mask list per dataset (no shared inner list).
            pmasks = [[1]*len(p0) for _ in range(len(self.xs))]
        if any(len(i) != len(self.xs) for i in [self.ys, fs, pmasks, self.xerrs, self.yerrs]):
            raise AssertionError("List size don't match")
        self._p0 = p0
        self.para = p0
        self.fs = fs
        self.pmasks = pmasks
        # Normalise each dataset by its range. Constant data has range 0,
        # which would divide by zero, so fall back to 1 in that case.
        self.scales = [(y.max() - y.min()) or 1. for y in self.ys]
        self.fullo = fullo
        self.verbose = verbose

    @staticmethod
    def _stack(datasets):
        '''Return *datasets* as one array holding one entry per dataset.

        Equal-shaped datasets give the same regular array as
        ``array(datasets)``. Datasets of different shapes are stored in a 1-D
        object array instead, because numpy refuses ragged nested input.
        '''
        arrays = [array(d) for d in datasets]
        if len({a.shape for a in arrays}) <= 1:
            return array(arrays)
        stacked = array([None]*len(arrays), dtype=object)
        for i, a in enumerate(arrays):
            # Element-wise assignment avoids numpy broadcasting the arrays.
            stacked[i] = a
        return stacked

    def __call__(self, *args):
        '''Evaluate a fitted model: ``fit(x)`` uses dataset 0, ``fit(i, x)`` dataset i.'''
        if len(args) == 1:
            return self.__custom_call__(0, *args)
        return self.__custom_call__(*args)

    def __custom_call__(self, fct_num, x):
        '''Evaluate model ``fct_num`` at *x* with its masked current parameters.'''
        return self.fs[fct_num](x, self._mask(self.para, self.pmasks[fct_num]))

    def __getitem__(self, i):
        '''Return the current value of global parameter *i*.'''
        return self.para[i]

    def __len__(self):
        '''Return the number of global parameters.'''
        return len(self.para)

    def _mask(self, data, mask):
        '''Keep the items of *data* whose corresponding *mask* entry is truthy.'''
        return [i for i, j in zip(data, mask) if j]

    def _ps(self, p):
        '''Split the global parameter vector *p* into one masked list per dataset.'''
        return [self._mask(p, mask) for mask in self.pmasks]

    def _residuals(self, y, f, x, p, yerr, scale):
        '''Return ``(y - f(x, p)) / yerr / scale`` for one dataset.'''
        # NOTE(review): complex data yields complex residuals, which leastsq
        # presumably does not support. Confirm before fitting complex data.
        return (y - f(x, p))/yerr/scale

    def _residuals_global(self, p):
        '''Concatenate the scaled residuals of all datasets for parameters *p*.'''
        errs = [self._residuals(y, f, x, mp, yerr, scale)
                for y, f, x, mp, yerr, scale in zip(self.ys, self.fs, self.xs,
                                                    self._ps(p), self.yerrs, self.scales)]
        return concatenate(errs)

    def leastsq(self, **kwargs):
        '''Run the global fit with ``scipy.optimize.leastsq``.

        Extra keyword arguments are forwarded to scipy.

        On success, updates ``para``, ``cv`` (covariance), ``it`` (number of
        function evaluations), the chi2 statistics (see :meth:`computevalues`),
        ``errs`` (one error vector per dataset) and ``err`` (``errs[0]``).

        Returns
        -------
        bool
            True when scipy returned a covariance matrix, False otherwise.
        '''
        # Full output is always requested: the covariance (index 1) and the
        # info dict (index 2) are used below. With full_output=0 scipy
        # returns only (x, ier), and the indexing would fail.
        self.lsq = leastsq(self._residuals_global, self.para, full_output=True, **kwargs)
        if self.lsq[1] is None:
            if self.verbose:
                print('\n --- FIT DID NOT CONVERGE ---\n')
            self.errs = None
            self.err = None
            self.chi2rs = None
            return False
        self.para = self.lsq[0]
        self.cv = self.lsq[1]
        self.it = self.lsq[2]['nfev']
        self.computevalues()
        # Scale the covariance-based errors by each dataset's reduced chi2.
        self.errs = array([self.sdcv*sqrt(chi2r) for chi2r in self.chi2rs])
        self.err = self.errs[0]
        if self.verbose:
            print(self)
        return True

    def computevalues(self):
        '''Derive errors, correlations and chi2 statistics from ``self.cv``.

        Sets the following attributes:

        - ``sdcv``: square root of the covariance diagonal.
        - ``corrM``: the correlation matrix.
        - ``chi2s``: the sum of squared scaled residuals of each dataset.
        - ``chi2rs``: each chi2 divided by ``len(y) - len(para)``.
        '''
        self.sdcv = sqrt(diag(self.cv))
        # Correlation matrix: cov[i, j] / (sd[i] * sd[j]).
        self.corrM = self.cv/self.sdcv/self.sdcv[:, None]
        self.chi2s = [sum(self._residuals(y, f, x, mp, yerr, scale)**2)
                      for y, f, x, mp, yerr, scale in zip(self.ys, self.fs, self.xs, self._ps(self.para),
                                                          self.yerrs, self.scales)]
        # Reduced chi^2. The degrees of freedom subtract the total number of
        # global parameters, not only those used by the dataset.
        self.chi2rs = [chi2/(len(y) - len(self.para)) for chi2, y in zip(self.chi2s, self.ys)]

    def __str__(self):
        '''Return a report of parameters, errors, covariance, correlations and chi2s.'''
        names = ', '.join(f.__name__ for f in self.fs)
        # Turn "f, g, h" into "f, g and h".
        head, sep, last = names.rpartition(', ')
        if sep:
            names = head + ' and ' + last
        s = ('\n--- FIT ON FUNCTION{} {} ---'
             '\n\nFit parameters are\n{}\nFit errors are\n{}\n\nFit covariance\n{}'
             '\nFit correlation matrix\n{}\nReduced chi2s are {}\n\n')
        return s.format('S' if len(self.xs) > 1 else '', names,
                        self.para, self.errs, self.cv, self.corrM, self.chi2rs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment