Created
July 21, 2020 01:17
-
-
Save JeanOlivier/597679307ce0517378bafd364cb23579 to your computer and use it in GitHub Desktop.
Fit many functions/data at once with individual and shared parameters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/python | |
# -*- coding: utf-8 -*- | |
# from pylab import * # Might be required by the loading script | |
from numpy import array, sqrt, diag, concatenate, mean, size, ndarray, iscomplex | |
from scipy.optimize import leastsq | |
class nlfit:
    '''
    Fit several datasets at once, each with its own model function, while
    letting datasets share some parameters and keep others individual.

    All datasets are fitted simultaneously with ``scipy.optimize.leastsq``
    on one global parameter vector. Dataset ``i`` is modelled by
    ``fs[i](x, p_i)``, where ``p_i`` is the list of global parameters whose
    flag in ``pmasks[i]`` is truthy, kept in their global order.

    Parameters
    ----------
    xs : sequence
        Abscissae. Either a single dataset (a flat sequence of scalars) or a
        sequence of datasets; datasets may have different lengths.
    ys : sequence
        Ordinates, nested like ``xs``.
    p0 : sequence
        Initial guess for the global parameter vector.
    fs : callable or sequence of callables
        Model function(s), called as ``f(x, p)``. A single callable is used
        for every dataset.
    pmasks : sequence, optional
        One mask per dataset (expected to have the same length as ``p0``)
        selecting which global parameters are passed to that dataset's
        function. For a single dataset a flat mask is accepted. Defaults to
        passing every parameter to every function.
    fullo : int, optional
        Forwarded to ``leastsq`` as ``full_output``. When falsy no covariance
        is available, so no errors or reduced chi2s are computed.
    xerrs : sequence, optional
        Stored only; x errors are not used by the fit (would require ODR).
    yerrs : sequence, optional
        Uncertainties on ``ys`` (nested like ``ys``); residuals are divided
        by them. Defaults to 1 everywhere.
    verbose : bool, optional
        Print the fit summary, or a failure message, from ``leastsq()``.

    Raises
    ------
    ValueError
        If ``xs`` is empty.
    AssertionError
        If the per-dataset arguments do not all have one entry per dataset.
    '''
    def __init__(self, xs, ys, p0, fs, pmasks=None, fullo=1, xerrs=None,
                 yerrs=None, verbose=True):
        if len(xs) == 0:
            raise ValueError("xs is empty: nothing to fit")
        # A flat sequence of scalars is a single dataset: wrap every
        # per-dataset argument so the rest of the class only handles lists
        # of datasets. Inspecting the first element (instead of
        # array(xs).shape) keeps datasets of different lengths working.
        if array(xs[0]).ndim == 0:
            xs, ys = [xs], [ys]
            if pmasks is not None and array(pmasks).ndim == 1:
                pmasks = [pmasks]
            if xerrs is not None and array(xerrs).ndim == 1:
                xerrs = [xerrs]
            if yerrs is not None and array(yerrs).ndim == 1:
                yerrs = [yerrs]
        # Abscissae and ordinates: one array per dataset.
        self.xs = [array(x) for x in xs]
        self.ys = [array(y) for y in ys]
        # One model function may serve every dataset.
        if callable(fs):
            fs = [fs] * len(self.xs)
        # x errors are stored but not used yet (would require ODR).
        if xerrs is None:
            self.xerrs = [array([1.] * len(x)) for x in self.xs]
        else:
            self.xerrs = [array(e) for e in xerrs]
        if yerrs is None:
            self.yerrs = [array([1.] * len(y)) for y in self.ys]
        else:
            self.yerrs = [array(e) for e in yerrs]
        if pmasks is None:
            # By default every function receives the full parameter vector.
            pmasks = [[1] * len(p0) for _ in self.xs]
        per_dataset = (self.ys, fs, pmasks, self.xerrs, self.yerrs)
        if any(len(item) != len(self.xs) for item in per_dataset):
            raise AssertionError("List size don't match")
        self._p0 = p0
        self.para = p0
        self.fs = fs
        self.pmasks = pmasks
        # Per-dataset residual normalisation: the peak-to-peak range of y.
        # A constant dataset has zero range; use 1 to avoid dividing by 0.
        self.scales = [(y.max() - y.min()) or 1. for y in self.ys]
        self.fullo = fullo
        self.verbose = verbose

    def __call__(self, *args):
        '''
        Evaluate a model with the current parameters.

        ``fit(x)`` evaluates the first function; ``fit(i, x)`` evaluates
        function ``i``.
        '''
        if len(args) == 1:
            return self.__custom_call__(0, *args)
        return self.__custom_call__(*args)

    def __custom_call__(self, fct_num, x):
        '''Evaluate function ``fct_num`` at ``x`` with its masked parameters.'''
        return self.fs[fct_num](x, self._mask(self.para, self.pmasks[fct_num]))

    def __getitem__(self, i):
        '''Return global parameter ``i``.'''
        return self.para[i]

    def __len__(self):
        '''Return the number of global parameters.'''
        return len(self.para)

    def _mask(self, data, mask):
        '''Keep the items of ``data`` whose flag in ``mask`` is truthy.'''
        return [i for i, j in zip(data, mask) if j]

    def _ps(self, p):
        '''Split the global parameter vector ``p`` into per-dataset lists.'''
        return [self._mask(p, mask) for mask in self.pmasks]

    def _residuals(self, y, f, x, p, yerr, scale):
        '''Residuals of one dataset, weighted by ``yerr`` and ``scale``.'''
        # NOTE(review): complex-valued data would presumably need real
        # residuals for leastsq (e.g. real and imaginary parts split);
        # not handled here.
        return (y - f(x, p)) / yerr / scale

    def _residuals_global(self, p):
        '''Concatenated residuals of all datasets (the leastsq objective).'''
        errs = [self._residuals(y, f, x, mp, yerr, scale)
                for y, f, x, mp, yerr, scale in zip(self.ys, self.fs, self.xs,
                                                    self._ps(p), self.yerrs,
                                                    self.scales)]
        return concatenate(errs)

    def leastsq(self, **kwargs):
        '''
        Run the simultaneous fit with ``scipy.optimize.leastsq``.

        Extra keyword arguments are forwarded to ``leastsq``; the raw result
        is kept in ``lsq``. On success ``para`` holds the fitted parameters.
        When ``fullo`` is truthy, ``cv``, ``sdcv``, ``corrM``, ``chi2s``,
        ``chi2rs``, ``errs``, ``err`` and ``it`` (function evaluations) are
        also updated.

        Returns True on success. Returns False otherwise; ``para`` is then
        left unchanged and ``errs``, ``err`` and ``chi2rs`` are set to None.
        '''
        self.lsq = leastsq(self._residuals_global, self.para,
                           full_output=self.fullo, **kwargs)
        if self.fullo:
            # (x, cov_x, infodict, mesg, ier): a None cov_x (singular
            # matrix) is treated as a failed fit.
            solution, cov = self.lsq[0], self.lsq[1]
            success = cov is not None
        else:
            # Only (x, ier) is returned; ier in 1..4 means a solution was found.
            solution, ier = self.lsq
            cov = None
            success = ier in (1, 2, 3, 4)
        if not success:
            if self.verbose:
                print('\n --- FIT DID NOT CONVERGE ---\n')
            self.errs = None
            self.err = None
            self.chi2rs = None
            return False
        self.para = solution
        if cov is None:
            # No covariance without full output, hence no error estimates.
            self.cv = None
            self.corrM = None
            self.errs = None
            self.err = None
            self.chi2rs = None
        else:
            self.cv = cov
            self.it = self.lsq[2]['nfev']
            self.computevalues()
            # cov_x is unscaled: scale the standard deviations by each
            # dataset's reduced chi2 (one row of errors per dataset).
            self.errs = array([self.sdcv * sqrt(chi2r) for chi2r in self.chi2rs])
            self.err = self.errs[0]
        if self.verbose:
            print(self)
        return True

    def computevalues(self):
        '''
        Derive statistics from ``cv`` and ``para``.

        Sets ``sdcv`` (square root of the diagonal of ``cv``), ``corrM``
        (correlation matrix), ``chi2s`` (per-dataset chi2 of the weighted
        residuals) and ``chi2rs`` (per-dataset reduced chi2).
        '''
        self.sdcv = sqrt(diag(self.cv))
        # Correlation matrix
        self.corrM = self.cv / self.sdcv / self.sdcv[:, None]
        self.chi2s = [sum(self._residuals(y, f, x, mp, yerr, scale) ** 2)
                      for y, f, x, mp, yerr, scale in zip(self.ys, self.fs, self.xs,
                                                          self._ps(self.para),
                                                          self.yerrs, self.scales)]
        # Reduced chi2. The degrees of freedom subtract the total number of
        # global parameters, not only those selected by each dataset's mask.
        self.chi2rs = [chi2 / (len(y) - len(self.para))
                       for chi2, y in zip(self.chi2s, self.ys)]

    def __str__(self):
        '''Human-readable fit summary; safe to call before any fit.'''
        s = '\n--- FIT ON FUNCTION{} {} ---'+\
            '\n\nFit parameters are\n{}\nFit errors are\n{}\n\nFit covariance\n{}'+\
            '\nFit correlation matrix\n{}\nReduced chi2s are {}\n\n'
        names = [getattr(f, '__name__', repr(f)) for f in self.fs]
        if len(names) > 1:
            joined = ', '.join(names[:-1]) + ' and ' + names[-1]
        else:
            joined = ', '.join(names)
        fmt = ['S' if len(self.xs) > 1 else '', joined, self.para,
               getattr(self, 'errs', None), getattr(self, 'cv', None),
               getattr(self, 'corrM', None), getattr(self, 'chi2rs', None)]
        return s.format(*fmt)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment