-
-
Save jvns/b8804fb9d0672ce147a28d22648b4bd7 to your computer and use it in GitHub Desktop.
rnn 123.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "rnn 123.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyMgy3H46gOlgYJzKwafyS7N", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/jvns/b8804fb9d0672ce147a28d22648b4bd7/rnn-123.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0GTbDNteFg2m" | |
}, | |
"source": [ | |
"import itertools\n", | |
"import fastprogress\n", | |
"from fastai.text import *\n" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "UPHnBLaTFl8O" | |
}, | |
"source": [ | |
"# Define the models\n", | |
"\n", | |
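"I'm defining 2 different versions of the RNN model because I wanted to test whether both formulations work: `RNN` concatenates the one-hot input with the hidden state and applies purely linear layers, while `RNN2` uses separate input-to-hidden and hidden-to-hidden weights followed by a `tanh`." | |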
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ryL6xSivFrXM" | |
}, | |
"source": [ | |
"input_size = 3\n", | |
"hidden_size = 4\n", | |
"class RNN(nn.Module):\n", | |
" def __init__(self, hidden_size=4):\n", | |
" super().__init__()\n", | |
" self.i2h = nn.Linear(input_size + hidden_size, hidden_size)\n", | |
" self.i2o = nn.Linear(input_size + hidden_size, input_size)\n", | |
" self.hidden = torch.zeros(hidden_size).cuda()\n", | |
"\n", | |
" def forward(self, input):\n", | |
" input = torch.nn.functional.one_hot(input, num_classes=input_size).type(torch.FloatTensor).cuda()\n", | |
" combined = torch.cat((input, self.hidden))\n", | |
" hidden = self.i2h(combined)\n", | |
" output = self.i2o(combined)\n", | |
" self.hidden = hidden.detach()\n", | |
" return output\n", | |
"\n" | |
], | |
"execution_count": 37, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "tOhSjTcC1oO4" | |
}, | |
"source": [ | |
"class RNN2(nn.Module):\n", | |
" def __init__(self):\n", | |
" super().__init__()\n", | |
" self.i2h = nn.Linear(input_size, hidden_size) # Wxh\n", | |
" self.h2h = nn.Linear(hidden_size, hidden_size) # Whh\n", | |
" self.h2o = nn.Linear(hidden_size, input_size) # Why\n", | |
" self.hidden = torch.zeros(hidden_size).cuda()\n", | |
"\n", | |
" def forward(self, input):\n", | |
" x = self.i2h(torch.nn.functional.one_hot(input, num_classes=input_size).type(torch.FloatTensor).cuda())\n", | |
" y = self.h2h(self.hidden)\n", | |
" hidden = torch.tanh(y + x)\n", | |
" self.hidden = hidden.detach()\n", | |
" return self.h2o(hidden)" | |
], | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "-yCxsmBo4UmB" | |
}, | |
"source": [ | |
"# Define our training function\n", | |
"\n", | |
"This trains an RNN one step at a time on a stream of (input, target) pairs." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pdTq4A_n564r" | |
}, | |
"source": [ | |
"\"\"\"\n", | |
"Example: repeat([0,1,2]) = [(0,1), (1,2), (2,0), (0,1), (1,2), (2, 0), ...]\n", | |
"We're going to use this to generate training data for our RNN\n", | |
"\"\"\"\n", | |
"def repeat(numbers):\n", | |
" n = len(numbers)\n", | |
" for i in range(10000000):\n", | |
" yield (numbers[i % n], numbers[(i+1) % n])\n" | |
], | |
"execution_count": 35, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "o3JolnM6GXwP" | |
}, | |
"source": [ | |
"# `vectorify` just turns a number into a 1-dimensional vector PyTorch can use\n", | |
"def vectorify(i):\n", | |
" return torch.Tensor([i]).type(torch.LongTensor).squeeze().cuda()\n", | |
"# trains an RNN model on a sequence of training data\n", | |
"def train(rnn, sequence, n_items=10000):\n", | |
" optimizer = torch.optim.Adam(rnn.parameters(), lr=0.05)\n", | |
" mb = master_bar(range(10))\n", | |
" prev = 0\n", | |
" for _ in mb:\n", | |
" running_loss = 0\n", | |
" for i in progress_bar(range(int(n_items/10)), parent=mb):\n", | |
" input, target = next(sequence)\n", | |
" input = vectorify(input)\n", | |
" target = vectorify(target)\n", | |
" # forward pass\n", | |
" output = rnn(input)\n", | |
" # compute loss\n", | |
" loss = F.cross_entropy(output.unsqueeze(0), target.unsqueeze(0))\n", | |
" running_loss += loss.item()\n", | |
" # compute gradients and take optimizer step\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" print(running_loss / (n_items/10))" | |
], | |
"execution_count": 86, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "W4mYhhVJNjxE" | |
}, | |
"source": [ | |
"def predict(rnn, i):\n", | |
" return int(torch.multinomial(F.softmax(rnn(vectorify(i)), dim=-1), 1)[0])" | |
], | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "M5Q9IC_g4dkY" | |
}, | |
"source": [ | |
"# Train different models on a bunch of sequences!" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Hup3sbyM4jKG" | |
}, | |
"source": [ | |
"## Model 1: `repeat([0, 1, 2])`\n", | |
"\n", | |
"Let's see if the RNN can learn the sequence 0, 1, 2, 0, 1, 2, ..." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 191 | |
}, | |
"id": "q4oytj_gKHWs", | |
"outputId": "dbdcc87c-5131-456f-9960-657fc20bd16e" | |
}, | |
"source": [ | |
"rnn_012 = RNN().cuda()\n", | |
"train(rnn_012, repeat([0,1,2]))" | |
], | |
"execution_count": 23, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0.0022184200907126067\n", | |
"0.0002202294015325606\n", | |
"0.0001041749426163733\n", | |
"6.315491446293891e-05\n", | |
"4.312607115134597e-05\n", | |
"3.182382781524211e-05\n", | |
"2.4523668352048844e-05\n", | |
"1.951815066859126e-05\n", | |
"1.6031735553406178e-05\n", | |
"1.3375555304810405e-05\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "j0d9vmAMJ_7E", | |
"outputId": "f3263186-5341-446b-b0a3-d91d18d62e13" | |
}, | |
"source": [ | |
"i = 0\n", | |
"for _ in range(10):\n", | |
" i = predict(rnn_012, i)\n", | |
" print(i, end=' ')" | |
], | |
"execution_count": 24, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"1 2 0 1 2 0 1 2 0 1 " | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "heLpKJlG4z4d" | |
}, | |
"source": [ | |
"Looks like it works! Hooray." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "pe0O19Wt44fk" | |
}, | |
"source": [ | |
"## Model 2: `repeat([0, 1, 2, 1])`\n", | |
"\n", | |
"Now let's give the RNN a harder training task: let's see if it can learn the sequence 0, 1, 2, 1, 0, 1, 2, 1, ...\n", | |
"\n", | |
"This is harder because the next number after `1` can be either 2 or 0, depending on what came before it." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 191 | |
}, | |
"id": "MpcAvgmWIAgE", | |
"outputId": "04b2ce96-0ed5-402c-c138-c880b63753fc" | |
}, | |
"source": [ | |
"count_up_down_rnn = RNN2().cuda()\n", | |
"train(count_up_down_rnn, repeat([0,1,2,1]))" | |
], | |
"execution_count": 27, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0.004983700064104051\n", | |
"0.0035814674655674024\n", | |
"0.003563214811589569\n", | |
"0.003552207035868196\n", | |
"0.0035483673014503436\n", | |
"0.0035461855988920435\n", | |
"0.0035448039943294135\n", | |
"0.0035438635173806687\n", | |
"0.003543186600592162\n", | |
"0.00354267576578859\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "mt1_UdKsM6Sf", | |
"outputId": "31499366-a6db-49ff-8918-ed59331fb44e" | |
}, | |
"source": [ | |
"i = 0\n", | |
"for _ in range(12):\n", | |
" i = predict(count_up_down_rnn, i)\n", | |
" print(i, end=' ')\n" | |
], | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"1 0 1 2 1 0 1 2 1 0 1 2 " | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "8idokkMY5Lbv" | |
}, | |
"source": [ | |
"That works too! Hooray. We can also print out the hidden vector after every step. Notice that the hidden vector after the RNN reads a `1` is different depending on whether the number before it was a 0 or a 2 (compare the vectors printed next to the 2s with those printed next to the 0s)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ckqoIwzgQPYz", | |
"outputId": "7cfed0e8-1fae-4cac-e1f9-9842c028d831" | |
}, | |
"source": [ | |
"i = 0\n", | |
"for _ in range(12):\n", | |
" i = predict(count_up_down_rnn, i)\n", | |
" print(i)\n", | |
" print(count_up_down_rnn.hidden)\n" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"1\n", | |
"tensor([-0.0075, 0.1063, 0.2824, -0.4940], device='cuda:0')\n", | |
"2\n", | |
"tensor([ 0.4515, -0.1440, 0.4424, -0.7459], device='cuda:0')\n", | |
"1\n", | |
"tensor([ 0.3542, -0.2290, 0.5238, -0.3698], device='cuda:0')\n", | |
"0\n", | |
"tensor([ 0.4261, 0.0932, 0.5027, -0.5520], device='cuda:0')\n", | |
"1\n", | |
"tensor([-0.0075, 0.1063, 0.2824, -0.4940], device='cuda:0')\n", | |
"2\n", | |
"tensor([ 0.4515, -0.1440, 0.4424, -0.7459], device='cuda:0')\n", | |
"1\n", | |
"tensor([ 0.3542, -0.2290, 0.5238, -0.3698], device='cuda:0')\n", | |
"0\n", | |
"tensor([ 0.4261, 0.0932, 0.5027, -0.5520], device='cuda:0')\n", | |
"1\n", | |
"tensor([-0.0075, 0.1063, 0.2824, -0.4940], device='cuda:0')\n", | |
"2\n", | |
"tensor([ 0.4515, -0.1440, 0.4424, -0.7459], device='cuda:0')\n", | |
"1\n", | |
"tensor([ 0.3542, -0.2290, 0.5238, -0.3698], device='cuda:0')\n", | |
"0\n", | |
"tensor([ 0.4261, 0.0932, 0.5027, -0.5520], device='cuda:0')\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "zFbdybd25XLf" | |
}, | |
"source": [ | |
"## Model 3: Can the RNN memorize an 8-number sequence?\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "V4nX2wgqPFey" | |
}, | |
"source": [ | |
"rnn_8_numbers = RNN().cuda()" | |
], | |
"execution_count": 96, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 191 | |
}, | |
"id": "Wnh5mnNcIs2k", | |
"outputId": "95da21ce-4626-4ced-8a6f-c1785938c0c6" | |
}, | |
"source": [ | |
"train(rnn_8_numbers, repeat([0,0,1,1,2,2,1,1]), n_items=40000)" | |
], | |
"execution_count": 97, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0.1449985207306754\n", | |
"0.012814083898633498\n", | |
"0.002181912179502801\n", | |
"0.00040016488425779253\n", | |
"7.474760249539259e-05\n", | |
"1.4078158818364272e-05\n", | |
"2.666311959181655e-06\n", | |
"5.075927006270575e-07\n", | |
"9.435413546121652e-08\n", | |
"1.746415985337535e-08\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "n9SbGcnJJBqv", | |
"outputId": "f35164a7-6394-4be8-d046-aa4ce7e3a9af" | |
}, | |
"source": [ | |
"i = 0\n", | |
"for _ in range(33):\n", | |
" i = predict(rnn_8_numbers, i)\n", | |
" print(i, end=' ')" | |
], | |
"execution_count": 98, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0 1 1 2 2 1 1 0 0 1 1 2 2 1 1 0 0 1 1 2 2 1 1 0 0 1 1 2 2 1 1 0 0 " | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "At65aULv5n_g" | |
}, | |
"source": [ | |
"Those results look really good! It seems 8 numbers isn't too much; we just need to train for longer (40,000 steps instead of the default 10,000)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "tFRGTGCT5eWp" | |
}, | |
"source": [ | |
"## Model 4: Can the RNN memorize a 20-number sequence?\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Egje4f5VPQTA" | |
}, | |
"source": [ | |
"rnn_many_numbers = RNN().cuda()" | |
], | |
"execution_count": 91, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 191 | |
}, | |
"id": "DZ85GdQ3PhOV", | |
"outputId": "0ee1f2d3-43e9-411d-9924-4585976033d9" | |
}, | |
"source": [ | |
"train(rnn_many_numbers, repeat([0,0,0,0,0,1,1,1,1,1,2,2,2,2,2,1,1,1,1,1]))" | |
], | |
"execution_count": 92, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0.5211231253417209\n", | |
"0.44895306542445906\n", | |
"0.4292228445082437\n", | |
"0.415300760838436\n", | |
"0.40403987938558567\n", | |
"0.39438674849120436\n", | |
"0.3858543072966204\n", | |
"0.37817201468275746\n", | |
"0.37116976792353673\n", | |
"0.3647305486374853\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "GO031I-2Pkac", | |
"outputId": "1de65502-8c10-4965-b5c7-9326ed86b003" | |
}, | |
"source": [ | |
"i = 0\n", | |
"for _ in range(30):\n", | |
" i = predict(rnn_many_numbers, i)\n", | |
" print(i, end=' ')" | |
], | |
"execution_count": 93, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0 0 0 1 1 1 1 2 2 2 2 2 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 " | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "RBbXWfuBPnq0" | |
}, | |
"source": [ | |
"Looks like that's a bit too hard -- it's pretty far off for this sequence! So it seems like maybe 20 numbers is too much for an RNN with a 4-dimensional hidden state to memorize, at least with the default 10,000 training steps. Let's try with a bigger hidden vector:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "YIiaXiux5mUS" | |
}, | |
"source": [ | |
"rnn_many_numbers2 = RNN(hidden_size=8).cuda()" | |
], | |
"execution_count": 64, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 191 | |
}, | |
"id": "w25UBxn26iHD", | |
"outputId": "6059f4ab-0ec4-4609-c9d4-acb405b58bd5" | |
}, | |
"source": [ | |
"train(rnn_many_numbers2, repeat([0,0,0,0,0,1,1,1,1,1,2,2,2,2,2,1,1,1,1,1]), n_items=100000)" | |
], | |
"execution_count": 85, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0.055236988693244234\n", | |
"0.04933105486383399\n", | |
"0.044032463341482204\n", | |
"0.03924229313720872\n", | |
"0.034912732995814315\n", | |
"0.031004845797666894\n", | |
"0.027484601506395576\n", | |
"0.024321592863570014\n", | |
"0.02148699332665616\n", | |
"0.018953162126702738\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "P1_0pfUv6ii0", | |
"outputId": "90f631f5-c731-402f-ff6a-94ed7ecbd3bc" | |
}, | |
"source": [ | |
"i = 0\n", | |
"for _ in range(30):\n", | |
" i = predict(rnn_many_numbers2, i)\n", | |
" print(i, end=' ')" | |
], | |
"execution_count": 90, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0 0 0 0 1 1 1 1 1 2 2 2 2 2 1 1 1 1 1 0 0 0 0 0 1 1 1 1 1 2 " | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "S3HJ9NQv6yH0" | |
}, | |
"source": [ | |
"This is a lot better! So it seems like we needed both a bigger hidden state (8 instead of 4) and more training (100,000 steps instead of 10,000)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LvvcsCHn634s" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment