Skip to content

Instantly share code, notes, and snippets.

@jagaudin
Created April 8, 2020 09:56
Show Gist options
  • Save jagaudin/9c281e42d0b629035f57a65c89c949a0 to your computer and use it in GitHub Desktop.
Save jagaudin/9c281e42d0b629035f57a65c89c949a0 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
from scipy import interpolate
from scipy.optimize import curve_fit
import numpy as np
def func(r, kb, b0):
    """Return the quadratic pair term ``kb * (r - b0)**2``.

    Evaluated element-wise when ``r`` is a NumPy array, which is how
    ``curve_fit`` supplies the sample positions.

    Parameters
    ----------
    r : float or numpy.ndarray
        Position(s) at which to evaluate the term.
    kb : float
        Multiplicative coefficient.
    b0 : float
        Position of the term's minimum (when ``kb > 0``).

    Returns
    -------
    float or numpy.ndarray
        ``kb * (r - b0)**2``, with the same shape as ``r``.
    """
    return kb * (r - b0) ** 2
def total_factory(n):
    """Build the model function fitted by ``curve_fit`` for ``n`` objects.

    The returned ``total(r, *args)`` unpacks ``args`` as two row-major
    ``n x n`` matrices: the first ``n**2`` values are ``kb`` and the next
    ``n**2`` are ``b0``. It returns the sum of
    ``func(r, kb[i, j], b0[i, j])`` over all ordered pairs ``i != j``.
    The diagonal (self-interaction) is skipped.

    NOTE(review): the ``n`` diagonal entries of each matrix never affect the
    result, so ``curve_fit`` cannot determine them. Expect an ill-defined
    ``pcov`` for those parameters.

    Parameters
    ----------
    n : int
        Number of objects.

    Returns
    -------
    callable
        ``total(r, *args) -> numpy.ndarray`` with the same shape as ``r``.
    """
    n_pairs = n * n

    def total(r, *args):
        """Sum the off-diagonal pair terms at positions ``r``.

        Raises
        ------
        ValueError
            If ``len(args) != 2 * n**2``.
        """
        if len(args) != 2 * n_pairs:
            raise ValueError(
                f"expected {2 * n_pairs} parameters for n={n}, got {len(args)}"
            )
        # Accept plain lists (e.g. the raw xdata) as well as arrays.
        r = np.asarray(r, dtype=float)
        kb = np.reshape(args[:n_pairs], (n, n))
        b0 = np.reshape(args[n_pairs:], (n, n))
        # Start from an array shaped like r so the result keeps r's shape
        # even when there are no off-diagonal pairs (n <= 1).
        sum_int = np.zeros_like(r)
        for i in range(n):
            for j in range(n):
                if i == j:
                    continue  # no self interaction
                sum_int += func(r, kb[i, j], b0[i, j])
        return sum_int

    return total
xdata = [1.3, 1.35, 1.4, 1.45, 1.5, 1.55, 1.6, 1.65, 1.7]
ydata = [-136.82, -164.87, -181.16, -188.53, -189.10, -184.49, -175.96, -164.51, -150.95]
x = np.array(xdata)
y = np.array(ydata)
plt.plot(x, y, 'bo', label='data')

# curve_fit's default Levenberg-Marquardt solver needs at least as many
# residuals as parameters (2 * n**2 = 32 below, vs. 9 measured points), so
# the data are linearly interpolated onto a finer grid.
# NOTE(review): interpolated points add no independent information. Also,
# the model is a sum of quadratics in r, i.e. a single quadratic, so only
# three parameter combinations are identifiable whatever the grid density.
f = interpolate.interp1d(xdata, ydata)
# linspace (unlike a float-step arange) hits both endpoints exactly, so no
# grid point can drift past xdata[-1] and make interp1d raise ValueError.
step = 0.01
num_points = int(round((xdata[-1] - xdata[0]) / step)) + 1
x = np.linspace(xdata[0], xdata[-1], num_points)
y = f(x)
plt.plot(x, y, 'r', label='interpolated')

# Initialize the parameters and create total function for n = 4
n = 4
kb = np.ones((n, n))
b0 = np.zeros((n, n))
total = total_factory(n)
# Flat 1-D initial guess: all kb entries, then all b0 entries, matching
# the order in which total() unpacks *args.
p0 = np.concatenate([kb.ravel(), b0.ravel()])
popt, pcov = curve_fit(total, x, y, p0=p0, maxfev=100000)
# Evaluate the fit on the dense grid so the plotted curve is smooth.
plt.plot(x, total(x, *popt), 'g-', label='curve fit')
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment