Skip to content

Instantly share code, notes, and snippets.

@bagrow
Created December 10, 2015 13:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bagrow/0f46aab8ec4ac608c0ee to your computer and use it in GitHub Desktop.
Save bagrow/0f46aab8ec4ac608c0ee to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# exhaustive_search_linearRegression.py
# Jim Bagrow
# Last Modified: 2015-10-05
import sys, os
import numpy as np
import scipy, scipy.stats
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
def linear_model(x, b1, b2):
    """Return the straight-line prediction ``b1*x + b2``.

    ``b1`` is the slope and ``b2`` the intercept.
    """
    return b1 * x + b2


def sum_squared_residuals(X, Y, b1, b2):
    """Return S = sum_i (b1*x_i + b2 - y_i)**2 over the paired samples X, Y."""
    return sum((linear_model(xi, b1, b2) - yi) ** 2 for xi, yi in zip(X, Y))


def exhaustive_search(X, Y, b1_vals, b2_vals):
    """Evaluate S(b1, b2) at every grid point and locate the smallest one.

    Args:
        X, Y: paired samples; must have the same length.
        b1_vals: candidate slopes (outer loop of the scan).
        b2_vals: candidate intercepts (inner loop of the scan).

    Returns:
        ``(record, min_S, min_b)`` where ``record`` is a list of
        ``(b1, b2, S)`` tuples in scan order, ``min_S`` is the smallest S,
        and ``min_b`` is the ``(b1, b2)`` pair that first attains it.

    Raises:
        ValueError: if X and Y differ in length or either grid is empty.
    """
    if len(X) != len(Y):
        # zip() would silently drop the unmatched tail and bias S.
        raise ValueError("X and Y must have the same length")
    record = [(b1, b2, sum_squared_residuals(X, Y, b1, b2))
              for b1 in b1_vals
              for b2 in b2_vals]
    if not record:
        raise ValueError("parameter grids must be non-empty")
    # min() keeps the first of tied minima, matching a strict '<' scan, and
    # needs no magic "large" starting value that real S values could exceed.
    best_b1, best_b2, min_S = min(record, key=lambda row: row[2])
    return record, min_S, (best_b1, best_b2)


if __name__ == '__main__':
    # our data:
    X = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    Y = [19, 20, 20.5, 21.5, 22, 23, 23, 25.5, 24]

    # parameters to search:
    num_pts_to_try = 20
    possible_b1_vals = np.linspace(0, 1, num_pts_to_try)    # slope
    possible_b2_vals = np.linspace(15, 25, num_pts_to_try)  # intercept

    record, min_S, min_b = exhaustive_search(
        X, Y, possible_b1_vals, possible_b2_vals)

    # Single-argument print() behaves the same on Python 2 and 3.
    print("")
    print("min S = {0:.4f} at beta_1 = {1:.4f}, beta_2 = {2:.4f}".format(
        min_S, min_b[0], min_b[1]))

    # 3d scatter plot of S = f(beta_1, beta_2)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    xr, yr, zr = zip(*record)
    ax.scatter(xr, yr, zr)
    ax.set_xlabel(r"$\beta_1$", fontsize=22)
    ax.set_ylabel(r"$\beta_2$", fontsize=22)
    ax.set_zlabel(r"S", fontsize=22)

    # Add a vertical column at the grid minimum. It spans the autoscaled
    # z-range, so it stays full-height after the zoom below.
    zL, zU = ax.get_zlim()
    b1s, b2s = min_b
    ax.plot([b1s, b1s], [b2s, b2s], [zL, zU], 'r-', linewidth=3)
    # TODO: overlay the exact least-squares fit (e.g. scipy.stats.linregress)
    # as a second column to show how close the grid search gets.

    # Zoom the S axis onto the low-error region around the minimum.
    ax.set_zlim(0, 200)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment