Skip to content

Instantly share code, notes, and snippets.

@unixpickle
Created June 25, 2020 23:03
Show Gist options
  • Save unixpickle/f02b9e8881081fa8d033c2e723d01f89 to your computer and use it in GitHub Desktop.
Save unixpickle/f02b9e8881081fa8d033c2e723d01f89 to your computer and use it in GitHub Desktop.
Binarized mnist
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torchvision import datasets, transforms\n",
"from tqdm.auto import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class XYSeqModel(nn.Module):\n",
" \"\"\"\n",
" A sequence model that takes in (x, y, previous_pixel)\n",
" and outputs binary log probabilities.\n",
" \"\"\"\n",
"\n",
" def __init__(self, layers=3):\n",
" super().__init__()\n",
" # Separate embeddings for each input.\n",
" self.x_embed = nn.Embedding(28, 128)\n",
" self.y_embed = nn.Embedding(28, 128)\n",
" self.prev_embed = nn.Embedding(2, 128)\n",
" # Simple LSTM with linear output layer.\n",
" self.rnn = nn.LSTM(128*3, 256, num_layers=layers)\n",
" self.out_layer = nn.Linear(256, 1)\n",
" # Learnable initial hidden state.\n",
" for i in range(2):\n",
" p = nn.Parameter(torch.zeros([layers, 256], dtype=torch.float))\n",
" self.register_parameter('hidden_%d' % i, p)\n",
"\n",
" def forward(self, inputs, hidden=None):\n",
" \"\"\"\n",
" Apply the model to a sequence.\n",
"\n",
" Args:\n",
" inputs: a [T x N x 3] integer batch of (x, y, previous).\n",
" hidden: if specified, a tuple of initial hidden states.\n",
" This is mostly useful for sampling.\n",
"\n",
" Returns:\n",
" A tuple (outputs, hidden):\n",
" outputs: an [N x 1] batch of logits.\n",
" hidden: a tuple of final hidden states, \n",
" \"\"\"\n",
" x_vec = self.x_embed(inputs[:, :, 0])\n",
" y_vec = self.y_embed(inputs[:, :, 1])\n",
" prev_vec = self.prev_embed(inputs[:, :, 2])\n",
" x = torch.cat([x_vec, y_vec, prev_vec], dim=-1)\n",
"\n",
" batch = x.shape[1]\n",
" if hidden is None:\n",
" init_hidden = (self.hidden_0, self.hidden_1)\n",
" hidden = tuple(h[:, None].repeat(1, batch, 1) for h in init_hidden)\n",
"\n",
" outputs = []\n",
" for t in range(inputs.shape[0]):\n",
" outs, hidden = self.rnn(x[t:t+1], hidden)\n",
" outs = self.out_layer(outs.view(batch, -1)).view(1, batch, 1)\n",
" outputs.append(outs)\n",
"\n",
" return torch.cat(outputs, dim=0), hidden\n",
" \n",
" def compute_loss(self, inputs, targets):\n",
" \"\"\"\n",
" Compute the binary cross-entropy loss (NLL).\n",
" \"\"\"\n",
" logits, _ = self(inputs)\n",
" return F.binary_cross_entropy_with_logits(logits, targets)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def binarized_pixel_sequences(device, train, batch):\n",
" \"\"\"\n",
" Iterate over tuples (inputs, targets):\n",
" inputs: an [T x N x 3] tensor of shifted inputs.\n",
" targets: a [T x N x 1] tensor of targets.\n",
" \"\"\"\n",
" mnist = datasets.MNIST('data', train=train, download=True,\n",
" transform=transforms.ToTensor())\n",
" loader = torch.utils.data.DataLoader(mnist, batch_size=batch, shuffle=True, drop_last=True)\n",
" xs = torch.tensor([[i % 28 for i in range(28*28)]] * batch)\n",
" ys = torch.tensor([[i // 28 for i in range(28*28)]] * batch)\n",
" coords = torch.stack([xs, ys], dim=-1).permute(1, 0, 2).long()\n",
" while True:\n",
" for images, _ in loader:\n",
" # Binarize stochastically.\n",
" binarized = (images > torch.rand_like(images)).view(images.shape[0], -1)\n",
" targets = binarized.float().permute(1, 0)[..., None]\n",
" shifted = torch.cat([torch.zeros_like(targets[:1]), targets[:-1]], dim=0)\n",
" inputs = torch.cat([coords, shifted.long()], dim=-1)\n",
" yield inputs.to(device), targets.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Change this as needed.\n",
"device = torch.device('cuda')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"model = XYSeqModel()\n",
"model.to(device)\n",
"opt = optim.Adam(model.parameters(), lr=1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 64\n",
"train_data = binarized_pixel_sequences(device, train=True, batch=BATCH_SIZE)\n",
"test_data = binarized_pixel_sequences(device, train=False, batch=BATCH_SIZE)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "49344736fb254aa690010240285af3bf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Train on the training set\n",
"for i in tqdm(range(200000 // BATCH_SIZE)):\n",
" train_in, train_targ = next(train_data)\n",
" train_loss = model.compute_loss(train_in, train_targ)\n",
" opt.zero_grad()\n",
" train_loss.backward()\n",
" opt.step()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "78f57225d98a4f2f88709d8d2792120b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=156), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"81.71128262311984\n"
]
}
],
"source": [
"# Evaluate on the test set.\n",
"losses = []\n",
"for i in tqdm(range(10000 // BATCH_SIZE)):\n",
" test_in, test_targ = next(test_data)\n",
" with torch.no_grad():\n",
" test_loss = model.compute_loss(test_in, test_targ)\n",
" losses.append(test_loss.item())\n",
"print(28 * 28 * sum(losses) / len(losses))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment