Skip to content

Instantly share code, notes, and snippets.

@kenenbek
Created January 29, 2020 19:37
Show Gist options
  • Save kenenbek/9cecdaf66faf3c67f241aeea4992070f to your computer and use it in GitHub Desktop.
Save kenenbek/9cecdaf66faf3c67f241aeea4992070f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import unicode_literals, print_function, division\n",
"\n",
"from io import open\n",
"import unicodedata\n",
"import string\n",
"import re\n",
"import random\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"import utilities\n",
"\n",
"# Run on the GPU when CUDA is available, otherwise on the CPU.\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
"# Dataset configuration.\n",
"num_replicates = 10\n",
"length = 3000  # sequence length, same value as int(3e3)\n",
"dataset = utilities.GenomDataset(num_replicates=num_replicates, length=length)\n",
"\n",
"# NOTE(review): the target also comes from get_X(), so the model is trained\n",
"# to reproduce its input -- confirm this is intended and not a typo.\n",
"X = dataset.get_X()\n",
"Y = dataset.get_X()\n",
"\n",
"XX = torch.tensor(X).float()\n",
"YY = torch.tensor(Y).float()\n",
"\n",
"# View as (length, num_replicates, features), the default (seq, batch, feature)\n",
"# layout of nn.GRU; assumes X is stored time-major -- TODO confirm.\n",
"input_tensor = XX.view(length, num_replicates, -1)\n",
"target_tensor = YY.view(length, num_replicates, -1)"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
"class EncoderRNN(nn.Module):\n",
"    \"\"\"Single-layer GRU that encodes a whole input sequence in one call.\n",
"\n",
"    Args:\n",
"        input_size: features per time step.\n",
"        hidden_size: size of the GRU hidden state.\n",
"        seq_length: stored as an attribute; not used by forward().\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size, seq_length):\n",
"        super(EncoderRNN, self).__init__()\n",
"        self.input_size = input_size\n",
"        self.hidden_size = hidden_size\n",
"        self.seq_length = seq_length\n",
"        self.gru = nn.GRU(self.input_size, self.hidden_size)\n",
"\n",
"    def forward(self, input, hidden):\n",
"        \"\"\"Return (per-step outputs, final hidden state) from the GRU.\"\"\"\n",
"        return self.gru(input, hidden)\n",
"\n",
"    def initHidden(self, batch_size):\n",
"        \"\"\"Return a zero hidden state of shape (1, batch_size, hidden_size).\"\"\"\n",
"        zeros_shape = (1, batch_size, self.hidden_size)\n",
"        return torch.zeros(*zeros_shape, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
"class AttnDecoderRNN(nn.Module):\n",
"    \"\"\"GRU decoder that attends over the encoder outputs.\n",
"\n",
"    Args:\n",
"        input_size: features per decoder input step.\n",
"        hidden_size: GRU state size. It must equal the encoder's hidden size,\n",
"            because attended encoder outputs are concatenated with the input\n",
"            and fed to attn_combine, which expects input_size + hidden_size.\n",
"        output_size: features per output step.\n",
"        length: number of encoder time steps attended over.\n",
"        dropout_p: drop probability for ``self.dropout``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size, output_size, length, dropout_p=0.1):\n",
"        super(AttnDecoderRNN, self).__init__()\n",
"        self.input_size = input_size\n",
"        self.hidden_size = hidden_size\n",
"        self.output_size = output_size\n",
"        self.length = length\n",
"        self.dropout_p = dropout_p\n",
"\n",
"        self.attn = nn.Linear(self.input_size + self.hidden_size, self.length)\n",
"        self.attn_combine = nn.Linear(self.input_size + self.hidden_size, self.hidden_size)\n",
"        # NOTE(review): self.dropout is never applied in forward(), so\n",
"        # dropout_p currently has no effect -- confirm whether it should be.\n",
"        self.dropout = nn.Dropout(self.dropout_p)\n",
"        self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n",
"        self.out = nn.Linear(self.hidden_size, self.output_size)\n",
"\n",
"    def forward(self, input, hidden, encoder_outputs):\n",
"        \"\"\"Decode all steps of ``input`` in a single GRU call.\n",
"\n",
"        Args:\n",
"            input: (steps, batch, input_size) decoder inputs.\n",
"            hidden: (1, batch, hidden_size) initial hidden state.\n",
"            encoder_outputs: (length, batch, hidden_size).\n",
"\n",
"        Returns:\n",
"            Tuple (output, hidden, attn_weights) with shapes\n",
"            (steps, batch, output_size), (1, batch, hidden_size) and\n",
"            (steps, batch, length).\n",
"        \"\"\"\n",
"        # Broadcast the initial hidden state along the step axis so it can be\n",
"        # concatenated with every input step (expand only grows size-1 dims).\n",
"        expanded_hidden = hidden.expand(input.shape[0], *hidden.shape[1:])\n",
"        attn_weights = F.softmax(self.attn(torch.cat((input, expanded_hidden), dim=-1)), dim=-1)\n",
"\n",
"        # bmm is batch-first: (batch, steps, length) @ (batch, length, hidden).\n",
"        attn_applied = torch.bmm(attn_weights.permute(1, 0, 2), encoder_outputs.permute(1, 0, 2))\n",
"        attn_applied = attn_applied.permute(1, 0, 2)\n",
"\n",
"        output = self.attn_combine(torch.cat((input, attn_applied), dim=-1))\n",
"        output, hidden = self.gru(F.relu(output), hidden)\n",
"        output = self.out(output)\n",
"        return output, hidden, attn_weights\n",
"\n",
"    def initHidden(self, batch_size=1):\n",
"        \"\"\"Return a zero hidden state of shape (1, batch_size, hidden_size).\n",
"\n",
"        batch_size defaults to 1 for backward compatibility; pass the real\n",
"        batch size for batched decoding, as with EncoderRNN.initHidden.\n",
"        \"\"\"\n",
"        return torch.zeros(1, batch_size, self.hidden_size, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [],
"source": [
"# Encoder: one feature per time step, 128-dim GRU state.\n",
"input_size = 1\n",
"hidden_size = 128\n",
"encoder = EncoderRNN(input_size=input_size, hidden_size=hidden_size, seq_length=length)\n",
"encoder = encoder.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 194,
"metadata": {},
"outputs": [],
"source": [
"# Decoder: one feature in and out per step. Its hidden size matches the\n",
"# encoder's because train() seeds it with the encoder's final hidden state.\n",
"input_size_d, hidden_size_d, output_size_d = 1, 128, 1\n",
"decoder = AttnDecoderRNN(\n",
"    input_size_d, hidden_size_d, output_size_d, length, dropout_p=0.1\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": 195,
"metadata": {},
"outputs": [],
"source": [
"# Not referenced by train() below; kept for future teacher-forcing experiments.\n",
"teacher_forcing_ratio = 0.5\n",
"\n",
"\n",
"def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, seq_length):\n",
"    \"\"\"Run one optimisation step of encoder and decoder on one batch.\n",
"\n",
"    The decoder is trained with teacher forcing, vectorised over the whole\n",
"    sequence: the ground-truth value at step t-1 is its input at step t, and\n",
"    the loss compares its predictions with target steps 1..T-1.\n",
"\n",
"    Args:\n",
"        input_tensor: encoder input, (seq_len, batch, features).\n",
"        target_tensor: decoder target, (seq_len, batch, features); needs at\n",
"            least two time steps.\n",
"        encoder, decoder: the encoder/decoder modules being trained.\n",
"        encoder_optimizer, decoder_optimizer: their optimizers.\n",
"        criterion: loss called as criterion(prediction, target).\n",
"        seq_length: unused; kept so existing callers keep working.\n",
"\n",
"    Returns:\n",
"        The loss for this step as a Python float.\n",
"\n",
"    Raises:\n",
"        ValueError: if target_tensor has fewer than two time steps.\n",
"    \"\"\"\n",
"    if target_tensor.size(0) < 2:\n",
"        raise ValueError('target_tensor needs at least two time steps for teacher forcing')\n",
"\n",
"    encoder_optimizer.zero_grad()\n",
"    decoder_optimizer.zero_grad()\n",
"\n",
"    encoder_hidden = encoder.initHidden(input_tensor.size(1))\n",
"    encoder_outputs, encoder_hidden = encoder(input_tensor, encoder_hidden)\n",
"\n",
"    # Shift by one step so the decoder never receives, as input, the value it\n",
"    # is scored on. Feeding target_tensor while predicting target_tensor lets\n",
"    # the model simply copy its input.\n",
"    decoder_inputs = target_tensor[:-1]\n",
"    decoder_targets = target_tensor[1:]\n",
"    decoder_output, _, _ = decoder(decoder_inputs, encoder_hidden, encoder_outputs)\n",
"    loss = criterion(decoder_output, decoder_targets)\n",
"\n",
"    loss.backward()\n",
"    encoder_optimizer.step()\n",
"    decoder_optimizer.step()\n",
"\n",
"    return loss.item()"
]
},
{
"cell_type": "code",
"execution_count": 196,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def trainIters(encoder, decoder, input_dataset, target_dataset, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):\n",
"    \"\"\"Train encoder and decoder for n_iters SGD steps on one fixed batch.\n",
"\n",
"    Args:\n",
"        encoder, decoder: the modules to train.\n",
"        input_dataset: encoder input tensor passed to train() every step.\n",
"        target_dataset: decoder target tensor passed to train() every step.\n",
"        n_iters: number of optimisation steps.\n",
"        print_every: print the mean loss of the last print_every steps.\n",
"        plot_every: record the mean loss of the last plot_every steps.\n",
"        learning_rate: SGD learning rate for both models.\n",
"\n",
"    Returns:\n",
"        List of per-plot_every mean losses (also drawn with showPlot).\n",
"    \"\"\"\n",
"    start = time.time()\n",
"    plot_losses = []\n",
"    print_loss_total = 0  # Reset every print_every\n",
"    plot_loss_total = 0  # Reset every plot_every\n",
"\n",
"    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)\n",
"    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)\n",
"    criterion = nn.MSELoss()\n",
"\n",
"    # Count steps from 1 so each reported average covers exactly\n",
"    # print_every / plot_every losses.\n",
"    for it in range(1, n_iters + 1):\n",
"        # Train on the arguments, not on notebook globals of the same role.\n",
"        loss = train(input_dataset, target_dataset, encoder, decoder,\n",
"                     encoder_optimizer, decoder_optimizer, criterion,\n",
"                     input_dataset.size(0))\n",
"        print_loss_total += loss\n",
"        plot_loss_total += loss\n",
"\n",
"        if it % print_every == 0:\n",
"            print_loss_avg = print_loss_total / print_every\n",
"            print_loss_total = 0\n",
"            print('%s (%d %d%%) %.4f' % (timeSince(start, it / n_iters),\n",
"                                         it, it / n_iters * 100, print_loss_avg))\n",
"\n",
"        if it % plot_every == 0:\n",
"            plot_losses.append(plot_loss_total / plot_every)\n",
"            plot_loss_total = 0\n",
"\n",
"    showPlot(plot_losses)\n",
"    return plot_losses"
]
},
{
"cell_type": "code",
"execution_count": 197,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt \n",
"plt.switch_backend('agg')\n",
"%matplotlib inline \n",
"\n",
"from helper import asMinutes, timeSince, showPlot"
]
},
{
"cell_type": "code",
"execution_count": 198,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.switch_backend('agg')\n",
"import matplotlib.ticker as ticker\n",
"import numpy as np\n",
"\n",
"\n",
"def showPlot(points, tick_base=0.2):\n",
"    \"\"\"Plot a sequence of values with a major y-axis tick every tick_base.\n",
"\n",
"    Args:\n",
"        points: sequence of y values (e.g. the losses from trainIters).\n",
"        tick_base: spacing of major y-axis ticks; defaults to 0.2.\n",
"    \"\"\"\n",
"    # plt.subplots() creates the figure itself; a preceding plt.figure()\n",
"    # call would leave an extra empty figure open on every call.\n",
"    fig, ax = plt.subplots()\n",
"    # this locator puts ticks at regular intervals\n",
"    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=tick_base))\n",
"    ax.plot(points)"
]
},
{
"cell_type": "code",
"execution_count": 199,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0m 18s (- 2m 42s) (1 10%) 0.0165\n",
"0m 36s (- 2m 27s) (2 20%) 0.0151\n",
"0m 52s (- 2m 2s) (3 30%) 0.0138\n",
"1m 8s (- 1m 42s) (4 40%) 0.0126\n",
"1m 24s (- 1m 24s) (5 50%) 0.0115\n",
"1m 39s (- 1m 6s) (6 60%) 0.0105\n",
"1m 55s (- 0m 49s) (7 70%) 0.0096\n",
"2m 11s (- 0m 32s) (8 80%) 0.0087\n",
"2m 27s (- 0m 16s) (9 90%) 0.0080\n",
"2m 43s (- 0m 0s) (10 100%) 0.0073\n"
]
}
],
"source": [
"losses = trainIters(\n",
"    encoder, decoder, input_tensor, target_tensor,\n",
"    n_iters=10, print_every=1, plot_every=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 203,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD4CAYAAADlwTGnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3hVVb7/8fc3DQi9hBogoUqTYqihKiqgl+gFFRRBQJEmlhn9OTN35o5T7hQroAKCiKADIirgoIKNHpBQBaWE0EKR0Htfvz9ynAmZQE5Ckp3kfF7Pk4fsvdfe+Z6jyefsvfZa25xziIhI4AnyugAREfGGAkBEJEApAEREApQCQEQkQCkAREQCVIjXBWRFhQoVXFRUlNdliIgUKKtXrz7knItIv75ABUBUVBQJCQlelyEiUqCY2a6M1usSkIhIgFIAiIgEKAWAiEiAUgCIiAQoBYCISIBSAIiIBCgFgIhIgAqIAFi7+ygTFm1HU1+LiPxbgRoIll0fr9nLtBW7+OnEef7nrgYEBZnXJYmIeC4gAuCFno0ICTYmL9vBkdPn+XvvpoSFBMTJj4jINQVEAAQFGb+7uyEVShThxflbOHLmIuP7tSA8LCBevohIhgLmY7CZMaJLHf76301Yui2FByeu5OjpC16XJSLimYAJgJ/1aVWDcf1u4Yf9J+g9fjl7j531uiQREU8EXAAA3NmoMtMGteLgyfP0HrecbT+d9LokEZE8F5ABANC6Vnk+GNKWS1ccvcfHs3rXUa9LEhHJUwEbAAANq5bio6HtKBseykOTVvDtloNelyQikmf8CgAz62ZmW8ws0cyez2C7mdkY3/YNZtYizbbJZnbQzDZmsN8TvuNuMrO/39hLyZ4a5cP5cGg7akeU4LF3E/hkbbIXZYiI5LlMA8DMgoE3gO5AQ6CvmTVM16w7UNf3NQQYl2bbFKBbBsftAsQBNzvnGgEvZaP+HBFRsggzhrShVXQ5nv5gPZOWJHlViohInvHnDKAVkOicS3LOXQBmkPqHO604YKpLtQIoY2ZVAJxzi4EjGRx3GPBX59x5XztPr7+ULBrKOwNb0qNJZf4070f+8vmPmjpCRAo1fwKgGrAnzXKyb11W26RXD+hgZivNbJGZtcyokZkNMbMEM0tISUnxo9zsKxISzNi+LejXpgYTFiXx3KwNXLp8JVd/poiIV/wZCpvRxDnpPxr70yajn10WaAO0BGaaWS2X7mO3c+4t4C2AmJiYXP9IHhxk/DGuMRVKFOG1r7Zx9MwFxvZtQbGw4Nz+0SIiecqfM4BkoHqa5UhgXzbaZHTcj32Xjb4DrgAV/Kgn15kZT3Wtxx/vaczXmw/y8NsrOX7motdliYjkKH8CYBVQ18yizSwM6APMTddmLtDfdzdQG+C4c25/JsedDdwKYGb1gDDgUJaqz2UPt6nJ631bsCH5OPdPiOfA8XNelyQikmMyDQDn3CVgJDAf+BGY6ZzbZGZDzWyor9lnQBKQCEwEhv+8v5lNB+KB+maWbGaDfZsmA7V8t4fOAAakv/yTH9x1cxXeGdiS5KNn6DVuOdtTTnldkohIjrB8+Df3mmJiYlxCQoInP/v75OM88s53OOCdR1rStHoZT+oQEckqM1vtnItJvz6gRwJnRZPI0swa1o7wsGD6TlzBkm25e0eSiEhuUwBkQXSF4nw0rB01yoUzaMoq5q7PrJ9bRCT/UgBkUaVSRfng8bY0r16WJ2esZcqyHV6XJCKSLQqAbChdLJSpg1vRtUElfv/pD7y8YItGDYtIgaMAyKaiocGMe6gF98dEMvabRH79yUYuX1EIiEjBoYfi3oCQ4CD+1utmKpQowpsLt3Pk9HlG92lO0VCNGhaR/E9nADfIzHiu20387u6GzN/0EwMmf8eJcxo1LCL5nwIghwxqH81rDzRj9a6j9JmwgoMnNWpYRPI3BUAOuqd5NSYNiGHHodP0HhfPrs
OnvS5JROSaFAA5rHP9ivzjsdacOHeRXuPi2bj3uNcliYhkSAGQC5rXKMusoW0JCzb6vLWC+O2HvS5JROQ/KABySZ2KJZk1rB2VSxdlwOTv+GJjZpOjiojkLQVALqpaphgfPt6WRtVKMfz9Nfxj5W6vSxIR+RcFQC4rWzyM9x9tTcd6Efz6k+959cutGjUsIvmCAiAPhIeFMLF/DL1aRDL66208MX0t5y5e9rosEQlwGgmcR0KDg3jpvpupXbE4L87fwu4jZ5jYP4ZKpYp6XZqIBCidAeQhM2N45zpM6HcLiQdP0fP1pWxIPuZ1WSISoBQAHrijUWU+GtaOkKAg7p8Qzz836LkCIpL3FAAeaVClFHNGxtK4amlG/mMtr365lSuaTVRE8pACwEMVShTh/cdaX9U5fPaCOodFJG+oE9hjRUKCeem+m6lfuQR/+Xwzu46cZmL/GKqULuZ1aSJSyOkMIB8wM4Z0rM2k/jHsSDlN3OvLWLdHncMikrsUAPnIbQ0q8fHwWMJCgnhgQjxz1u31uiQRKcQUAPlM/colmTMilqaRZXhyxjpeXrBFncMikisUAPlQ+RJFeO/R1v963vDw99dw5sIlr8sSkUJGAZBPhYWkPm/4f+5qwIIfDtB7XDz7jp31uiwRKUQUAPmYmfFoh1q8PaAlu4+coefry1i7+6jXZYlIIaEAKAC63FSRT4a3IzwsmAfeWsHsteocFpEbpwAoIOpWKsnsEbE0r16Gpz5Yx9+/2KzOYRG5IQqAAqRc8TCmDW5N31bVeXPhdoa+t5rT59U5LCLZowAoYMJCgvi/e5vwv//VkK9+/Ile45aTfPSM12WJSAGkACiAzIyBsdG8M7AVe4+d5Z43lrF61xGvyxKRAkYBUIB1qhfBJ8NjKVEkhL5vreSj1clelyQiBYgCoICrU7EEs0fEEhNVll98uJ6/fP4jl9U5LCJ+8CsAzKybmW0xs0Qzez6D7WZmY3zbN5hZizTbJpvZQTPbeI1j/9LMnJlVyP7LCGxlwsN4d1ArHmpdgwmLknh8WgKn1DksIpnINADMLBh4A+gONAT6mlnDdM26A3V9X0OAcWm2TQG6XePY1YHbgd1ZLVyuFhocxJ/vbcIf4hrx7ZYUer25nD1H1DksItfmzxlAKyDROZfknLsAzADi0rWJA6a6VCuAMmZWBcA5txi4Vg/lq8BzgK5Z5JD+baOYMrAl+4+ndg6v2qnOYRHJmD8BUA3Yk2Y52bcuq22uYmY9gb3OufWZtBtiZglmlpCSkuJHudKhbgSzR8RSulgoD05cwYcJezLfSUQCjj8BYBmsS/+J3Z82/25sFg78BvhdZj/cOfeWcy7GORcTERGRWXPxqRVRgk+Gx9I6ujzPztrAn+f9oM5hEbmKPwGQDFRPsxwJ7MtGm7RqA9HAejPb6Wu/xswq+1GP+Kl0eChTBrZkQNuaTFyyg8emJnDy3EWvyxKRfMKfAFgF1DWzaDMLA/oAc9O1mQv0990N1AY47pzbf60DOue+d85VdM5FOeeiSA2QFs65A9l7GXItIcFBvBDXmD/d05hFW1PoNW45uw+rc1hE/AgA59wlYCQwH/gRmOmc22RmQ81sqK/ZZ0ASkAhMBIb/vL+ZTQfigfpmlmxmg3P4NYgf+rWpybRBrfjpxHni3ljK8u2HvC5JRDxmzhWc68IxMTEuISHB6zIKtJ2HTvPo1ASSUk7xizvqM6xTbYKCMurCEZHCwsxWO+di0q/XSOAAE1WhOHNGxHL3zVV5cf4WBk5ZxZHTF7wuS0Q8oAAIQMWLhDC6TzP+fG9j4rcf5q4xSzSZnEgAUgAEKDPjodY1+Xh4O0KDg3hgwgomLUmiIF0SFJEbowAIcI2rlebTJ9pzW4OK/Gnejzw+bTXHz+pWUZFAoAAQShcLZXy/W/jt3Q35ZvNB7h67hO+Tj3tdlojkMgWAAKmXhAa3j2bm0LZcvuzoNW4501bs0iUhkUJMASBXaVGjLPNGdaBdnf
L8dvZGRs1Yp6mlRQopBYD8h7LFw5g8oCXP3lmfeRv20XPsUjYfOOF1WSKSwxQAkqGgIGNElzr847E2nDx/iXveWKZZRUUKGQWAXFebWuX5bFQHWtQoy7OzNvDsh+s5e+Gy12WJSA5QAEimIkoWYdrg1oy6rS6z1iRzzxvL2J5yyuuyROQGKQDEL8FBxjO31+Pdga1IOXWenmOXMnf99Wb8FpH8TgEgWdKxXgTzRrWnQZVSjJq+lv+Z/T3nLuqSkEhBpACQLKtSuhjTh7Th8Y61eG/FbnqP1zMGRAoiBYBkS2hwEL/q0YCJ/WPYffgMd41dwvxNep6PSEGiAJAbcnvDSswb1YHoCsV5fNpq/vTPH7h4+YrXZYmIHxQAcsOqlwvnw6FteaRdFJOW7uCBCfHsO3bW67JEJBMKAMkRRUKC+X3PRrz+YHO2/nSKu8Ys4dstB70uS0SuQwEgOerum6syd2QslUoVZeA7q3hp/hYu6ZKQSL6kAJAcVyuiBLNHxNKnZXVe/zaRfm+v5OCJc16XJSLpKAAkVxQNDeavvW7m5fuasn7PcXqMWcry7Ye8LktE0lAASK7qdUskc0bGUiY8lH6TVjL2621cuaJnDIjkBwoAyXX1KpVkzohYejatystfbuWRKas4fOq812WJBDwFgOSJ4kVCePWBZvzfvU1YkXSYu8YsJWHnEa/LEgloCgDJM2bGg61r8PGwdhQJDeKBt1bw+jfbdJeQiEcUAJLnGlcrzadPtKdHkyq8tGAr902IZ+eh016XJRJwFADiiVJFQxnbtzmj+zRj+8FTdB+9hPdX6iH0InlJASCeimtWjflPdyQmqiy/+WQjg6as0pgBkTyiABDPVSldjHcHtuKFno2ITzrMna8t5rPv93tdlkihpwCQfCEoyBjQLop5ozpQo1w4w99fw9MfrOP42YtelyZSaCkAJF+pHVGCWcPa8VTXusxdv4/ury1meaJGEIvkBgWA5DuhwUE81bUeHw9rR9HQYB6ctJI/fPqDHj0pksMUAJJvNa1ehnmjOjCgbU0mL9vB3WOXsnHvca/LEik0FACSrxULC+aFuMZMHdSKk+cucs8byxj7tQaPieQEvwLAzLqZ2RYzSzSz5zPYbmY2xrd9g5m1SLNtspkdNLON6fZ50cw2+9p/YmZlbvzlSGHVsV4E85/qSPcmVXj5y9TBYzs0eEzkhmQaAGYWDLwBdAcaAn3NrGG6Zt2Bur6vIcC4NNumAN0yOPSXQGPn3M3AVuBXWS1eAkuZ8DDG9m3OmL7N2X7wFD1GL2HaCg0eE8kuf84AWgGJzrkk59wFYAYQl65NHDDVpVoBlDGzKgDOucXAf8z65Zxb4Jy75FtcAURm90VIYOnZtCoLnu5ETFRZfjt7I4+8s4qfNHhMJMv8CYBqwJ40y8m+dVltcz2DgM8z2mBmQ8wswcwSUlJSsnBIKcwqly7K1EGt+ENcI1buSB08Nm+DBo+JZIU/AWAZrEt/zu1Pm4wPbvYb4BLwfkbbnXNvOedinHMxERER/hxSAoSZ0b9t6uCxmuXCGfGPNTw1Y60Gj4n4yZ8ASAaqp1mOBPZlo81/MLMBwN3AQ04XciWb0g4e+3TDfrq9tphlGjwmkil/AmAVUNfMos0sDOgDzE3XZi7Q33c3UBvguHPuuufjZtYN+H9AT+fcmWzULvIvaQePFQsL5qFJK/n93E0aPCZyHZkGgK+jdiQwH/gRmOmc22RmQ81sqK/ZZ0ASkAhMBIb/vL+ZTQfigfpmlmxmg32bXgdKAl+a2TozG59TL0oCV9PqZZj3RAceaRfFlOU7uWvMEjYkH/O6LJF8yQrSlZeYmBiXkJDgdRlSQCzZlsKzH27g0KnzjLqtLsM71yYkWGMfJfCY2WrnXEz69fptkEKrQ93UwWM9mlThlS+30nt8PEkpp7wuSyTfUABIoVY6PJQxfZsztm9zdhw6TY8xS5gWv1ODx0RQAEiA+K+mVZn/VEdaRpXjt3M2MU
CDx0QUABI4fh489se4Rny34zB3vLqYT9dnereySKGlAJCAYmY87Bs8FlWhOE9MX8uo6Ws5cvqC16WJ5DkFgASk2hEl+GhoW57uWo/Pvt9P11cW8cnaZPUNSEBRAEjACgkO4smudfnnqPbULB/O0x+sp//k79h1WNNMS2BQAEjAu6lyKWYNbccf4xqxdvcx7nh1MeMWbueiHjojhZwCQAQIDkrtG/jqmU50rh/B377YzH+NXcra3Ue9Lk0k1ygARNKoXLooEx6OYcLDt3DszEX+e9xy/nfORk6e0wyjUvgoAEQycGejynz5TEf6t6nJ1BW7uP2VxSzYdMDrskRylAJA5BpKFg3lhbjGfDSsHWXCQxkybTWPT0vgwHENIJPCQQEgkokWNcry6RPtea5bfRZuSaHrK4uYFr+TK1d0y6gUbAoAET+EBgcxvHMdFjzdkWbVy/DbOZvoNX45mw+c8Lo0kWxTAIhkQc3yxZk2uBWv3N+UnYdOc/eYpbw4f7MePCMFkgJAJIvMjP9uEcnXv+hMXLNqvPHtdj2GUgokBYBINpUrHsbL9zfl/UdbA/DQpJX8YuZ6zSskBYYCQOQGxdapwBdPdWREl9rMWbeX215eyMdrNK+Q5H8KAJEcUDQ0mGfvvIl/jmpPVIXiPDNzPQ+/rXmFJH9TAIjkoLTzCq3bkzqv0JsLEzWvkORLCgCRHJZ2XqEu9Svy9y+2aF4hyZcUACK5pHLpoox/+BbNKyT5lgJAJJf9PK/QgLZR/5pXaL7mFZJ8QAEgkgdKFg3l9z0b8bFvXqHHNa+Q5AMKAJE81DyDeYWmxu/ksuYVEg8oAETyWPp5hX43ZxO9Na+QeEABIOKRn+cVevWBpuw6fIa7xizl93M3ceyMRhJL3lAAiHjIzLi3eSRfPdOJPi2rMzV+J51fWsjU+J1c0tgByWUKAJF8oFzxMP58bxPmjepAg8ql+N2cTfQYs4Sl2zTBnOQeBYBIPtKgSin+8VhrxvdrwdmLl+n39koem5rAzkOaUkJyngJAJJ8xM7o1rsKXT3fiuW71WZZ4iDteXcxfPv9Rg8gkRykARPKpoqHBDO9ch4W/7EzPZlWZsCiJLi8tYuaqPXocpeQIBYBIPlexVFFeuq8pc0bEUqNcMZ77aANxbywjYecRr0uTAk4BIFJANK1eho+GtWN0n2aknDxP7/HxPDF9LXuPnfW6NCmg/AoAM+tmZlvMLNHMns9gu5nZGN/2DWbWIs22yWZ20Mw2ptunnJl9aWbbfP+WvfGXI1K4mRlxzarxzS87Meq2uizYdIDbXl7Iq19u5ewFPZdYsibTADCzYOANoDvQEOhrZg3TNesO1PV9DQHGpdk2BeiWwaGfB752ztUFvvYti4gfwsNCeOb2enz9i050bVCJ0V9v49aXFzJn3V49iUz85s8ZQCsg0TmX5Jy7AMwA4tK1iQOmulQrgDJmVgXAObcYyOhiZRzwru/7d4F7svMCRAJZZNlwXn+wBTMfb0u54mE8OWMd942PZ0PyMa9LkwLAnwCoBuxJs5zsW5fVNulVcs7tB/D9WzGjRmY2xMwSzCwhJSXFj3JFAk+r6HLMHdmev/Vqws7Dp4l7YxnPfriegyc126hcmz8BYBmsS3+O6U+bbHHOveWci3HOxUREROTEIUUKpeAg44GWNfj2l50Z0qEWs9ftpcuLCxm3cDvnL6l/QP6TPwGQDFRPsxwJ7MtGm/R++vkyke/fg37UIiKZKFk0lF/1aMCCpzvRtnYF/vbFZu54NfUhNOofkLT8CYBVQF0zizazMKAPMDddm7lAf9/dQG2A4z9f3rmOucAA3/cDgDlZqFtEMhFdoTiTBsQwbXArwoKDeHzaavq9vZItB056XZrkE5kGgHPuEjASmA/8CMx0zm0ys6FmNtTX7DMgCUgEJgLDf97fzKYD8UB9M0s2s8G+TX8FbjezbcDtvmURyWEd6kbw+ZMdeKFnIzbuPUH30Yv57eyNHD2taacDnR
WkU8KYmBiXkJDgdRkiBdbR0xd47autvLdyNyWKhPBU17r0a1OT0GCNCS3MzGy1cy4m/Xr9VxcJIGWLh/FCXGM+f7IDN0eW5oVPf6D76CUs2qo77AKRAkAkANWrVJKpg1oxqX8Mly5fYcDk7xg8ZRVJKae8Lk3ykAJAJECZGV0bVmL+0x35dY+bWLnjCHe+tpg/fPoDh0+d97o8yQPqAxARAFJOnuflBVuYmbCHYqHBPNqhFo92iKZk0VCvS5MbdK0+AAWAiFwl8eBJXl6wlc83HqBseCjDO9fh4bY1KRoa7HVpkk0KABHJkg3Jx3hx/haWbDtE5VJFebJrXXrfEqk7hgog3QUkIllyc2QZpg1uzfTH2lC1TFF+9fH33PHqYuau36cnkhUSCgARua62tcvz0bB2TOofQ5GQIEZNX8tdY5fyzeafNLVEAacAEJFM/XzH0GejOjC6TzPOXLjEoCkJ3Dc+npVJh70uT7JJASAifgsKSn0i2VfPdOLP9zZmz9EzPPDWCgZM/o6Ne497XZ5kkTqBRSTbzl28zNT4nby5cDvHzlzkriZVeOaOetSOKOF1aZKG7gISkVxz4txFJi1OYtLSHZy/dIXeLSIZ1bUu1coU87o0QQEgInng0KnzvPntdt5bsQuAfm1qMrxLbSqUKOJxZYFNASAieWbvsbOM+WobH65OHVU8uH00j3asRSmNKvaEAkBE8tz2lFO8smAr877fT5nwUIZ1qs2AdlEaVZzHFAAi4pmNe4/z4vwtLNqaQqVSRRh1W13uj6muUcV5RCOBRcQzjauV5t1BrfhgSBsiy4bzm0820vWVRcxZt1ejij2kABCRPNO6VnlmDW3L5EdiCA8L4ckZ6+gxZglf/aBRxV5QAIhInjIzbr2pEvOeaM+Yvs05d/Eyj05NoNe45cRv16jivKQAEBFPBAUZPZtW5ctnOvF/9zZh37Fz9J24goffXsn3yRpVnBfUCSwi+cK5i5eZFr+LNxcmcvTMRe5oWImRt9bh5sgyXpdW4OkuIBEpEE6eu8jbS3cweekOTpy7RIe6FRjZpQ6ta5X3urQCSwEgIgXKyXMXeW/Fbt5emsShUxdoGVWW4V3q0LleBGbmdXkFigJARAqkcxcvM+O73by1OIl9x8/RuFopRnSuw52NKhMUpCDwhwJARAq0C5euMHvtXt5cmMjOw2eoU7EEwzvXpmfTqoRoQNl1KQBEpFC4fMUx7/v9vPltIpsPnKR6uWIM7VSb3rdEUiREU0xkRAEgIoXKlSuOrzcf5PVvE1m/5xiVShXhsQ61eLB1DcLDQrwuL19RAIhIoeScY1niYV7/dhsrko5QrngYg2KjeLhtFKWLafZRUACISABYvesIr3+TyLdbUihZJISH29ZkcPtoygf48wgUACISMDbuPc64hdv5bON+ioQE0bdVDYZ0rEWV0oH5hDIFgIgEnMSDpxi3cDuz1+0lyKD3LZEM7VSbmuWLe11anlIAiEjA2nPkDBMWb2dmQjKXLl+hZ9OqDO9Sh3qVSnpdWp5QAIhIwDt44hwTlyTx/srdnLlwmTsbVWJkl7o0iSztdWm5SgEgIuJz9PQF3lm2gynLd3Li3CU61otgZJc6tIou53VpueKGnghmZt3MbIuZJZrZ8xlsNzMb49u+wcxaZLavmTUzsxVmts7MEsysVXZfnIhIVpQtHsYzd9Rn2fO38ly3+mzae5z7J8Rz//h4Fm1NCZiH02R6BmBmwcBW4HYgGVgF9HXO/ZCmTQ/gCaAH0BoY7Zxrfb19zWwB8Kpz7nPf/s855zpfrxadAYhIbjh74TIfrNrNhMVJ7D9+jibVSjOiS23uaFg45hu6kTOAVkCicy7JOXcBmAHEpWsTB0x1qVYAZcysSib7OqCU7/vSwL4svyoRkRxQLCyYR2KjWfRsF/7Wqwknz11k6HtruPO1xcxM2MO5i5e9LjFX+BMA1YA9aZaTfev8aXO9fZ8CXjSzPcBLwK8y+uFmNsR3iS
ghJSXFj3JFRLInLCSIB1rW4KtnOjG6TzOCg4znZm2g/d++4dUvt5Jy8rzXJeYofwIgo/Of9NeNrtXmevsOA552zlUHngbezuiHO+fecs7FOOdiIiIi/ChXROTGhAQHEdesGp8/2YH3H21N08gyjP56G7F//YZfzFzPpn2F45GV/syYlAxUT7McyX9errlWm7Dr7DsAeNL3/YfAJP9KFhHJG2ZGbJ0KxNapQFLKKd5dvpMPVyfz0Zpk2tQqx6DYaG5rUIngAtpP4M8ZwCqgrplFm1kY0AeYm67NXKC/726gNsBx59z+TPbdB3TyfX8rsO0GX4uISK6pFVGCF+IaE//8bfy6x03sOXKWIdNW0+WlhUxeuoOT5y56XWKW+TUOwHeXzmtAMDDZOfdnMxsK4Jwbb6nPZ3sd6AacAQY65xKuta9vfXtgNKlnIeeA4c651derQ3cBiUh+cenyFRb88BOTl+4gYddRShQJ4f6Y6gyMjaJ6uXCvy7uKBoKJiOSSdXuO8c6yHczbsJ8rznF7w0oMio2mVXS5fPH8YgWAiEguO3D8HNNW7OT9lbs5duYijaqWYlBsNHc3reLp08oUACIieeTshcvMXreXyUt3sO3gKSJKFuHhNjV5sHUNKnjwbAIFgIhIHnPOsWTbISYv28HCLSmEhQRxT7OqDIyNpkGVUpkfIIdcKwD04EwRkVxiZnSsF0HHehEkHjzFlOU7mLU6mZkJybSrXZ7B7aPpUr+iZ9NN6AxARCQPHTtzgenf7WFq/E72Hz9HVPlwBsZG0/uWSIoXyZ3P5LoEJCKSj1y8fIUvNh7g7aU7WLfnGCWLhtCnZXX6t83520gVACIi+dSa3UeZvHQHn288gHOOOxtVZnD7aG6pWTZHbiNVH4CISD7VokZZWjxYln3HzjI1fhfTv9vN5xsPcHNkaQbFRtOjSRXCQvx6fEuW6AxARCSfOXPhEh+v2cvkZTtISjlNxZJFeO2BZrSrUyFbx9MZgIhIAREeFkK/NjV5sFUNFm1LYcqynURVKJ7jP0cBICKSTwUFGV3qV6RL/Yq5c/xcOaqIiOR7CgARkQClABARCVAKABGRAKUAEBEJUAoAEZEApQAQEQlQCgARkQBVoKaCMLMUYPDV4BYAAAL1SURBVFc2d68AHMrBcgo6vR//pvfiano/rlYY3o+azrmI9CsLVADcCDNLyGgujECl9+Pf9F5cTe/H1Qrz+6FLQCIiAUoBICISoAIpAN7yuoB8Ru/Hv+m9uJrej6sV2vcjYPoARETkaoF0BiAiImkoAEREAlRABICZdTOzLWaWaGbPe12PV8ysupl9a2Y/mtkmM3vS65ryAzMLNrO1ZvZPr2vxmpmVMbNZZrbZ9/9JW69r8oqZPe37PdloZtPNrKjXNeW0Qh8AZhYMvAF0BxoCfc2sobdVeeYS8AvnXAOgDTAigN+LtJ4EfvS6iHxiNPCFc+4moCkB+r6YWTVgFBDjnGsMBAN9vK0q5xX6AABaAYnOuSTn3AVgBhDncU2ecM7td86t8X1/ktRf7mreVuUtM4sE7gImeV2L18ysFNAReBvAOXfBOXfM26o8FQIUM7MQIBzY53E9OS4QAqAasCfNcjIB/kcPwMyigObASm8r8dxrwHPAFa8LyQdqASnAO75LYpPMLOefRF4AOOf2Ai8Bu4H9wHHn3AJvq8p5gRAAlsG6gL731cxKAB8BTznnTnhdj1fM7G7goHNutde15BMhQAtgnHOuOXAaCMg+MzMrS+qVgmigKlDczPp5W1XOC4QASAaqp1mOpBCeyvnLzEJJ/eP/vnPuY6/r8Vgs0NPMdpJ6afBWM3vP25I8lQwkO+d+PiucRWogBKKuwA7nXIpz7iLwMdDO45pyXCAEwCqgrplFm1kYqR05cz2uyRNmZqRe3/3ROfeK1/V4zTn3K+dcpHMuitT/L75xzhW6T3n+cs4dAPaYWX3fqtuAHzwsyUu7gTZmFu77vbmNQtghHuJ1AbnNOX
fJzEYC80ntyZ/snNvkcVleiQUeBr43s3W+db92zn3mYU2SvzwBvO/7sJQEDPS4Hk8451aa2SxgDal3z62lEE4JoakgREQCVCBcAhIRkQwoAEREApQCQEQkQCkAREQClAJARCRAKQBERAKUAkBEJED9fzTylcI7GUklAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"# Loss curve returned by trainIters: one point per plot_every steps.\n",
"plt.plot(range(len(losses)), losses)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment