Skip to content

Instantly share code, notes, and snippets.

@ansjsun
Created May 10, 2014 12:00
Show Gist options
  • Save ansjsun/b4d5e80cca99b05d1903 to your computer and use it in GitHub Desktop.
Save ansjsun/b4d5e80cca99b05d1903 to your computer and use it in GitHub Desktop.
Python: linear regression fitted by gradient descent (python 梯度线性回归)
import math ;
import matplotlib.pyplot as plt
# Training samples: 50 (x, y) points; x_list is in ascending order.
x_list = [
    2.0658746, 2.3684087, 2.5399929, 2.5420804, 2.549079,
    2.7866882, 2.9116825, 3.035627, 3.1146696, 3.1582389,
    3.3275944, 3.3793165, 3.4122006, 3.4215823, 3.5315732,
    3.6393002, 3.6732537, 3.9256462, 4.0498646, 4.2483348,
    4.3440052, 4.3826531, 4.4230602, 4.6102443, 4.6881183,
    4.9777333, 5.0359967, 5.0684536, 5.4161491, 5.4395623,
    5.4563207, 5.5698458, 5.6015729, 5.6877617, 5.7215602,
    5.8538914, 6.1978026, 6.3510941, 6.4797033, 6.7383791,
    6.8637686, 7.0223387, 7.0782373, 7.1514232, 7.4664023,
    7.5973874, 7.7440717, 7.7729662, 7.8264514, 7.9306356,
]
y_list = [
    0.77918926, 0.91596757, 0.90538354, 0.90566138, 0.9389889,
    0.9668474, 0.96436824, 0.91445939, 0.93933944, 0.96074971,
    0.89837094, 0.91209739, 0.94238499, 0.96624578, 1.05265,
    1.0143791, 0.95969426, 0.96853716, 1.0766065, 1.1454978,
    1.0340625, 1.0070009, 0.96683648, 1.0895919, 1.0634462,
    1.1237239, 1.0323374, 1.0874452, 1.0702988, 1.1606493,
    1.0778037, 1.1069758, 1.0971875, 1.1648603, 1.1411796,
    1.0844156, 1.1252493, 1.1168341, 1.1970789, 1.2069462,
    1.1251046, 1.1235672, 1.2132829, 1.2522652, 1.2497065,
    1.1799706, 1.1897299, 1.3029934, 1.2601134, 1.2562267,
]
# Scatter plot of the raw samples. One call draws every point as a single
# star-marker series (same colour), instead of 50 one-point plot calls that
# each advance matplotlib's colour cycle.
plt.plot(x_list, y_list, "*")
# Gradient-descent settings and the initial model parameters.
MAX_ITR = 1000             # number of training passes run by the loop below
alpha = 0.001              # learning rate (step size)
theta0, theta1 = 0.0, 0.0  # intercept and slope of y = theta0 + theta1*x
m = len(x_list)            # number of training samples
# Model: y = theta0 + theta1 * x
def fun(xs=None, ys=None, lr=None):
    """Run one epoch of stochastic gradient descent on y = theta0 + theta1*x.

    Visits each (x, y) sample once, in order. After each sample it moves the
    module-level ``theta0`` (intercept) and ``theta1`` (slope) against the
    gradient of the per-sample loss 0.5 * (theta0 + theta1*x - y)**2.

    Args:
        xs: Sample inputs; defaults to the module-level ``x_list``.
        ys: Sample targets; defaults to the module-level ``y_list``.
        lr: Learning rate; defaults to the module-level ``alpha``.

    Returns:
        The updated ``(theta0, theta1)`` pair, which is also written back to
        the module-level globals.
    """
    global theta0, theta1
    xs = x_list if xs is None else xs
    ys = y_list if ys is None else ys
    lr = alpha if lr is None else lr
    for x, y in zip(xs, ys):
        # Residual of the current model on this sample. Both parameters are
        # updated from this same residual, i.e. from the pre-update model.
        err = theta0 + theta1 * x - y
        theta1 -= lr * err * x
        theta0 -= lr * err
    return theta0, theta1
# Train for MAX_ITR epochs, reporting the parameters after each epoch.
for i in range(MAX_ITR):
    fun()
    print("ite ", i, theta0, theta1)
# Sum of the signed residuals of the fitted line over the training set.
# For an exact least-squares fit with an intercept this sum is 0, so a value
# near 0 is consistent with theta0 having converged. It is NOT an error
# magnitude: positive and negative residuals cancel.
v = sum(theta0 + theta1 * x - y for x, y in zip(x_list, y_list))
print(v, theta0, theta1)
# NOTE(review): max_x is computed but never used below.
max_x = max(x_list) + 3

# Overlay the fitted line y = theta0 + theta1*x between x = 2 and x = 9,
# then display the figure.
x1, x2 = 2, 9
y1 = theta0 + theta1 * x1
y2 = theta0 + theta1 * x2
plt.plot([x1, x2], [y1, y2])
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment