Training a 2-layer neural network. Only the numpy library is used, for shuffling the sample order and for indexing.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2 layer neural network and its training"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# -*- coding: utf-8 -*-\n",
"\n",
"from __future__ import print_function\n",
"\n",
"%matplotlib inline\n",
"import six\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## neural network model"
]
},
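{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the forward and backward passes implemented in the next cell amount to a single linear layer with an identity activation and a squared-error loss,\n",
"\n",
"$$y_i = \\sum_j w_{ij} x_j + b_i, \\qquad L = \\sum_i (y_i - t_i)^2,$$\n",
"\n",
"and `backward` returns the gradients\n",
"\n",
"$$\\frac{\\partial L}{\\partial w_{ij}} = (y_i - t_i)\\, x_j, \\qquad \\frac{\\partial L}{\\partial b_i} = y_i - t_i,$$\n",
"\n",
"where the factor 2 from the squared error is dropped and effectively absorbed into the learning rate."
]
},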
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def forward(x, t, w, b):\n",
" \"\"\"\n",
" neural netの順伝播を行う関数\n",
" 入力:neural netの入力(リスト)、教師ラベル(リスト)\n",
" 出力:順伝播の誤差(double型のスカラー)\n",
" \"\"\"\n",
" h = [0. for i in six.moves.range(len(t))]\n",
" y = [0. for i in six.moves.range(len(t))]\n",
" loss = 0.\n",
" \n",
" # 行列積\n",
" for i in six.moves.range(len(t)):\n",
" for j in six.moves.range(len(x)):\n",
" h[i] += w[i][j] * x[j]\n",
" \n",
" # バイアスの加算\n",
" for i in six.moves.range(len(t)):\n",
" y[i] = h[i] + b[i]\n",
" \n",
" # 活性化関数をyに適用\n",
" for i in six.moves.range(len(t)):\n",
" y[i] = activation(y[i])\n",
" \n",
" # lossの計算\n",
" for i in six.moves.range(len(t)):\n",
" loss += (y[i] - t[i]) ** 2\n",
" \n",
" return loss, y\n",
"\n",
"\n",
"def activation(x):\n",
" \"\"\"\n",
" 活性化関数:今回は y=x\n",
" 入力:double型のスカラー\n",
" 出力:活性化関数を通したdouble型のスカラー\n",
" \"\"\"\n",
" return x * 1.\n",
"\n",
"\n",
"def backward(x, loss, w, b, y, t):\n",
" \"\"\"\n",
" パラメータの更新を行うための微分値の計算\n",
" 入力:誤差loss(double型のスカラー)、重みw(リスト)、\n",
" バイアスb(リスト)\n",
" 出力:wの微分値(リスト)、bの微分値(リスト)\n",
" \"\"\"\n",
" diff_w = [[0. for i in six.moves.range(len(w[0]))]\\\n",
" for j in six.moves.range(len(w))]\n",
" diff_b = [0. for i in six.moves.range(len(b))]\n",
" for i in six.moves.range(len(w)):\n",
" for j in six.moves.range(len(w[0])):\n",
" diff_w[i][j] += (y[i] - t[i]) * 1. * x[j]\n",
" diff_b[i] += (y[i] - t[i]) * 1. * 1.\n",
" \n",
" return diff_w, diff_b"
]
},
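{
"cell_type": "markdown",
"metadata": {},
"source": [
"The pure-Python loops above can be cross-checked against an equivalent vectorized computation. The next cell is an optional sketch and is not used by the training code; the names (`forward_backward_vectorized`, `x_vec`, `w_mat`, ...) are illustrative only, and the arguments are assumed to be numpy arrays of matching shapes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Illustrative sketch: the same forward/backward computation written with\n",
"# numpy matrix operations, handy for cross-checking the loop version above.\n",
"def forward_backward_vectorized(x_vec, t_vec, w_mat, b_vec):\n",
"    y = w_mat.dot(x_vec) + b_vec        # linear layer with identity activation\n",
"    loss = np.sum((y - t_vec) ** 2)     # squared-error loss, as in forward()\n",
"    dw = np.outer(y - t_vec, x_vec)     # same gradient convention as backward()\n",
"    db = y - t_vec\n",
"    return loss, y, dw, db\n",
"\n",
"# example usage with small synthetic arrays\n",
"x0 = np.array([1., 1.])\n",
"t0 = np.array([1., 0., 0.])\n",
"w0 = np.full((3, 2), 0.01)\n",
"b0 = np.ones(3)\n",
"print(forward_backward_vectorized(x0, t0, w0, b0))"
]
},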
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## training neural net (trainer)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def update(alpha, w, b, dw, db):\n",
" \"\"\"\n",
" パラメータ(重みw, バイアスb)を誤差逆伝播法に従って更新する\n",
" wとbに対してそれぞれの微分値に学習率をかけて減算\n",
" 入力:学習率aplha(double), 重みw(リスト), バイアスb(リスト),\n",
" wの微分値dw(リスト), bの微分値db(リスト)\n",
" 出力:更新後のw(リスト), b(リスト)\n",
" \"\"\"\n",
" for i in six.moves.range(len(w)):\n",
" for j in six.moves.range(len(w[0])):\n",
" w[i][j] -= alpha * dw[i][j]\n",
" b[i] -= alpha * db[i]\n",
" \n",
" return w, b\n",
"\n",
"\n",
"def train_nn(epoch, dataset, target, batchsize, alpha, w, b):\n",
" loss_list = []\n",
" for itr in six.moves.range(epoch):\n",
" perm = np.random.permutation(len(dataset))\n",
" ave_loss = 0\n",
" \n",
" for i in six.moves.range(0, len(dataset), batchsize):\n",
" diff_w = [[0. for i in six.moves.range(len(dataset[0]))]\\\n",
" for j in six.moves.range(len(target[0]))]\n",
" diff_b = [0. for i in six.moves.range(len(target[0]))]\n",
" \n",
" x = dataset[perm[i : i + batchsize]]\n",
" t = target[perm[i : i + batchsize]]\n",
" \n",
" #forward, backward on neural net\n",
" for bidx in six.moves.range(batchsize):\n",
" loss, y = forward(x[bidx], t[bidx], w, b)\n",
" dfw, dfb = backward(x[bidx], loss, w, b, y, t[bidx])\n",
" diff_w = [[diff_w[i][j] + dfw[i][j]\\\n",
" for j in six.moves.range(len(dataset[0]))]\\\n",
" for i in six.moves.range(len(target[0]))]\n",
" diff_b += dfw\n",
" ave_loss += loss\n",
" \n",
" #update params\n",
" w, b = update(alpha, w, b, diff_w, diff_b)\n",
" \n",
" loss_list.append(ave_loss / len(dataset))\n",
" print(\"epoch={}, loss={}\".format(itr, ave_loss/len(dataset)))\n",
" return loss_list"
]
},
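{
"cell_type": "markdown",
"metadata": {},
"source": [
"The step performed by `update` after each minibatch is plain gradient descent: the per-sample gradients are summed over the minibatch and scaled by the learning rate $\\alpha$,\n",
"\n",
"$$w_{ij} \\leftarrow w_{ij} - \\alpha \\sum_{n \\in \\mathrm{batch}} \\frac{\\partial L^{(n)}}{\\partial w_{ij}}, \\qquad b_i \\leftarrow b_i - \\alpha \\sum_{n \\in \\mathrm{batch}} \\frac{\\partial L^{(n)}}{\\partial b_i},$$\n",
"\n",
"where $L^{(n)}$ is the loss of the $n$-th sample in the minibatch."
]
},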
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## parameters on neural net"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# hyper parameters\n",
"alpha = 0.0001\n",
"epoch = 500\n",
"batchsize = 6\n",
"\n",
"# set dataset and target\n",
"dataset = np.asarray([[1, 1], [2, 1], [1, 3], [2, 4], [4, 3], [4, 2]])\n",
"#target = np.asarray([[1], [1], [2], [2], [3], [3]])\n",
"target = np.asarray([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])\n",
"\n",
"# neural parameters\n",
"w = [[np.random.rand() * 0.01 for j in six.moves.range(len(dataset[0]))]\\\n",
" for i in six.moves.range(len(target[0]))]\n",
"b = [1.0 for i in six.moves.range(len(target[0]))]"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch=0, loss=2.0622183934591263\n",
"epoch=1, loss=2.0429150213025147\n",
"epoch=2, loss=2.0238943329986694\n",
"epoch=3, loss=2.005152111517894\n",
"epoch=4, loss=1.9866842028496556\n",
"epoch=5, loss=1.9684865150606694\n",
"epoch=6, loss=1.9505550173670667\n",
"epoch=7, loss=1.932885739220433\n",
"epoch=8, loss=1.9154747694075012\n",
"epoch=9, loss=1.8983182551633087\n",
"epoch=10, loss=1.8814124012976017\n",
"epoch=11, loss=1.8647534693343042\n",
"epoch=12, loss=1.8483377766638431\n",
"epoch=13, loss=1.832161695708148\n",
"epoch=14, loss=1.8162216530981272\n",
"epoch=15, loss=1.8005141288634423\n",
"epoch=16, loss=1.7850356556343872\n",
"epoch=17, loss=1.7697828178557025\n",
"epoch=18, loss=1.754752251012136\n",
"epoch=19, loss=1.739940640865581\n",
"epoch=20, loss=1.7253447227036152\n",
"epoch=21, loss=1.7109612805992738\n",
"epoch=22, loss=1.6967871466818831\n",
"epoch=23, loss=1.6828192004187958\n",
"epoch=24, loss=1.6690543679078598\n",
"epoch=25, loss=1.6554896211804653\n",
"epoch=26, loss=1.6421219775150027\n",
"epoch=27, loss=1.6289484987605907\n",
"epoch=28, loss=1.6159662906709047\n",
"epoch=29, loss=1.6031725022479668\n",
"epoch=30, loss=1.5905643250957409\n",
"epoch=31, loss=1.5781389927833933\n",
"epoch=32, loss=1.565893780218066\n",
"epoch=33, loss=1.5538260030270272\n",
"epoch=34, loss=1.5419330169490582\n",
"epoch=35, loss=1.5302122172349322\n",
"epoch=36, loss=1.5186610380568597\n",
"epoch=37, loss=1.5072769519267568\n",
"epoch=38, loss=1.496057469123208\n",
"epoch=39, loss=1.4850001371269979\n",
"epoch=40, loss=1.4741025400650765\n",
"epoch=41, loss=1.463362298162835\n",
"epoch=42, loss=1.452777067204572\n",
"epoch=43, loss=1.4423445380020183\n",
"epoch=44, loss=1.4320624358708116\n",
"epoch=45, loss=1.4219285201147915\n",
"epoch=46, loss=1.4119405835180032\n",
"epoch=47, loss=1.4020964518442964\n",
"epoch=48, loss=1.3923939833444032\n",
"epoch=49, loss=1.3828310682703784\n",
"epoch=50, loss=1.3734056283973064\n",
"epoch=51, loss=1.3641156165521462\n",
"epoch=52, loss=1.3549590161496299\n",
"epoch=53, loss=1.3459338407350883\n",
"epoch=54, loss=1.3370381335341153\n",
"epoch=55, loss=1.3282699670089608\n",
"epoch=56, loss=1.3196274424215564\n",
"epoch=57, loss=1.311108689403069\n",
"epoch=58, loss=1.3027118655298915\n",
"epoch=59, loss=1.294435155905968\n",
"epoch=60, loss=1.2862767727513662\n",
"epoch=61, loss=1.2782349549969922\n",
"epoch=62, loss=1.2703079678853715\n",
"epoch=63, loss=1.2624941025773881\n",
"epoch=64, loss=1.2547916757649111\n",
"epoch=65, loss=1.2471990292892021\n",
"epoch=66, loss=1.2397145297650314\n",
"epoch=67, loss=1.2323365682104117\n",
"epoch=68, loss=1.2250635596818653\n",
"epoch=69, loss=1.2178939429151425\n",
"epoch=70, loss=1.2108261799713114\n",
"epoch=71, loss=1.203858755888135\n",
"epoch=72, loss=1.1969901783366625\n",
"epoch=73, loss=1.1902189772829512\n",
"epoch=74, loss=1.1835437046548434\n",
"epoch=75, loss=1.1769629340137266\n",
"epoch=76, loss=1.1704752602311987\n",
"epoch=77, loss=1.1640792991705637\n",
"epoch=78, loss=1.1577736873730942\n",
"epoch=79, loss=1.1515570817489733\n",
"epoch=80, loss=1.1454281592728666\n",
"epoch=81, loss=1.1393856166840355\n",
"epoch=82, loss=1.1334281701909406\n",
"epoch=83, loss=1.1275545551802524\n",
"epoch=84, loss=1.1217635259302166\n",
"epoch=85, loss=1.1160538553283048\n",
"epoch=86, loss=1.110424334593077\n",
"epoch=87, loss=1.1048737730002134\n",
"epoch=88, loss=1.0994009976126307\n",
"epoch=89, loss=1.0940048530146378\n",
"epoch=90, loss=1.0886842010500626\n",
"epoch=91, loss=1.0834379205642948\n",
"epoch=92, loss=1.078264907150183\n",
"epoch=93, loss=1.07316407289773\n",
"epoch=94, loss=1.0681343461475286\n",
"epoch=95, loss=1.0631746712478842\n",
"epoch=96, loss=1.0582840083155671\n",
"epoch=97, loss=1.0534613330001399\n",
"epoch=98, loss=1.0487056362518083\n",
"epoch=99, loss=1.0440159240927405\n",
"epoch=100, loss=1.0393912173918072\n",
"epoch=101, loss=1.0348305516426848\n",
"epoch=102, loss=1.0303329767452791\n",
"epoch=103, loss=1.0258975567904123\n",
"epoch=104, loss=1.0215233698477315\n",
"epoch=105, loss=1.0172095077567842\n",
"epoch=106, loss=1.0129550759212187\n",
"epoch=107, loss=1.0087591931060598\n",
"epoch=108, loss=1.004620991238015\n",
"epoch=109, loss=1.0005396152087631\n",
"epoch=110, loss=0.9965142226811872\n",
"epoch=111, loss=0.9925439838985\n",
"epoch=112, loss=0.9886280814962233\n",
"epoch=113, loss=0.9847657103169789\n",
"epoch=114, loss=0.9809560772280461\n",
"epoch=115, loss=0.9771984009416473\n",
"epoch=116, loss=0.9734919118379185\n",
"epoch=117, loss=0.9698358517905264\n",
"epoch=118, loss=0.9662294739948937\n",
"epoch=119, loss=0.9626720427989871\n",
"epoch=120, loss=0.9591628335366408\n",
"epoch=121, loss=0.955701132363365\n",
"epoch=122, loss=0.9522862360946132\n",
"epoch=123, loss=0.9489174520464646\n",
"epoch=124, loss=0.9455940978786878\n",
"epoch=125, loss=0.9423155014401514\n",
"epoch=126, loss=0.939081000616543\n",
"epoch=127, loss=0.9358899431803671\n",
"epoch=128, loss=0.9327416866431825\n",
"epoch=129, loss=0.9296355981100491\n",
"epoch=130, loss=0.9265710541361519\n",
"epoch=131, loss=0.9235474405855643\n",
"epoch=132, loss=0.9205641524921262\n",
"epoch=133, loss=0.9176205939223987\n",
"epoch=134, loss=0.9147161778406697\n",
"epoch=135, loss=0.9118503259759745\n",
"epoch=136, loss=0.9090224686911051\n",
"epoch=137, loss=0.9062320448535792\n",
"epoch=138, loss=0.9034785017085339\n",
"epoch=139, loss=0.9007612947535234\n",
"epoch=140, loss=0.8980798876151846\n",
"epoch=141, loss=0.8954337519277482\n",
"epoch=142, loss=0.8928223672133648\n",
"epoch=143, loss=0.8902452207642199\n",
"epoch=144, loss=0.8877018075264108\n",
"epoch=145, loss=0.8851916299855592\n",
"epoch=146, loss=0.8827141980541343\n",
"epoch=147, loss=0.8802690289604572\n",
"epoch=148, loss=0.8778556471393685\n",
"epoch=149, loss=0.8754735841245266\n",
"epoch=150, loss=0.8731223784423173\n",
"epoch=151, loss=0.870801575507349\n",
"epoch=152, loss=0.8685107275195092\n",
"epoch=153, loss=0.8662493933625606\n",
"epoch=154, loss=0.8640171385042521\n",
"epoch=155, loss=0.8618135348979227\n",
"epoch=156, loss=0.8596381608855772\n",
"epoch=157, loss=0.8574906011024096\n",
"epoch=158, loss=0.8553704463827527\n",
"epoch=159, loss=0.8532772936674355\n",
"epoch=160, loss=0.8512107459125248\n",
"epoch=161, loss=0.8491704119994282\n",
"epoch=162, loss=0.8471559066463444\n",
"epoch=163, loss=0.8451668503210351\n",
"epoch=164, loss=0.8432028691549007\n",
"epoch=165, loss=0.8412635948583403\n",
"epoch=166, loss=0.8393486646373772\n",
"epoch=167, loss=0.8374577211115306\n",
"epoch=168, loss=0.8355904122329139\n",
"epoch=169, loss=0.8337463912065428\n",
"epoch=170, loss=0.831925316411836\n",
"epoch=171, loss=0.8301268513252854\n",
"epoch=172, loss=0.8283506644442861\n",
"epoch=173, loss=0.8265964292121027\n",
"epoch=174, loss=0.8248638239439577\n",
"epoch=175, loss=0.8231525317542244\n",
"epoch=176, loss=0.821462240484708\n",
"epoch=177, loss=0.8197926426339986\n",
"epoch=178, loss=0.818143435287881\n",
"epoch=179, loss=0.8165143200507828\n",
"epoch=180, loss=0.8149050029782497\n",
"epoch=181, loss=0.813315194510427\n",
"epoch=182, loss=0.8117446094065373\n",
"epoch=183, loss=0.8101929666803377\n",
"epoch=184, loss=0.8086599895365407\n",
"epoch=185, loss=0.8071454053081853\n",
"epoch=186, loss=0.8056489453949456\n",
"epoch=187, loss=0.8041703452023592\n",
"epoch=188, loss=0.8027093440819663\n",
"epoch=189, loss=0.8012656852723419\n",
"epoch=190, loss=0.79983911584101\n",
"epoch=191, loss=0.7984293866272258\n",
"epoch=192, loss=0.797036252185613\n",
"epoch=193, loss=0.7956594707306438\n",
"epoch=194, loss=0.7942988040819484\n",
"epoch=195, loss=0.7929540176104424\n",
"epoch=196, loss=0.791624880185259\n",
"epoch=197, loss=0.7903111641214738\n",
"epoch=198, loss=0.7890126451286129\n",
"epoch=199, loss=0.7877291022599281\n",
"epoch=200, loss=0.7864603178624318\n",
"epoch=201, loss=0.7852060775276782\n",
"epoch=202, loss=0.7839661700432816\n",
"epoch=203, loss=0.7827403873451556\n",
"epoch=204, loss=0.7815285244704709\n",
"epoch=205, loss=0.7803303795113122\n",
"epoch=206, loss=0.7791457535690304\n",
"epoch=207, loss=0.7779744507092752\n",
"epoch=208, loss=0.7768162779177018\n",
"epoch=209, loss=0.7756710450563374\n",
"epoch=210, loss=0.7745385648206026\n",
"epoch=211, loss=0.773418652696971\n",
"epoch=212, loss=0.7723111269212661\n",
"epoch=213, loss=0.7712158084375792\n",
"epoch=214, loss=0.7701325208578016\n",
"epoch=215, loss=0.7690610904217614\n",
"epoch=216, loss=0.7680013459579568\n",
"epoch=217, loss=0.7669531188448757\n",
"epoch=218, loss=0.7659162429728941\n",
"epoch=219, loss=0.7648905547067429\n",
"epoch=220, loss=0.763875892848537\n",
"epoch=221, loss=0.7628720986013554\n",
"epoch=222, loss=0.7618790155333675\n",
"epoch=223, loss=0.760896489542494\n",
"epoch=224, loss=0.7599243688215968\n",
"epoch=225, loss=0.7589625038241884\n",
"epoch=226, loss=0.7580107472306555\n",
"epoch=227, loss=0.7570689539149856\n",
"epoch=228, loss=0.7561369809119917\n",
"epoch=229, loss=0.7552146873850262\n",
"epoch=230, loss=0.7543019345941792\n",
"epoch=231, loss=0.7533985858649501\n",
"epoch=232, loss=0.7525045065573881\n",
"epoch=233, loss=0.7516195640356941\n",
"epoch=234, loss=0.7507436276382773\n",
"epoch=235, loss=0.7498765686482579\n",
"epoch=236, loss=0.7490182602644131\n",
"epoch=237, loss=0.7481685775725547\n",
"epoch=238, loss=0.7473273975173358\n",
"epoch=239, loss=0.7464945988744796\n",
"epoch=240, loss=0.7456700622234215\n",
"epoch=241, loss=0.744853669920362\n",
"epoch=242, loss=0.7440453060717204\n",
"epoch=243, loss=0.7432448565079869\n",
"epoch=244, loss=0.7424522087579648\n",
"epoch=245, loss=0.7416672520233977\n",
"epoch=246, loss=0.7408898771539754\n",
"epoch=247, loss=0.7401199766227154\n",
"epoch=248, loss=0.7393574445017091\n",
"epoch=249, loss=0.7386021764382332\n",
"epoch=250, loss=0.7378540696312178\n",
"epoch=251, loss=0.7371130228080642\n",
"epoch=252, loss=0.7363789362018117\n",
"epoch=253, loss=0.7356517115286443\n",
"epoch=254, loss=0.7349312519657336\n",
"epoch=255, loss=0.7342174621294134\n",
"epoch=256, loss=0.7335102480536798\n",
"epoch=257, loss=0.7328095171690142\n",
"epoch=258, loss=0.7321151782815202\n",
"epoch=259, loss=0.7314271415523743\n",
"epoch=260, loss=0.7307453184775836\n",
"epoch=261, loss=0.7300696218680454\n",
"epoch=262, loss=0.7293999658299047\n",
"epoch=263, loss=0.7287362657452064\n",
"epoch=264, loss=0.7280784382528355\n",
"epoch=265, loss=0.7274264012297426\n",
"epoch=266, loss=0.7267800737724492\n",
"epoch=267, loss=0.726139376178832\n",
"epoch=268, loss=0.7255042299301762\n",
"epoch=269, loss=0.7248745576734987\n",
"epoch=270, loss=0.7242502832041361\n",
"epoch=271, loss=0.7236313314485909\n",
"epoch=272, loss=0.7230176284476362\n",
"epoch=273, loss=0.7224091013396715\n",
"epoch=274, loss=0.7218056783443277\n",
"epoch=275, loss=0.7212072887463185\n",
"epoch=276, loss=0.7206138628795294\n",
"epoch=277, loss=0.7200253321113496\n",
"epoch=278, loss=0.7194416288272326\n",
"epoch=279, loss=0.7188626864154936\n",
"epoch=280, loss=0.7182884392523281\n",
"epoch=281, loss=0.7177188226870589\n",
"epoch=282, loss=0.7171537730276009\n",
"epoch=283, loss=0.7165932275261447\n",
"epoch=284, loss=0.7160371243650531\n",
"epoch=285, loss=0.7154854026429699\n",
"epoch=286, loss=0.7149380023611339\n",
"epoch=287, loss=0.7143948644098993\n",
"epoch=288, loss=0.713855930555458\n",
"epoch=289, loss=0.7133211434267585\n",
"epoch=290, loss=0.7127904465026211\n",
"epoch=291, loss=0.7122637840990466\n",
"epoch=292, loss=0.7117411013567131\n",
"epoch=293, loss=0.7112223442286614\n",
"epoch=294, loss=0.7107074594681619\n",
"epoch=295, loss=0.7101963946167666\n",
"epoch=296, loss=0.7096890979925358\n",
"epoch=297, loss=0.7091855186784426\n",
"epoch=298, loss=0.7086856065109517\n",
"epoch=299, loss=0.7081893120687667\n",
"epoch=300, loss=0.7076965866617467\n",
"epoch=301, loss=0.7072073823199894\n",
"epoch=302, loss=0.7067216517830753\n",
"epoch=303, loss=0.7062393484894741\n",
"epoch=304, loss=0.7057604265661096\n",
"epoch=305, loss=0.7052848408180795\n",
"epoch=306, loss=0.7048125467185297\n",
"epoch=307, loss=0.7043435003986795\n",
"epoch=308, loss=0.7038776586379957\n",
"epoch=309, loss=0.7034149788545152\n",
"epoch=310, loss=0.7029554190953088\n",
"epoch=311, loss=0.702498938027091\n",
"epoch=312, loss=0.7020454949269687\n",
"epoch=313, loss=0.7015950496733278\n",
"epoch=314, loss=0.7011475627368563\n",
"epoch=315, loss=0.7007029951717024\n",
"epoch=316, loss=0.700261308606763\n",
"epoch=317, loss=0.6998224652371042\n",
"epoch=318, loss=0.6993864278155093\n",
"epoch=319, loss=0.6989531596441517\n",
"epoch=320, loss=0.6985226245663957\n",
"epoch=321, loss=0.6980947869587157\n",
"epoch=322, loss=0.6976696117227398\n",
"epoch=323, loss=0.697247064277409\n",
"epoch=324, loss=0.6968271105512575\n",
"epoch=325, loss=0.696409716974804\n",
"epoch=326, loss=0.6959948504730602\n",
"epoch=327, loss=0.6955824784581491\n",
"epoch=328, loss=0.695172568822035\n",
"epoch=329, loss=0.6947650899293617\n",
"epoch=330, loss=0.6943600106103968\n",
"epoch=331, loss=0.6939573001540827\n",
"epoch=332, loss=0.6935569283011921\n",
"epoch=333, loss=0.6931588652375821\n",
"epoch=334, loss=0.6927630815875544\n",
"epoch=335, loss=0.6923695484073105\n",
"epoch=336, loss=0.6919782371785065\n",
"epoch=337, loss=0.6915891198019049\n",
"epoch=338, loss=0.6912021685911185\n",
"epoch=339, loss=0.6908173562664514\n",
"epoch=340, loss=0.6904346559488294\n",
"epoch=341, loss=0.6900540411538221\n",
"epoch=342, loss=0.6896754857857541\n",
"epoch=343, loss=0.6892989641319053\n",
"epoch=344, loss=0.6889244508567947\n",
"epoch=345, loss=0.6885519209965548\n",
"epoch=346, loss=0.6881813499533841\n",
"epoch=347, loss=0.6878127134900879\n",
"epoch=348, loss=0.6874459877246966\n",
"epoch=349, loss=0.6870811491251673\n",
"epoch=350, loss=0.6867181745041621\n",
"epoch=351, loss=0.6863570410139074\n",
"epoch=352, loss=0.6859977261411264\n",
"epoch=353, loss=0.6856402077020514\n",
"epoch=354, loss=0.6852844638375063\n",
"epoch=355, loss=0.6849304730080662\n",
"epoch=356, loss=0.6845782139892879\n",
"epoch=357, loss=0.6842276658670109\n",
"epoch=358, loss=0.68387880803273\n",
"epoch=359, loss=0.6835316201790365\n",
"epoch=360, loss=0.6831860822951271\n",
"epoch=361, loss=0.6828421746623808\n",
"epoch=362, loss=0.6824998778500008\n",
"epoch=363, loss=0.6821591727107226\n",
"epoch=364, loss=0.6818200403765857\n",
"epoch=365, loss=0.6814824622547687\n",
"epoch=366, loss=0.6811464200234857\n",
"epoch=367, loss=0.6808118956279459\n",
"epoch=368, loss=0.6804788712763714\n",
"epoch=369, loss=0.6801473294360766\n",
"epoch=370, loss=0.679817252829605\n",
"epoch=371, loss=0.679488624430923\n",
"epoch=372, loss=0.6791614274616721\n",
"epoch=373, loss=0.6788356453874774\n",
"epoch=374, loss=0.6785112619143084\n",
"epoch=375, loss=0.6781882609848978\n",
"epoch=376, loss=0.6778666267752117\n",
"epoch=377, loss=0.6775463436909724\n",
"epoch=378, loss=0.6772273963642341\n",
"epoch=379, loss=0.6769097696500102\n",
"epoch=380, loss=0.6765934486229489\n",
"epoch=381, loss=0.6762784185740607\n",
"epoch=382, loss=0.6759646650074932\n",
"epoch=383, loss=0.675652173637356\n",
"epoch=384, loss=0.6753409303845902\n",
"epoch=385, loss=0.6750309213738878\n",
"epoch=386, loss=0.6747221329306544\n",
"epoch=387, loss=0.6744145515780194\n",
"epoch=388, loss=0.6741081640338896\n",
"epoch=389, loss=0.6738029572080478\n",
"epoch=390, loss=0.6734989181992929\n",
"epoch=391, loss=0.673196034292625\n",
"epoch=392, loss=0.6728942929564719\n",
"epoch=393, loss=0.6725936818399556\n",
"epoch=394, loss=0.6722941887702016\n",
"epoch=395, loss=0.6719958017496875\n",
"epoch=396, loss=0.6716985089536305\n",
"epoch=397, loss=0.6714022987274162\n",
"epoch=398, loss=0.6711071595840621\n",
"epoch=399, loss=0.6708130802017239\n",
"epoch=400, loss=0.6705200494212341\n",
"epoch=401, loss=0.6702280562436811\n",
"epoch=402, loss=0.6699370898280224\n",
"epoch=403, loss=0.6696471394887341\n",
"epoch=404, loss=0.669358194693496\n",
"epoch=405, loss=0.6690702450609102\n",
"epoch=406, loss=0.6687832803582546\n",
"epoch=407, loss=0.6684972904992699\n",
"epoch=408, loss=0.6682122655419797\n",
"epoch=409, loss=0.6679281956865424\n",
"epoch=410, loss=0.6676450712731361\n",
"epoch=411, loss=0.6673628827798751\n",
"epoch=412, loss=0.6670816208207566\n",
"epoch=413, loss=0.6668012761436392\n",
"epoch=414, loss=0.6665218396282511\n",
"epoch=415, loss=0.6662433022842272\n",
"epoch=416, loss=0.6659656552491778\n",
"epoch=417, loss=0.6656888897867834\n",
"epoch=418, loss=0.6654129972849204\n",
"epoch=419, loss=0.6651379692538127\n",
"epoch=420, loss=0.6648637973242125\n",
"epoch=421, loss=0.6645904732456082\n",
"epoch=422, loss=0.6643179888844568\n",
"epoch=423, loss=0.6640463362224466\n",
"epoch=424, loss=0.6637755073547822\n",
"epoch=425, loss=0.6635054944884975\n",
"epoch=426, loss=0.6632362899407928\n",
"epoch=427, loss=0.6629678861373965\n",
"epoch=428, loss=0.6627002756109529\n",
"epoch=429, loss=0.6624334509994317\n",
"epoch=430, loss=0.6621674050445634\n",
"epoch=431, loss=0.6619021305902967\n",
"epoch=432, loss=0.6616376205812796\n",
"epoch=433, loss=0.6613738680613634\n",
"epoch=434, loss=0.661110866172128\n",
"epoch=435, loss=0.6608486081514301\n",
"epoch=436, loss=0.6605870873319734\n",
"epoch=437, loss=0.6603262971398985\n",
"epoch=438, loss=0.6600662310933969\n",
"epoch=439, loss=0.6598068828013415\n",
"epoch=440, loss=0.659548245961942\n",
"epoch=441, loss=0.6592903143614164\n",
"epoch=442, loss=0.6590330818726857\n",
"epoch=443, loss=0.6587765424540862\n",
"epoch=444, loss=0.6585206901481006\n",
"epoch=445, loss=0.6582655190801099\n",
"epoch=446, loss=0.6580110234571624\n",
"epoch=447, loss=0.6577571975667621\n",
"epoch=448, loss=0.657504035775674\n",
"epoch=449, loss=0.6572515325287484\n",
"epoch=450, loss=0.6569996823477625\n",
"epoch=451, loss=0.656748479830279\n",
"epoch=452, loss=0.6564979196485211\n",
"epoch=453, loss=0.6562479965482664\n",
"epoch=454, loss=0.6559987053477553\n",
"epoch=455, loss=0.6557500409366162\n",
"epoch=456, loss=0.655501998274807\n",
"epoch=457, loss=0.6552545723915734\n",
"epoch=458, loss=0.6550077583844197\n",
"epoch=459, loss=0.6547615514180993\n",
"epoch=460, loss=0.6545159467236161\n",
"epoch=461, loss=0.6542709395972436\n",
"epoch=462, loss=0.6540265253995577\n",
"epoch=463, loss=0.6537826995544839\n",
"epoch=464, loss=0.6535394575483581\n",
"epoch=465, loss=0.6532967949290037\n",
"epoch=466, loss=0.6530547073048193\n",
"epoch=467, loss=0.6528131903438826\n",
"epoch=468, loss=0.6525722397730664\n",
"epoch=469, loss=0.6523318513771682\n",
"epoch=470, loss=0.6520920209980526\n",
"epoch=471, loss=0.6518527445338068\n",
"epoch=472, loss=0.6516140179379083\n",
"epoch=473, loss=0.6513758372184056\n",
"epoch=474, loss=0.6511381984371106\n",
"epoch=475, loss=0.6509010977088031\n",
"epoch=476, loss=0.6506645312004478\n",
"epoch=477, loss=0.6504284951304219\n",
"epoch=478, loss=0.6501929857677556\n",
"epoch=479, loss=0.6499579994313825\n",
"epoch=480, loss=0.6497235324894025\n",
"epoch=481, loss=0.6494895813583546\n",
"epoch=482, loss=0.6492561425025012\n",
"epoch=483, loss=0.649023212433123\n",
"epoch=484, loss=0.6487907877078252\n",
"epoch=485, loss=0.6485588649298512\n",
"epoch=486, loss=0.6483274407474106\n",
"epoch=487, loss=0.648096511853014\n",
"epoch=488, loss=0.6478660749828199\n",
"epoch=489, loss=0.6476361269159895\n",
"epoch=490, loss=0.647406664474052\n",
"epoch=491, loss=0.6471776845202808\n",
"epoch=492, loss=0.6469491839590757\n",
"epoch=493, loss=0.6467211597353574\n",
"epoch=494, loss=0.6464936088339696\n",
"epoch=495, loss=0.6462665282790906\n",
"epoch=496, loss=0.6460399151336527\n",
"epoch=497, loss=0.6458137664987712\n",
"epoch=498, loss=0.6455880795131828\n",
"epoch=499, loss=0.6453628513526889\n"
]
}
],
"source": [
"loss_list = train_nn(epoch=epoch, dataset=dataset, target=target,\\\n",
" batchsize=batchsize,alpha=alpha, w=w, b=b)\n"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEZCAYAAABiu9n+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xm8XfO5x/HPN4MpEwmCTGJIlYpQlCKOmVI6GhsubU1t\nub29Faq90tu6bXWgqoOgKhRVauYS5Nw2xhQZkCASkUQGIkREBXnuH791ZDvOPjkn2fusPXzfr9d6\nnbXX+u21nr1ysp/zm9ZSRGBmZtaSTnkHYGZmlctJwszMinKSMDOzopwkzMysKCcJMzMryknCzMyK\ncpKwuifpPElX5x1HtZJ0paT/zjsOKw8nCbPEE4bMWuAkYRVHUue8Y6gGvk7WEZwkrF0kjZQ0XdIS\nSU9J+ly2fS1JiyVtW1B2Q0nLJG2YvT5M0pNZufGSti8oO1PSWZImAUsldSp2rqx8J0m/lPSKpBck\nfUPSCkmdsv09JV0u6WVJsyX9SJLa+BkPz873mqQHJG3T7PPPyWKaKmmfbPsukiZIekPSPEm/KHLs\nvbN4zslinyHp2IL9a0n6haRZ2XF+J2ntZu89S9I84I9FznGSpGckLZJ0t6SBBftWSPpWds0WSrqg\nYJ8kfV/Si5LmS/qTpJ4F+/eU9GD27zdL0vEFp+0t6Y7sujwsaXBbrrVVgYjw4qXNC/BFoG+2/mVg\nacHry4EfFZQ9HbgrW98RWADsDAgYAcwEumb7ZwJPAJsBa7fhXKcCTwGbAr2AscD7QKds/83A74B1\ngA2BR4CvF/lM5wFjsvUh2Xn2BToD3wWeB7pk+14qiGEgMDhbfwg4LltfD9i1yLn2Bt4Ffg50BYZn\n59s6238hcEv2mboBtwLnN3vv/2TvXbuF4x8BPJfF2gn4HvBgwf4VwP3Z8fsDzwInZftOyt47KPsM\nNxVcl0HAEuDI7LpsAAzN9l0JvAJ8MjvnNcC1ef+ueinR//m8A/BS3QvwJPDZbH0/YHrBvvEFX5y/\nA37Y7L3TgL2y9ZnACe041/2FX/rZud/PvqT6Av8q/BIFjgYeKHLcwiTxfeD6gn0C5mRf5lsC87Nz\ndWl2jMbsOH1W8Rn2BpYD6xRs+wtwbra+tCnxZK93B2YUvPdfZIm1yPHvAk4seN0JeAsYkL1eARxQ\nsP80YGy2fh9wasG+IcA72THOBm4qcs4rgdEFrw8Bnsn7d9NLaRY3N1m7SDq+oMloMbAd6S91gHHA\nulnTyyBgB9JfxZD+Ev1O1oTzWvbe/qSaQ5M57TjXZsDsguKF6wNJf2nPKzjXHwre25rNgFlNLyJ9\n680G+kXEC8C/A6OABZKulbRpVvSrwMeAaZIelXRoK+dYHBH/Kng9C9hM0kakv+Afb7pOwN1An4Ky\nr0TEu60cexDw64L3LyJ1yvcrKFN4nWex8t/gQ589W+9CSroDgBdaOe/8gvVlQPdWyloV6ZJ3AFY9\nsrbt0cA+EfFwtu1J0l/bRMQKSTcAx5Kalu6IiLeyt88mNZv8pJVTfDDCaFXnAuaRkkyTgQXrs0l/\ncffJvuTb42XgE822DQDmAkTE9cD1krpn8f2UVAN6gfS5kfRF4EZJvSPi7RbOsYGkdQv2DQSmAK+S\nvmC3i4h5ReJb1ed5CfhxRFzXSpkBwNRsfRDpM5P9HFRQbhDwHunfcjaw6yrObTXINQlrj26k5opX\ns47jE/noF+p1wFGkL8xrC7ZfBpwqaVcASd0kfUZSt9U81w3AmZI2k7Q+cFbTjoiYD9wLXCipR9Yh\nu4Wk4W34jDcAh0raR1IXSf9JSjgPSRqSbV+L1GT0dhYjko5T1kEPvEH6Ml9R5BwCfiipq6S9gEOB\nG7KEdhlwUVarQFI/SQe2Ie4mlwLfUzaAQFIvSV9qVua7ktaXNAA4A7g+234d8G1Jm2dJ8HxS09sK\n4M/AfpK+JKmzpN6SdmhHXFalnCSszSJiKvBLUifwfFLzz/hmZR4jtYFvSmoqadr+OPB14JKsGeQ5\n4ITCt7bzXJeREsFk4HHgTuC97AsN4HhgLeAZ4DXgr8AmbfiMzwFfAS4hdcYeSuoHeQ9Ym1RzeIX0\nV/dGwDnZWw8Gnpa0hNT5fFREvFPkNPOAxdkxrgZOiYjns30jgenAI5Jezz7jkFXFXRD/LVmM12fv\nn5zFVuhW0jV7AridlaOk/pjF83dS09IyUhIhImYDnwH+k3Q9nwSGtjUuq15qf228HQeX+gNjSG2a\nK4DLIuLiZmWOJf3HAHgTOC0ippQtKKtJkg4Gfh8RFT30UtLewNURMXCVhctz/hXAVhExI4/zW/Up\nd03iPeA/ImI70iiNb6hgzHlmBjA8InYAfkz6C9GsVZLWkXRI1vTRjzSy6G95x2VWa8qaJCJifkRM\nzNaXkjrL+jUr80hEvJG9fKT5frMiBPyQ1PTxOPA0KVFY63z7EWuXsjY3fehE0uakseSfyBJGS2X+\nExgSESd3SFBmZtaqDhkCm42UuBE4s5UEsQ9wIrBnR8RkZmarVvYkIakLKUFcHRG3FikzlDTm/OCI\nWFykjKvJZmarISLadN+ylnTEENg/kqbo/7qlndmkqZuAEdmEpKLynp5eKct5552XewyVsvha+Fr4\nWrS+rKmy1iQk7QEcB0zJZssG6YZjg0h3PBgN/ADoDfxOkoB3I8IzO83MKkBZk0REPEi6Y2RrZb5O\nmmRlZmYVxjOuq1BDQ0PeIVQMX4uVfC1W8rUonQ4bArumJEW1xGpmVikkERXecW1mZlXKScLMzIpy\nkjAzs6KcJMzMrCgnCTMzK8pJwszMinKSMDOzopwkzMysKCcJMzMryknCzMyKcpIwM7OinCTMzKwo\nJwkzMyuqqpKEbwJrZtaxqipJTJ+edwRmZvWlqpLEuHF5R2BmVl+qKkk88EDeEZiZ1ZeyJglJ/SU9\nIOlpSVMknVGk3MWSnpc0UdKwYscbN879EmZmHancNYn3gP+IiO2A3YFvSNqmsICkQ4AtI2Jr4BTg\nD8UO1r07PPNMOcM1M7NCZU0SETE/IiZm60uBqUC/ZsWOAMZkZR4Feknq29Lx9tnHTU5mZh2pw/ok\nJG0ODAMebbarHzC74PVcPppIANh3X3dem5l1pA5JEpK6AzcCZ2Y1itWyzz7Q2AgrVpQsNDMza0WX\ncp9AUhdSgrg6Im5tochcYEDB6/7Zto+49NJRdOoEp50GxxzTQENDQ8njNTOrZo2NjTQ2NpbseIoy\nDxeSNAZ4NSL+o8j+zwDfiIhDJe0GXBQRu7VQLiKCb3wDttgCvvOdsoZtZlYTJBERWt33l3sI7B7A\nccC+kp6U9ISkgyWdIulkgIi4C5gpaTpwKXB6a8fcZx/3S5iZdZSy1yRKpakm8eqrsOWW8Oqr0LVr\n3lGZmVW2iq5JlMOGG6Yk8dhjeUdiZlb7qi5JABx4INxzT95RmJnVvqpMEgcdBPfem3cUZma1r+r6\nJADeeQc22ghefBF69843LjOzSlZ3fRI
Aa68Nw4fD/ffnHYmZWW2ryiQB7pcwM+sIVZskmvolqqS1\nzMysKlVtkhgyBCSYNi3vSMzMalfVJgkp1Sbc5GRmVj5VmyQg9Ut4KKyZWflU5RDYJosXw6BBsHAh\nrLNOToGZmVWwuhwC22SDDWC77WD8+LwjMTOrTVWdJAAOOQTuvjvvKMzMalPVJ4nDDoM77sg7CjOz\n2lT1SWLHHWHpUnjuubwjMTOrPVWfJCQ49FC48868IzEzqz1VnyTATU5mZuVS1UNgm7z1Fmy6Kcye\nDb16dXBgZmYVrK6HwDbp1g323NMT68zMSq2sSULSFZIWSJpcZH9PSbdJmihpiqR/W91zHXqom5zM\nzEqtrM1NkvYElgJjImJoC/vPAXpGxDmSNgSeBfpGxHstlC3a3ATpAUS77ALz50PnziX7CGZmVa2i\nm5siYjywuLUiQI9svQewqKUE0Rabbw6bbAKPPbY67zYzs5bk3SdxCbCtpJeBScCZa3Iwj3IyMyut\nLjmf/yDgyYjYV9KWwFhJQyNiaUuFR40a9cF6Q0MDDQ0NH9r/2c/CySfD+eeXL2Azs0rW2NhIY2Nj\nyY5X9iGwkgYBtxfpk7gD+ElEPJi9vh8YGRH/bKFsq30SACtWQL9+8Pe/w9ZblyZ+M7NqVtF9Ehll\nS0tmAfsDSOoLDAFmrO6JOnWCz30Obr55dY9gZmaFyj0E9lrgIWCIpJcknSjpFEknZ0V+DHw6GyI7\nFjgrIl5bk3N+4Qvwt7+tWdxmZpbUxIzrQu++C337wpQpqenJzKyeVUNzU4fq2jWNcrrllrwjMTOr\nfjWXJMBNTmZmpVJzzU0Ay5alG/7NmAF9+pQ5MDOzCubmphastx7svz/cfnvekZiZVbeaTBLgJicz\ns1KoyeYmgNdfh4EDYe5c6NFj1eXNzGqRm5uKWH99GD4cbrst70jMzKpXzSYJgKOPhuuvzzsKM7Pq\nVbPNTQBLlsCAATBzJvTuXabAzMwqmJubWtGzJxxwgO/lZGa2umo6SUBqcvrLX/KOwsysOtV0cxOk\niXWbbQbPPQcbb1yGwMzMKpibm1ZhvfXSvZxuvDHvSMzMqk/NJwnwKCczs9VV881NAMuXp3s5TZoE\n/fuXODAzswrm5qY2WGst+PznXZswM2uvukgSAF/5CowZA1VScTIzqwh1kySGD0+T6yZNyjsSM7Pq\nUTdJolMnOP54uOqqvCMxM6seZU0Skq6QtEDS5FbKNEh6UtJTksaVM54RI+Daa9NzsM3MbNXKXZO4\nEjio2E5JvYDfAodFxCeAL5czmK23hq22gnvuKedZzMxqR1mTRESMBxa3UuRY4KaImJuVf7Wc8YCb\nnMzM2iPvPokhQG9J4yRNkDSi3Cc86igYOxYWt5a6zMwMgC4VcP6dgH2BbsDDkh6OiOktFR41atQH\n6w0NDTQ0NLT7hOuvDwcdlG76d+qpqxOymVnlamxspLGxsWTHK/uMa0mDgNsjYmgL+0YC60TED7PX\nlwN3R8RNLZRd7RnXzd15J/zoR/DIIyU5nJlZxaqGGdfKlpbcCuwpqbOk9YBPAVPLHdBBB8GcOTC5\n6JgrMzOD8g+BvRZ4CBgi6SVJJ0o6RdLJABExDbgHmAw8AoyOiGfKGRNAly7w1a/CZZeV+0xmZtWt\nLm7w15KXXoIdd4TZs9PtxM3MalE1NDdVpIED4VOf8nMmzMxaU7dJAuDkk2H06LyjMDOrXHWdJA49\nFGbMgKefzjsSM7PKVNdJomtXOOkkd2CbmRVTtx3XTWbOhF12SR3Y665b8sObmeXKHddraPDglCT8\n1Dozs4+q+yQBcMYZ8Jvf+Kl1ZmbNOUmQZmAvXQoPPph3JGZmlcVJgvTUum9+M9UmzMxspbrvuG6y\nZAlsvnm6n1P//mU7jZlZh3LHdYn07AnHHQd/+EPekZiZVQ7XJAo8+ywMHw6zZsE665T1VGZmHcI1\niRL62MfSTf88HNbMLHGSaObb34Zf/tLDYc3MwEniIw48MI12+t//zTsSM7P8OUk0I8FZZ8EFF+Qd\niZlZ/tqUJCSdKamnkiskPSHpwHIHl5cjj0x3h33ssbwjMTPLV1trEidFxBLgQGADYATw07JFlbOu\nXVPfxM9/nnckZmb5amuSaBo+9Rng6oh4umBbTfra12DcOJg+Pe9IzMzy09Yk8bike0lJ4h5JPYAV\nq3pT1jS1QNLkVZTbRdK7kr7QxnjKrnt3OPXUNNLJzKxetWkynaROwDBgRkS8Lqk30D8iVvXlvyew\nFBgTEUNbOfZY4G3gjxHxtyLlyj6ZrrkFC+DjH4ennoLNNuvQU5uZlURHTabbHXg2SxBfAb4PvLGq\nN0XEeGDxKop9C7gRWNjGWDpM375wwgke6WRm9autSeL3wDJJOwDfAV4AxqzpySVtBnwuIn5PhfZx\nnHUWjBkD8+fnHYmZWcfr0sZy70VESDoCuCQirpD01RKc/yJgZMHrVhPFqFGjPlhvaGigoaGhBCG0\nbtNN4StfSSOd3D9hZpWusbGRxsbGkh2vrX0S/wf8L3ASsBepaWhSRGzfhvcOAm5vqU9C0oymVWBD\n4C3g5Ii4rYWyHd4n0WTuXNh+e5g2DTbeOJcQzMxWS0f1SRwFvEOaLzEf6A+0dRaBKFJDiIgtsmUw\nqV/i9JYSRN769YNjj4Vf/CLvSMzMOlabbxUuqS+wS/bysYhYZUezpGuBBqAPsAA4D1gLiIgY3azs\nH4E7Kml0U6E5c2Do0HQ78Y02yi0MM7N2WdOaRFubm44k1RwaSbWCvYDvRsSNq3vi9so7SQB861vQ\npQtceGGuYZiZtVlHJYlJwAFNtQdJGwH3RcQOq3vi9qqEJLFgAWy7LTzxBAwalGsoZmZt0lF9Ep2a\nNS8tasd7a0bfvnD66VAwyMrMrKa1tSbxc2AocF226ShgckSMLP6u0qqEmgTAG2/A1lun+zptt13e\n0ZiZta5DmpuyE30R2CN7+Y+IuHl1T7o6KiVJQJovMX483NyhV8DMrP06LEnkrZKSxNtvw5AhcMMN\nsPvueUdjZlZcWZOEpDeBlgqINIy15+qeuL0qKUkAXHklXH55qlGoIm8oYmZW5o7riOgRET1bWHp0\nZIKoRMcfn2oUN9yQdyRmZuXj5qY18Pe/w4gR6XYd666bdzRmZh/VUUNgrQXDh8Ouu/rGf2ZWu1yT\nWEMzZ8LOO8OUKX4wkZlVHo9uqgDnnAMvvwxXXZV3JGZmH+YkUQHefDM95vT662HPPfOOxsxsJfdJ\nVIAePdJN/047Dd59N+9ozMxKx0miRL70pfTciYsuyjsSM7PScXNTCU2fDrvtlu4SO3Bg3tGYmbm5\nqaJstRWccUZazMxqgZNEiY0cCVOn+uZ/ZlYb3NxUBv/4Bxx9dJo70bt33tGYWT3zENgKdcYZ8Prr\nMGZM3pGYWT2r6D4JSVdIWiBpcpH9x0qalC3jJW1fzng60k9+ku4Qe+edeUdiZrb6yt0ncSVwUCv7
\nZwDDs2dl/xi4rMzxdJhu3eCKK+DUU9PT7MzMqlHZm5skDQJuj4ihqyi3PjAlIgYU2V9VzU1NTj8d\nli2DP/0p70jMrB5VdHNTO30NuDvvIErtggvgoYfgxhvzjsTMrP265B0AgKR9gBOBVu98NGrUqA/W\nGxoaaGhoKGtcpdC9O1xzDXz2s2miXf/+eUdkZrWssbGRxsbGkh0v9+YmSUOBm4CDI+KFVo5Tlc1N\nTc4/Hx54AMaOhU6VVH8zs5pWDc1NypaP7pAGkhLEiNYSRC04+2xYvhx+9au8IzEza7uy1iQkXQs0\nAH2ABcB5wFpARMRoSZcBXwBmkRLJuxGxa5FjVXVNAuDFF9OT7G6/HT71qbyjMbN64Ml0VeaWW+DM\nM9NNAPv0yTsaM6t1ThJV6DvfgWnTUo3C/RNmVk7V0Cdhzfz0p+mWHT/7Wd6RmJm1zjWJnMyZAzvv\nDH/+M+y3X97RmFmtck2iSvXvD9ddB8ceCy/U9LguM6tmThI52mcfOO+8NNFuyZK8ozEz+yg3N1WA\n006D2bPh1luhc+e8ozGzWuLmphpw8cXw1ltw7rl5R2Jm9mFOEhWga1f461/hhht8t1gzqywVcYM/\ngw03hLvugoYG6NsXDjkk74jMzFyTqCjbbAM33wwnnAATJuQdjZmZk0TF2X13uPxyOPxwmD4972jM\nrN65uakCHX44LFgABx0EDz4Im2ySd0RmVq+cJCrU178OCxfC/vvDuHGw0UZ5R2Rm9chJooKdey68\n/TYccEB6YFHv3nlHZGb1xpPpKlwEnHVWqk3cdx+sv37eEZlZNfGtwutABPz7v8Ojj8K990LPnnlH\nZGbVwjOu64AEF10En/xkumPsokV5R2Rm9cJJokpIcMklsO++sPfeMG9e3hGZWT1wx3UVkdIDi3r1\ngr32Sn0Um2+ed1RmVsvKWpOQdIWkBZImt1LmYknPS5ooaVg546kFEnzve+k52cOHwzPP5B2RmdWy\ncjc3XQkcVGynpEOALSNia+AU4A9ljqdmfOtbcP756ZkUjY15R2NmtaqsSSIixgOLWylyBDAmK/so\n0EtS33LGVEtGjEhPtzvySLjmmryjMbNalHefRD9gdsHrudm2BfmEU3323TfNoTj0UJg5E77//dQk\nZWZWCnkniXYZNWrUB+sNDQ00NDTkFksl2W47ePjh9BjUZ5+F0aNhvfXyjsrM8tDY2EhjCdugyz6Z\nTtIg4PaIGNrCvj8A4yLiL9nracDeEfGRmkQ9T6Zrq2XL0j2fpk6Fv/3NI5/MrDom0ylbWnIbcDyA\npN2A11tKENY2662X+iZGjIDddkv3ezIzWxNlrUlIuhZoAPqQ+hnOA9YCIiJGZ2UuAQ4G3gJOjIgn\nihzLNYl2eOABOPbYdDuPs86CTp42aVaXfO8mK+qll1Ki6NYNxoxJj0U1s/pSDc1NlpOBA9Mcip13\nhp12SjO0zczawzWJOnHffXD88XDccfCjH8E66+QdkZl1BNckrE323x8mTkxzKXbaCSZMyDsiM6sG\nThJ1ZOON4a9/hR/8AA47LE28e+edvKMys0rmJFFnJDjmmFSrmDwZdtklPczIzKwlThJ1atNN4dZb\nYeRI+Pzn4dRTYXFrd9kys7rkJFHHpNSR/cwz0LkzbLttGirr8QFm1sSjm+wDEybAaaelmdsXXpge\nl2pm1c2jm6xkmvonRoxINws8/niYPXvV7zOz2uUkYR/SuXO6SeCzz8KgQTBsWBoF9eabeUdmZnlw\nkrAW9eiRJt1NmpRqE1ttBRdcAG+9lXdkZtaRnCSsVf37w1VXpdt7PP44bLkl/OpX8PbbeUdmZh3B\nScLa5OMfh7/8BcaOhfHjU7L4+c9hyZK8IzOzcnKSsHbZfvv0QKO77koT8rbYAs4+G+bNyzsyMysH\nJwlbLcOGwZ//nIbNLluWHqHa9FQ8M6sdThK2RgYPhosvhueeS/0X++6blptugvfeyzs6M1tTnkxn\nJbV8eWqO+u1v0x1nTz451TA23TTvyMzqkyfTWUVZay04+mj4xz/gzjth7tzU6X344Sl5LF+ed4Rm\n1h5lTxKSDpY0TdJzkka2sL+npNskTZQ0RdK/lTsm6xg77ACXXgpz5sAXvwi/+Q306wdnnAFPPOF7\nRJlVg7I2N0nqBDwH7Ae8DEwAjo6IaQVlzgF6RsQ5kjYEngX6RsR7zY7l5qYaMHNmmndx1VXQvTsc\ndRQceSQMGZJ3ZGa1qdKbm3YFno+IWRHxLnA9cESzMgH0yNZ7AIuaJwirHYMHw6hR8MIL8LvfwcKF\nsPfeabTU//wPPP983hGaWaFyJ4l+QOEt4uZk2wpdAmwr6WVgEnBmmWOyCtCpE+y1VxoZNWdO+vny\ny2nbsGHwX/+VhteuWJF3pGb1rRI6rg8CnoyIzYAdgd9K6p5zTNaBOneG4cPhkktSR/fFF8O//pXu\nQtuvH3zta+kBSb5vlFnH61Lm488FBha87p9tK3Qi8BOAiHhB0kxgG+CfzQ82atSoD9YbGhpoaGgo\nbbSWu6aEMXx4uqHg9Olwxx0pcYwYAZ/+NOy/f1qGDk01EjNbqbGxkcbGxpIdr9wd151JHdH7AfOA\nx4BjImJqQZnfAgsj4oeS+pKSww4R8VqzY7njus698QY88ADcdx/cfz8sWpQm7jUljcGD847QrPKs\nacd12SfTSToY+DWpaeuKiPippFOAiIjRkjYF/gQ0Tbf6SURc18JxnCTsQ2bPTsmiKWmsu27q09hz\nT9hjD9hmG9c0zCo+SZSKk4S1JiLdN+rBB9NdasePh9dfT8lijz1S4vjkJ2GddfKO1KxjOUmYFfHy\nyylpNCWOZ56Bj30sPaa1adluO+jaNe9IzcrHScKsjd5+Oz1p75//TMNrJ0yAWbNSB/guu8COO6ZZ\n4ttu6xqH1Q4nCbM1sGRJukXIhAkpgUyalEZUDR6cEsbQoWnZYYc0HFer/V/NLB9OEmYl9s47MG0a\nTJ6cksbkyWlZvjzdrPDjH0+d4k3L4MFp6K5ZJXKSMOsgCxemzvFp09LStL5gAWy11YcTx1ZbpUe8\n9unj2ofly0nCLGdvvZUeulSYOF54IS0RKVk0JY3CpV8/10Cs/JwkzCpUBLz22sqE0XxZtAgGDUrL\nwIErfzYtAwak53OYrQknCbMqtWwZvPgivPTSymXWrJXrL7+cmqsKE8iAAakGstlmadl0UycSa52T\nhFmNev99mDfvowlk3ryUQObOTf0h66+/Mmm0tPTrBxtv7KateuUkYVbHVqyAV15JSaNwmTv3w68X\nLYLevVOy6Ns3LU3rLf30PJHa4SRhZqv03nspmSxcmGofrf1cuDAliebJo08f2HDDD/9sWu/Z06O4\nKpWThJmVVES6427z5PHqq6lGsmjRyvWmn2+/nWoqxZJI8/XevVMzmW+JUn5OEmaWu+XLiyeQlra9\n9lpKROuuCxtskBLGBht8eGlpW+F2N4m1jZOEmVWlFSv
gzTfT3XoXL165NH9dbFvnzi0nj169Vi49\nexZf79GjPjrznSTMrO5EpCaulpLIG2+kZcmSltebXi9dCt26tZ5IVpVoevWCtdeu7P4YJwkzs9XQ\nVJMplkzasv7GGylh9eixZkv37ivXS91P4yRhZpajd95JyaYty9Klqy7TteuaJZvmy9prO0mYmdWE\npma01Uk0S5euXJpeb745TJzoJGFmZkWsaXNT2R8TL+lgSdMkPSdpZJEyDZKelPSUpHHljsnMzNqm\nrElCUifgEuAgYDvgGEnbNCvTC/gtcFhEfAL4cjljqgWNjY15h1AxfC1W8rVYydeidMpdk9gVeD4i\nZkXEu8D1wBHNyhwL3BQRcwEi4tUyx1T1/B9gJV+LlXwtVvK1KJ1yJ4l+wOyC13OybYWGAL0ljZM0\nQdKIMsdkZmZt1CXvAEgx7ATsC3QDHpb0cERMzzcsMzMr6+gmSbsBoyLi4Oz12UBExM8KyowE1omI\nH2avLwfujoibmh3LQ5vMzFbDmoxuKndNYgKwlaRBwDzgaOCYZmVuBX4jqTOwNvAp4FfND7QmH9LM\nzFZPWZNERLwv6ZvAvaT+jysiYqqkU9LuGB0R0yTdA0wG3gdGR8Qz5YzLzMzapmom05mZWccr+2S6\nUmjLhLwu0NgJAAAEvUlEQVRaIukKSQskTS7YtoGkeyU9K+mebH5J075zJD0vaaqkA/OJuvQk9Zf0\ngKSnJU2RdEa2vR6vxdqSHs0mnU6RdF62ve6uRRNJnSQ9Iem27HVdXgtJL0qalP1uPJZtK921iIiK\nXkiJbDowCOgKTAS2yTuuMn/mPYFhwOSCbT8DzsrWRwI/zda3BZ4kNR1unl0r5f0ZSnQdNgGGZevd\ngWeBberxWmSfb73sZ2fgEdI8pLq8Ftln/DZwDXBb9rourwUwA9ig2baSXYtqqEm0ZUJeTYmI8cDi\nZpuPAK7K1q8CPpetHw5cHxHvRcSLwPOka1b1ImJ+REzM1pcCU4H+1OG1AIiIZdnq2qT/5EGdXgtJ\n/YHPAJcXbK7LawGIj7YKlexaVEOSaMuEvHqwcUQsgPTlCWycbW9+feZSg9dH0uak2tUjQN96vBZZ\n88qTwHxgbERMoE6vBXAh8F1SomxSr9cigLHZZOSvZdtKdi0qYTKdrZ66GXEgqTtwI3BmRCxtYc5M\nXVyLiFgB7CipJ3CzpO346Gev+Wsh6VBgQURMlNTQStGavxaZPSJinqSNgHslPUsJfy+qoSYxFxhY\n8Lp/tq3eLJDUF0DSJsDCbPtcYEBBuZq6PpK6kBLE1RFxa7a5Lq9Fk4hYAjQCB1Of12IP4HBJM4Dr\ngH0lXQ3Mr8NrQUTMy36+AtxCaj4q2e9FNSSJDybkSVqLNCHvtpxj6gjKlia3Af+WrZ9AmoTYtP1o\nSWtJGgxsBTzWUUF2gD8Cz0TErwu21d21kLRh0wgVSesCB5D6aOruWkTE9yJiYERsQfo+eCAiRgC3\nU2fXQtJ6WU0bSd2AA4EplPL3Iu+e+Tb23h9MGtnyPHB23vF0wOe9FngZeAd4CTgR2AC4L7sO9wLr\nF5Q/hzRKYSpwYN7xl/A67EGaYDmRNCLjiex3oXcdXovts88/kTTx9Nxse91di2bXZW9Wjm6qu2sB\nDC74/zGl6fuxlNfCk+nMzKyoamhuMjOznDhJmJlZUU4SZmZWlJOEmZkV5SRhZmZFOUmYmVlRThJm\nHUDS3pJuzzsOs/ZykjDrOJ6UZFXHScKsgKTjsof7PCHp99mdV9+U9CtJT0kaK6lPVnaYpIclTZR0\nU8FtM7bMyk2U9M/s9gcAPST9NXvYy9W5fUizdnCSMMtI2gY4Cvh0ROwErACOA9YDHouITwB/B87L\n3nIV8N2IGAY8VbD9z8Bvsu2fBuZl24cBZ5Ae/LKlpE+X/1OZrRnfKtxspf2AnYAJkgSsAywgJYsb\nsjLXADdlt+vuFekBUZASxg3Zzdb6RcRtABGxHCAdjsciu2OnpImkJ4M91AGfy2y1OUmYrSTgqog4\n90MbpR80KxcF5dvjnYL19/H/P6sCbm4yW+l+4EvZw1uaHiY/kPRM6S9lZY4Dxkd6psNrkvbIto8A\n/i/SY1ZnSzoiO8Za2a29zaqS/5Ixy0TEVEnfJz3dqxOwHPgm8Bawa1ajWEDqt4B0n/5LsyQwg3RL\nd0gJY7Sk/86O8eWWTle+T2JWOr5VuNkqSHozInrkHYdZHtzcZLZq/kvK6pZrEmZmVpRrEmZmVpST\nhJmZFeUkYWZmRTlJmJlZUU4SZmZWlJOEmZkV9f+MDb/uII+vBwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x278c1a16a90>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(loss_list)\n",
"plt.title(\"average loss per epoch\")\n",
"plt.xlabel(\"epoch\")\n",
"plt.ylabel(\"loss\")\n",
"plt.xlim([0,epoch])\n",
"plt.show()"
]
},
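{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the trained parameters can be used to classify the training points by taking the argmax of the network output. The next cell is a small optional sketch of that check; it reuses `forward`, `dataset`, `target`, and the trained `w`, `b` from above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Illustrative sketch: classify the training points with the learned w, b\n",
"# by taking the argmax of the network output.\n",
"for x_i, t_i in zip(dataset, target):\n",
"    _, y_i = forward(x_i, t_i, w, b)\n",
"    print(\"input={}, predicted={}, true={}\".format(\n",
"        x_i, int(np.argmax(y_i)), int(np.argmax(t_i))))"
]
},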
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [Root]",
"language": "python",
"name": "Python [Root]"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}