Created
June 25, 2020 23:03
-
-
Save unixpickle/f02b9e8881081fa8d033c2e723d01f89 to your computer and use it in GitHub Desktop.
Binarized MNIST
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"import torch.optim as optim\n", | |
"from torchvision import datasets, transforms\n", | |
"from tqdm.auto import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class XYSeqModel(nn.Module):
    """
    An autoregressive sequence model over binarized pixels.

    Each timestep's input is an integer triple (x, y, previous_pixel),
    where x and y are pixel coordinates in [0, 28) and previous_pixel is
    the binarized value of the preceding pixel (0 or 1). The model emits
    one logit per timestep for P(current pixel == 1).
    """

    def __init__(self, layers=3):
        super().__init__()
        # Separate embeddings for each input component.
        self.x_embed = nn.Embedding(28, 128)
        self.y_embed = nn.Embedding(28, 128)
        self.prev_embed = nn.Embedding(2, 128)
        # Simple LSTM with a linear output layer producing one logit.
        self.rnn = nn.LSTM(128 * 3, 256, num_layers=layers)
        self.out_layer = nn.Linear(256, 1)
        # Learnable initial hidden state (h_0 and c_0), one row per layer.
        for i in range(2):
            p = nn.Parameter(torch.zeros([layers, 256], dtype=torch.float))
            self.register_parameter('hidden_%d' % i, p)

    def forward(self, inputs, hidden=None):
        """
        Apply the model to a sequence.

        Args:
            inputs: a [T x N x 3] integer batch of (x, y, previous).
            hidden: if specified, a tuple (h, c) of initial hidden states.
                This is mostly useful for sampling.

        Returns:
            A tuple (outputs, hidden):
                outputs: a [T x N x 1] batch of logits.
                hidden: a tuple (h, c) of final hidden states.
        """
        x_vec = self.x_embed(inputs[:, :, 0])
        y_vec = self.y_embed(inputs[:, :, 1])
        prev_vec = self.prev_embed(inputs[:, :, 2])
        x = torch.cat([x_vec, y_vec, prev_vec], dim=-1)

        batch = x.shape[1]
        if hidden is None:
            # Broadcast the learned initial state across the batch.
            init_hidden = (self.hidden_0, self.hidden_1)
            hidden = tuple(h[:, None].repeat(1, batch, 1) for h in init_hidden)

        # Run the whole sequence through the LSTM in one call instead of a
        # per-timestep Python loop; the result is identical but this avoids
        # T separate forward calls and Python-level overhead.
        outs, hidden = self.rnn(x, hidden)
        # nn.Linear acts on the last dimension: [T x N x 256] -> [T x N x 1].
        return self.out_layer(outs), hidden

    def compute_loss(self, inputs, targets):
        """
        Compute the mean binary cross-entropy loss (NLL, nats per pixel).

        Args:
            inputs: a [T x N x 3] integer batch of (x, y, previous).
            targets: a [T x N x 1] float batch of binary pixel values.
        """
        logits, _ = self(inputs)
        return F.binary_cross_entropy_with_logits(logits, targets)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def binarized_pixel_sequences(device, train, batch):
    """
    Endlessly yield (inputs, targets) pairs of binarized MNIST pixel streams.

    Yields:
        inputs: a [T x N x 3] long tensor of (x, y, previous_pixel) triples.
        targets: a [T x N x 1] float tensor of binary pixel values.
    where T = 28*28 (pixels in raster order) and N = batch.
    """
    dataset = datasets.MNIST('data', train=train, download=True,
                             transform=transforms.ToTensor())
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True, drop_last=True)
    # Precompute the per-pixel (x, y) coordinate channels, shaped [T x N x 2].
    flat_idx = torch.arange(28 * 28)
    xs = (flat_idx % 28).repeat(batch, 1)
    ys = (flat_idx // 28).repeat(batch, 1)
    coords = torch.stack([xs, ys], dim=-1).permute(1, 0, 2).long()
    while True:
        for images, _ in loader:
            # Stochastic binarization: pixel intensity acts as P(pixel == 1).
            binarized = (images > torch.rand_like(images)).view(images.shape[0], -1)
            targets = binarized.float().permute(1, 0)[..., None]
            # The "previous pixel" input is the target sequence shifted right
            # by one step, with a zero at the first position.
            shifted = torch.cat([torch.zeros_like(targets[:1]), targets[:-1]], dim=0)
            inputs = torch.cat([coords, shifted.long()], dim=-1)
            yield inputs.to(device), targets.to(device)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Compute device for the model and data.
# Change this as needed (e.g. torch.device('cpu') if no GPU is available).
device = torch.device('cuda')
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Instantiate the model, move it to the compute device, and set up Adam.
model = XYSeqModel()
model.to(device)
opt = optim.Adam(model.parameters(), lr=1e-3)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Infinite generators over stochastically-binarized MNIST pixel sequences,
# one for the train split and one for the test split.
BATCH_SIZE = 64
train_data = binarized_pixel_sequences(device, train=True, batch=BATCH_SIZE)
test_data = binarized_pixel_sequences(device, train=False, batch=BATCH_SIZE)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "49344736fb254aa690010240285af3bf", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
# Train on the training set: one Adam step per batch,
# ~200k images total (200000 // BATCH_SIZE gradient steps).
for i in tqdm(range(200000 // BATCH_SIZE)):
    train_in, train_targ = next(train_data)
    train_loss = model.compute_loss(train_in, train_targ)
    opt.zero_grad()
    train_loss.backward()
    opt.step()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "78f57225d98a4f2f88709d8d2792120b", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, max=156), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"81.71128262311984\n" | |
] | |
} | |
], | |
"source": [ | |
# Evaluate on the test set (~10k images, no gradients).
losses = []
for i in tqdm(range(10000 // BATCH_SIZE)):
    test_in, test_targ = next(test_data)
    with torch.no_grad():
        test_loss = model.compute_loss(test_in, test_targ)
    losses.append(test_loss.item())
# compute_loss returns the mean BCE per pixel; multiply by 28*28
# to report total nats per image.
print(28 * 28 * sum(losses) / len(losses))
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment