@jvns
Created November 21, 2020 04:04
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "rnn 123.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMgy3H46gOlgYJzKwafyS7N",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jvns/b8804fb9d0672ce147a28d22648b4bd7/rnn-123.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0GTbDNteFg2m"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from fastprogress import master_bar, progress_bar\n"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "UPHnBLaTFl8O"
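},
"source": [
"# Define the models\n",
"\n",
"I'm defining 2 different versions of the RNN model, just because I wanted to try 2 different ways of writing it and see whether they both work.\n",
"\n",
"`RNN` concatenates the one-hot input with the previous hidden state and feeds that into two linear layers: `i2h` computes the new hidden state and `i2o` computes the output. `RNN2` is closer to the textbook formulation: separate input-to-hidden (`Wxh`), hidden-to-hidden (`Whh`) and hidden-to-output (`Why`) weights, with a `tanh` applied to the hidden state."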
]
},
{
"cell_type": "code",
"metadata": {
"id": "ryL6xSivFrXM"
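},
"source": [
"input_size = 3   # the numbers 0, 1, 2\n",
"hidden_size = 4\n",
"\n",
"class RNN(nn.Module):\n",
"    def __init__(self, hidden_size=4):\n",
"        super().__init__()\n",
"        # both layers read [one-hot input, previous hidden state]\n",
"        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)\n",
"        self.i2o = nn.Linear(input_size + hidden_size, input_size)\n",
"        self.hidden = torch.zeros(hidden_size).cuda()\n",
"\n",
"    def forward(self, input):\n",
"        # one-hot encode the input number as a float vector of length input_size\n",
"        input = torch.nn.functional.one_hot(input, num_classes=input_size).float().cuda()\n",
"        combined = torch.cat((input, self.hidden))\n",
"        hidden = self.i2h(combined)\n",
"        output = self.i2o(combined)\n",
"        # detach so gradients don't flow back into earlier steps.\n",
"        # note: since the output only uses i2o and the stored hidden state is detached,\n",
"        # i2h never receives a gradient; its weights stay at their random initialization\n",
"        self.hidden = hidden.detach()\n",
"        return output"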
],
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "tOhSjTcC1oO4"
},
"source": [
"class RNN2(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.i2h = nn.Linear(input_size, hidden_size)   # Wxh: input -> hidden\n",
"        self.h2h = nn.Linear(hidden_size, hidden_size)  # Whh: hidden -> hidden\n",
"        self.h2o = nn.Linear(hidden_size, input_size)   # Why: hidden -> output\n",
"        self.hidden = torch.zeros(hidden_size).cuda()\n",
"\n",
"    def forward(self, input):\n",
"        # h_t = tanh(Wxh x_t + Whh h_{t-1}), output = Why h_t\n",
"        x = self.i2h(torch.nn.functional.one_hot(input, num_classes=input_size).float().cuda())\n",
"        y = self.h2h(self.hidden)\n",
"        hidden = torch.tanh(y + x)\n",
"        # keep the new hidden state for the next call, cut off from the graph\n",
"        self.hidden = hidden.detach()\n",
"        return self.h2o(hidden)"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "-yCxsmBo4UmB"
},
"source": [
"# Define our training function\n",
"\n",
"This trains an RNN on a sequence of (input, target) pairs, where the target is the number that comes next in the sequence."
]
},
{
"cell_type": "code",
"metadata": {
"id": "pdTq4A_n564r"
},
"source": [
"\"\"\"\n",
"Example: repeat([0,1,2]) yields (0,1), (1,2), (2,0), (0,1), (1,2), (2,0), ...\n",
"We're going to use this to generate training data for our RNN:\n",
"each pair is (current number, next number).\n",
"\"\"\"\n",
"def repeat(numbers):\n",
"    n = len(numbers)\n",
"    for i in range(10000000):\n",
"        yield (numbers[i % n], numbers[(i+1) % n])\n"
],
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "o3JolnM6GXwP"
},
"source": [
"# `vectorify` turns a number into a 0-dimensional (scalar) LongTensor on the GPU\n",
"def vectorify(i):\n",
"    return torch.tensor(i, dtype=torch.long).cuda()\n",
"\n",
"# trains an RNN model on a sequence of training data:\n",
"# 10 rounds of n_items // 10 (input, target) pairs, one optimizer step per pair\n",
"def train(rnn, sequence, n_items=10000):\n",
"    optimizer = torch.optim.Adam(rnn.parameters(), lr=0.05)\n",
"    mb = master_bar(range(10))\n",
"    for _ in mb:\n",
"        running_loss = 0\n",
"        for _ in progress_bar(range(n_items // 10), parent=mb):\n",
"            input, target = next(sequence)\n",
"            input = vectorify(input)\n",
"            target = vectorify(target)\n",
"            # forward pass\n",
"            output = rnn(input)\n",
"            # compute loss (cross_entropy expects a batch dimension, hence unsqueeze)\n",
"            loss = F.cross_entropy(output.unsqueeze(0), target.unsqueeze(0))\n",
"            running_loss += loss.item()\n",
"            # compute gradients and take optimizer step\n",
"            optimizer.zero_grad()\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"        # average loss over this round\n",
"        print(running_loss / (n_items // 10))"
],
"execution_count": 86,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "W4mYhhVJNjxE"
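},
"source": [
"# feed the number `i` to the RNN and sample the next number from the softmax of its output\n",
"# (this also updates the RNN's hidden state, so calling it repeatedly generates a sequence)\n",
"def predict(rnn, i):\n",
"    return int(torch.multinomial(F.softmax(rnn(vectorify(i)), dim=-1), 1)[0])"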
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "M5Q9IC_g4dkY"
},
"source": [
"# Train different models on a bunch of sequences!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hup3sbyM4jKG"
},
"source": [
"## Model 1: `repeat([0, 1, 2])`\n",
"\n",
"Let's see if the RNN can learn the sequence 0, 1, 2, 0, 1, 2, ..."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 191
},
"id": "q4oytj_gKHWs",
"outputId": "dbdcc87c-5131-456f-9960-657fc20bd16e"
},
"source": [
"rnn_012 = RNN().cuda()\n",
"train(rnn_012, repeat([0,1,2]))"
],
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"text": [
"0.0022184200907126067\n",
"0.0002202294015325606\n",
"0.0001041749426163733\n",
"6.315491446293891e-05\n",
"4.312607115134597e-05\n",
"3.182382781524211e-05\n",
"2.4523668352048844e-05\n",
"1.951815066859126e-05\n",
"1.6031735553406178e-05\n",
"1.3375555304810405e-05\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "j0d9vmAMJ_7E",
"outputId": "f3263186-5341-446b-b0a3-d91d18d62e13"
},
"source": [
"i = 0\n",
"for _ in range(10):\n",
" i = predict(rnn_012, i)\n",
" print(i, end=' ')"
],
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"text": [
"1 2 0 1 2 0 1 2 0 1 "
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "heLpKJlG4z4d"
},
"source": [
"Looks like it works! Hooray."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pe0O19Wt44fk"
},
"source": [
"## Model 2: `repeat([0, 1, 2, 1])`\n",
"\n",
"Now let's give the RNN a harder task: let's see if it can learn the sequence 0, 1, 2, 1, 0, 1, 2, 1, ...\n",
"\n",
"This is harder because the number after a `1` can be either 2 or 0, depending on whether the number before the `1` was a 0 or a 2. So the model can't just look at the current number: it has to remember something about the previous one."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 191
},
"id": "MpcAvgmWIAgE",
"outputId": "04b2ce96-0ed5-402c-c138-c880b63753fc"
},
"source": [
"count_up_down_rnn = RNN2().cuda()\n",
"train(count_up_down_rnn, repeat([0,1,2,1]))"
],
"execution_count": 27,
"outputs": [
{
"output_type": "stream",
"text": [
"0.004983700064104051\n",
"0.0035814674655674024\n",
"0.003563214811589569\n",
"0.003552207035868196\n",
"0.0035483673014503436\n",
"0.0035461855988920435\n",
"0.0035448039943294135\n",
"0.0035438635173806687\n",
"0.003543186600592162\n",
"0.00354267576578859\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mt1_UdKsM6Sf",
"outputId": "31499366-a6db-49ff-8918-ed59331fb44e"
},
"source": [
"i = 0\n",
"for _ in range(12):\n",
" i = predict(count_up_down_rnn, i)\n",
" print(i, end=' ')\n"
],
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"text": [
"1 0 1 2 1 0 1 2 1 0 1 2 "
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8idokkMY5Lbv"
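},
"source": [
"That works too! Hooray. We can also print out the hidden vector after every step (each printed number is the model's prediction, followed by the hidden vector that produced it). Notice that the hidden vector computed after reading a `1` is different depending on whether the number before it was a 0 or a 2. Since `RNN2`'s output is computed from the hidden vector alone, that difference is what lets it predict a 2 in one case and a 0 in the other."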
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ckqoIwzgQPYz",
"outputId": "7cfed0e8-1fae-4cac-e1f9-9842c028d831"
},
"source": [
"i = 0\n",
"for _ in range(12):\n",
" i = predict(count_up_down_rnn, i)\n",
" print(i)\n",
" print(count_up_down_rnn.hidden)\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"1\n",
"tensor([-0.0075, 0.1063, 0.2824, -0.4940], device='cuda:0')\n",
"2\n",
"tensor([ 0.4515, -0.1440, 0.4424, -0.7459], device='cuda:0')\n",
"1\n",
"tensor([ 0.3542, -0.2290, 0.5238, -0.3698], device='cuda:0')\n",
"0\n",
"tensor([ 0.4261, 0.0932, 0.5027, -0.5520], device='cuda:0')\n",
"1\n",
"tensor([-0.0075, 0.1063, 0.2824, -0.4940], device='cuda:0')\n",
"2\n",
"tensor([ 0.4515, -0.1440, 0.4424, -0.7459], device='cuda:0')\n",
"1\n",
"tensor([ 0.3542, -0.2290, 0.5238, -0.3698], device='cuda:0')\n",
"0\n",
"tensor([ 0.4261, 0.0932, 0.5027, -0.5520], device='cuda:0')\n",
"1\n",
"tensor([-0.0075, 0.1063, 0.2824, -0.4940], device='cuda:0')\n",
"2\n",
"tensor([ 0.4515, -0.1440, 0.4424, -0.7459], device='cuda:0')\n",
"1\n",
"tensor([ 0.3542, -0.2290, 0.5238, -0.3698], device='cuda:0')\n",
"0\n",
"tensor([ 0.4261, 0.0932, 0.5027, -0.5520], device='cuda:0')\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zFbdybd25XLf"
},
"source": [
"## Model 3: Can the RNN memorize an 8-number sequence?\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "V4nX2wgqPFey"
},
"source": [
"rnn_8_numbers = RNN().cuda()"
],
"execution_count": 96,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 191
},
"id": "Wnh5mnNcIs2k",
"outputId": "95da21ce-4626-4ced-8a6f-c1785938c0c6"
},
"source": [
"train(rnn_8_numbers, repeat([0,0,1,1,2,2,1,1]), n_items=40000)"
],
"execution_count": 97,
"outputs": [
{
"output_type": "stream",
"text": [
"0.1449985207306754\n",
"0.012814083898633498\n",
"0.002181912179502801\n",
"0.00040016488425779253\n",
"7.474760249539259e-05\n",
"1.4078158818364272e-05\n",
"2.666311959181655e-06\n",
"5.075927006270575e-07\n",
"9.435413546121652e-08\n",
"1.746415985337535e-08\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "n9SbGcnJJBqv",
"outputId": "f35164a7-6394-4be8-d046-aa4ce7e3a9af"
},
"source": [
"i = 0\n",
"for _ in range(33):\n",
" i = predict(rnn_8_numbers, i)\n",
" print(i, end=' ')"
],
"execution_count": 98,
"outputs": [
{
"output_type": "stream",
"text": [
"0 1 1 2 2 1 1 0 0 1 1 2 2 1 1 0 0 1 1 2 2 1 1 0 0 1 1 2 2 1 1 0 0 "
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "At65aULv5n_g"
},
"source": [
"Those results look really good! It seems like 8 numbers isn't too much; we just needed to train for longer (40,000 training pairs instead of the default 10,000)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tFRGTGCT5eWp"
},
"source": [
"## Model 4: Can the RNN memorize a 20-number sequence?\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Egje4f5VPQTA"
},
"source": [
"rnn_many_numbers = RNN().cuda()"
],
"execution_count": 91,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 191
},
"id": "DZ85GdQ3PhOV",
"outputId": "0ee1f2d3-43e9-411d-9924-4585976033d9"
},
"source": [
"train(rnn_many_numbers, repeat([0,0,0,0,0,1,1,1,1,1,2,2,2,2,2,1,1,1,1,1]))"
],
"execution_count": 92,
"outputs": [
{
"output_type": "stream",
"text": [
"0.5211231253417209\n",
"0.44895306542445906\n",
"0.4292228445082437\n",
"0.415300760838436\n",
"0.40403987938558567\n",
"0.39438674849120436\n",
"0.3858543072966204\n",
"0.37817201468275746\n",
"0.37116976792353673\n",
"0.3647305486374853\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GO031I-2Pkac",
"outputId": "1de65502-8c10-4965-b5c7-9326ed86b003"
},
"source": [
"i = 0\n",
"for _ in range(30):\n",
" i = predict(rnn_many_numbers, i)\n",
" print(i, end=' ')"
],
"execution_count": 93,
"outputs": [
{
"output_type": "stream",
"text": [
"0 0 0 1 1 1 1 2 2 2 2 2 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 "
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RBbXWfuBPnq0"
},
"source": [
"Looks like that's a bit too hard: the predictions are pretty far off for this sequence! So it seems like 20 numbers may be too much for an RNN with a 4-dimensional hidden state to memorize, at least with this much training. Let's try with a bigger hidden vector:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "YIiaXiux5mUS"
},
"source": [
"rnn_many_numbers2 = RNN(hidden_size=8).cuda()"
],
"execution_count": 64,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 191
},
"id": "w25UBxn26iHD",
"outputId": "6059f4ab-0ec4-4609-c9d4-acb405b58bd5"
},
"source": [
"train(rnn_many_numbers2, repeat([0,0,0,0,0,1,1,1,1,1,2,2,2,2,2,1,1,1,1,1]), n_items=100000)"
],
"execution_count": 85,
"outputs": [
{
"output_type": "stream",
"text": [
"0.055236988693244234\n",
"0.04933105486383399\n",
"0.044032463341482204\n",
"0.03924229313720872\n",
"0.034912732995814315\n",
"0.031004845797666894\n",
"0.027484601506395576\n",
"0.024321592863570014\n",
"0.02148699332665616\n",
"0.018953162126702738\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "P1_0pfUv6ii0",
"outputId": "90f631f5-c731-402f-ff6a-94ed7ecbd3bc"
},
"source": [
"i = 0\n",
"for _ in range(30):\n",
" i = predict(rnn_many_numbers2, i)\n",
" print(i, end=' ')"
],
"execution_count": 90,
"outputs": [
{
"output_type": "stream",
"text": [
"0 0 0 0 1 1 1 1 1 2 2 2 2 2 1 1 1 1 1 0 0 0 0 0 1 1 1 1 1 2 "
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S3HJ9NQv6yH0"
},
"source": [
"This is a lot better! So it seems like we needed both a bigger hidden state (8 dimensions instead of 4) and more training (100,000 pairs instead of 10,000)."
]
}
]
}