Skip to content

Instantly share code, notes, and snippets.

@canard0328
Created September 29, 2015 08:03
Show Gist options
  • Save canard0328/be986d937e850d02e91c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import chainer.functions as F\n",
"from chainer import Variable, FunctionSet, optimizers, function"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X = np.array([\n",
" [1, 3, 0.001],\n",
" [2, 0.001, 5],\n",
" [0.3, 1, 0.2],\n",
" [4, 10, 1],\n",
" [5, 9, 5],\n",
" [0.1, 0.3, 0.9],\n",
" [10, 1, 0.1],\n",
" [2, 10, 10],\n",
" [0.2, 0.5, 11],\n",
" [9, 0.1, 9],\n",
" [0.9, 11, 0.1]\n",
" ]).astype(np.float32)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"y = 2*X[:,0]**2*X[:,1] - X[:,1]**0.5*X[:,2]**0.5 + np.random.normal(scale=0.1,size=X.shape[0]).astype(np.float32)\n",
"y = y.reshape(len(y), 1)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X = np.log(X).astype(np.float32)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"num_units = 2\n",
"model = FunctionSet(\n",
" l1 = F.Linear(3, num_units),\n",
" l2 = F.Linear(num_units, 1)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class Expo(function.Function):\n",
" def forward_cpu(self, x):\n",
" self.y = np.exp(x[0])\n",
" return self.y,\n",
"\n",
" def backward_cpu(self, x, gy):\n",
" return gy[0] * self.y,\n",
"\n",
"def expo(x):\n",
" return Expo()(x)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def forward(x_data, y_data):\n",
" x, t = Variable(x_data), Variable(y_data)\n",
" h1 = expo(model.l1(x))\n",
" y = model.l2(h1)\n",
" return F.mean_squared_error(y, t)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"optimizer = optimizers.Adam()\n",
"optimizer.setup(model)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:100 train mean loss=31372.7792969\n",
"epoch:200 train mean loss=31234.3144531\n",
"epoch:300 train mean loss=31152.5253906\n",
"epoch:400 train mean loss=31094.3867188\n",
"epoch:500 train mean loss=31048.0175781\n",
"epoch:600 train mean loss=31007.4199219\n",
"epoch:700 train mean loss=30967.8574219\n",
"epoch:800 train mean loss=30920.8417969\n",
"epoch:900 train mean loss=30803.2382812\n",
"epoch:1000 train mean loss=28303.3183594\n",
"epoch:1100 train mean loss=14691.9912109\n",
"epoch:1200 train mean loss=9743.24511719\n",
"epoch:1300 train mean loss=5964.33398438\n",
"epoch:1400 train mean loss=2970.22753906\n",
"epoch:1500 train mean loss=974.683654785\n",
"epoch:1600 train mean loss=357.173095703\n",
"epoch:1700 train mean loss=278.009765625\n",
"epoch:1800 train mean loss=253.25994873\n",
"epoch:1900 train mean loss=235.991958618\n",
"epoch:2000 train mean loss=221.130355835\n",
"epoch:2100 train mean loss=207.482162476\n",
"epoch:2200 train mean loss=194.632583618\n",
"epoch:2300 train mean loss=182.404312134\n",
"epoch:2400 train mean loss=170.71421814\n",
"epoch:2500 train mean loss=159.516830444\n",
"epoch:2600 train mean loss=148.791152954\n",
"epoch:2700 train mean loss=138.524398804\n",
"epoch:2800 train mean loss=128.708847046\n",
"epoch:2900 train mean loss=119.339553833\n",
"epoch:3000 train mean loss=110.41217804\n",
"epoch:3100 train mean loss=101.923614502\n",
"epoch:3200 train mean loss=93.8698654175\n",
"epoch:3300 train mean loss=86.245262146\n",
"epoch:3400 train mean loss=79.0436172485\n",
"epoch:3500 train mean loss=72.2586212158\n",
"epoch:3600 train mean loss=65.8811569214\n",
"epoch:3700 train mean loss=59.9036750793\n",
"epoch:3800 train mean loss=54.3155403137\n",
"epoch:3900 train mean loss=49.1061859131\n",
"epoch:4000 train mean loss=44.2637748718\n",
"epoch:4100 train mean loss=39.7768783569\n",
"epoch:4200 train mean loss=35.6319160461\n",
"epoch:4300 train mean loss=31.815990448\n",
"epoch:4400 train mean loss=28.3146286011\n",
"epoch:4500 train mean loss=25.1136341095\n",
"epoch:4600 train mean loss=22.1983680725\n",
"epoch:4700 train mean loss=19.5539932251\n",
"epoch:4800 train mean loss=17.1654510498\n",
"epoch:4900 train mean loss=15.0172872543\n",
"epoch:5000 train mean loss=13.0945644379\n",
"epoch:5100 train mean loss=11.3816661835\n",
"epoch:5200 train mean loss=9.86406230927\n",
"epoch:5300 train mean loss=8.52664470673\n",
"epoch:5400 train mean loss=7.3549823761\n",
"epoch:5500 train mean loss=6.3347454071\n",
"epoch:5600 train mean loss=5.45214128494\n",
"epoch:5700 train mean loss=4.69383335114\n",
"epoch:5800 train mean loss=4.04701328278\n",
"epoch:5900 train mean loss=3.4996817112\n",
"epoch:6000 train mean loss=3.0401597023\n",
"epoch:6100 train mean loss=2.6576256752\n",
"epoch:6200 train mean loss=2.34209609032\n",
"epoch:6300 train mean loss=2.08418512344\n",
"epoch:6400 train mean loss=1.87544584274\n",
"epoch:6500 train mean loss=1.70815765858\n",
"epoch:6600 train mean loss=1.57541358471\n",
"epoch:6700 train mean loss=1.47113156319\n",
"epoch:6800 train mean loss=1.389950037\n",
"epoch:6900 train mean loss=1.32728886604\n",
"epoch:7000 train mean loss=1.27922046185\n",
"epoch:7100 train mean loss=1.24244046211\n",
"epoch:7200 train mean loss=1.21423864365\n",
"epoch:7300 train mean loss=1.19241094589\n",
"epoch:7400 train mean loss=1.17518079281\n",
"epoch:7500 train mean loss=1.16118133068\n",
"epoch:7600 train mean loss=1.1493421793\n",
"epoch:7700 train mean loss=1.138854146\n",
"epoch:7800 train mean loss=1.12911236286\n",
"epoch:7900 train mean loss=1.11967742443\n",
"epoch:8000 train mean loss=1.11018323898\n",
"epoch:8100 train mean loss=1.10039067268\n",
"epoch:8200 train mean loss=1.09007024765\n",
"epoch:8300 train mean loss=1.07904744148\n",
"epoch:8400 train mean loss=1.06715524197\n",
"epoch:8500 train mean loss=1.05417370796\n",
"epoch:8600 train mean loss=1.03992474079\n",
"epoch:8700 train mean loss=1.02412843704\n",
"epoch:8800 train mean loss=1.006529212\n",
"epoch:8900 train mean loss=0.986808001995\n",
"epoch:9000 train mean loss=0.964643120766\n",
"epoch:9100 train mean loss=0.939798891544\n",
"epoch:9200 train mean loss=0.912258565426\n",
"epoch:9300 train mean loss=0.882494330406\n",
"epoch:9400 train mean loss=0.851729571819\n",
"epoch:9500 train mean loss=0.822065412998\n",
"epoch:9600 train mean loss=0.795871198177\n",
"epoch:9700 train mean loss=0.77457010746\n",
"epoch:9800 train mean loss=0.757784485817\n",
"epoch:9900 train mean loss=0.744022369385\n",
"epoch:10000 train mean loss=0.731878876686\n",
"epoch:10100 train mean loss=0.720541238785\n",
"epoch:10200 train mean loss=0.709630548954\n",
"epoch:10300 train mean loss=0.698968350887\n",
"epoch:10400 train mean loss=0.688422620296\n",
"epoch:10500 train mean loss=0.677916944027\n",
"epoch:10600 train mean loss=0.667325556278\n",
"epoch:10700 train mean loss=0.656499803066\n",
"epoch:10800 train mean loss=0.645299911499\n",
"epoch:10900 train mean loss=0.633540391922\n",
"epoch:11000 train mean loss=0.621025323868\n",
"epoch:11100 train mean loss=0.607669651508\n",
"epoch:11200 train mean loss=0.593504369259\n",
"epoch:11300 train mean loss=0.578767538071\n",
"epoch:11400 train mean loss=0.564007997513\n",
"epoch:11500 train mean loss=0.549864828587\n",
"epoch:11600 train mean loss=0.540891170502\n",
"epoch:11700 train mean loss=0.525401651859\n",
"epoch:11800 train mean loss=0.514214456081\n",
"epoch:11900 train mean loss=0.505761563778\n",
"epoch:12000 train mean loss=0.495530039072\n",
"epoch:12100 train mean loss=0.487383395433\n",
"epoch:12200 train mean loss=0.47992375493\n",
"epoch:12300 train mean loss=0.473052024841\n",
"epoch:12400 train mean loss=0.466601073742\n",
"epoch:12500 train mean loss=0.46064427495\n",
"epoch:12600 train mean loss=0.457746893167\n",
"epoch:12700 train mean loss=0.449678599834\n",
"epoch:12800 train mean loss=0.444524407387\n",
"epoch:12900 train mean loss=0.441940665245\n",
"epoch:13000 train mean loss=0.438120096922\n",
"epoch:13100 train mean loss=0.429325580597\n",
"epoch:13200 train mean loss=0.424200803041\n",
"epoch:13300 train mean loss=0.421023637056\n",
"epoch:13400 train mean loss=0.413643360138\n",
"epoch:13500 train mean loss=0.408192455769\n",
"epoch:13600 train mean loss=0.417691737413\n",
"epoch:13700 train mean loss=0.397109866142\n",
"epoch:13800 train mean loss=0.391566067934\n",
"epoch:13900 train mean loss=0.386039912701\n",
"epoch:14000 train mean loss=0.380708485842\n",
"epoch:14100 train mean loss=0.379563778639\n",
"epoch:14200 train mean loss=0.370126068592\n",
"epoch:14300 train mean loss=0.365221321583\n",
"epoch:14400 train mean loss=0.360549241304\n",
"epoch:14500 train mean loss=0.356159001589\n",
"epoch:14600 train mean loss=0.352070212364\n",
"epoch:14700 train mean loss=0.348999798298\n",
"epoch:14800 train mean loss=0.356187701225\n",
"epoch:14900 train mean loss=0.341304928064\n",
"epoch:15000 train mean loss=0.338270157576\n",
"epoch:15100 train mean loss=0.335548371077\n",
"epoch:15200 train mean loss=0.337527573109\n",
"epoch:15300 train mean loss=0.330541372299\n",
"epoch:15400 train mean loss=0.328361898661\n",
"epoch:15500 train mean loss=0.329148501158\n",
"epoch:15600 train mean loss=0.324667870998\n",
"epoch:15700 train mean loss=0.32468059659\n",
"epoch:15800 train mean loss=0.321730703115\n",
"epoch:15900 train mean loss=0.320273518562\n",
"epoch:16000 train mean loss=0.319922745228\n",
"epoch:16100 train mean loss=0.318046182394\n",
"epoch:16200 train mean loss=0.317025691271\n",
"epoch:16300 train mean loss=0.316511541605\n",
"epoch:16400 train mean loss=0.31578925252\n",
"epoch:16500 train mean loss=0.322080284357\n",
"epoch:16600 train mean loss=0.313970118761\n",
"epoch:16700 train mean loss=0.313417315483\n",
"epoch:16800 train mean loss=0.313607424498\n",
"epoch:16900 train mean loss=0.312390357256\n",
"epoch:17000 train mean loss=0.311956524849\n",
"epoch:17100 train mean loss=0.311582326889\n",
"epoch:17200 train mean loss=0.3112231493\n",
"epoch:17300 train mean loss=0.310920864344\n",
"epoch:17400 train mean loss=0.310798853636\n",
"epoch:17500 train mean loss=0.310363650322\n",
"epoch:17600 train mean loss=0.310130536556\n",
"epoch:17700 train mean loss=0.30990344286\n",
"epoch:17800 train mean loss=0.30972135067\n",
"epoch:17900 train mean loss=0.309542268515\n",
"epoch:18000 train mean loss=0.30946713686\n",
"epoch:18100 train mean loss=0.309371173382\n",
"epoch:18200 train mean loss=0.30907946825\n",
"epoch:18300 train mean loss=0.309720695019\n",
"epoch:18400 train mean loss=0.308821648359\n",
"epoch:18500 train mean loss=0.308723211288\n",
"epoch:18600 train mean loss=0.310669630766\n",
"epoch:18700 train mean loss=0.3085103333\n",
"epoch:18800 train mean loss=0.309707671404\n",
"epoch:18900 train mean loss=0.3083370924\n",
"epoch:19000 train mean loss=0.308256536722\n",
"epoch:19100 train mean loss=0.308253377676\n",
"epoch:19200 train mean loss=0.308115869761\n",
"epoch:19300 train mean loss=0.308184027672\n",
"epoch:19400 train mean loss=0.308013409376\n",
"epoch:19500 train mean loss=0.30789411068\n",
"epoch:19600 train mean loss=0.307871729136\n",
"epoch:19700 train mean loss=0.308270901442\n",
"epoch:19800 train mean loss=0.30770689249\n",
"epoch:19900 train mean loss=0.307707577944\n",
"epoch:20000 train mean loss=0.327270388603\n",
"epoch:20100 train mean loss=0.307528465986\n",
"epoch:20200 train mean loss=0.307558685541\n",
"epoch:20300 train mean loss=0.307413429022\n",
"epoch:20400 train mean loss=0.30802705884\n",
"epoch:20500 train mean loss=0.3073117733\n",
"epoch:20600 train mean loss=0.307250380516\n",
"epoch:20700 train mean loss=0.307330220938\n",
"epoch:20800 train mean loss=0.307140320539\n",
"epoch:20900 train mean loss=0.313841462135\n",
"epoch:21000 train mean loss=0.307033151388\n",
"epoch:21100 train mean loss=0.30697080493\n",
"epoch:21200 train mean loss=0.30835917592\n",
"epoch:21300 train mean loss=0.311429440975\n",
"epoch:21400 train mean loss=0.306810289621\n",
"epoch:21500 train mean loss=0.306930065155\n",
"epoch:21600 train mean loss=0.306723326445\n",
"epoch:21700 train mean loss=0.306637555361\n",
"epoch:21800 train mean loss=0.307242542505\n",
"epoch:21900 train mean loss=0.308094322681\n",
"epoch:22000 train mean loss=0.306472569704\n",
"epoch:22100 train mean loss=0.306425184011\n",
"epoch:22200 train mean loss=0.306489676237\n",
"epoch:22300 train mean loss=0.306289583445\n",
"epoch:22400 train mean loss=0.306257933378\n",
"epoch:22500 train mean loss=0.306520462036\n",
"epoch:22600 train mean loss=0.306117296219\n",
"epoch:22700 train mean loss=0.307666122913\n",
"epoch:22800 train mean loss=0.305995106697\n",
"epoch:22900 train mean loss=0.306275159121\n",
"epoch:23000 train mean loss=0.305876225233\n",
"epoch:23100 train mean loss=0.305820375681\n",
"epoch:23200 train mean loss=0.306178838015\n",
"epoch:23300 train mean loss=0.30633610487\n",
"epoch:23400 train mean loss=0.305618166924\n",
"epoch:23500 train mean loss=0.327315598726\n",
"epoch:23600 train mean loss=0.305489331484\n",
"epoch:23700 train mean loss=0.305429697037\n",
"epoch:23800 train mean loss=0.305358588696\n",
"epoch:23900 train mean loss=0.305315464735\n",
"epoch:24000 train mean loss=0.306744635105\n",
"epoch:24100 train mean loss=0.305161654949\n",
"epoch:24200 train mean loss=0.306147366762\n",
"epoch:24300 train mean loss=0.305026143789\n",
"epoch:24400 train mean loss=0.305058479309\n",
"epoch:24500 train mean loss=0.315707623959\n",
"epoch:24600 train mean loss=0.304824322462\n",
"epoch:24700 train mean loss=0.304865241051\n",
"epoch:24800 train mean loss=0.304689764977\n",
"epoch:24900 train mean loss=0.306201130152\n",
"epoch:25000 train mean loss=0.304550260305\n",
"epoch:25100 train mean loss=0.304488837719\n",
"epoch:25200 train mean loss=0.311301916838\n",
"epoch:25300 train mean loss=0.304344534874\n",
"epoch:25400 train mean loss=0.304272085428\n",
"epoch:25500 train mean loss=0.304375559092\n",
"epoch:25600 train mean loss=0.304685622454\n",
"epoch:25700 train mean loss=0.304057300091\n",
"epoch:25800 train mean loss=0.30401173234\n",
"epoch:25900 train mean loss=0.303913414478\n",
"epoch:26000 train mean loss=0.30386030674\n",
"epoch:26100 train mean loss=0.314157217741\n",
"epoch:26200 train mean loss=0.303698331118\n",
"epoch:26300 train mean loss=0.303641915321\n",
"epoch:26400 train mean loss=0.303722947836\n",
"epoch:26500 train mean loss=0.303478807211\n",
"epoch:26600 train mean loss=0.303917139769\n",
"epoch:26700 train mean loss=0.303332239389\n",
"epoch:26800 train mean loss=0.303517878056\n",
"epoch:26900 train mean loss=0.303733468056\n",
"epoch:27000 train mean loss=0.30311319232\n",
"epoch:27100 train mean loss=0.303036987782\n",
"epoch:27200 train mean loss=0.303193837404\n",
"epoch:27300 train mean loss=0.30289003253\n",
"epoch:27400 train mean loss=0.302912324667\n",
"epoch:27500 train mean loss=0.302889496088\n",
"epoch:27600 train mean loss=0.302685052156\n",
"epoch:27700 train mean loss=0.302594870329\n",
"epoch:27800 train mean loss=0.302679657936\n",
"epoch:27900 train mean loss=0.302458643913\n",
"epoch:28000 train mean loss=0.302377492189\n",
"epoch:28100 train mean loss=0.302309781313\n",
"epoch:28200 train mean loss=0.305227398872\n",
"epoch:28300 train mean loss=0.302150338888\n",
"epoch:28400 train mean loss=0.302149742842\n",
"epoch:28500 train mean loss=0.302710354328\n",
"epoch:28600 train mean loss=0.302416056395\n",
"epoch:28700 train mean loss=0.305107146502\n",
"epoch:28800 train mean loss=0.313704073429\n",
"epoch:28900 train mean loss=0.301713943481\n",
"epoch:29000 train mean loss=0.30163320899\n",
"epoch:29100 train mean loss=0.30195248127\n",
"epoch:29200 train mean loss=0.301480799913\n",
"epoch:29300 train mean loss=0.301409363747\n",
"epoch:29400 train mean loss=0.311064451933\n",
"epoch:29500 train mean loss=0.301260054111\n",
"epoch:29600 train mean loss=0.301843434572\n",
"epoch:29700 train mean loss=0.309216678143\n",
"epoch:29800 train mean loss=0.301059812307\n",
"epoch:29900 train mean loss=0.300995200872\n",
"epoch:30000 train mean loss=0.302341520786\n",
"epoch:30100 train mean loss=0.300906777382\n",
"epoch:30200 train mean loss=0.300751328468\n",
"epoch:30300 train mean loss=0.300980657339\n",
"epoch:30400 train mean loss=0.300626903772\n",
"epoch:30500 train mean loss=0.300533741713\n",
"epoch:30600 train mean loss=0.301971584558\n",
"epoch:30700 train mean loss=0.30038946867\n",
"epoch:30800 train mean loss=0.300339192152\n",
"epoch:30900 train mean loss=0.302649110556\n",
"epoch:31000 train mean loss=0.300175487995\n",
"epoch:31100 train mean loss=0.300122857094\n",
"epoch:31200 train mean loss=0.310733556747\n",
"epoch:31300 train mean loss=0.299963474274\n",
"epoch:31400 train mean loss=0.299942463636\n",
"epoch:31500 train mean loss=0.303368359804\n",
"epoch:31600 train mean loss=0.299751162529\n",
"epoch:31700 train mean loss=0.300648093224\n",
"epoch:31800 train mean loss=0.299620181322\n",
"epoch:31900 train mean loss=0.299550026655\n",
"epoch:32000 train mean loss=0.299581199884\n",
"epoch:32100 train mean loss=0.299404680729\n",
"epoch:32200 train mean loss=0.300023168325\n",
"epoch:32300 train mean loss=0.299285322428\n",
"epoch:32400 train mean loss=0.299195408821\n",
"epoch:32500 train mean loss=0.300106972456\n",
"epoch:32600 train mean loss=0.299063831568\n",
"epoch:32700 train mean loss=0.298996120691\n",
"epoch:32800 train mean loss=0.301807135344\n",
"epoch:32900 train mean loss=0.29896324873\n",
"epoch:33000 train mean loss=0.298790425062\n",
"epoch:33100 train mean loss=0.298866122961\n",
"epoch:33200 train mean loss=0.298656284809\n",
"epoch:33300 train mean loss=0.298620909452\n",
"epoch:33400 train mean loss=0.301293730736\n",
"epoch:33500 train mean loss=0.298484444618\n",
"epoch:33600 train mean loss=0.298389077187\n",
"epoch:33700 train mean loss=0.298373103142\n",
"epoch:33800 train mean loss=0.298268288374\n",
"epoch:33900 train mean loss=0.298193216324\n",
"epoch:34000 train mean loss=0.298143565655\n",
"epoch:34100 train mean loss=0.307447820902\n",
"epoch:34200 train mean loss=0.297997444868\n",
"epoch:34300 train mean loss=0.297934323549\n",
"epoch:34400 train mean loss=0.298451930285\n",
"epoch:34500 train mean loss=0.297804027796\n",
"epoch:34600 train mean loss=0.297748446465\n",
"epoch:34700 train mean loss=0.297688066959\n",
"epoch:34800 train mean loss=0.308348745108\n",
"epoch:34900 train mean loss=0.297550559044\n",
"epoch:35000 train mean loss=0.297623038292\n",
"epoch:35100 train mean loss=0.299623310566\n",
"epoch:35200 train mean loss=0.297395825386\n",
"epoch:35300 train mean loss=0.297301203012\n",
"epoch:35400 train mean loss=0.297259509563\n",
"epoch:35500 train mean loss=0.298073500395\n",
"epoch:35600 train mean loss=0.297115474939\n",
"epoch:35700 train mean loss=0.297309368849\n",
"epoch:35800 train mean loss=0.297997355461\n",
"epoch:35900 train mean loss=0.30123513937\n",
"epoch:36000 train mean loss=0.302880972624\n",
"epoch:36100 train mean loss=0.296812802553\n",
"epoch:36200 train mean loss=0.29674872756\n",
"epoch:36300 train mean loss=0.307402461767\n",
"epoch:36400 train mean loss=0.296707242727\n",
"epoch:36500 train mean loss=0.296569257975\n",
"epoch:36600 train mean loss=0.296524018049\n",
"epoch:36700 train mean loss=0.297457069159\n",
"epoch:36800 train mean loss=0.296395897865\n",
"epoch:36900 train mean loss=0.296792954206\n",
"epoch:37000 train mean loss=0.298014909029\n",
"epoch:37100 train mean loss=0.296293199062\n",
"epoch:37200 train mean loss=0.296202689409\n",
"epoch:37300 train mean loss=0.296123474836\n",
"epoch:37400 train mean loss=0.296047776937\n",
"epoch:37500 train mean loss=0.29599031806\n",
"epoch:37600 train mean loss=0.295957833529\n",
"epoch:37700 train mean loss=0.307167738676\n",
"epoch:37800 train mean loss=0.295820564032\n",
"epoch:37900 train mean loss=0.295789271593\n",
"epoch:38000 train mean loss=0.295714735985\n",
"epoch:38100 train mean loss=0.295713573694\n",
"epoch:38200 train mean loss=0.296147704124\n",
"epoch:38300 train mean loss=0.295808196068\n",
"epoch:38400 train mean loss=0.295494288206\n",
"epoch:38500 train mean loss=0.295455008745\n",
"epoch:38600 train mean loss=0.295552790165\n",
"epoch:38700 train mean loss=0.295327395201\n",
"epoch:38800 train mean loss=0.317661732435\n",
"epoch:38900 train mean loss=0.295218706131\n",
"epoch:39000 train mean loss=0.299768358469\n",
"epoch:39100 train mean loss=0.295127242804\n",
"epoch:39200 train mean loss=0.295069098473\n",
"epoch:39300 train mean loss=0.295096039772\n",
"epoch:39400 train mean loss=0.29495254159\n",
"epoch:39500 train mean loss=0.294953644276\n",
"epoch:39600 train mean loss=0.294852524996\n",
"epoch:39700 train mean loss=0.299119263887\n",
"epoch:39800 train mean loss=0.294748276472\n",
"epoch:39900 train mean loss=0.294712215662\n",
"epoch:40000 train mean loss=0.294731557369\n",
"epoch:40100 train mean loss=0.294594615698\n",
"epoch:40200 train mean loss=0.294544219971\n",
"epoch:40300 train mean loss=0.314175754786\n",
"epoch:40400 train mean loss=0.294441640377\n",
"epoch:40500 train mean loss=0.294942080975\n",
"epoch:40600 train mean loss=0.295029520988\n",
"epoch:40700 train mean loss=0.294298529625\n",
"epoch:40800 train mean loss=0.294265091419\n",
"epoch:40900 train mean loss=0.294349759817\n",
"epoch:41000 train mean loss=0.306304693222\n",
"epoch:41100 train mean loss=0.294096827507\n",
"epoch:41200 train mean loss=0.294087916613\n",
"epoch:41300 train mean loss=0.300193369389\n",
"epoch:41400 train mean loss=0.293954223394\n",
"epoch:41500 train mean loss=0.293906062841\n",
"epoch:41600 train mean loss=0.29413741827\n",
"epoch:41700 train mean loss=0.295879274607\n",
"epoch:41800 train mean loss=0.293755233288\n",
"epoch:41900 train mean loss=0.293728470802\n",
"epoch:42000 train mean loss=0.298869729042\n",
"epoch:42100 train mean loss=0.293612837791\n",
"epoch:42200 train mean loss=0.293574541807\n",
"epoch:42300 train mean loss=0.294453859329\n",
"epoch:42400 train mean loss=0.293473601341\n",
"epoch:42500 train mean loss=0.297279566526\n",
"epoch:42600 train mean loss=0.293381243944\n",
"epoch:42700 train mean loss=0.294555008411\n",
"epoch:42800 train mean loss=0.293386310339\n",
"epoch:42900 train mean loss=0.293245047331\n",
"epoch:43000 train mean loss=0.29320409894\n",
"epoch:43100 train mean loss=0.293887257576\n",
"epoch:43200 train mean loss=0.296142578125\n",
"epoch:43300 train mean loss=0.293074965477\n",
"epoch:43400 train mean loss=0.293064892292\n",
"epoch:43500 train mean loss=0.293043881655\n",
"epoch:43600 train mean loss=0.292933791876\n",
"epoch:43700 train mean loss=0.293460249901\n",
"epoch:43800 train mean loss=0.309539079666\n",
"epoch:43900 train mean loss=0.292802125216\n",
"epoch:44000 train mean loss=0.293288260698\n",
"epoch:44100 train mean loss=0.292713463306\n",
"epoch:44200 train mean loss=0.293365508318\n",
"epoch:44300 train mean loss=0.292625814676\n",
"epoch:44400 train mean loss=0.29267424345\n",
"epoch:44500 train mean loss=0.293920695782\n",
"epoch:44600 train mean loss=0.292497456074\n",
"epoch:44700 train mean loss=0.292532473803\n",
"epoch:44800 train mean loss=0.292414575815\n",
"epoch:44900 train mean loss=0.292801588774\n",
"epoch:45000 train mean loss=0.292331457138\n",
"epoch:45100 train mean loss=0.292316317558\n",
"epoch:45200 train mean loss=0.292837500572\n",
"epoch:45300 train mean loss=0.292220532894\n",
"epoch:45400 train mean loss=0.292164713144\n",
"epoch:45500 train mean loss=0.29532968998\n",
"epoch:45600 train mean loss=0.292092591524\n",
"epoch:45700 train mean loss=0.292056530714\n",
"epoch:45800 train mean loss=0.292046755552\n",
"epoch:45900 train mean loss=0.291977256536\n",
"epoch:46000 train mean loss=0.291925936937\n",
"epoch:46100 train mean loss=0.293803840876\n",
"epoch:46200 train mean loss=0.291845291853\n",
"epoch:46300 train mean loss=0.300246447325\n",
"epoch:46400 train mean loss=0.291764855385\n",
"epoch:46500 train mean loss=0.291726797819\n",
"epoch:46600 train mean loss=0.291685223579\n",
"epoch:46700 train mean loss=0.291653990746\n",
"epoch:46800 train mean loss=0.291615366936\n",
"epoch:46900 train mean loss=0.291570425034\n",
"epoch:47000 train mean loss=0.292142659426\n",
"epoch:47100 train mean loss=0.303349345922\n",
"epoch:47200 train mean loss=0.291460245848\n",
"epoch:47300 train mean loss=0.29142421484\n",
"epoch:47400 train mean loss=0.313145369291\n",
"epoch:47500 train mean loss=0.291339725256\n",
"epoch:47600 train mean loss=0.291306138039\n",
"epoch:47700 train mean loss=0.291265368462\n",
"epoch:47800 train mean loss=0.292409271002\n",
"epoch:47900 train mean loss=0.291192978621\n",
"epoch:48000 train mean loss=0.291167706251\n",
"epoch:48100 train mean loss=0.293959289789\n",
"epoch:48200 train mean loss=0.291084080935\n",
"epoch:48300 train mean loss=0.291054308414\n",
"epoch:48400 train mean loss=0.302320361137\n",
"epoch:48500 train mean loss=0.290974408388\n",
"epoch:48600 train mean loss=0.291046977043\n",
"epoch:48700 train mean loss=0.290899753571\n",
"epoch:48800 train mean loss=0.29087126255\n",
"epoch:48900 train mean loss=0.301152408123\n",
"epoch:49000 train mean loss=0.290869742632\n",
"epoch:49100 train mean loss=0.290761172771\n",
"epoch:49200 train mean loss=0.291226267815\n",
"epoch:49300 train mean loss=0.290697187185\n",
"epoch:49400 train mean loss=0.290658682585\n",
"epoch:49500 train mean loss=0.290621548891\n",
"epoch:49600 train mean loss=0.290820926428\n",
"epoch:49700 train mean loss=0.293107151985\n",
"epoch:49800 train mean loss=0.290510058403\n",
"epoch:49900 train mean loss=0.29054531455\n",
"epoch:50000 train mean loss=0.292022317648\n"
]
}
],
"source": [
"# Mini-batch size for stochastic gradient descent (here N == batchsize,\n",
"# so each epoch is a single full-batch update).\n",
"batchsize = 11\n",
"# Number of training epochs.\n",
"n_epoch = 50000\n",
"N = X.shape[0]\n",
"# Learning loop\n",
"for epoch in xrange(1, n_epoch+1):\n",
"\n",
" # Shuffle the order of the N samples.\n",
" perm = np.random.permutation(N)\n",
" sum_loss = 0\n",
" # Train on the data in mini-batches.\n",
" for i in xrange(0, N, batchsize):\n",
" x_batch = X[perm[i:i+batchsize]]\n",
" y_batch = y[perm[i:i+batchsize]]\n",
"\n",
" # Reset accumulated gradients.\n",
" optimizer.zero_grads()\n",
" # Forward pass computes the loss.\n",
" loss = forward(x_batch, y_batch)\n",
" # Backpropagate to compute gradients.\n",
" loss.backward()\n",
" optimizer.weight_decay(0.01)\n",
" optimizer.update()\n",
" # Weight by the actual batch size: the final slice can be shorter\n",
" # than batchsize when N is not a multiple of it, so multiplying by\n",
" # batchsize would overcount those samples in the epoch mean.\n",
" sum_loss += loss.data * len(x_batch)\n",
"\n",
" # Report mean training loss every 100 epochs.\n",
" if epoch % 100 == 0:\n",
" print 'epoch:{} train mean loss={}'.format(epoch, sum_loss / N)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 6.07095146]\n",
" [ -0.90871316]\n",
" [ -0.47295213]\n",
" [ 316.75616455]\n",
" [ 443.15646362]\n",
" [ -0.95780635]\n",
" [ 199.78878784]\n",
" [ 70.07402039]\n",
" [ -0.80237651]\n",
" [ 15.18830585]\n",
" [ 16.46421432]]\n",
"[[ 5.80544519e+00]\n",
" [ -2.21579187e-02]\n",
" [ -3.42641264e-01]\n",
" [ 3.16766022e+02]\n",
" [ 4.43226501e+02]\n",
" [ -6.06474221e-01]\n",
" [ 1.99794052e+02]\n",
" [ 6.99967651e+01]\n",
" [ -2.24294662e+00]\n",
" [ 1.52209854e+01]\n",
" [ 1.68022785e+01]]\n"
]
},
{
"data": {
"text/plain": [
"array(0.29077136516571045, dtype=float32)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_test = Variable(X[range(11)])\n",
"t = Variable(y[range(11)])\n",
"h1 = expo(model.l1(x_test))\n",
"y_test = model.l2(h1)\n",
"print y_test.data\n",
"print t.data\n",
"F.mean_squared_error(y_test, t).data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 2.24651504e+00, 1.13109624e+00, -2.16289815e-02],\n",
" [ 1.10743952e+00, 4.99774486e-01, 8.01726012e-04]], dtype=float32)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.l1.W"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.59255338, 0.89250612]], dtype=float32)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.l2.W"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment