Created
October 8, 2018 15:45
-
-
Save braingineer/1d7baecf2c99013d88d4d1db77449aec to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from argparse import Namespace\n", | |
"import json\n", | |
"import os\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import seaborn as sns\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"import torch.optim as optim\n", | |
"from torch.utils.data import Dataset, DataLoader\n", | |
"from tqdm import tqdm_notebook\n", | |
"\n", | |
"from vocabulary import Vocabulary\n", | |
"\n", | |
"%matplotlib inline\n", | |
"\n", | |
"plt.style.use('fivethirtyeight')\n", | |
"plt.rcParams['figure.figsize'] = (14, 6)\n", | |
"\n", | |
"START_TOKEN = \"^\"\n", | |
"END_TOKEN = \"_\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
# Runtime configuration for this notebook: data location and training settings.
args = Namespace(
    surname_csv="../data/surnames.csv",  # raw surnames CSV (surname, nationality, split)
    cuda=False,                          # keep everything on CPU
    num_epochs=100,
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
class RawSurnames(object):
    """Thin wrapper around the raw surnames CSV."""

    def __init__(self, data_path, delimiter=","):
        # Load the whole file once; pandas accepts a path or a file-like object.
        self.data = pd.read_csv(data_path, delimiter=delimiter)

    def get_data(self, filter_to_nationality=None):
        """Return the frame, optionally restricted to the given nationalities.

        filter_to_nationality: iterable of nationality labels, or None for all rows.
        """
        if filter_to_nationality is None:
            return self.data
        mask = self.data.nationality.isin(filter_to_nationality)
        return self.data[mask]
"\n", | |
"# vectorizer\n", | |
"\n", | |
class SurnamesVectorizer(object):
    """Maps surnames (char-level) and nationalities (labels) to integer vectors
    using two Vocabulary instances."""

    def __init__(self, surname_vocab, nationality_vocab, max_seq_length):
        # Character vocabulary for surnames (uses mask + start/end tokens).
        self.surname_vocab = surname_vocab
        # Label vocabulary for nationalities (plain, no special tokens).
        self.nationality_vocab = nationality_vocab
        # Longest vectorized surname, including the start and end tokens.
        self.max_seq_length = max_seq_length

    def save(self, filename):
        """Serialize both vocabularies and the max length to a JSON file."""
        contents = {
            "surname_vocab": self.surname_vocab.get_serializable_contents(),
            "nationality_vocab": self.nationality_vocab.get_serializable_contents(),
            "max_seq_length": self.max_seq_length,
        }
        with open(filename, "w") as fp:
            json.dump(contents, fp)

    @classmethod
    def load(cls, filename):
        """Inverse of save(): rebuild a vectorizer from a JSON file."""
        with open(filename, "r") as fp:
            contents = json.load(fp)
        contents["surname_vocab"] = Vocabulary.deserialize_from_contents(contents["surname_vocab"])
        contents["nationality_vocab"] = Vocabulary.deserialize_from_contents(contents["nationality_vocab"])
        return cls(**contents)

    @classmethod
    def fit(cls, surname_df):
        """Build both vocabularies and the max sequence length from a DataFrame
        with `surname` and `nationality` columns."""
        surname_vocab = Vocabulary(use_unks=False,
                                   use_mask=True,
                                   use_start_end=True,
                                   start_token=START_TOKEN,
                                   end_token=END_TOKEN)
        nationality_vocab = Vocabulary(use_unks=False, use_start_end=False, use_mask=False)

        longest = 0
        for _, row in surname_df.iterrows():
            surname_vocab.add_many(row.surname)
            nationality_vocab.add(row.nationality)
            longest = max(longest, len(row.surname))

        # +2 leaves room for the start and end tokens added during transform().
        return cls(surname_vocab, nationality_vocab, longest + 2)

    @classmethod
    def fit_transform(cls, surname_df, split='train'):
        """Convenience: fit on the full frame, then vectorize one split."""
        vectorizer = cls.fit(surname_df)
        return vectorizer, vectorizer.transform(surname_df, split)

    def transform(self, surname_df, split='train'):
        """Vectorize the rows of one split into a VectorizedSurnames dataset."""
        split_df = surname_df[surname_df.split == split].reset_index()
        n_data = len(split_df)

        # Zero is the mask/padding index, so start from all-zero matrices.
        x_surnames = np.zeros((n_data, self.max_seq_length), dtype=np.int64)
        y_nationalities = np.zeros(n_data, dtype=np.int64)

        for index, row in split_df.iterrows():
            encoded = list(self.surname_vocab.map(row.surname,
                                                  include_start_end=True))
            x_surnames[index, :len(encoded)] = encoded
            y_nationalities[index] = self.nationality_vocab[row.nationality]

        return VectorizedSurnames(x_surnames, y_nationalities)
"\n", | |
"# vec data\n", | |
"\n", | |
"\n", | |
class VectorizedSurnames(Dataset):
    """PyTorch Dataset over pre-vectorized surname / nationality arrays."""

    def __init__(self, x_surnames, y_nationalities):
        # Integer matrix (n_data, max_seq_length); zeros are padding.
        self.x_surnames = x_surnames
        # Integer vector of nationality label indices, row-aligned with x_surnames.
        self.y_nationalities = y_nationalities

    def __len__(self):
        return len(self.x_surnames)

    def __getitem__(self, index):
        row = self.x_surnames[index]
        # x_lengths counts the non-zero (non-padding) entries in the row.
        return {
            'x_surnames': row,
            'y_nationalities': self.y_nationalities[index],
            'x_lengths': len(row.nonzero()[0]),
        }
"\n", | |
"# data generator\n", | |
"\n", | |
def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu"):
    """
    A generator function which wraps the PyTorch DataLoader. It will
    ensure each tensor is on the right device location.

    Args:
        dataset: a torch-compatible Dataset whose items are dicts of tensors.
        batch_size: number of items per yielded batch.
        shuffle: shuffle the dataset order each epoch.
        drop_last: drop the final partial batch.
        device: target device string (e.g. "cpu", "cuda").

    Yields:
        dict mapping each field name to a batched tensor on `device`.
    """
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)

    for data_dict in dataloader:
        # Use the already-bound tensor rather than re-indexing data_dict.
        yield {name: tensor.to(device) for name, tensor in data_dict.items()}
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Load the full surnames frame and fit the vectorizer on ALL rows,
# so the character vocabulary covers every split.
# NOTE(review): fit() also sees the test split's surnames — confirm this
# vocabulary sharing is intended and not a leakage concern.
raw_data = RawSurnames(args.surname_csv).get_data()

vectorizer = SurnamesVectorizer.fit(raw_data)

# Vectorize each split separately (rows are selected by the `split` column).
train_dataset = vectorizer.transform(raw_data, split='train')
test_dataset = vectorizer.transform(raw_data, split='test')
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Tasks\n", | |
"\n", | |
"1. embed this vector\n", | |
"2. apply convnet to embedded surnames\n", | |
"3. compute prediction vector \n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
# Model hyperparameters; only the character-embedding size so far.
hyperparams = Namespace(embedding_dim=64)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
# Pull a single batch of 8 to inspect tensor shapes in the cells below.
batch_gen = generate_batches(train_dataset, batch_size=8)
batch_dict = next(batch_gen)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
# Character embedding table. padding_idx=0 pins the mask/padding token's
# embedding to zeros (and it receives no gradient updates).
embeddings = nn.Embedding(num_embeddings=len(vectorizer.surname_vocab), 
                          embedding_dim=hyperparams.embedding_dim, 
                          padding_idx=0)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([8, 22, 64])" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Sanity check: (batch, max_seq_length, embedding_dim).
embeddings(batch_dict['x_surnames']).shape
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 1, 27, 26, 37, 25, 20, 5, 20, 18, 2, 0, 0, 0, 0, 0, 0, 0,\n", | |
" 0, 0, 0, 0, 0])" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Inspect one vectorized row: start token, character indices, end token, zero padding.
train_dataset[1000]['x_surnames']
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'^Ingledew_<MASK><MASK><MASK><MASK><MASK><MASK><MASK><MASK><MASK><MASK><MASK><MASK>'" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Round-trip the vector back through the vocabulary to verify the encoding.
"".join(list(vectorizer.surname_vocab.lookup_many(train_dataset[1000]['x_surnames'])))
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
# --- Scratch transcript (duplicates the notebook cells above) ---
batch_gen = generate_batches(train_dataset, batch_size=8)
batch_dict = next(batch_gen)
embeddings = nn.Embedding(num_embeddings=len(vectorizer.surname_vocab),
                          embedding_dim=hyperparams.embedding_dim,
                          padding_idx=0)
x_embedded = embeddings(batch_dict['x_surnames'])
x_embedded.shape  # (batch, seq, embedding_dim)
# Conv1d expects channels before the sequence axis: (batch, channels, seq).
x_embedded_fixed = x_embedded.permute(0, 2, 1)
x_embedded_fixed.shape
conv1 = nn.Conv1d(in_channels=hyperparams.embedding_dim,
                  out_channels=16,
                  kernel_size=3)
conv1(x_embedded_fixed).shape  # seq axis shrinks by kernel_size-1 (no padding)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
x_embedded = embeddings(batch_dict['x_surnames'])
in_channels=feat,
out_channels=??,
kernel_size=??,
stride=1,
padding=0
)