Created June 3, 2012 23:37
-
-
Save DM-/2865428 to your computer and use it in GitHub Desktop.
Gradient descent
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Gradient descent test: fit a straight line y = theta0 + theta1 * x.


def linear(theta0, theta1, x):
    """Evaluate the line with intercept *theta0* and slope *theta1* at *x*."""
    slope_term = theta1 * x
    return theta0 + slope_term
def cost(theta0, theta1, xm):
    """Return the squared-error cost of the line ``linear(theta0, theta1, x)``.

    Each element of *xm* is indexed as ``point[0]`` (x) and ``point[1]`` (y).
    The result is ``sum(residual ** 2) / (2 * len(xm))``, i.e. half the mean
    squared residual. An empty *xm* raises ZeroDivisionError.
    """
    total = 0
    for point in xm:
        residual = linear(theta0, theta1, point[0]) - point[1]
        # Plain sequential accumulation (not sum()) keeps float rounding
        # identical across Python versions.
        total += residual * residual
    return total * (1.0 / (2 * len(xm)))
def dcost(theta0, theta1, xm):
    """Return the partial derivative of ``cost`` with respect to *theta0*.

    This is the mean residual ``linear(theta0, theta1, x) - y`` over the
    points of *xm* (indexed as ``point[0]``, ``point[1]``). An empty *xm*
    raises ZeroDivisionError.
    """
    acc = 0
    for point in xm:
        acc += linear(theta0, theta1, point[0]) - point[1]
    return acc * (1.0 / len(xm))
def dcost2(theta0, theta1, xm):
    """Return the partial derivative of ``cost`` with respect to *theta1*.

    This is the mean of ``residual * x`` over the points of *xm*, where
    ``residual = linear(theta0, theta1, x) - y`` and each point is indexed
    as ``point[0]`` (x), ``point[1]`` (y). An empty *xm* raises
    ZeroDivisionError.
    """
    acc = 0
    for point in xm:
        residual = linear(theta0, theta1, point[0]) - point[1]
        acc += residual * point[0]
    return acc * (1.0 / len(xm))
def grd(xm, alpha=0, theta0=1, theta1=1, tol=1e-16, max_iter=10 ** 6):
    """Fit ``y = theta0 + theta1 * x`` to *xm* by batch gradient descent.

    Parameters:
        xm: sequence of points indexed as ``point[0]`` (x), ``point[1]`` (y).
        alpha: learning rate. A value <= 0 or None (including the default 0)
            selects 0.01, which is what the original code always used.
            A positive value is honoured; previously it was silently ignored.
        theta0, theta1: starting intercept and slope.
        tol: iteration stops once ``cost(theta0, theta1, xm)`` is no longer
            greater than this. The default matches the original threshold.
        max_iter: maximum number of update steps. This prevents an endless
            loop on data that is not exactly linear, where the cost never
            reaches *tol*. Pass None for no limit.

    Returns:
        The tuple ``(theta0, theta1)`` after the last update.
    """
    if alpha is None or alpha <= 0:
        alpha = 0.01
    steps = 0
    # NOTE: a NaN cost (e.g. after divergence with too large an alpha) also
    # ends the loop, because ``NaN > tol`` is False.
    while cost(theta0, theta1, xm) > tol:
        if max_iter is not None and steps >= max_iter:
            break
        # Evaluate both partial derivatives at the *current* parameters
        # before updating either one (simultaneous update).
        grad0 = dcost(theta0, theta1, xm)
        grad1 = dcost2(theta0, theta1, xm)
        theta0, theta1 = theta0 - alpha * grad0, theta1 - alpha * grad1
        steps += 1
    return theta0, theta1
def quack(z, n=20):
    """Sample the function *z* at x = 0, 1, ..., n-1.

    Returns a list of ``[x, z(x)]`` pairs, the point format that ``cost``
    and ``grd`` index as ``point[0]`` / ``point[1]``. The default ``n=20``
    matches the previously hard-coded sample count.

    ``range`` replaces the Python 2-only ``xrange`` so this runs on both
    Python 2 and Python 3.
    """
    return [[x, z(x)] for x in range(n)]
# Sample datasets (presumably inputs for grd()): each holds 20 points [x, y]
# with x = 0..19 lying exactly on the line y = intercept + slope * x.


def _points_on_line(intercept, slope, count=20):
    """Return ``count`` points ``[x, intercept + slope * x]`` for x = 0, 1, ..."""
    return [[x, intercept + slope * x] for x in range(count)]


# intercept 0, slope 2
q = _points_on_line(0, 2)
# intercept 4, slope 5
l = _points_on_line(4, 5)
# intercept 20, slope -10
p = _points_on_line(20, -10)
# intercept 300, slope 25
m = _points_on_line(300, 25)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.