Last active
May 7, 2018 19:59
-
-
Save Stvad/8f9ad2215ad9549468d95c0b889511db to your computer and use it in GitHub Desktop.
Fast.ai Lesson 6: problem training the concatenation ("cat") model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline\n", | |
"\n", | |
"from pathlib import Path\n", | |
"\n", | |
"from torch import nn\n", | |
"from torch import optim\n", | |
"from torch.autograd import Variable\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"\n", | |
"import torch\n", | |
"\n", | |
"from itertools import count\n", | |
"\n", | |
"import numpy as np\n", | |
"\n", | |
"from fastai.column_data import ColumnarModelData\n", | |
"from fastai.dataset import get_cv_idxs\n", | |
"from fastai.io import get_data\n", | |
"from fastai.learner import fit, set_lrs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"path = 'data/nietzsche'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"input_text_path = f'{path}/input_text.txt'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"get_data(\"https://s3.amazonaws.com/text-datasets/nietzsche.txt\", input_text_path)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Decode explicitly: read_text() without an encoding uses the locale's preferred encoding,\n", | |
"# so the decoded characters (and hence the vocabulary) would vary between machines.\n", | |
"# NOTE(review): assumes the downloaded corpus is UTF-8 - confirm.\n", | |
"input_text = Path(input_text_path).read_text(encoding='utf-8')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Id 0 is an extra NUL character; the text's distinct characters follow in sorted order.\n", | |
"vocab = ['\\0'] + sorted(set(input_text))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Map each vocabulary character to its position in vocab.\n", | |
"char_to_id = {ch: idx for idx, ch in enumerate(vocab)}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The whole corpus as a list of character ids.\n", | |
"int_text = [char_to_id[ch] for ch in input_text]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Model sizes shared by every model below.\n", | |
"hidden_layer_size = 256\n", | |
"embedding_size = 42" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CharLoopModel(nn.Module):\n", | |
"    \"\"\"Character-level RNN built from explicit Linear layers.\n", | |
"\n", | |
"    At each step the embedded character is projected (Linear + ReLU), added to the\n", | |
"    previous hidden state and passed through the recurrent Linear + tanh. The final\n", | |
"    hidden state is mapped to log-probabilities for the next character.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, hidden_layer_size, embedding_size, vocab_size):\n", | |
"        super().__init__()\n", | |
"        # Stored so forward() uses this model's size, not the notebook-level global.\n", | |
"        self.hidden_layer_size = hidden_layer_size\n", | |
"\n", | |
"        self.embedding = nn.Embedding(vocab_size, embedding_size)\n", | |
"        self.input = nn.Linear(embedding_size, hidden_layer_size)\n", | |
"        self.recurrent = nn.Linear(hidden_layer_size, hidden_layer_size)\n", | |
"        self.output = nn.Linear(hidden_layer_size, vocab_size)\n", | |
"\n", | |
"    def _zero_hidden(self, batch_size):\n", | |
"        \"\"\"Return a zero hidden state with the same tensor type (and device) as the weights.\"\"\"\n", | |
"        return Variable(self.output.weight.data.new(batch_size, self.hidden_layer_size).zero_())\n", | |
"\n", | |
"    def forward(self, *character_tensor):\n", | |
"        \"\"\"Predict the character that follows the given sequence.\n", | |
"\n", | |
"        Args:\n", | |
"            *character_tensor: one tensor of character ids per time step, in order\n", | |
"                (presumably each of shape (batch,) - TODO confirm against ColumnarModelData).\n", | |
"\n", | |
"        Returns:\n", | |
"            Log-probabilities over the vocabulary (log_softmax over the last dim).\n", | |
"        \"\"\"\n", | |
"        batch_size = len(character_tensor[0])\n", | |
"        hidden_activations = self._zero_hidden(batch_size)\n", | |
"\n", | |
"        for chars in character_tensor:\n", | |
"            inp = F.relu(self.input(self.embedding(chars)))\n", | |
"            hidden_activations = F.tanh(self.recurrent(hidden_activations + inp))\n", | |
"\n", | |
"        return F.log_softmax(self.output(hidden_activations), dim=-1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"char_loop_model = CharLoopModel(hidden_layer_size, embedding_size, len(vocab)).cuda()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Characters per training window: the first sample_length - 1 are inputs,\n", | |
"# the last is the character the model must predict.\n", | |
"sample_length = 9" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Row j is int_text shifted left by j, so column i is the window int_text[i:i+sample_length].\n", | |
"# The last row becomes the targets; the remaining rows are the per-step inputs.\n", | |
"*input_rows, output = np.array([[int_text[i+j] for i in range(len(int_text)-sample_length)] for j in range(sample_length)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One row per sample, one column per input time step. Built from input_rows (not by\n", | |
"# transposing inputs in place) so re-running this cell gives the same result.\n", | |
"inputs = np.array(input_rows).T\n", | |
"assert inputs.shape == (len(output), sample_length - 1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"val_indexes = get_cv_idxs(len(output))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model_data = ColumnarModelData.from_arrays('.', val_indexes, inputs, output, bs=1024)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"lr=1e-3\n", | |
"optimizer = optim.Adam(char_loop_model.parameters(), lr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "63b26782980b4ee7859a8f6f2dc4db56", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch trn_loss val_loss \n", | |
" 0 1.959717 1.924573 \n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"[array([1.92457])]" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Train one epoch; crit=F.nll_loss matches the model's log_softmax output.\n", | |
"fit(char_loop_model, model_data, 1, optimizer, crit=F.nll_loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Lower the learning rate from 1e-3 to 1e-4 for the next epoch.\n", | |
"set_lrs(optimizer, 1e-4)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "cd83f4bf500d43659cfbb04957496514", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch trn_loss val_loss \n", | |
" 0 1.873925 1.873269 \n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"[array([1.87327])]" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Another epoch at the lowered learning rate.\n", | |
"fit(char_loop_model, model_data, 1, optimizer, crit=F.nll_loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 112, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# NOTE(review): this cell and the next fit carry execution counts 112/113, i.e. they were run\n", | |
"# after the cells below; re-run top to bottom to check the recorded loss reproduces.\n", | |
"set_lrs(optimizer, 1e-5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 113, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "e55d3bf70ef3421b89cb1091ff1d4c4c", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch trn_loss val_loss \n", | |
" 0 1.860882 1.868817 \n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"[array([1.86882])]" | |
] | |
}, | |
"execution_count": 113, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Final CharLoopModel epoch at lr=1e-5 (set in the previous cell).\n", | |
"fit(char_loop_model, model_data, 1, optimizer, crit=F.nll_loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CharLoopModelCat(nn.Module):\n", | |
"    \"\"\"Character-level RNN that concatenates each embedded character with the hidden state.\n", | |
"\n", | |
"    Per step: cat(embedding, hidden) -> Linear + ReLU -> recurrent Linear + tanh gives the\n", | |
"    new hidden state; the final hidden state is mapped to next-character log-probabilities.\n", | |
"\n", | |
"    NOTE(review): in the originally recorded run this model was trained with Adam at lr=1e-2\n", | |
"    and its loss stayed around 3.1-3.2; a too-high learning rate is a likely cause - TODO confirm.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, hidden_layer_size, embedding_size, vocab_size):\n", | |
"        super().__init__()\n", | |
"        # Stored so forward() uses this model's size, not the notebook-level global.\n", | |
"        self.hidden_layer_size = hidden_layer_size\n", | |
"\n", | |
"        self.embedding = nn.Embedding(vocab_size, embedding_size)\n", | |
"        # Takes the embedding and the previous hidden state concatenated along dim 1.\n", | |
"        self.input = nn.Linear(embedding_size + hidden_layer_size, hidden_layer_size)\n", | |
"        self.recurrent = nn.Linear(hidden_layer_size, hidden_layer_size)\n", | |
"        self.output = nn.Linear(hidden_layer_size, vocab_size)\n", | |
"\n", | |
"    def _zero_hidden(self, batch_size):\n", | |
"        \"\"\"Return a zero hidden state with the same tensor type (and device) as the weights.\"\"\"\n", | |
"        return Variable(self.output.weight.data.new(batch_size, self.hidden_layer_size).zero_())\n", | |
"\n", | |
"    def forward(self, *inputs):\n", | |
"        \"\"\"Predict the character that follows the given sequence.\n", | |
"\n", | |
"        Args:\n", | |
"            *inputs: one tensor of character ids per time step, in order\n", | |
"                (presumably each of shape (batch,) - TODO confirm against ColumnarModelData).\n", | |
"\n", | |
"        Returns:\n", | |
"            Log-probabilities over the vocabulary (log_softmax over the last dim).\n", | |
"        \"\"\"\n", | |
"        batch_size = len(inputs[0])\n", | |
"        hidden_activations = self._zero_hidden(batch_size)\n", | |
"\n", | |
"        for chars in inputs:\n", | |
"            inp = torch.cat((self.embedding(chars), hidden_activations), 1)\n", | |
"            inp = F.relu(self.input(inp))\n", | |
"            hidden_activations = F.tanh(self.recurrent(inp))\n", | |
"\n", | |
"        return F.log_softmax(self.output(hidden_activations), dim=-1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"char_cat_loop_model = CharLoopModelCat(hidden_layer_size, embedding_size, len(vocab)).cuda()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Originally 1e-2: at that rate this model's loss stayed around 3.1-3.2, while CharLoopModel\n", | |
"# reached ~1.9 with Adam at 1e-3. Start at 1e-3 here too; the high rate is the likely\n", | |
"# culprit for this ReLU-based model - TODO re-run the cells below to confirm.\n", | |
"lr=1e-3\n", | |
"optimizer = optim.Adam(char_cat_loop_model.parameters(), lr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "003272b16e9c4f6b967d2e3d90fbce68", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch trn_loss val_loss \n", | |
" 0 3.193492 3.207976 \n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"[array([3.20798])]" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# One epoch of the concatenation model at the learning rate chosen above.\n", | |
"fit(char_cat_loop_model, model_data, 1, optimizer, crit=F.nll_loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Set the learning rate to 1e-3 for the next epoch.\n", | |
"set_lrs(optimizer, 1e-3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "98621aa3cd9b43e2b24588ac032154a0", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch trn_loss val_loss \n", | |
" 0 3.127785 3.129088 \n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"[array([3.12909])]" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Second epoch of the concatenation model.\n", | |
"fit(char_cat_loop_model, model_data, 1, optimizer, crit=F.nll_loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class TorchRnnModel(nn.Module):\n", | |
"    \"\"\"Character-level model using torch's built-in nn.RNN over the stacked input characters.\"\"\"\n", | |
"\n", | |
"    def __init__(self, hidden_layer_size, embedding_size, vocab_size):\n", | |
"        super().__init__()\n", | |
"        # Stored so forward() uses this model's size, not the notebook-level global.\n", | |
"        self.hidden_layer_size = hidden_layer_size\n", | |
"\n", | |
"        self.embedding = nn.Embedding(vocab_size, embedding_size)\n", | |
"        self.rnn = nn.RNN(embedding_size, hidden_layer_size)\n", | |
"        self.output = nn.Linear(hidden_layer_size, vocab_size)\n", | |
"\n", | |
"    def _zero_hidden(self, batch_size):\n", | |
"        \"\"\"Zero h_0 shaped (num_layers * num_directions = 1, batch, hidden), typed like the weights.\"\"\"\n", | |
"        return Variable(self.output.weight.data.new(1, batch_size, self.hidden_layer_size).zero_())\n", | |
"\n", | |
"    def forward(self, *inputs):\n", | |
"        \"\"\"Predict the character that follows the given sequence.\n", | |
"\n", | |
"        Args:\n", | |
"            *inputs: one tensor of character ids per time step, in order\n", | |
"                (presumably each of shape (batch,) - TODO confirm against ColumnarModelData).\n", | |
"\n", | |
"        Returns:\n", | |
"            Log-probabilities over the vocabulary for the last time step.\n", | |
"        \"\"\"\n", | |
"        batch_size = len(inputs[0])\n", | |
"        initial_hidden = self._zero_hidden(batch_size)\n", | |
"        # torch.stack puts time steps on dim 0, matching nn.RNN's default (seq, batch, feature) layout.\n", | |
"        outputs, _ = self.rnn(self.embedding(torch.stack(inputs)), initial_hidden)\n", | |
"\n", | |
"        # Only the last time step's output is used to predict the next character.\n", | |
"        return F.log_softmax(self.output(outputs[-1]), dim=-1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"torch_rnn_model = TorchRnnModel(\n", | |
"    hidden_layer_size=hidden_layer_size,\n", | |
"    embedding_size=embedding_size,\n", | |
"    vocab_size=len(vocab),\n", | |
").cuda()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Initial Adam learning rate for the nn.RNN-based model.\n", | |
"lr = 0.01\n", | |
"optimizer = optim.Adam(torch_rnn_model.parameters(), lr=lr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "0064386f20b344929d7c511705431d58", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch trn_loss val_loss \n", | |
" 0 1.828697 1.807756 \n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"[array([1.80776])]" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# One epoch of the nn.RNN-based model at the learning rate set above.\n", | |
"fit(torch_rnn_model, model_data, 1, optimizer, crit=F.nll_loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# NOTE(review): the execution count jumps from 30 to 34, so cells run in between are gone;\n", | |
"# confirm the recorded outputs reproduce on a fresh kernel.\n", | |
"# Drop the learning rate to 1e-4 for the next epoch.\n", | |
"set_lrs(optimizer, 1e-4)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "403ded53acc94172bd128cfa85b81c45", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch trn_loss val_loss \n", | |
" 0 1.526433 1.557709 \n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"[array([1.55771])]" | |
] | |
}, | |
"execution_count": 35, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Second epoch of the nn.RNN-based model at lr=1e-4.\n", | |
"fit(torch_rnn_model, model_data, 1, optimizer, crit=F.nll_loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Environment (conda_fastai)", | |
"language": "python", | |
"name": "conda_fastai" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment