Created
January 29, 2020 19:37
-
-
Save kenenbek/9cecdaf66faf3c67f241aeea4992070f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 115, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# __future__ imports must precede every other statement in a module (PEP 236).\n", | |
"from __future__ import unicode_literals, print_function, division\n", | |
"\n", | |
"# Standard library\n", | |
"from io import open\n", | |
"import random\n", | |
"import re\n", | |
"import string\n", | |
"import unicodedata\n", | |
"\n", | |
"# Third-party\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"from torch import optim\n", | |
"\n", | |
"# Local project modules\n", | |
"import utilities\n", | |
"\n", | |
"# Seed every RNG the notebook may draw from so Restart & Run All is reproducible.\n", | |
"# NOTE(review): whether utilities.GenomDataset draws from numpy's global RNG is an assumption - TODO confirm.\n", | |
"SEED = 0\n", | |
"random.seed(SEED)\n", | |
"np.random.seed(SEED)\n", | |
"torch.manual_seed(SEED)\n", | |
"\n", | |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 116, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"num_replicates = 10\n", | |
"length = int(3e3)\n", | |
"dataset = utilities.GenomDataset(num_replicates=num_replicates, length=length)\n", | |
"\n", | |
"X = dataset.get_X()\n", | |
"# NOTE(review): the target is also built from get_X(), i.e. it is identical to the input.\n", | |
"# Confirm this reconstruction setup is intended rather than a distinct target array.\n", | |
"Y = dataset.get_X()\n", | |
"\n", | |
"XX = torch.tensor(X, dtype=torch.float)\n", | |
"YY = torch.tensor(Y, dtype=torch.float)\n", | |
"\n", | |
"# Reshape to the (seq_len, batch, features) layout nn.GRU uses by default.\n", | |
"# NOTE(review): view() reinterprets memory without transposing; if X is laid out as\n", | |
"# (num_replicates, length) this interleaves replicates and a transpose is needed - TODO confirm.\n", | |
"# The tensors are moved to `device` so they live alongside the models and hidden states.\n", | |
"input_tensor = XX.view(length, num_replicates, -1).to(device)\n", | |
"target_tensor = YY.view(length, num_replicates, -1).to(device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 117, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class EncoderRNN(nn.Module):\n", | |
"    \"\"\"Single-layer GRU encoder that consumes a whole sequence in one call.\n", | |
"\n", | |
"    Args:\n", | |
"        input_size: number of features per time step.\n", | |
"        hidden_size: size of the GRU hidden state.\n", | |
"        seq_length: stored on the module; not used by the computation.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, input_size, hidden_size, seq_length):\n", | |
"        super(EncoderRNN, self).__init__()\n", | |
"        self.input_size, self.hidden_size, self.seq_length = input_size, hidden_size, seq_length\n", | |
"        self.gru = nn.GRU(input_size, hidden_size)\n", | |
"\n", | |
"    def forward(self, input, hidden):\n", | |
"        \"\"\"Run the GRU over `input`; returns the (outputs, final hidden) pair from nn.GRU.\"\"\"\n", | |
"        return self.gru(input, hidden)\n", | |
"\n", | |
"    def initHidden(self, batch_size):\n", | |
"        \"\"\"Zero initial hidden state of shape (1, batch_size, hidden_size) on `device`.\"\"\"\n", | |
"        return torch.zeros(1, batch_size, self.hidden_size, device=device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 123, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class AttnDecoderRNN(nn.Module):\n", | |
"    \"\"\"GRU decoder that attends over a fixed number (`length`) of encoder positions.\n", | |
"\n", | |
"    Args:\n", | |
"        input_size: features per decoder time step.\n", | |
"        hidden_size: GRU hidden size; must match the encoder outputs' feature size.\n", | |
"        output_size: features predicted per time step.\n", | |
"        length: number of encoder positions attended over.\n", | |
"        dropout_p: dropout probability applied to the combined input/context\n", | |
"            features before the GRU (active only in training mode).\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, input_size, hidden_size, output_size, length, dropout_p=0.1):\n", | |
"        super(AttnDecoderRNN, self).__init__()\n", | |
"        self.input_size = input_size\n", | |
"        self.hidden_size = hidden_size\n", | |
"        self.output_size = output_size\n", | |
"        self.length = length\n", | |
"        self.dropout_p = dropout_p\n", | |
"\n", | |
"        # One attention score per encoder position, computed from (input, hidden).\n", | |
"        self.attn = nn.Linear(self.input_size + self.hidden_size, self.length)\n", | |
"        self.attn_combine = nn.Linear(self.input_size + self.hidden_size, self.hidden_size)\n", | |
"        self.dropout = nn.Dropout(self.dropout_p)\n", | |
"        self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n", | |
"        self.out = nn.Linear(self.hidden_size, self.output_size)\n", | |
"\n", | |
"    def forward(self, input, hidden, encoder_outputs):\n", | |
"        \"\"\"Decode a whole input sequence in one call.\n", | |
"\n", | |
"        Shapes implied by the cat/bmm/Linear operations below:\n", | |
"            input: (seq_len, batch, input_size)\n", | |
"            hidden: (1, batch, hidden_size)\n", | |
"            encoder_outputs: (length, batch, hidden_size)\n", | |
"\n", | |
"        Returns:\n", | |
"            output: (seq_len, batch, output_size)\n", | |
"            hidden: final GRU state, (1, batch, hidden_size)\n", | |
"            attn_weights: (seq_len, batch, length); each row sums to 1.\n", | |
"        \"\"\"\n", | |
"        # Broadcast the initial state along the time axis so every step is scored against it.\n", | |
"        expand_sizes = [input.shape[0] // hidden.shape[0]] + [-1] * (hidden.dim() - 1)\n", | |
"        concatenated = torch.cat((input, hidden.expand(*expand_sizes)), dim=-1)\n", | |
"        attn_weights = F.softmax(self.attn(concatenated), dim=-1)\n", | |
"\n", | |
"        # bmm needs batch-first operands: (batch, seq, length) @ (batch, length, hidden).\n", | |
"        attn_applied = torch.bmm(attn_weights.permute(1, 0, 2), encoder_outputs.permute(1, 0, 2))\n", | |
"        attn_applied = attn_applied.permute(1, 0, 2)\n", | |
"\n", | |
"        output = self.attn_combine(torch.cat((input, attn_applied), dim=-1))\n", | |
"        # Regularise the combined features with the configured dropout before the GRU.\n", | |
"        output = self.dropout(F.relu(output))\n", | |
"        output, hidden = self.gru(output, hidden)\n", | |
"        output = self.out(output)\n", | |
"        return output, hidden, attn_weights\n", | |
"\n", | |
"    def initHidden(self, batch_size=1):\n", | |
"        \"\"\"Zero initial hidden state of shape (1, batch_size, hidden_size) on `device`.\n", | |
"\n", | |
"        `batch_size` defaults to 1 so existing no-argument calls keep working.\n", | |
"        \"\"\"\n", | |
"        return torch.zeros(1, batch_size, self.hidden_size, device=device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 193, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"input_size = 1\n", | |
"hidden_size = 128\n", | |
"encoder = EncoderRNN(\n", | |
"    input_size=input_size,\n", | |
"    hidden_size=hidden_size,\n", | |
"    seq_length=length,\n", | |
").to(device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 194, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Decoder hyper-parameters. hidden_size_d must equal the encoder's hidden_size because\n", | |
"# train() hands the encoder's final hidden state to the decoder, and `length` must\n", | |
"# equal the number of encoder output positions the decoder attends over.\n", | |
"input_size_d = 1\n", | |
"hidden_size_d = 128\n", | |
"output_size_d = 1\n", | |
"decoder = AttnDecoderRNN(\n", | |
"    input_size_d, hidden_size_d, output_size_d, length, dropout_p=0.1\n", | |
").to(device)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 195, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"teacher_forcing_ratio = 0.5\n", | |
"\n", | |
"\n", | |
"def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, seq_length):\n", | |
"    \"\"\"Run one optimisation step of `encoder` and `decoder` on a single batch.\n", | |
"\n", | |
"    The decoder is teacher-forced in one vectorised call: it is fed the target\n", | |
"    shifted by one step (target[:-1]) and scored against the next values\n", | |
"    (target[1:]), so no time step is given the value it has to predict.\n", | |
"\n", | |
"    Args:\n", | |
"        input_tensor: encoder input, (seq_len, batch, features).\n", | |
"        target_tensor: decoder target, (seq_len, batch, features); seq_len >= 2.\n", | |
"        encoder, decoder: the models being trained.\n", | |
"        encoder_optimizer, decoder_optimizer: optimizers for the two models.\n", | |
"        criterion: loss function comparing predictions with targets.\n", | |
"        seq_length: unused; kept so existing call sites keep working.\n", | |
"\n", | |
"    Returns:\n", | |
"        The loss for this step as a Python float.\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if target_tensor has fewer than two time steps.\n", | |
"    \"\"\"\n", | |
"    if target_tensor.size(0) < 2:\n", | |
"        raise ValueError('target_tensor needs at least 2 time steps for next-step prediction')\n", | |
"\n", | |
"    encoder_optimizer.zero_grad()\n", | |
"    decoder_optimizer.zero_grad()\n", | |
"\n", | |
"    encoder_hidden = encoder.initHidden(input_tensor.size(1))\n", | |
"    encoder_outputs, encoder_hidden = encoder(input_tensor, encoder_hidden)\n", | |
"\n", | |
"    # TODO: free-running decoding (feeding back the decoder's own predictions with\n", | |
"    # probability 1 - teacher_forcing_ratio) is not implemented; every step is teacher-forced.\n", | |
"    decoder_inputs = target_tensor[:-1]\n", | |
"    decoder_targets = target_tensor[1:]\n", | |
"    decoder_outputs, _, _ = decoder(decoder_inputs, encoder_hidden, encoder_outputs)\n", | |
"    loss = criterion(decoder_outputs, decoder_targets)\n", | |
"\n", | |
"    loss.backward()\n", | |
"    encoder_optimizer.step()\n", | |
"    decoder_optimizer.step()\n", | |
"\n", | |
"    return loss.item()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 196, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import time\n", | |
"\n", | |
"def trainIters(encoder, decoder, input_dataset, target_dataset, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):\n", | |
"    \"\"\"Train `encoder`/`decoder` for `n_iters` full-batch SGD steps.\n", | |
"\n", | |
"    Args:\n", | |
"        encoder, decoder: models to train (updated in place).\n", | |
"        input_dataset, target_dataset: input / target tensors passed to train()\n", | |
"            on every iteration, each (seq_len, batch, features).\n", | |
"        n_iters: number of optimisation steps.\n", | |
"        print_every: print the mean loss of the last `print_every` steps this often.\n", | |
"        plot_every: record the mean loss of the last `plot_every` steps this often.\n", | |
"        learning_rate: SGD learning rate for both models.\n", | |
"\n", | |
"    Returns:\n", | |
"        List of averaged losses, one per completed `plot_every` window.\n", | |
"    \"\"\"\n", | |
"    start = time.time()\n", | |
"    plot_losses = []\n", | |
"    print_loss_total = 0  # Reset every print_every\n", | |
"    plot_loss_total = 0  # Reset every plot_every\n", | |
"\n", | |
"    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)\n", | |
"    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)\n", | |
"    criterion = nn.MSELoss()\n", | |
"    seq_length = input_dataset.size(0)\n", | |
"\n", | |
"    for i in range(n_iters):\n", | |
"        # Train on the tensors passed in, not on whatever globals happen to exist.\n", | |
"        loss = train(input_dataset, target_dataset, encoder, decoder,\n", | |
"                     encoder_optimizer, decoder_optimizer, criterion, seq_length)\n", | |
"        print_loss_total += loss\n", | |
"        plot_loss_total += loss\n", | |
"        step = i + 1\n", | |
"\n", | |
"        # Average only once a full window has accumulated (not already at the first step).\n", | |
"        if step % print_every == 0:\n", | |
"            print_loss_avg = print_loss_total / print_every\n", | |
"            print_loss_total = 0\n", | |
"            print('%s (%d %d%%) %.4f' % (timeSince(start, step / n_iters),\n", | |
"                                         step, step / n_iters * 100, print_loss_avg))\n", | |
"\n", | |
"        if step % plot_every == 0:\n", | |
"            plot_loss_avg = plot_loss_total / plot_every\n", | |
"            plot_losses.append(plot_loss_avg)\n", | |
"            plot_loss_total = 0\n", | |
"\n", | |
"    showPlot(plot_losses)\n", | |
"    return plot_losses" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 197, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from matplotlib import pyplot as plt\n", | |
"%matplotlib inline\n", | |
"\n", | |
"# NOTE: the showPlot imported here is shadowed by the showPlot defined in a later cell.\n", | |
"from helper import asMinutes, timeSince, showPlot" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 198, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"import matplotlib.ticker as ticker\n", | |
"import numpy as np\n", | |
"\n", | |
"# Figures are left to the notebook's inline backend; switching to the\n", | |
"# non-interactive 'agg' backend here would keep showPlot's figure from displaying.\n", | |
"\n", | |
"\n", | |
"def showPlot(points, tick_base=0.2):\n", | |
"    \"\"\"Plot a loss curve on a new figure.\n", | |
"\n", | |
"    Args:\n", | |
"        points: sequence of loss values, one per recorded window.\n", | |
"        tick_base: spacing of the major y-axis ticks; pass None to keep\n", | |
"            matplotlib's automatic ticks (useful when losses are far below 0.2).\n", | |
"\n", | |
"    Returns:\n", | |
"        The matplotlib Axes the curve was drawn on.\n", | |
"    \"\"\"\n", | |
"    # plt.subplots() creates the figure itself; an extra plt.figure() call\n", | |
"    # would leave an empty figure behind.\n", | |
"    fig, ax = plt.subplots()\n", | |
"    if tick_base is not None:\n", | |
"        # this locator puts ticks at regular intervals\n", | |
"        ax.yaxis.set_major_locator(ticker.MultipleLocator(base=tick_base))\n", | |
"    ax.plot(points)\n", | |
"    ax.set(xlabel='recorded window', ylabel='average loss')\n", | |
"    return ax" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 199, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0m 18s (- 2m 42s) (1 10%) 0.0165\n", | |
"0m 36s (- 2m 27s) (2 20%) 0.0151\n", | |
"0m 52s (- 2m 2s) (3 30%) 0.0138\n", | |
"1m 8s (- 1m 42s) (4 40%) 0.0126\n", | |
"1m 24s (- 1m 24s) (5 50%) 0.0115\n", | |
"1m 39s (- 1m 6s) (6 60%) 0.0105\n", | |
"1m 55s (- 0m 49s) (7 70%) 0.0096\n", | |
"2m 11s (- 0m 32s) (8 80%) 0.0087\n", | |
"2m 27s (- 0m 16s) (9 90%) 0.0080\n", | |
"2m 43s (- 0m 0s) (10 100%) 0.0073\n" | |
] | |
} | |
], | |
"source": [ | |
"# 10 full-batch SGD steps; log and record the loss after every step.\n", | |
"losses = trainIters(encoder, decoder, input_tensor, target_tensor,\n", | |
"                    n_iters=10, print_every=1, plot_every=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 203, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD4CAYAAADlwTGnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3hVVb7/8fc3DQi9hBogoUqTYqihKiqgl+gFFRRBQJEmlhn9OTN35o5T7hQroAKCiKADIirgoIKNHpBQBaWE0EKR0Htfvz9ynAmZQE5Ckp3kfF7Pk4fsvdfe+Z6jyefsvfZa25xziIhI4AnyugAREfGGAkBEJEApAEREApQCQEQkQCkAREQCVIjXBWRFhQoVXFRUlNdliIgUKKtXrz7knItIv75ABUBUVBQJCQlelyEiUqCY2a6M1usSkIhIgFIAiIgEKAWAiEiAUgCIiAQoBYCISIBSAIiIBCgFgIhIgAqIAFi7+ygTFm1HU1+LiPxbgRoIll0fr9nLtBW7+OnEef7nrgYEBZnXJYmIeC4gAuCFno0ICTYmL9vBkdPn+XvvpoSFBMTJj4jINQVEAAQFGb+7uyEVShThxflbOHLmIuP7tSA8LCBevohIhgLmY7CZMaJLHf76301Yui2FByeu5OjpC16XJSLimYAJgJ/1aVWDcf1u4Yf9J+g9fjl7j531uiQREU8EXAAA3NmoMtMGteLgyfP0HrecbT+d9LokEZE8F5ABANC6Vnk+GNKWS1ccvcfHs3rXUa9LEhHJUwEbAAANq5bio6HtKBseykOTVvDtloNelyQikmf8CgAz62ZmW8ws0cyez2C7mdkY3/YNZtYizbbJZnbQzDZmsN8TvuNuMrO/39hLyZ4a5cP5cGg7akeU4LF3E/hkbbIXZYiI5LlMA8DMgoE3gO5AQ6CvmTVM16w7UNf3NQQYl2bbFKBbBsftAsQBNzvnGgEvZaP+HBFRsggzhrShVXQ5nv5gPZOWJHlViohInvHnDKAVkOicS3LOXQBmkPqHO604YKpLtQIoY2ZVAJxzi4EjGRx3GPBX59x5XztPr7+ULBrKOwNb0qNJZf4070f+8vmPmjpCRAo1fwKgGrAnzXKyb11W26RXD+hgZivNbJGZtcyokZkNMbMEM0tISUnxo9zsKxISzNi+LejXpgYTFiXx3KwNXLp8JVd/poiIV/wZCpvRxDnpPxr70yajn10WaAO0BGaaWS2X7mO3c+4t4C2AmJiYXP9IHhxk/DGuMRVKFOG1r7Zx9MwFxvZtQbGw4Nz+0SIiecqfM4BkoHqa5UhgXzbaZHTcj32Xjb4DrgAV/Kgn15kZT3Wtxx/vaczXmw/y8NsrOX7motdliYjkKH8CYBVQ18yizSwM6APMTddmLtDfdzdQG+C4c25/JsedDdwKYGb1gDDgUJaqz2UPt6nJ631bsCH5OPdPiOfA8XNelyQikmMyDQDn3CVgJDAf+BGY6ZzbZGZDzWyor9lnQBKQCEwEhv+8v5lNB+KB+maWbGaDfZsmA7V8t4fOAAakv/yTH9x1cxXeGdiS5KNn6DVuOdtTTnldkohIjrB8+Df3mmJiYlxCQoInP/v75OM88s53OOCdR1rStHoZT+oQEckqM1vtnItJvz6gRwJnRZPI0swa1o7wsGD6TlzBkm25e0eSiEhuUwBkQXSF4nw0rB01yoUzaMoq5q7PrJ9bRCT/UgBkUaVSRfng8bY0r16WJ2esZcqyHV6XJCKSLQqAbChdLJSpg1vRtUElfv/pD7y8YItGDYtIgaMAyKaiocGMe6gF98dEMvabRH79yUYuX1EIiEjBoYfi3oCQ4CD+1utmKpQowpsLt3Pk9HlG92lO0VCNGhaR/E9nADfIzHiu20387u6GzN/0EwMmf8eJcxo1LCL5nwIghwxqH81rDzRj9a6j9JmwgoMnNWpYRPI3BUAOuqd5NSYNiGHHodP0HhfPrs
OnvS5JROSaFAA5rHP9ivzjsdacOHeRXuPi2bj3uNcliYhkSAGQC5rXKMusoW0JCzb6vLWC+O2HvS5JROQ/KABySZ2KJZk1rB2VSxdlwOTv+GJjZpOjiojkLQVALqpaphgfPt6WRtVKMfz9Nfxj5W6vSxIR+RcFQC4rWzyM9x9tTcd6Efz6k+959cutGjUsIvmCAiAPhIeFMLF/DL1aRDL66208MX0t5y5e9rosEQlwGgmcR0KDg3jpvpupXbE4L87fwu4jZ5jYP4ZKpYp6XZqIBCidAeQhM2N45zpM6HcLiQdP0fP1pWxIPuZ1WSISoBQAHrijUWU+GtaOkKAg7p8Qzz836LkCIpL3FAAeaVClFHNGxtK4amlG/mMtr365lSuaTVRE8pACwEMVShTh/cdaX9U5fPaCOodFJG+oE9hjRUKCeem+m6lfuQR/+Xwzu46cZmL/GKqULuZ1aSJSyOkMIB8wM4Z0rM2k/jHsSDlN3OvLWLdHncMikrsUAPnIbQ0q8fHwWMJCgnhgQjxz1u31uiQRKcQUAPlM/colmTMilqaRZXhyxjpeXrBFncMikisUAPlQ+RJFeO/R1v963vDw99dw5sIlr8sSkUJGAZBPhYWkPm/4f+5qwIIfDtB7XDz7jp31uiwRKUQUAPmYmfFoh1q8PaAlu4+coefry1i7+6jXZYlIIaEAKAC63FSRT4a3IzwsmAfeWsHsteocFpEbpwAoIOpWKsnsEbE0r16Gpz5Yx9+/2KzOYRG5IQqAAqRc8TCmDW5N31bVeXPhdoa+t5rT59U5LCLZowAoYMJCgvi/e5vwv//VkK9+/Ile45aTfPSM12WJSAGkACiAzIyBsdG8M7AVe4+d5Z43lrF61xGvyxKRAkYBUIB1qhfBJ8NjKVEkhL5vreSj1clelyQiBYgCoICrU7EEs0fEEhNVll98uJ6/fP4jl9U5LCJ+8CsAzKybmW0xs0Qzez6D7WZmY3zbN5hZizTbJpvZQTPbeI1j/9LMnJlVyP7LCGxlwsN4d1ArHmpdgwmLknh8WgKn1DksIpnINADMLBh4A+gONAT6mlnDdM26A3V9X0OAcWm2TQG6XePY1YHbgd1ZLVyuFhocxJ/vbcIf4hrx7ZYUer25nD1H1DksItfmzxlAKyDROZfknLsAzADi0rWJA6a6VCuAMmZWBcA5txi4Vg/lq8BzgK5Z5JD+baOYMrAl+4+ndg6v2qnOYRHJmD8BUA3Yk2Y52bcuq22uYmY9gb3OufWZtBtiZglmlpCSkuJHudKhbgSzR8RSulgoD05cwYcJezLfSUQCjj8BYBmsS/+J3Z82/25sFg78BvhdZj/cOfeWcy7GORcTERGRWXPxqRVRgk+Gx9I6ujzPztrAn+f9oM5hEbmKPwGQDFRPsxwJ7MtGm7RqA9HAejPb6Wu/xswq+1GP+Kl0eChTBrZkQNuaTFyyg8emJnDy3EWvyxKRfMKfAFgF1DWzaDMLA/oAc9O1mQv0990N1AY47pzbf60DOue+d85VdM5FOeeiSA2QFs65A9l7GXItIcFBvBDXmD/d05hFW1PoNW45uw+rc1hE/AgA59wlYCQwH/gRmOmc22RmQ81sqK/ZZ0ASkAhMBIb/vL+ZTQfigfpmlmxmg3P4NYgf+rWpybRBrfjpxHni3ljK8u2HvC5JRDxmzhWc68IxMTEuISHB6zIKtJ2HTvPo1ASSUk7xizvqM6xTbYKCMurCEZHCwsxWO+di0q/XSOAAE1WhOHNGxHL3zVV5cf4WBk5ZxZHTF7wuS0Q8oAAIQMWLhDC6TzP+fG9j4rcf5q4xSzSZnEgAUgAEKDPjodY1+Xh4O0KDg3hgwgomLUmiIF0SFJEbowAIcI2rlebTJ9pzW4OK/Gnejzw+bTXHz+pWUZFAoAAQShcLZXy/W/jt3Q35ZvNB7h67hO+Tj3tdlojkMgWAAKmXhAa3j2bm0LZcvuzoNW4501bs0iUhkUJMASBXaVGjLPNGdaBdnf
L8dvZGRs1Yp6mlRQopBYD8h7LFw5g8oCXP3lmfeRv20XPsUjYfOOF1WSKSwxQAkqGgIGNElzr847E2nDx/iXveWKZZRUUKGQWAXFebWuX5bFQHWtQoy7OzNvDsh+s5e+Gy12WJSA5QAEimIkoWYdrg1oy6rS6z1iRzzxvL2J5yyuuyROQGKQDEL8FBxjO31+Pdga1IOXWenmOXMnf99Wb8FpH8TgEgWdKxXgTzRrWnQZVSjJq+lv+Z/T3nLuqSkEhBpACQLKtSuhjTh7Th8Y61eG/FbnqP1zMGRAoiBYBkS2hwEL/q0YCJ/WPYffgMd41dwvxNep6PSEGiAJAbcnvDSswb1YHoCsV5fNpq/vTPH7h4+YrXZYmIHxQAcsOqlwvnw6FteaRdFJOW7uCBCfHsO3bW67JEJBMKAMkRRUKC+X3PRrz+YHO2/nSKu8Ys4dstB70uS0SuQwEgOerum6syd2QslUoVZeA7q3hp/hYu6ZKQSL6kAJAcVyuiBLNHxNKnZXVe/zaRfm+v5OCJc16XJSLpKAAkVxQNDeavvW7m5fuasn7PcXqMWcry7Ye8LktE0lAASK7qdUskc0bGUiY8lH6TVjL2621cuaJnDIjkBwoAyXX1KpVkzohYejatystfbuWRKas4fOq812WJBDwFgOSJ4kVCePWBZvzfvU1YkXSYu8YsJWHnEa/LEgloCgDJM2bGg61r8PGwdhQJDeKBt1bw+jfbdJeQiEcUAJLnGlcrzadPtKdHkyq8tGAr902IZ+eh016XJRJwFADiiVJFQxnbtzmj+zRj+8FTdB+9hPdX6iH0InlJASCeimtWjflPdyQmqiy/+WQjg6as0pgBkTyiABDPVSldjHcHtuKFno2ITzrMna8t5rPv93tdlkihpwCQfCEoyBjQLop5ozpQo1w4w99fw9MfrOP42YtelyZSaCkAJF+pHVGCWcPa8VTXusxdv4/ury1meaJGEIvkBgWA5DuhwUE81bUeHw9rR9HQYB6ctJI/fPqDHj0pksMUAJJvNa1ehnmjOjCgbU0mL9vB3WOXsnHvca/LEik0FACSrxULC+aFuMZMHdSKk+cucs8byxj7tQaPieQEvwLAzLqZ2RYzSzSz5zPYbmY2xrd9g5m1SLNtspkdNLON6fZ50cw2+9p/YmZlbvzlSGHVsV4E85/qSPcmVXj5y9TBYzs0eEzkhmQaAGYWDLwBdAcaAn3NrGG6Zt2Bur6vIcC4NNumAN0yOPSXQGPn3M3AVuBXWS1eAkuZ8DDG9m3OmL7N2X7wFD1GL2HaCg0eE8kuf84AWgGJzrkk59wFYAYQl65NHDDVpVoBlDGzKgDOucXAf8z65Zxb4Jy75FtcAURm90VIYOnZtCoLnu5ETFRZfjt7I4+8s4qfNHhMJMv8CYBqwJ40y8m+dVltcz2DgM8z2mBmQ8wswcwSUlJSsnBIKcwqly7K1EGt+ENcI1buSB08Nm+DBo+JZIU/AWAZrEt/zu1Pm4wPbvYb4BLwfkbbnXNvOedinHMxERER/hxSAoSZ0b9t6uCxmuXCGfGPNTw1Y60Gj4n4yZ8ASAaqp1mOBPZlo81/MLMBwN3AQ04XciWb0g4e+3TDfrq9tphlGjwmkil/AmAVUNfMos0sDOgDzE3XZi7Q33c3UBvguHPuuufjZtYN+H9AT+fcmWzULvIvaQePFQsL5qFJK/n93E0aPCZyHZkGgK+jdiQwH/gRmOmc22RmQ81sqK/ZZ0ASkAhMBIb/vL+ZTQfigfpmlmxmg32bXgdKAl+a2TozG59TL0oCV9PqZZj3RAceaRfFlOU7uWvMEjYkH/O6LJF8yQrSlZeYmBiXkJDgdRlSQCzZlsKzH27g0KnzjLqtLsM71yYkWGMfJfCY2WrnXEz69fptkEKrQ93UwWM9mlThlS+30nt8PEkpp7wuSyTfUABIoVY6PJQxfZsztm9zdhw6TY8xS5gWv1ODx0RQAEiA+K+mVZn/VEdaRpXjt3M2MU
CDx0QUABI4fh489se4Rny34zB3vLqYT9dnereySKGlAJCAYmY87Bs8FlWhOE9MX8uo6Ws5cvqC16WJ5DkFgASk2hEl+GhoW57uWo/Pvt9P11cW8cnaZPUNSEBRAEjACgkO4smudfnnqPbULB/O0x+sp//k79h1WNNMS2BQAEjAu6lyKWYNbccf4xqxdvcx7nh1MeMWbueiHjojhZwCQAQIDkrtG/jqmU50rh/B377YzH+NXcra3Ue9Lk0k1ygARNKoXLooEx6OYcLDt3DszEX+e9xy/nfORk6e0wyjUvgoAEQycGejynz5TEf6t6nJ1BW7uP2VxSzYdMDrskRylAJA5BpKFg3lhbjGfDSsHWXCQxkybTWPT0vgwHENIJPCQQEgkokWNcry6RPtea5bfRZuSaHrK4uYFr+TK1d0y6gUbAoAET+EBgcxvHMdFjzdkWbVy/DbOZvoNX45mw+c8Lo0kWxTAIhkQc3yxZk2uBWv3N+UnYdOc/eYpbw4f7MePCMFkgJAJIvMjP9uEcnXv+hMXLNqvPHtdj2GUgokBYBINpUrHsbL9zfl/UdbA/DQpJX8YuZ6zSskBYYCQOQGxdapwBdPdWREl9rMWbeX215eyMdrNK+Q5H8KAJEcUDQ0mGfvvIl/jmpPVIXiPDNzPQ+/rXmFJH9TAIjkoLTzCq3bkzqv0JsLEzWvkORLCgCRHJZ2XqEu9Svy9y+2aF4hyZcUACK5pHLpoox/+BbNKyT5lgJAJJf9PK/QgLZR/5pXaL7mFZJ8QAEgkgdKFg3l9z0b8bFvXqHHNa+Q5AMKAJE81DyDeYWmxu/ksuYVEg8oAETyWPp5hX43ZxO9Na+QeEABIOKRn+cVevWBpuw6fIa7xizl93M3ceyMRhJL3lAAiHjIzLi3eSRfPdOJPi2rMzV+J51fWsjU+J1c0tgByWUKAJF8oFzxMP58bxPmjepAg8ql+N2cTfQYs4Sl2zTBnOQeBYBIPtKgSin+8VhrxvdrwdmLl+n39koem5rAzkOaUkJyngJAJJ8xM7o1rsKXT3fiuW71WZZ4iDteXcxfPv9Rg8gkRykARPKpoqHBDO9ch4W/7EzPZlWZsCiJLi8tYuaqPXocpeQIBYBIPlexVFFeuq8pc0bEUqNcMZ77aANxbywjYecRr0uTAk4BIFJANK1eho+GtWN0n2aknDxP7/HxPDF9LXuPnfW6NCmg/AoAM+tmZlvMLNHMns9gu5nZGN/2DWbWIs22yWZ20Mw2ptunnJl9aWbbfP+WvfGXI1K4mRlxzarxzS87Meq2uizYdIDbXl7Iq19u5ewFPZdYsibTADCzYOANoDvQEOhrZg3TNesO1PV9DQHGpdk2BeiWwaGfB752ztUFvvYti4gfwsNCeOb2enz9i050bVCJ0V9v49aXFzJn3V49iUz85s8ZQCsg0TmX5Jy7AMwA4tK1iQOmulQrgDJmVgXAObcYyOhiZRzwru/7d4F7svMCRAJZZNlwXn+wBTMfb0u54mE8OWMd942PZ0PyMa9LkwLAnwCoBuxJs5zsW5fVNulVcs7tB/D9WzGjRmY2xMwSzCwhJSXFj3JFAk+r6HLMHdmev/Vqws7Dp4l7YxnPfriegyc126hcmz8BYBmsS3+O6U+bbHHOveWci3HOxUREROTEIUUKpeAg44GWNfj2l50Z0qEWs9ftpcuLCxm3cDvnL6l/QP6TPwGQDFRPsxwJ7MtGm/R++vkyke/fg37UIiKZKFk0lF/1aMCCpzvRtnYF/vbFZu54NfUhNOofkLT8CYBVQF0zizazMKAPMDddm7lAf9/dQG2A4z9f3rmOucAA3/cDgDlZqFtEMhFdoTiTBsQwbXArwoKDeHzaavq9vZItB056XZrkE5kGgHPuEjASmA/8CMx0zm0ys6FmNtTX7DMgCUgEJgLDf97fzKYD8UB9M0s2s8G+TX8FbjezbcDtvmURyWEd6kbw+ZMdeKFnIzbuPUH30Yv57eyNHD2taacDnR
WkU8KYmBiXkJDgdRkiBdbR0xd47autvLdyNyWKhPBU17r0a1OT0GCNCS3MzGy1cy4m/Xr9VxcJIGWLh/FCXGM+f7IDN0eW5oVPf6D76CUs2qo77AKRAkAkANWrVJKpg1oxqX8Mly5fYcDk7xg8ZRVJKae8Lk3ykAJAJECZGV0bVmL+0x35dY+bWLnjCHe+tpg/fPoDh0+d97o8yQPqAxARAFJOnuflBVuYmbCHYqHBPNqhFo92iKZk0VCvS5MbdK0+AAWAiFwl8eBJXl6wlc83HqBseCjDO9fh4bY1KRoa7HVpkk0KABHJkg3Jx3hx/haWbDtE5VJFebJrXXrfEqk7hgog3QUkIllyc2QZpg1uzfTH2lC1TFF+9fH33PHqYuau36cnkhUSCgARua62tcvz0bB2TOofQ5GQIEZNX8tdY5fyzeafNLVEAacAEJFM/XzH0GejOjC6TzPOXLjEoCkJ3Dc+npVJh70uT7JJASAifgsKSn0i2VfPdOLP9zZmz9EzPPDWCgZM/o6Ne497XZ5kkTqBRSTbzl28zNT4nby5cDvHzlzkriZVeOaOetSOKOF1aZKG7gISkVxz4txFJi1OYtLSHZy/dIXeLSIZ1bUu1coU87o0QQEgInng0KnzvPntdt5bsQuAfm1qMrxLbSqUKOJxZYFNASAieWbvsbOM+WobH65OHVU8uH00j3asRSmNKvaEAkBE8tz2lFO8smAr877fT5nwUIZ1qs2AdlEaVZzHFAAi4pmNe4/z4vwtLNqaQqVSRRh1W13uj6muUcV5RCOBRcQzjauV5t1BrfhgSBsiy4bzm0820vWVRcxZt1ejij2kABCRPNO6VnlmDW3L5EdiCA8L4ckZ6+gxZglf/aBRxV5QAIhInjIzbr2pEvOeaM+Yvs05d/Eyj05NoNe45cRv16jivKQAEBFPBAUZPZtW5ctnOvF/9zZh37Fz9J24goffXsn3yRpVnBfUCSwi+cK5i5eZFr+LNxcmcvTMRe5oWImRt9bh5sgyXpdW4OkuIBEpEE6eu8jbS3cweekOTpy7RIe6FRjZpQ6ta5X3urQCSwEgIgXKyXMXeW/Fbt5emsShUxdoGVWW4V3q0LleBGbmdXkFigJARAqkcxcvM+O73by1OIl9x8/RuFopRnSuw52NKhMUpCDwhwJARAq0C5euMHvtXt5cmMjOw2eoU7EEwzvXpmfTqoRoQNl1KQBEpFC4fMUx7/v9vPltIpsPnKR6uWIM7VSb3rdEUiREU0xkRAEgIoXKlSuOrzcf5PVvE1m/5xiVShXhsQ61eLB1DcLDQrwuL19RAIhIoeScY1niYV7/dhsrko5QrngYg2KjeLhtFKWLafZRUACISABYvesIr3+TyLdbUihZJISH29ZkcPtoygf48wgUACISMDbuPc64hdv5bON+ioQE0bdVDYZ0rEWV0oH5hDIFgIgEnMSDpxi3cDuz1+0lyKD3LZEM7VSbmuWLe11anlIAiEjA2nPkDBMWb2dmQjKXLl+hZ9OqDO9Sh3qVSnpdWp5QAIhIwDt44hwTlyTx/srdnLlwmTsbVWJkl7o0iSztdWm5SgEgIuJz9PQF3lm2gynLd3Li3CU61otgZJc6tIou53VpueKGnghmZt3MbIuZJZrZ8xlsNzMb49u+wcxaZLavmTUzsxVmts7MEsysVXZfnIhIVpQtHsYzd9Rn2fO38ly3+mzae5z7J8Rz//h4Fm1NCZiH02R6BmBmwcBW4HYgGVgF9HXO/ZCmTQ/gCaAH0BoY7Zxrfb19zWwB8Kpz7nPf/s855zpfrxadAYhIbjh74TIfrNrNhMVJ7D9+jibVSjOiS23uaFg45hu6kTOAVkCicy7JOXcBmAHEpWsTB0x1qVYAZcysSib7OqCU7/vSwL4svyoRkRxQLCyYR2KjWfRsF/7Wqwknz11k6HtruPO1xcxM2MO5i5e9LjFX+BMA1YA9aZaTfev8aXO9fZ8CXjSzPcBLwK8y+uFmNsR3iS
ghJSXFj3JFRLInLCSIB1rW4KtnOjG6TzOCg4znZm2g/d++4dUvt5Jy8rzXJeYofwIgo/Of9NeNrtXmevsOA552zlUHngbezuiHO+fecs7FOOdiIiIi/ChXROTGhAQHEdesGp8/2YH3H21N08gyjP56G7F//YZfzFzPpn2F45GV/syYlAxUT7McyX9errlWm7Dr7DsAeNL3/YfAJP9KFhHJG2ZGbJ0KxNapQFLKKd5dvpMPVyfz0Zpk2tQqx6DYaG5rUIngAtpP4M8ZwCqgrplFm1kY0AeYm67NXKC/726gNsBx59z+TPbdB3TyfX8rsO0GX4uISK6pFVGCF+IaE//8bfy6x03sOXKWIdNW0+WlhUxeuoOT5y56XWKW+TUOwHeXzmtAMDDZOfdnMxsK4Jwbb6nPZ3sd6AacAQY65xKuta9vfXtgNKlnIeeA4c651derQ3cBiUh+cenyFRb88BOTl+4gYddRShQJ4f6Y6gyMjaJ6uXCvy7uKBoKJiOSSdXuO8c6yHczbsJ8rznF7w0oMio2mVXS5fPH8YgWAiEguO3D8HNNW7OT9lbs5duYijaqWYlBsNHc3reLp08oUACIieeTshcvMXreXyUt3sO3gKSJKFuHhNjV5sHUNKnjwbAIFgIhIHnPOsWTbISYv28HCLSmEhQRxT7OqDIyNpkGVUpkfIIdcKwD04EwRkVxiZnSsF0HHehEkHjzFlOU7mLU6mZkJybSrXZ7B7aPpUr+iZ9NN6AxARCQPHTtzgenf7WFq/E72Hz9HVPlwBsZG0/uWSIoXyZ3P5LoEJCKSj1y8fIUvNh7g7aU7WLfnGCWLhtCnZXX6t83520gVACIi+dSa3UeZvHQHn288gHOOOxtVZnD7aG6pWTZHbiNVH4CISD7VokZZWjxYln3HzjI1fhfTv9vN5xsPcHNkaQbFRtOjSRXCQvx6fEuW6AxARCSfOXPhEh+v2cvkZTtISjlNxZJFeO2BZrSrUyFbx9MZgIhIAREeFkK/NjV5sFUNFm1LYcqynURVKJ7jP0cBICKSTwUFGV3qV6RL/Yq5c/xcOaqIiOR7CgARkQClABARCVAKABGRAKUAEBEJUAoAEZEApQAQEQlQCgARkQBVoKaCMLMUYPDV4BYAAAL1SURBVFc2d68AHMrBcgo6vR//pvfiano/rlYY3o+azrmI9CsLVADcCDNLyGgujECl9+Pf9F5cTe/H1Qrz+6FLQCIiAUoBICISoAIpAN7yuoB8Ru/Hv+m9uJrej6sV2vcjYPoARETkaoF0BiAiImkoAEREAlRABICZdTOzLWaWaGbPe12PV8ysupl9a2Y/mtkmM3vS65ryAzMLNrO1ZvZPr2vxmpmVMbNZZrbZ9/9JW69r8oqZPe37PdloZtPNrKjXNeW0Qh8AZhYMvAF0BxoCfc2sobdVeeYS8AvnXAOgDTAigN+LtJ4EfvS6iHxiNPCFc+4moCkB+r6YWTVgFBDjnGsMBAN9vK0q5xX6AABaAYnOuSTn3AVgBhDncU2ecM7td86t8X1/ktRf7mreVuUtM4sE7gImeV2L18ysFNAReBvAOXfBOXfM26o8FQIUM7MQIBzY53E9OS4QAqAasCfNcjIB/kcPwMyigObASm8r8dxrwHPAFa8LyQdqASnAO75LYpPMLOefRF4AOOf2Ai8Bu4H9wHHn3AJvq8p5gRAAlsG6gL731cxKAB8BTznnTnhdj1fM7G7goHNutde15BMhQAtgnHOuOXAaCMg+MzMrS+qVgmigKlDczPp5W1XOC4QASAaqp1mOpBCeyvnLzEJJ/eP/vnPuY6/r8Vgs0NPMdpJ6afBWM3vP25I8lQwkO+d+PiucRWogBKKuwA7nXIpz7iLwMdDO45pyXCAEwCqgrplFm1kYqR05cz2uyRNmZqRe3/3ROfeK1/V4zTn3K+dcpHMuitT/L75xzhW6T3n+cs4dAPaYWX3fqtuAHzwsyUu7gTZmFu77vbmNQtghHuJ1AbnNOX
fJzEYC80ntyZ/snNvkcVleiQUeBr43s3W+db92zn3mYU2SvzwBvO/7sJQEDPS4Hk8451aa2SxgDal3z62lEE4JoakgREQCVCBcAhIRkQwoAEREApQCQEQkQCkAREQClAJARCRAKQBERAKUAkBEJED9fzTylcI7GUklAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"%matplotlib inline\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(losses)\n", | |
"ax.set(title='Training loss', xlabel='logged step', ylabel='mean MSE loss')\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment