Last active
April 26, 2017 09:50
-
-
Save mikigom/c083b8c6f0bcc1540adbb1741263f148 to your computer and use it in GitHub Desktop.
Project Nagne DL seminar code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# The optimal values of m and b can be computed directly with the closed-form least-squares solution, with far less effort than running gradient descent.\n", | |
"# This notebook only demonstrates how gradient descent works.\n", | |
"\n", | |
"from numpy import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# y = mx + b\n", | |
"# m is slope, b is y-intercept\n", | |
"def compute_error_for_line_given_points(b, m, points):\n", | |
"    \"\"\"Return the mean squared error of the line y = m * x + b over points.\n", | |
"\n", | |
"    points is read as points[i, 0] (x) and points[i, 1] (y), so it must\n", | |
"    support two-dimensional indexing with one row per sample.\n", | |
"    Raises ZeroDivisionError when points is empty.\n", | |
"    \"\"\"\n", | |
"    num_points = len(points)\n", | |
"    squared_error_sum = 0\n", | |
"    for i in range(num_points):\n", | |
"        x = points[i, 0]\n", | |
"        y = points[i, 1]\n", | |
"        # Squared vertical distance between the sample and the line.\n", | |
"        squared_error_sum += (y - (m * x + b)) ** 2\n", | |
"    return squared_error_sum / float(num_points)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def step_gradient(b_current, m_current, points, learningRate):\n", | |
"    \"\"\"Take one gradient-descent step on the mean squared error of y = m * x + b.\n", | |
"\n", | |
"    Accumulates the partial derivatives of the error with respect to b and m\n", | |
"    over every row of points (points[i, 0] is x, points[i, 1] is y), then moves\n", | |
"    each parameter against its gradient, scaled by learningRate.\n", | |
"    Returns the updated parameters as [new_b, new_m].\n", | |
"    \"\"\"\n", | |
"    b_gradient = 0\n", | |
"    m_gradient = 0\n", | |
"    N = float(len(points))\n", | |
"    for i in range(len(points)):\n", | |
"        x = points[i, 0]\n", | |
"        y = points[i, 1]\n", | |
"        # Prediction error for this sample; shared by both partial derivatives.\n", | |
"        residual = y - ((m_current * x) + b_current)\n", | |
"        # -(2/N) stays inside the loop so empty input never divides by zero.\n", | |
"        b_gradient += -(2/N) * residual\n", | |
"        m_gradient += -(2/N) * x * residual\n", | |
"    new_b = b_current - (learningRate * b_gradient)\n", | |
"    new_m = m_current - (learningRate * m_gradient)\n", | |
"    return [new_b, new_m]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def gradient_descent_runner(points, starting_b, starting_m, learning_rate, num_iterations):\n", | |
"    \"\"\"Run num_iterations gradient-descent steps starting from (starting_b, starting_m).\n", | |
"\n", | |
"    Before every step the current b, m and error are printed.\n", | |
"    Returns the final parameters as [b, m].\n", | |
"    \"\"\"\n", | |
"    # Convert once up front so the error report and the gradient step both\n", | |
"    # receive the same array instead of rebuilding it on every iteration.\n", | |
"    points = array(points)\n", | |
"    b = starting_b\n", | |
"    m = starting_m\n", | |
"    for i in range(num_iterations):\n", | |
"        print(\"... b = {0}, m = {1}, error = {2}\".format(b, m, compute_error_for_line_given_points(b, m, points)))\n", | |
"        b, m = step_gradient(b, m, points, learning_rate)\n", | |
"    return [b, m]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def run():\n", | |
"    \"\"\"Fit y = m * x + b to linear_regression_data.csv with gradient descent.\n", | |
"\n", | |
"    Loads the comma-separated samples, prints the starting error, runs the\n", | |
"    descent, and prints the final b, m and error.\n", | |
"    \"\"\"\n", | |
"    points = genfromtxt(\"linear_regression_data.csv\", delimiter=\",\")\n", | |
"    learning_rate = 0.0001\n", | |
"    initial_b = 0 # initial y-intercept guess\n", | |
"    initial_m = 0 # initial slope guess\n", | |
"    num_iterations = 100\n", | |
"    # print(...) with a single argument behaves the same on Python 2 and 3.\n", | |
"    initial_error = compute_error_for_line_given_points(initial_b, initial_m, points)\n", | |
"    print(\"Starting gradient descent at b = {0}, m = {1}, error = {2}\".format(initial_b, initial_m, initial_error))\n", | |
"    print(\"Running...\")\n", | |
"    [b, m] = gradient_descent_runner(points, initial_b, initial_m, learning_rate, num_iterations)\n", | |
"    final_error = compute_error_for_line_given_points(b, m, points)\n", | |
"    print(\"After {0} iterations b = {1}, m = {2}, error = {3}\".format(num_iterations, b, m, final_error))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Starting gradient descent at b = 0, m = 0, error = 5565.10783448\n", | |
"Running...\n", | |
"... b = 0, m = 0, error = 5565.10783448\n", | |
"... b = 0.0145470101107, m = 0.737070297359, error = 1484.58655741\n", | |
"... b = 0.0218739629596, m = 1.10679545435, error = 457.854257574\n", | |
"... b = 0.0255792243213, m = 1.29225466491, error = 199.509985726\n", | |
"... b = 0.0274677895591, m = 1.38528325565, error = 134.505910582\n", | |
"... b = 0.0284450719817, m = 1.43194723238, error = 118.149693422\n", | |
"... b = 0.0289652407665, m = 1.4553540089, error = 114.03414906\n", | |
"... b = 0.029256114126, m = 1.46709461772, error = 112.998577317\n", | |
"... b = 0.0294319691638, m = 1.47298329822, error = 112.737981876\n", | |
"... b = 0.0295501290244, m = 1.4759365619, error = 112.672384359\n", | |
"... b = 0.0296393478747, m = 1.47741737555, error = 112.655851815\n", | |
"... b = 0.0297140492452, m = 1.47815958573, error = 112.651664898\n", | |
"... b = 0.0297814681995, m = 1.47853130111, error = 112.650584362\n", | |
"... b = 0.0298452339563, m = 1.47871717063, error = 112.650285447\n", | |
"... b = 0.0299071669873, m = 1.47880981703, error = 112.650183203\n", | |
"... b = 0.0299681804689, m = 1.47885570128, error = 112.650130445\n", | |
"... b = 0.0300287324645, m = 1.47887812893, error = 112.650090139\n", | |
"... b = 0.0300890527455, m = 1.47888879039, error = 112.650052967\n", | |
"... b = 0.0301492565689, m = 1.47889354976, error = 112.650016584\n", | |
"... b = 0.0302094017495, m = 1.47889534855, error = 112.649980399\n", | |
"... b = 0.0302695172878, m = 1.47889566228, error = 112.649944265\n", | |
"... b = 0.0303296177311, m = 1.47889523108, error = 112.649908144\n", | |
"... b = 0.0303897103765, m = 1.47889442622, error = 112.649872027\n", | |
"... b = 0.0304497988843, m = 1.47889343392, error = 112.649835911\n", | |
"... b = 0.0305098850906, m = 1.47889234761, error = 112.649799796\n", | |
"... b = 0.0305699699165, m = 1.47889121415, error = 112.649763681\n", | |
"... b = 0.0306300538238, m = 1.47889005704, error = 112.649727567\n", | |
"... b = 0.0306901370445, m = 1.47888888807, error = 112.649691454\n", | |
"... b = 0.0307502196946, m = 1.47888771316, error = 112.649655341\n", | |
"... b = 0.0308103018326, m = 1.47888653527, error = 112.649619228\n", | |
"... b = 0.0308703834876, m = 1.47888535589, error = 112.649583117\n", | |
"... b = 0.0309304646744, m = 1.47888417577, error = 112.649547005\n", | |
"... b = 0.0309905454003, m = 1.47888299528, error = 112.649510895\n", | |
"... b = 0.031050625669, m = 1.47888181461, error = 112.649474784\n", | |
"... b = 0.0311107054824, m = 1.47888063386, error = 112.649438675\n", | |
"... b = 0.0311707848414, m = 1.47887945306, error = 112.649402566\n", | |
"... b = 0.0312308637464, m = 1.47887827226, error = 112.649366457\n", | |
"... b = 0.0312909421978, m = 1.47887709144, error = 112.649330349\n", | |
"... b = 0.0313510201956, m = 1.47887591064, error = 112.649294242\n", | |
"... b = 0.0314110977398, m = 1.47887472983, error = 112.649258135\n", | |
"... b = 0.0314711748306, m = 1.47887354904, error = 112.649222028\n", | |
"... b = 0.031531251468, m = 1.47887236825, error = 112.649185922\n", | |
"... b = 0.0315913276518, m = 1.47887118747, error = 112.649149817\n", | |
"... b = 0.0316514033823, m = 1.4788700067, error = 112.649113712\n", | |
"... b = 0.0317114786593, m = 1.47886882594, error = 112.649077608\n", | |
"... b = 0.0317715534828, m = 1.47886764519, error = 112.649041505\n", | |
"... b = 0.031831627853, m = 1.47886646444, error = 112.649005401\n", | |
"... b = 0.0318917017697, m = 1.47886528371, error = 112.648969299\n", | |
"... b = 0.0319517752329, m = 1.47886410298, error = 112.648933197\n", | |
"... b = 0.0320118482428, m = 1.47886292227, error = 112.648897095\n", | |
"... b = 0.0320719207993, m = 1.47886174156, error = 112.648860994\n", | |
"... b = 0.0321319929023, m = 1.47886056086, error = 112.648824894\n", | |
"... b = 0.0321920645519, m = 1.47885938017, error = 112.648788794\n", | |
"... b = 0.0322521357481, m = 1.47885819949, error = 112.648752695\n", | |
"... b = 0.032312206491, m = 1.47885701882, error = 112.648716596\n", | |
"... b = 0.0323722767804, m = 1.47885583815, error = 112.648680498\n", | |
"... b = 0.0324323466164, m = 1.4788546575, error = 112.6486444\n", | |
"... b = 0.032492415999, m = 1.47885347685, error = 112.648608303\n", | |
"... b = 0.0325524849283, m = 1.47885229622, error = 112.648572207\n", | |
"... b = 0.0326125534042, m = 1.47885111559, error = 112.648536111\n", | |
"... b = 0.0326726214266, m = 1.47884993497, error = 112.648500015\n", | |
"... b = 0.0327326889957, m = 1.47884875436, error = 112.64846392\n", | |
"... b = 0.0327927561115, m = 1.47884757376, error = 112.648427826\n", | |
"... b = 0.0328528227738, m = 1.47884639317, error = 112.648391732\n", | |
"... b = 0.0329128889828, m = 1.47884521259, error = 112.648355639\n", | |
"... b = 0.0329729547384, m = 1.47884403201, error = 112.648319546\n", | |
"... b = 0.0330330200407, m = 1.47884285145, error = 112.648283454\n", | |
"... b = 0.0330930848896, m = 1.47884167089, error = 112.648247362\n", | |
"... b = 0.0331531492852, m = 1.47884049034, error = 112.648211271\n", | |
"... b = 0.0332132132274, m = 1.4788393098, error = 112.64817518\n", | |
"... b = 0.0332732767162, m = 1.47883812928, error = 112.64813909\n", | |
"... b = 0.0333333397517, m = 1.47883694875, error = 112.648103001\n", | |
"... b = 0.0333934023339, m = 1.47883576824, error = 112.648066912\n", | |
"... b = 0.0334534644627, m = 1.47883458774, error = 112.648030824\n", | |
"... b = 0.0335135261382, m = 1.47883340725, error = 112.647994736\n", | |
"... b = 0.0335735873604, m = 1.47883222676, error = 112.647958648\n", | |
"... b = 0.0336336481292, m = 1.47883104628, error = 112.647922562\n", | |
"... b = 0.0336937084447, m = 1.47882986582, error = 112.647886475\n", | |
"... b = 0.0337537683069, m = 1.47882868536, error = 112.64785039\n", | |
"... b = 0.0338138277158, m = 1.47882750491, error = 112.647814305\n", | |
"... b = 0.0338738866714, m = 1.47882632447, error = 112.64777822\n", | |
"... b = 0.0339339451737, m = 1.47882514404, error = 112.647742136\n", | |
"... b = 0.0339940032226, m = 1.47882396362, error = 112.647706053\n", | |
"... b = 0.0340540608183, m = 1.4788227832, error = 112.64766997\n", | |
"... b = 0.0341141179606, m = 1.4788216028, error = 112.647633887\n", | |
"... b = 0.0341741746497, m = 1.4788204224, error = 112.647597805\n", | |
"... b = 0.0342342308854, m = 1.47881924201, error = 112.647561724\n", | |
"... b = 0.0342942866679, m = 1.47881806164, error = 112.647525643\n", | |
"... b = 0.0343543419971, m = 1.47881688127, error = 112.647489563\n", | |
"... b = 0.034414396873, m = 1.47881570091, error = 112.647453483\n", | |
"... b = 0.0344744512956, m = 1.47881452055, error = 112.647417404\n", | |
"... b = 0.034534505265, m = 1.47881334021, error = 112.647381326\n", | |
"... b = 0.0345945587811, m = 1.47881215988, error = 112.647345248\n", | |
"... b = 0.0346546118439, m = 1.47881097955, error = 112.64730917\n", | |
"... b = 0.0347146644534, m = 1.47880979924, error = 112.647273093\n", | |
"... b = 0.0347747166097, m = 1.47880861893, error = 112.647237017\n", | |
"... b = 0.0348347683127, m = 1.47880743863, error = 112.647200941\n", | |
"... b = 0.0348948195625, m = 1.47880625834, error = 112.647164866\n", | |
"... b = 0.034954870359, m = 1.47880507806, error = 112.647128791\n", | |
"... b = 0.0350149207023, m = 1.47880389779, error = 112.647092717\n", | |
"After 100 iterations b = 0.0350749705923, m = 1.47880271753, error = 112.647056643\n" | |
] | |
} | |
], | |
"source": [ | |
"# Run the demo only when this code executes as the main module.\n", | |
"if __name__ == \"__main__\":\n", | |
"    run()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32.502345269453031 | 31.70700584656992 | |
---|---|---|
53.426804033275019 | 68.77759598163891 | |
61.530358025636438 | 62.562382297945803 | |
47.475639634786098 | 71.546632233567777 | |
59.813207869512318 | 87.230925133687393 | |
55.142188413943821 | 78.211518270799232 | |
52.211796692214001 | 79.64197304980874 | |
39.299566694317065 | 59.171489321869508 | |
48.10504169176825 | 75.331242297063056 | |
52.550014442733818 | 71.300879886850353 | |
45.419730144973755 | 55.165677145959123 | |
54.351634881228918 | 82.478846757497919 | |
44.164049496773352 | 62.008923245725825 | |
58.16847071685779 | 75.392870425994957 | |
56.727208057096611 | 81.43619215887864 | |
48.955888566093719 | 60.723602440673965 | |
44.687196231480904 | 82.892503731453715 | |
60.297326851333466 | 97.379896862166078 | |
45.618643772955828 | 48.847153317355072 | |
38.816817537445637 | 56.877213186268506 | |
66.189816606752601 | 83.878564664602763 | |
65.41605174513407 | 118.59121730252249 | |
47.48120860786787 | 57.251819462268969 | |
41.57564261748702 | 51.391744079832307 | |
51.84518690563943 | 75.380651665312357 | |
59.370822011089523 | 74.765564032151374 | |
57.31000343834809 | 95.455052922574737 | |
63.615561251453308 | 95.229366017555307 | |
46.737619407976972 | 79.052406169565586 | |
50.556760148547767 | 83.432071421323712 | |
52.223996085553047 | 63.358790317497878 | |
35.567830047746632 | 41.412885303700563 | |
42.436476944055642 | 76.617341280074044 | |
58.16454011019286 | 96.769566426108199 | |
57.504447615341789 | 74.084130116602523 | |
45.440530725319981 | 66.588144414228594 | |
61.89622268029126 | 77.768482417793024 | |
33.093831736163963 | 50.719588912312084 | |
36.436009511386871 | 62.124570818071781 | |
37.675654860850742 | 60.810246649902211 | |
44.555608383275356 | 52.682983366387781 | |
43.318282631865721 | 58.569824717692867 | |
50.073145632289034 | 82.905981485070512 | |
43.870612645218372 | 61.424709804339123 | |
62.997480747553091 | 115.24415280079529 | |
32.669043763467187 | 45.570588823376085 | |
40.166899008703702 | 54.084054796223612 | |
53.575077531673656 | 87.994452758110413 | |
33.864214971778239 | 52.725494375900425 | |
64.707138666121296 | 93.576118692658241 | |
38.119824026822805 | 80.166275447370964 | |
44.502538064645101 | 65.101711570560326 | |
40.599538384552318 | 65.562301260400375 | |
41.720676356341293 | 65.280886920822823 | |
51.088634678336796 | 73.434641546324301 | |
55.078095904923202 | 71.13972785861894 | |
41.377726534895203 | 79.102829683549857 | |
62.494697427269791 | 86.520538440347153 | |
49.203887540826003 | 84.742697807826218 | |
41.102685187349664 | 59.358850248624933 | |
41.182016105169822 | 61.684037524833627 | |
50.186389494880601 | 69.847604158249183 | |
52.378446219236217 | 86.098291205774103 | |
50.135485486286122 | 59.108839267699643 | |
33.644706006191782 | 69.89968164362763 | |
39.557901222906828 | 44.862490711164398 | |
56.130388816875467 | 85.498067778840223 | |
57.362052133238237 | 95.536686846467219 | |
60.269214393997906 | 70.251934419771587 | |
35.678093889410732 | 52.721734964774988 | |
31.588116998132829 | 50.392670135079896 | |
53.66093226167304 | 63.642398775657753 | |
46.682228649471917 | 72.247251068662365 | |
43.107820219102464 | 57.812512976181402 | |
70.34607561504933 | 104.25710158543822 | |
44.492855880854073 | 86.642020318822006 | |
57.50453330326841 | 91.486778000110135 | |
36.930076609191808 | 55.231660886212836 | |
55.805733357942742 | 79.550436678507609 | |
38.954769073377065 | 44.847124242467601 | |
56.901214702247074 | 80.207523139682763 | |
56.868900661384046 | 83.14274979204346 | |
34.33312470421609 | 55.723489260543914 | |
59.04974121466681 | 77.634182511677864 | |
57.788223993230673 | 99.051414841748269 | |
54.282328705967409 | 79.120646274680027 | |
51.088719898979143 | 69.588897851118475 | |
50.282836348230731 | 69.510503311494389 | |
44.211741752090113 | 73.687564318317285 | |
38.005488008060688 | 61.366904537240131 | |
32.940479942618296 | 67.170655768995118 | |
53.691639571070056 | 85.668203145001542 | |
68.76573426962166 | 114.85387123391394 | |
46.230966498310252 | 90.123572069967423 | |
68.319360818255362 | 97.919821035242848 | |
50.030174340312143 | 81.536990783015028 | |
49.239765342753763 | 72.111832469615663 | |
50.039575939875988 | 85.232007342325673 | |
48.149858891028863 | 66.224957888054632 | |
25.128484647772304 | 53.454394214850524 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment