Skip to content

Instantly share code, notes, and snippets.

@Dapid
Created March 1, 2023 13:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Dapid/1da960739b9e006f41a962607f6b1c54 to your computer and use it in GitHub Desktop.
Save Dapid/1da960739b9e006f41a962607f6b1c54 to your computer and use it in GitHub Desktop.
import numpy as np
import scipy.optimize as spo
import scipy.stats as sps
import matplotlib.pyplot as plt
# Measured sample: 75 rows, each one (x, y) pair. The (75, 2) array is
# transposed so it unpacks directly into the two coordinate vectors:
# x takes the first column and y the second.
x, y = np.array([[3.16275414, 3.79136358],
[3.06332232, 3.56686702],
[2.71045949, 3.65764056],
[3.31620986, 3.9009491 ],
[3.0538026 , 3.77374607],
[2.65205418, 3.46548462],
[3.37982853, 3.87501873],
[3.17203268, 3.6633317 ],
[3.05775755, 3.75619677],
[2.91007762, 3.43503791],
[3.11535693, 3.69096722],
[3.39581938, 3.74696218],
[3.13272406, 3.42290705],
[2.78470603, 3.53131451],
[3.27740626, 3.84695112],
[2.85441284, 3.56755741],
[2.78566755, 3.38201978],
[3.24804465, 3.70167436],
[3.43091588, 3.75821148],
[3.54227291, 3.81811405],
[3.13807594, 3.67988712],
[3.46801239, 3.89121322],
[3.34413811, 3.81356879],
[2.88571352, 3.389147 ],
[3.46944009, 4.04670213],
[3.24443329, 3.66727863],
[3.19430041, 3.62833333],
[3.33884561, 3.69026204],
[3.14456693, 3.618733 ],
[2.86876214, 3.72446602],
[3.10657774, 3.75643054],
[3.58450835, 3.85183673],
[3.24242311, 3.80806151],
[2.69351248, 3.39682342],
[3.08693798, 3.69868217],
[3.21788634, 3.76964108],
[3.83225025, 4.05771948],
[3.12537788, 3.82863089],
[3.20272503, 3.49631319],
[3.06877493, 3.71745014],
[2.67606864, 3.61357254],
[3.1096587 , 3.62928328],
[3.34974315, 3.92875 ],
[3.26829268, 3.80015679],
[3.38762994, 3.55692308],
[3.32415556, 3.77868889],
[3.27641221, 3.646743 ],
[3.18084615, 3.95869231],
[3.42302139, 3.9386631 ],
[3.67470085, 4.11934473],
[3.31432177, 3.80624606],
[3.19971609, 3.59706625],
[2.97818182, 3.603367 ],
[2.97725424, 3.64698305],
[3.47149813, 3.71947566],
[2.82501901, 3.4169962 ],
[3.17886555, 3.75096639],
[2.88793991, 3.57669528],
[3.32315789, 3.99167464],
[3.36554404, 3.96227979],
[2.83045455, 3.48659091],
[3.60608187, 3.97976608],
[3.00285714, 3.69065476],
[3.26709877, 3.86067901],
[3.88263514, 4.1775 ],
[3.2752381 , 3.9162585 ],
[3.40239726, 3.82006849],
[3.54143939, 3.98871212],
[3.46833333, 3.8315 ],
[2.96941667, 3.68858333],
[2.86350427, 3.95418803],
[3.10672727, 3.65509091],
[3.30963636, 3.79818182],
[2.79913462, 3.48932692],
[3.10911765, 3.92911765]]).T
# Toggle: when True, the sample is replaced by 10 noise-free synthetic points
# lying exactly on a known line, and a wider parameter grid is used.
fake_data = False
if fake_data:
    x = np.random.uniform(2.91, 3.2, size=10)
    y = 2.114798570871842 + x * 0.5088641205763367
    x_ = np.arange(0.46, 0.54, 0.001)
    y_ = np.arange(2, 2.26, 0.001)
else:
    x_ = np.arange(0.46, 0.52, 0.001)
    y_ = np.arange(2.1, 2.26, 0.001)

# x_ is the axis of candidate slopes and y_ the axis of candidate intercepts
# (they are passed to loss_2d in that order). sparse=True keeps the grids as
# broadcastable (1, len(x_)) and (len(y_), 1) arrays instead of dense copies.
X, Y = np.meshgrid(x_, y_, sparse=True)
def calc_sse(c_s, x, y):
    """Return the sum of squared residuals of the line ``c_s[0] + c_s[1] * x``.

    Parameters
    ----------
    c_s : sequence
        Line coefficients, indexed as ``c_s[0]`` = intercept and
        ``c_s[1]`` = slope.
    x, y : array_like
        Sample coordinates; plain lists are accepted as well as arrays.

    Returns
    -------
    float
        ``sum((y - (c_s[0] + x * c_s[1])) ** 2)``.
    """
    # Convert up front so list inputs broadcast instead of raising TypeError.
    x = np.asarray(x)
    y = np.asarray(y)
    errors = y - (c_s[0] + x * c_s[1])
    return np.sum(errors ** 2)
def loss_2d(slope, intercept, x, y):
    """Return log(SSE) of the line ``slope * x + intercept`` over a parameter grid.

    Parameters
    ----------
    slope, intercept : array_like
        Candidate parameters. They may be scalars or arrays of any rank, as
        long as they broadcast against each other (e.g. the sparse output of
        ``np.meshgrid``).
    x, y : array_like
        1-D sample coordinates.

    Returns
    -------
    ndarray or scalar
        Natural log of the summed squared error, with the broadcast shape of
        ``slope`` and ``intercept``. A perfect fit gives ``-inf``, since the
        log is taken of a zero sum.
    """
    # A trailing axis is appended to the parameters so the samples broadcast
    # along it; this matches the original [:, :, None] indexing for 2-D grids
    # while also accepting other ranks.
    slope = np.asarray(slope)[..., None]
    intercept = np.asarray(intercept)[..., None]
    x = np.asarray(x)
    y = np.asarray(y)

    pred = slope * x + intercept
    error = pred - y
    return np.log(np.sum(error * error, axis=-1))
# Driver: compare the closed-form least-squares fit with two iterative
# minimizers of calc_sse, and draw their paths over the log-SSE surface.

# Background surface; loss[i, j] belongs to intercept y_[i] and slope x_[j].
loss = loss_2d(X, Y, x, y)

# Reference optimum from ordinary least squares.
ls_res = sps.linregress(x, y)
print(f'LS inter: {ls_res.intercept}')
print(f'LS slope: {ls_res.slope}')
ls_sse = calc_sse((ls_res.intercept, ls_res.slope), x, y)
print(f'LS SSE error: {ls_sse}')

# Starting point in calc_sse's parameter order: [intercept, slope].
start = [2.25, 0.47]

# Each optimizer is run once; retall=True returns the optimum together with
# the list of iterates, so the same run feeds both the printout and the plot.
print('\nBFGS minimization:')
bfgs_opt, bfgs_path = spo.fmin_bfgs(calc_sse, start, args=(x, y), retall=True)
print(bfgs_opt)
traj_bfgs = np.array(bfgs_path)

print('\nPowell minimization:')
powell_opt, powell_path = spo.fmin_powell(calc_sse, start, args=(x, y), retall=True)
print(powell_opt)
traj_powell = np.array(powell_path)

plt.figure()
# loss is transposed so rows follow the slope (vertical axis) and columns the
# intercept (horizontal axis), matching the extent and the axis labels below.
plt.imshow(
    loss.T,
    aspect='auto',
    interpolation='none',
    extent=(y_.min(), y_.max(), x_.min(), x_.max()),
    origin='lower',
)
plt.plot(*traj_powell.T, label='Powell', marker='+', color='xkcd:gold')
plt.plot(*traj_bfgs.T, label='BFGS', marker='+', color='xkcd:bluish')
plt.scatter([start[0]], [start[1]], color='xkcd:orange', marker='^', label='Start')
plt.scatter([ls_res.intercept], [ls_res.slope], color='xkcd:leaf green', marker='v', label='Optimum', s=50)
plt.xlabel('Intercept')
plt.ylabel('Slope')
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment