Skip to content

Instantly share code, notes, and snippets.

@zhezh
Last active April 8, 2020 23:07
Show Gist options
  • Save zhezh/ccc7e7b70338c6b882e08113d7706530 to your computer and use it in GitHub Desktop.
Save zhezh/ccc7e7b70338c6b882e08113d7706530 to your computer and use it in GitHub Desktop.
[pytorch 分层设置学习率] #pytorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
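"# 分层设置学习率\n",
"\n",
"演示如何利用 `torch.optim` 的参数组(parameter groups)为网络的不同层设置不同的学习率:本例中 `linear1` 使用 1e-4,其余层使用 1e-3。"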
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.0\n"
]
}
],
"source": [
"# Record which PyTorch version produced the outputs stored in this notebook.\n",
"print('{}'.format(torch.__version__))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# A minimal two-layer fully connected network: Linear -> ReLU -> Linear.\n",
"class TwoLayerNet(torch.nn.Module):\n",
"    \"\"\"Two fully connected layers with a ReLU in between.\n",
"\n",
"    The layers are stored as the attributes `linear1` and `linear2`, so their\n",
"    parameters are reported as 'linear1.*' and 'linear2.*' by named_parameters().\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, D_in, H, D_out):\n",
"        \"\"\"Create both linear layers.\n",
"\n",
"        Args:\n",
"            D_in: number of input features.\n",
"            H: number of hidden units.\n",
"            D_out: number of output features.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.linear1 = torch.nn.Linear(D_in, H)\n",
"        self.linear2 = torch.nn.Linear(H, D_out)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map a batch of inputs `x` to predictions via linear1 -> ReLU -> linear2.\"\"\"\n",
"        return self.linear2(F.relu(self.linear1(x)))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Problem sizes for this toy example.\n",
"N = 64       # batch size\n",
"D_in = 1000  # input dimension\n",
"H = 100      # hidden dimension\n",
"D_out = 10   # output dimension"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Fix the RNG seed so the random data and the initial weights are reproducible.\n",
"torch.manual_seed(0)\n",
"\n",
"# Random inputs and regression targets.\n",
"x = torch.randn(N, D_in)\n",
"y = torch.randn(N, D_out)\n",
"\n",
"# Construct our model by instantiating the class defined above.\n",
"model = TwoLayerNet(D_in, H, D_out)\n",
"\n",
"# Loss = sum of squared errors over the whole batch (not the mean).\n",
"# `reduction='sum'` replaces the deprecated `size_average=False`; on PyTorch 0.4.0,\n",
"# which predates the `reduction` argument, use `size_average=False` instead.\n",
"criterion = torch.nn.MSELoss(reduction='sum')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"查看模型各参数的名称及其对象 id"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"参数名: linear1.weight , id: 140705409071936\n",
"参数名: linear1.bias , id: 140705409072008\n",
"参数名: linear2.weight , id: 140705409072296\n",
"参数名: linear2.bias , id: 140705409072656\n"
]
}
],
"source": [
"# Print every parameter's qualified name together with its object id.\n",
"for name, param in model.named_parameters():\n",
"    print('参数名: {: <18}, id: {}'.format(name, id(param)))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Split the parameters into two groups by name: everything belonging to the\n",
"# `linear1` submodule, and everything else. Match on the 'linear1.' prefix rather\n",
"# than a substring search, so names such as 'linear10.weight' are not caught too.\n",
"all_parameters = list(model.parameters())\n",
"\n",
"lin1_parameters = [p for pname, p in model.named_parameters()\n",
"                   if pname.startswith('linear1.')]\n",
"\n",
"# Compare by object identity: `p in some_list` would fall back to elementwise `==`\n",
"# on tensors, which cannot be used as a truth value. A set makes lookups O(1).\n",
"lin1_parameters_id = set(map(id, lin1_parameters))\n",
"other_parameters = [p for p in all_parameters if id(p) not in lin1_parameters_id]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"现在得到了两组参数:一组是 `linear1` 的参数,另一组是其余所有参数(本程序中即 `linear2` 的参数)。下面打印两组参数的 id,可与上面的输出对照,确认分组正确。"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"linear1组参数id: \n",
"140705409071936\n",
"140705409072008\n",
"\n",
"\n",
"other组参数id: \n",
"140705409072296\n",
"140705409072656\n"
]
}
],
"source": [
"def show_param_ids(title, params):\n",
"    \"\"\"Print a group heading followed by the id() of each parameter in the group.\"\"\"\n",
"    print(title)\n",
"    for param in params:\n",
"        print(id(param))\n",
"\n",
"\n",
"show_param_ids('linear1组参数id: ', lin1_parameters)\n",
"print('\\n')\n",
"show_param_ids('other组参数id: ', other_parameters)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"构造优化器(optimizer):通过参数组为 `linear1` 和其余层分别指定学习率"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# SGD applies the top-level lr (1e-4) to any group that does not set its own 'lr',\n",
"# so linear1 trains with 1e-4 while every other layer uses 1e-3.\n",
"param_groups = [\n",
"    {'params': lin1_parameters},               # inherits lr=1e-4\n",
"    {'params': other_parameters, 'lr': 1e-3},  # per-group override\n",
"]\n",
"optimizer = optim.SGD(param_groups, lr=1e-4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"训练网络"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 719.4832153320312\n",
"1 618.053466796875\n",
"2 556.1675415039062\n",
"3 500.43212890625\n",
"4 447.5682678222656\n",
"5 396.4813537597656\n",
"6 347.07867431640625\n",
"7 300.09326171875\n",
"8 256.21868896484375\n",
"9 216.0878448486328\n",
"10 180.35525512695312\n",
"11 149.25791931152344\n",
"12 122.67479705810547\n",
"13 100.36481475830078\n",
"14 81.89598083496094\n",
"15 66.74868774414062\n",
"16 54.42697525024414\n",
"17 44.45567321777344\n",
"18 36.39970397949219\n",
"19 29.910730361938477\n",
"20 24.662458419799805\n",
"21 20.40622901916504\n",
"22 16.947093963623047\n",
"23 14.121529579162598\n",
"24 11.805665969848633\n",
"25 9.908512115478516\n",
"26 8.346722602844238\n",
"27 7.056419849395752\n",
"28 5.984086036682129\n",
"29 5.090843677520752\n",
"30 4.344158172607422\n",
"31 3.7168867588043213\n",
"32 3.1883127689361572\n",
"33 2.7413690090179443\n",
"34 2.3627569675445557\n",
"35 2.0411109924316406\n",
"36 1.7670531272888184\n",
"37 1.5327770709991455\n",
"38 1.3317700624465942\n",
"39 1.1588612794876099\n",
"40 1.0100197792053223\n",
"41 0.8820691108703613\n",
"42 0.7714465856552124\n",
"43 0.6757064461708069\n",
"44 0.592779815196991\n",
"45 0.5206921696662903\n",
"46 0.45796385407447815\n",
"47 0.403344064950943\n",
"48 0.3557352125644684\n",
"49 0.31413477659225464\n",
"50 0.27769696712493896\n",
"51 0.24573519825935364\n",
"52 0.2176705002784729\n",
"53 0.19302386045455933\n",
"54 0.17133362591266632\n",
"55 0.15221278369426727\n",
"56 0.135351300239563\n",
"57 0.12046048790216446\n",
"58 0.1073007881641388\n",
"59 0.09565050154924393\n",
"60 0.08532743155956268\n",
"61 0.07617135345935822\n",
"62 0.06804817914962769\n",
"63 0.06084805354475975\n",
"64 0.05446131154894829\n",
"65 0.04884392023086548\n",
"66 0.043833404779434204\n",
"67 0.039361849427223206\n",
"68 0.03537042811512947\n",
"69 0.03180677816271782\n",
"70 0.028617050498723984\n",
"71 0.025749675929546356\n",
"72 0.023186029866337776\n",
"73 0.02088870480656624\n",
"74 0.018828196451067924\n",
"75 0.016980471089482307\n",
"76 0.015320717357099056\n",
"77 0.013830263167619705\n",
"78 0.012490017339587212\n",
"79 0.011284894309937954\n",
"80 0.01020009908825159\n",
"81 0.00922376848757267\n",
"82 0.008343766443431377\n",
"83 0.007550488226115704\n",
"84 0.0068353088572621346\n",
"85 0.006190172396600246\n",
"86 0.005607489496469498\n",
"87 0.005081566050648689\n",
"88 0.00460641598328948\n",
"89 0.004176917020231485\n",
"90 0.0037886006757616997\n",
"91 0.0034373654052615166\n",
"92 0.0031197601929306984\n",
"93 0.0028322283178567886\n",
"94 0.0025717862881720066\n",
"95 0.002335888333618641\n",
"96 0.0021221640054136515\n",
"97 0.0019284778973087668\n",
"98 0.0017529240576550364\n",
"99 0.0015937236603349447\n",
"100 0.0014493277994915843\n",
"101 0.0013182209804654121\n",
"102 0.0011992763029411435\n",
"103 0.0010912807192653418\n",
"104 0.0009932058164849877\n",
"105 0.0009040983277373016\n",
"106 0.0008233404951170087\n",
"107 0.0007499091443605721\n",
"108 0.0006831525824964046\n",
"109 0.0006224379176273942\n",
"110 0.0005672484403476119\n",
"111 0.0005170325748622417\n",
"112 0.0004713317903224379\n",
"113 0.0004297299892641604\n",
"114 0.0003918729198630899\n",
"115 0.0003573991998564452\n",
"116 0.0003260155499447137\n",
"117 0.0002974196686409414\n",
"118 0.0002713669091463089\n",
"119 0.0002476317167747766\n",
"120 0.00022600177908316255\n",
"121 0.0002062939602183178\n",
"122 0.00018832141358871013\n",
"123 0.0001719275169307366\n",
"124 0.0001569933956488967\n",
"125 0.00014336439198814332\n",
"126 0.00013093784218654037\n",
"127 0.0001196042139781639\n",
"128 0.0001092585880542174\n",
"129 9.982137999031693e-05\n",
"130 9.120586764765903e-05\n",
"131 8.33444792078808e-05\n",
"132 7.616882066940889e-05\n",
"133 6.961503822822124e-05\n",
"134 6.363449210766703e-05\n",
"135 5.817100827698596e-05\n",
"136 5.318074545357376e-05\n",
"137 4.8620247980579734e-05\n",
"138 4.4461063225753605e-05\n",
"139 4.0658000216353685e-05\n",
"140 3.7181245716055855e-05\n",
"141 3.400376590434462e-05\n",
"142 3.110080797341652e-05\n",
"143 2.8448537705116905e-05\n",
"144 2.6025612896773964e-05\n",
"145 2.3810694983694702e-05\n",
"146 2.1784513592137955e-05\n",
"147 1.9931732822442427e-05\n",
"148 1.8238761185784824e-05\n",
"149 1.669217635935638e-05\n",
"150 1.5273904864443466e-05\n",
"151 1.3978798961034045e-05\n",
"152 1.279618481930811e-05\n",
"153 1.1711815204762388e-05\n",
"154 1.071973474608967e-05\n",
"155 9.814746590564027e-06\n",
"156 8.985691238194704e-06\n",
"157 8.226681529777125e-06\n",
"158 7.533013558713719e-06\n",
"159 6.897260846017161e-06\n",
"160 6.3151273934636265e-06\n",
"161 5.784675977338338e-06\n",
"162 5.296439212543191e-06\n",
"163 4.851282938034274e-06\n",
"164 4.442661975190276e-06\n",
"165 4.068862381245708e-06\n",
"166 3.7270233406161424e-06\n",
"167 3.413488684600452e-06\n",
"168 3.1275621950044297e-06\n",
"169 2.8650440526689636e-06\n",
"170 2.624645276227966e-06\n",
"171 2.405057784926612e-06\n",
"172 2.2035642359696794e-06\n",
"173 2.019183966694982e-06\n",
"174 1.850062517405604e-06\n",
"175 1.6953827071120031e-06\n",
"176 1.5531543340330245e-06\n",
"177 1.4232090279620024e-06\n",
"178 1.3044416391494451e-06\n",
"179 1.1956101388932439e-06\n",
"180 1.0957942322420422e-06\n",
"181 1.0046904890259611e-06\n",
"182 9.207004154632159e-07\n",
"183 8.440935062026256e-07\n",
"184 7.73405361087498e-07\n",
"185 7.091608722475939e-07\n",
"186 6.500075642179581e-07\n",
"187 5.959356030871277e-07\n",
"188 5.460437932924833e-07\n",
"189 5.007885874874773e-07\n",
"190 4.5921461833131616e-07\n",
"191 4.2098395169887226e-07\n",
"192 3.8619594988631434e-07\n",
"193 3.539971658028662e-07\n",
"194 3.244428512516606e-07\n",
"195 2.975311303998751e-07\n",
"196 2.726736170188815e-07\n",
"197 2.5038809781108284e-07\n",
"198 2.2951981293317658e-07\n",
"199 2.1044718323537381e-07\n",
"200 1.931230571017295e-07\n",
"201 1.7704486765524052e-07\n",
"202 1.625433725394032e-07\n",
"203 1.4899816846991598e-07\n",
"204 1.3666438292148086e-07\n",
"205 1.2532166238088394e-07\n",
"206 1.1494875451489861e-07\n",
"207 1.054767295727288e-07\n",
"208 9.676008971837291e-08\n",
"209 8.872893886291422e-08\n",
"210 8.135914697504631e-08\n",
"211 7.470005414234038e-08\n",
"212 6.852687306491134e-08\n",
"213 6.293034005011577e-08\n",
"214 5.7790256136058815e-08\n",
"215 5.302708672161316e-08\n",
"216 4.8682117892440147e-08\n",
"217 4.453647051150256e-08\n",
"218 4.099248585021087e-08\n",
"219 3.761395817036828e-08\n",
"220 3.453242669593237e-08\n",
"221 3.170368501059784e-08\n",
"222 2.915122720992258e-08\n",
"223 2.6717120960029206e-08\n",
"224 2.451883673870725e-08\n",
"225 2.2580659120308155e-08\n",
"226 2.0710798409595554e-08\n",
"227 1.9040477639009623e-08\n",
"228 1.7526446072224644e-08\n",
"229 1.6122124080197864e-08\n",
"230 1.4806063042271944e-08\n",
"231 1.3633435713700237e-08\n",
"232 1.2587334730085331e-08\n",
"233 1.1583493275679757e-08\n",
"234 1.0632591695980409e-08\n",
"235 9.816670143436568e-09\n",
"236 9.059893280038978e-09\n",
"237 8.38562641547469e-09\n",
"238 7.752122499482539e-09\n",
"239 7.144894009769587e-09\n",
"240 6.591914125664289e-09\n",
"241 6.152251152968802e-09\n",
"242 5.6817590632363135e-09\n",
"243 5.275397008119853e-09\n",
"244 4.890560845183245e-09\n",
"245 4.5295536210687715e-09\n",
"246 4.211996085246028e-09\n",
"247 3.908982026956664e-09\n",
"248 3.6504110845214655e-09\n",
"249 3.4168203821849374e-09\n",
"250 3.187831110196271e-09\n",
"251 2.98542235377397e-09\n",
"252 2.8059994328089033e-09\n",
"253 2.626949102690901e-09\n",
"254 2.4697992540012592e-09\n",
"255 2.3114088421039014e-09\n",
"256 2.1845512065965522e-09\n",
"257 2.0542922918309614e-09\n",
"258 1.9309427390368228e-09\n",
"259 1.8381356436947272e-09\n",
"260 1.7248911188261218e-09\n",
"261 1.6398804536521538e-09\n",
"262 1.5482324311477669e-09\n",
"263 1.4635379574912122e-09\n",
"264 1.3910612661760524e-09\n",
"265 1.317505660125562e-09\n",
"266 1.2539014271339965e-09\n",
"267 1.2011907024600532e-09\n",
"268 1.1445275838184443e-09\n",
"269 1.0953595808160799e-09\n",
"270 1.0399190397691882e-09\n",
"271 1.0023235574863065e-09\n",
"272 9.588654314995892e-10\n",
"273 9.140234680238279e-10\n",
"274 8.714531318787522e-10\n",
"275 8.341509150078252e-10\n",
"276 8.028351317079796e-10\n",
"277 7.760427300773642e-10\n",
"278 7.34666716351029e-10\n",
"279 7.090298348444435e-10\n",
"280 6.824536491478739e-10\n",
"281 6.50852871597607e-10\n",
"282 6.332582236368012e-10\n",
"283 6.151034126489208e-10\n",
"284 5.901392152729557e-10\n",
"285 5.691445092992353e-10\n",
"286 5.507771461132904e-10\n",
"287 5.32691946109054e-10\n",
"288 5.124157764768711e-10\n",
"289 4.940373665718312e-10\n",
"290 4.766944616818591e-10\n",
"291 4.6186365842970645e-10\n",
"292 4.470791514776806e-10\n",
"293 4.3581016573313036e-10\n",
"294 4.2106490516502504e-10\n",
"295 4.066213477038616e-10\n",
"296 3.981817098264173e-10\n",
"297 3.8970893179168797e-10\n",
"298 3.72095521061766e-10\n",
"299 3.6663283520255163e-10\n",
"300 3.573504825382656e-10\n",
"301 3.487677924240984e-10\n",
"302 3.3633168472491093e-10\n",
"303 3.290345773621084e-10\n",
"304 3.224511213595349e-10\n",
"305 3.122619940398863e-10\n",
"306 3.0483571222816863e-10\n",
"307 2.9414332081145744e-10\n",
"308 2.8667790363812173e-10\n",
"309 2.8062335788447967e-10\n",
"310 2.762368667141857e-10\n",
"311 2.6522151141961103e-10\n",
"312 2.57844356976733e-10\n",
"313 2.5349577992273e-10\n",
"314 2.478507399317209e-10\n",
"315 2.41882402995941e-10\n",
"316 2.3629426193494396e-10\n",
"317 2.3070964583205011e-10\n",
"318 2.264485127190241e-10\n",
"319 2.2164882429454025e-10\n",
"320 2.1533744232193897e-10\n",
"321 2.097950702051321e-10\n",
"322 2.064158427517171e-10\n",
"323 1.990261011552974e-10\n",
"324 1.9497652103961371e-10\n",
"325 1.9140954099494678e-10\n",
"326 1.8693983860895713e-10\n",
"327 1.8207133023473432e-10\n",
"328 1.783010128431073e-10\n",
"329 1.763045681668629e-10\n",
"330 1.7145321273837055e-10\n",
"331 1.6931771262829187e-10\n",
"332 1.6673835923075586e-10\n",
"333 1.6416644432748484e-10\n",
"334 1.6168200112076647e-10\n",
"335 1.5868138747432425e-10\n",
"336 1.5622357574240908e-10\n",
"337 1.5223147742382537e-10\n",
"338 1.4958895233618819e-10\n",
"339 1.4682752236261365e-10\n",
"340 1.436487179207191e-10\n",
"341 1.4157740257925155e-10\n",
"342 1.410829092440835e-10\n",
"343 1.380528469319131e-10\n",
"344 1.3580145341585137e-10\n",
"345 1.341790289988154e-10\n",
"346 1.3181483682345174e-10\n",
"347 1.3028327028319353e-10\n",
"348 1.2861381404327688e-10\n",
"349 1.2620478273550617e-10\n",
"350 1.2542905603041277e-10\n",
"351 1.238803642999997e-10\n",
"352 1.218108947043106e-10\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"353 1.2094443502252972e-10\n",
"354 1.1777642749954964e-10\n",
"355 1.1760770135538223e-10\n",
"356 1.1544241951266798e-10\n",
"357 1.1402352673162142e-10\n",
"358 1.1242246023002167e-10\n",
"359 1.1167516911214648e-10\n",
"360 1.0990484911044263e-10\n",
"361 1.0966394459188678e-10\n",
"362 1.0963921437401325e-10\n",
"363 1.0603967703914918e-10\n",
"364 1.03834024711702e-10\n",
"365 1.039718450224214e-10\n",
"366 1.0166703590108739e-10\n",
"367 1.0114989401621699e-10\n",
"368 9.834176528666916e-11\n",
"369 9.60351104195567e-11\n",
"370 9.521219923591673e-11\n",
"371 9.493408836824813e-11\n",
"372 9.308730175572322e-11\n",
"373 9.198990180703248e-11\n",
"374 9.021075553228286e-11\n",
"375 8.961704989207675e-11\n",
"376 8.807096718577156e-11\n",
"377 8.61079402225684e-11\n",
"378 8.594966405262028e-11\n",
"379 8.452996635988086e-11\n",
"380 8.392590788997012e-11\n",
"381 8.171358034658738e-11\n",
"382 8.029973908030286e-11\n",
"383 7.93237697749305e-11\n",
"384 7.853767636234465e-11\n",
"385 7.892052289459883e-11\n",
"386 7.902740961629462e-11\n",
"387 7.743728575038133e-11\n",
"388 7.645977601056231e-11\n",
"389 7.571010485207808e-11\n",
"390 7.520549460959813e-11\n",
"391 7.54181994633285e-11\n",
"392 7.461733314562125e-11\n",
"393 7.3604816686057e-11\n",
"394 7.3249614707116e-11\n",
"395 7.211352348601707e-11\n",
"396 7.178711791677728e-11\n",
"397 7.167869769952873e-11\n",
"398 7.025070108968023e-11\n",
"399 6.825126575016327e-11\n",
"400 6.793333950927405e-11\n",
"401 6.772558902579107e-11\n",
"402 6.730412061006774e-11\n",
"403 6.625886644906487e-11\n",
"404 6.461721435702117e-11\n",
"405 6.47099734907286e-11\n",
"406 6.391782242376465e-11\n",
"407 6.264266882993752e-11\n",
"408 6.173965505507084e-11\n",
"409 6.116362971653189e-11\n",
"410 6.077430919626536e-11\n",
"411 6.030231869402769e-11\n",
"412 5.947273923334606e-11\n",
"413 5.867167862660949e-11\n",
"414 5.855109452834739e-11\n",
"415 5.7634869099487673e-11\n",
"416 5.798097418852066e-11\n",
"417 5.714314438298729e-11\n",
"418 5.7525491314880384e-11\n",
"419 5.753773499317383e-11\n",
"420 5.72689846933816e-11\n",
"421 5.631775948367057e-11\n",
"422 5.598362051717487e-11\n",
"423 5.5867376697049664e-11\n",
"424 5.629419153052595e-11\n",
"425 5.5520841396594633e-11\n",
"426 5.46322084793438e-11\n",
"427 5.4010414196614676e-11\n",
"428 5.341607364761636e-11\n",
"429 5.3519549902958374e-11\n",
"430 5.247411186126705e-11\n",
"431 5.307565845158457e-11\n",
"432 5.2371048470112314e-11\n",
"433 5.2449510012930745e-11\n",
"434 5.1879195384074706e-11\n",
"435 5.1574945703070085e-11\n",
"436 5.1512395043973314e-11\n",
"437 5.08322897663227e-11\n",
"438 5.057060326052465e-11\n",
"439 4.9787240302689995e-11\n",
"440 4.823476340565236e-11\n",
"441 4.8459992962879284e-11\n",
"442 4.797246627719076e-11\n",
"443 4.824013410953398e-11\n",
"444 4.7559997606860804e-11\n",
"445 4.690832444698145e-11\n",
"446 4.619358368040949e-11\n",
"447 4.583974866356755e-11\n",
"448 4.605983650041168e-11\n",
"449 4.5755635391664384e-11\n",
"450 4.547212259509159e-11\n",
"451 4.5017298916372184e-11\n",
"452 4.5089352390670356e-11\n",
"453 4.437321690642371e-11\n",
"454 4.417576374149412e-11\n",
"455 4.4186581477090314e-11\n",
"456 4.324371069563959e-11\n",
"457 4.303992579002269e-11\n",
"458 4.266882333570088e-11\n",
"459 4.2662408328286716e-11\n",
"460 4.2061472360632735e-11\n",
"461 4.153392560435343e-11\n",
"462 4.1709347781138106e-11\n",
"463 4.15860852698291e-11\n",
"464 4.1531614952683427e-11\n",
"465 4.106182407981329e-11\n",
"466 4.0790346794716825e-11\n",
"467 4.062897934753451e-11\n",
"468 3.990976646384148e-11\n",
"469 4.011613263799063e-11\n",
"470 3.9666533946380866e-11\n",
"471 3.945664281412853e-11\n",
"472 3.835248785222234e-11\n",
"473 3.8230120458226935e-11\n",
"474 3.747497104300557e-11\n",
"475 3.7175280215295814e-11\n",
"476 3.713800100779707e-11\n",
"477 3.680154792018442e-11\n",
"478 3.743382687160235e-11\n",
"479 3.680740781608627e-11\n",
"480 3.6479482629081517e-11\n",
"481 3.598047901287593e-11\n",
"482 3.572271645158054e-11\n",
"483 3.5526505348659754e-11\n",
"484 3.541624979397362e-11\n",
"485 3.4761606787503396e-11\n",
"486 3.4769877949036854e-11\n",
"487 3.460022199308632e-11\n",
"488 3.453215144388899e-11\n",
"489 3.4166957457726355e-11\n",
"490 3.428968220475781e-11\n",
"491 3.439980245101282e-11\n",
"492 3.4364174700263206e-11\n",
"493 3.4181008717881767e-11\n",
"494 3.368537393466653e-11\n",
"495 3.368653272994848e-11\n",
"496 3.291827227469568e-11\n",
"497 3.2805473615393765e-11\n",
"498 3.3228246543171025e-11\n",
"499 3.3127479925898484e-11\n"
]
}
],
"source": [
"# Train with plain SGD. The loss is logged every `log_every` steps (and at the final\n",
"# step) instead of on every iteration, so the cell output stays readable.\n",
"num_steps = 500\n",
"log_every = 50\n",
"\n",
"for t in range(num_steps):\n",
"    # Forward pass: compute predicted y by passing x to the model.\n",
"    y_pred = model(x)\n",
"\n",
"    # Compute the loss and log it periodically.\n",
"    loss = criterion(y_pred, y)\n",
"    if t % log_every == 0 or t == num_steps - 1:\n",
"        print(t, loss.item())\n",
"\n",
"    # Zero gradients, perform a backward pass, and update the weights.\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment