Last active: December 8, 2022 10:35
Shows a way to use multiprocessing and lmfit.
"""Shows a way to use multiprocessing and lmfit.

This is a dummy sample showing how to perform parallel fits using lmfit and
multiprocessing. There are other options that do not use map.

It is important to keep in mind that, as the minimization is performed in
separate processes, side effects are not visible in the main process.

Note that the cost of spawning new processes and pickling usually outweighs the
benefits of multiprocessing. In this example, the length of each dataset and
the number of datasets can be modified to check this claim.

@author: azelcer
"""
from numpy import exp, sin, linspace, random
from time import time
from multiprocessing import Pool, set_start_method, get_start_method
from lmfit import Parameters
from lmfit.minimizer import Minimizer
from matplotlib import pyplot as plt


# This is the model function
def model(x, params):
    amp = params['amp']
    phaseshift = params['phase']
    freq = params['frequency']
    decay = params['decay']
    return amp * sin(x*freq + phaseshift) * exp(-(x**2)*decay)


# And this is the residual calculation
def residual(params, x, data, error_scale):
    return (data - model(x, params)) / error_scale


# This function performs the minimization. It will run in a different process:
# changes to global variables or objects made inside this function will not be
# visible to the caller. Only the data, params and the return value (an lmfit
# MinimizerResult) will be pickled.
def do_fit(x, data, initial_params, error_scale):
    mnz = Minimizer(residual, initial_params, fcn_args=(x, data, error_scale))
    out = mnz.minimize()
    return out


if __name__ == "__main__":
    # This check lets the program be run more than once in the same session
    # (set_start_method raises if called twice)
    if get_start_method(allow_none=True) is None:
        set_start_method('spawn')  # test the 'worst' case: 'spawn' is the only
        # method available on Windows
    # Number of datasets to create
    n_samples = 4
    # Let's create an array of simulated data
    error_scale = 0.2
    x = linspace(0, 100, num=150000)  # with 1.5E5 points parallel fits might be worthwhile
    noises = [random.normal(size=x.size, scale=error_scale) for _ in range(n_samples)]
    amplitudes = linspace(4, 8, n_samples)
    frequencies = linspace(.5, 1.7, n_samples)
    phases = linspace(0, 2, n_samples)
    decays = linspace(1E-4, 1E-3, n_samples)
    all_data = [amp * sin(x*freq + phase) * exp(-(x**2)*decay) + noise
                for amp, phase, freq, decay, noise
                in zip(amplitudes, phases, frequencies, decays, noises)]
    # Create the starting parameters (same guess for all the datasets)
    params = Parameters()
    params.add('amp', value=10)
    params.add('decay', value=0.007)
    params.add('phase', value=0.2)
    params.add('frequency', value=1.0)
    # Perform the fits sequentially
    print("Starting sequential fitting...")
    start_time = time()
    sequential_results = [do_fit(x, d, params, error_scale) for d in all_data]
    print("  took", time()-start_time, "seconds")
    # Perform the fits in parallel
    print("Starting parallel fitting...")
    start_time = time()
    with Pool() as p:
        parallel_results = p.starmap(do_fit, zip((x, ) * n_samples, all_data,
                                                 (params, ) * n_samples,
                                                 (error_scale, ) * n_samples))
    print("  took", time()-start_time, "seconds")
    for i in range(n_samples):
        plt.figure("Dataset #" + str(i))
        plt.plot(x, all_data[i], label="Data")
        plt.plot(x, model(x, sequential_results[i].params), label="sequential fit")
        plt.scatter(x, model(x, parallel_results[i].params), label="parallel fit")
        plt.legend()
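The docstring mentions that there are options that do not use map. One common pattern is to freeze the shared arguments with `functools.partial` so that plain `Pool.map` only iterates over the per-dataset argument. A minimal sketch of that idea, with `fit_one` as an illustrative stand-in for `do_fit` (no lmfit dependency here):

```python
# Sketch: an alternative to starmap. functools.partial freezes the shared
# arguments, so Pool.map only needs the per-dataset argument.
from functools import partial
from multiprocessing import Pool


def fit_one(data, scale):
    # Stand-in for do_fit: a simple reduction instead of Minimizer.minimize()
    return sum(data) * scale


if __name__ == "__main__":
    datasets = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    worker = partial(fit_one, scale=0.5)  # shared argument frozen once
    with Pool(2) as p:
        results = p.map(worker, datasets)
    print(results)  # [1.5, 3.5, 5.5]
```

A `partial` of a module-level function pickles fine, so this also works under the 'spawn' start method; `Pool.apply_async` per dataset would be another option.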
Well, I updated it again. It all depends on how the OS starts a new process. The 'spawn' method (the only one currently available on Windows) reloads (and re-executes) the file, but in those processes the value of __name__ is not "__main__" but something like "__mp_main__". So the guard avoids executing that block again in the new processes (which I guess would lead to some ugly infinite recursion).
finally works! Thank you azelcer. I just wanted to write you that it was not working :D, but then I realized that when running from Spyder with the run-cell command (cells can be made with #%%) it returns a RuntimeError. But when I run it as a whole file from Spyder it works... I wonder why Spyder does that; anyway, it works now...
hey azelcer, thanks for your effort on this. I guess I need to completely understand what is happening here and what the main guard is doing, because even with this variation I do not get past "Starting parallel fitting..."; now it just hangs there for over 2 min, then I closed the terminal. I am surprised that OS differences cause such pain; it seems there are different defaults for how a process is started depending on the OS.