Skip to content

Instantly share code, notes, and snippets.

@stsievert
Last active February 14, 2022 21:20
Show Gist options
  • Save stsievert/8d42ebb35499e37e0ab55d7156f12fdf to your computer and use it in GitHub Desktop.
Save stsievert/8d42ebb35499e37e0ab55d7156f12fdf to your computer and use it in GitHub Desktop.
PyTorch MNIST autoencoder
import random

import numpy as np
import scipy.signal
import skimage
import skimage.filters
import skimage.transform
import skimage.util
from keras.datasets import mnist


def noise_img(x):
    """Return a noisy copy of image ``x`` using Gaussian noise of random variance.

    On every call the Gaussian variance is sampled uniformly from [0, 0.1),
    so different images receive different amounts of corruption.

    Parameters
    ----------
    x : ndarray
        Image to corrupt. ``dataset`` passes flattened MNIST rows already
        scaled by 1/255.

    Returns
    -------
    ndarray
        Result of ``skimage.util.random_noise`` with the Gaussian settings.
    """
    # Both configurations are built, in this order, even though only the
    # Gaussian one is used: each construction consumes one draw from the
    # global NumPy RNG, and keeping both preserves the random stream.
    # NOTE(review): uniform(0.1, 0.1) is a degenerate range (always 0.1).
    salt_pepper = dict(mode="s&p", amount=np.random.uniform(0.1, 0.1))
    gaussian = dict(mode="gaussian", var=np.random.uniform(0.0, 0.10))

    # To pick the corruption type at random instead, pass
    # random.choice([salt_pepper, gaussian]) below.
    return skimage.util.random_noise(x, **gaussian)


def train_formatting(img):
    """Return a new flat float32 array holding the 28x28 image ``img``.

    ``img`` must contain exactly 28 * 28 = 784 values (any shape that can be
    reshaped to 28x28); otherwise ``reshape`` raises ``ValueError``.  The
    result is always a fresh 1-D array, so writing to it never touches
    ``img``.
    """
    as_float32 = img.astype("float32")
    return as_float32.reshape(28, 28).flatten()


def blur_img(img, max_angle=5, max_width=2):
    """Apply a random, roughly horizontal motion blur to a flattened square image.

    A short line of ones is drawn through the centre row of an ``n x n``
    kernel. The kernel is rotated by a random angle with
    ``skimage.transform.rotate``, normalised to sum to 1, and convolved with
    the image using ``mode="same"``, so the output keeps the input size.

    Parameters
    ----------
    img : ndarray
        1-D array of length ``n * n`` (e.g. a flattened 28x28 MNIST digit).
    max_angle : float, optional
        The rotation angle is drawn from
        ``np.random.uniform(-max_angle, max_angle)``. Default 5.
    max_width : int, optional
        The line half-length ``w`` is drawn from ``1 .. max_width``
        (inclusive), giving ``2 * w`` ones. Default 2.

    Returns
    -------
    ndarray
        Flattened blurred image of length ``n * n``.

    Raises
    ------
    AssertionError
        If ``img`` is not 1-D.
    ValueError
        If ``img.shape[0]`` is not a perfect square.
    """
    assert img.ndim == 1
    n = int(np.sqrt(img.shape[0]))
    if n * n != img.shape[0]:
        raise ValueError(
            "blur_img expects a flattened square image; got length {}".format(
                img.shape[0]
            )
        )
    image = img.reshape(n, n)

    kernel = np.zeros((n, n))
    # NOTE: the angle comes from NumPy's global RNG but the width comes from
    # Python's `random` module, so reproducible runs must seed both.
    angle = np.random.uniform(-max_angle, max_angle)
    w = random.choice(range(1, max_width + 1))
    centre = n // 2
    # 2*w ones on the centre row; the length is even, so the segment's midpoint
    # sits half a pixel left of `centre`. NOTE(review): confirm this is intended.
    kernel[centre, centre - w : centre + w] = 1
    kernel = skimage.transform.rotate(kernel, angle)
    kernel /= kernel.sum()  # keep overall brightness unchanged

    blurred = scipy.signal.convolve(image, kernel, mode="same")
    return blurred.flat[:]


def dataset(n=None, transforms=None):
    """Load MNIST and return ``(noisy, clean)`` image pairs for denoising.

    The Keras MNIST train and test splits are concatenated (70,000 images),
    divided by 255 (presumably mapping 8-bit pixels into [0, 1]) and
    flattened. Each function in ``transforms`` is then applied row by row,
    in a randomly shuffled order, to produce the corrupted inputs.

    Parameters
    ----------
    n : int, optional
        Keep only the first ``n`` images. ``None`` keeps every image, and so
        does any other falsy value, including 0.
    transforms : sequence of callables, optional
        Per-image corruption functions. Each maps a flat 1-D image to a flat
        1-D image of the same length (e.g. ``noise_img``, ``blur_img``).
        Defaults to ``[noise_img]``. The order is randomised with
        ``random.shuffle`` on a copy, so the caller's sequence is never
        modified.

    Returns
    -------
    noisy, clean : ndarray
        float32 arrays of shape ``(N, 1, 28, 28)``.
    """
    (x_train, _), (x_test, _) = mnist.load_data()
    x = np.concatenate((x_train, x_test))
    if n:
        x = x[:n]
    x = x.astype("float32") / 255.
    x = np.reshape(x, (len(x), 28 * 28))
    y = np.apply_along_axis(train_formatting, 1, x)
    clean = y.copy()
    noisy = y.copy()

    # Resolve the default at call time (the original looked up noise_img per
    # call too) and copy so shuffling never mutates the caller's sequence.
    order = [noise_img] if transforms is None else list(transforms)
    random.shuffle(order)
    for fn in order:
        noisy = np.apply_along_axis(fn, 1, noisy)

    # Corruption functions may return float64 (e.g. via skimage), so cast back.
    noisy = noisy.reshape(-1, 1, 28, 28).astype("float32")
    clean = clean.reshape(-1, 1, 28, 28).astype("float32")
    return noisy, clean
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook aims to show a simple example with an autoencoder. The highlights of this notebook are that\n",
"\n",
"* we can customize problem difficulty (because we control the input)\n",
"* we can customize the model expressiveness to capture problem difficulty (well, it's a todo)\n",
"\n",
"I will spend some time manually tuning these to make it a realistic problem. After this is done, we have 400 parameter combinations, each with 2 continuous variables to tune. The variables we have to tune are\n",
"\n",
"* model\n",
" * initialization\n",
" * activation function\n",
" * weight decay (which is similar to $\\ell_2$ regularization)\n",
"* optimizer\n",
" * which optimizer to use (e.g., Adam, SGD)\n",
" * batch size used to approximate gradient\n",
" * learning rate (but not for Adam)\n",
" * momentum for SGD\n",
"\n",
"I believe this is an interesting problem. For example, switching the activation from ReLU to PReLU caused the loss after the first epoch to drop from ~0.20 to ~0.12."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.4.1'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"torch.__version__"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n",
"/home/ssievert/miniconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [
"import numpy as np\n",
"import noisy_mnist"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Input data\n",
"`noisy_mnist` is a module I wrote to add noise to MNIST images. Its `dataset` function returns NumPy ndarrays of noisy and clean images."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import noisy_mnist\n",
"noisy, clean = noisy_mnist.dataset()#n=1024)\n",
"assert isinstance(noisy, np.ndarray)\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"noisy_train, noisy_test, clean_train, clean_test = train_test_split(noisy, clean)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(52500, 1, 28, 28)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"noisy_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x144 with 24 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import numpy as np\n",
"\n",
"w = 1.0\n",
"ncols = 12\n",
"idx = np.random.choice(len(clean_test), size=ncols)\n",
"fig, axs = plt.subplots(nrows=2, ncols=ncols, figsize=(ncols*w, 2*w))\n",
"for col, (lower, upper, i) in enumerate(zip(axs[1], axs[0], idx)):\n",
" if col == 0:\n",
" upper.text(-40, 15, 'ground truth')\n",
" lower.text(-20, 15, 'input')\n",
" clean = clean_train[i].reshape(28, 28)\n",
" noisy = noisy_train[i].reshape(28, 28)\n",
" kwargs = {'cbar': False, 'xticklabels': False, 'yticklabels': False, 'cmap': 'gray'}\n",
" sns.heatmap(clean, ax=upper, **kwargs)\n",
" sns.heatmap(noisy, ax=lower, **kwargs)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This input has random amounts of noise: it ranges from nearly clean images to moderately corrupted ones."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model creation"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import skorch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This convolutional model is adapted from https://github.com/Kaixhin/Autoencoders/blob/master/models/ConvAE.lua, which is in turn adapted from https://blog.keras.io/building-autoencoders-in-keras.html.\n",
"\n",
"I have allowed customization of\n",
"\n",
"* the activation function\n",
"* the initialization scheme\n",
"* **TODO** width/depth to allow more complex models (which is suited for random noise)\n",
"\n",
"**Notes on other models**\n",
"* https://github.com/baldassarreFe/zalando-pytorch/blob/master/notebooks/4.0-fb-autoencoder.ipynb\n",
"* Maybe go with VGG deconv net at https://github.com/csgwon/pytorch-deconvnet/blob/master/models/vgg16_deconv.py (will require ImageNet)\n",
"* A simple DAE is at https://github.com/ReyhaneAskari/pytorch_experiments/blob/master/DAE.py"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"from toolz import partial\n",
" \n",
"def _initialize(method, layer, gain=1):\n",
" weight = layer.weight.data\n",
" _before = weight.data.clone()\n",
" kwargs = {'gain': gain} if 'xavier' in str(method) else {}\n",
" method(weight.data, **kwargs)\n",
" assert torch.all(weight.data != _before)\n",
" \n",
"class Autoencoder(nn.Module):\n",
" \"\"\" Autoencoder adapted from [1]\n",
" \n",
" [1]:https://github.com/Kaixhin/Autoencoders/blob/master/models/ConvAE.lua\n",
" \"\"\"\n",
" def __init__(self, activation='ReLU', init='xavier_uniform_', width_factor=1):\n",
" super().__init__()\n",
" device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
" \n",
" self.activation = activation\n",
" self.init = init\n",
" \n",
" init_method = getattr(torch.nn.init, init)\n",
" act_layer = getattr(nn, activation)\n",
" print(act_layer)\n",
" \n",
" gain = 1\n",
" if self.activation in ['LeakyReLU', 'ReLU']:\n",
" name = 'leaky_relu' if self.activation == 'LeakyReLU' else 'relu'\n",
" gain = torch.nn.init.calculate_gain(name)\n",
" if self.activation != 'PReLU':\n",
" act_layer = partial(act_layer, inplace=True)\n",
" \n",
" width = int(width_factor * 32)\n",
" layers = [nn.Conv2d(1, width, kernel_size=3, padding=1, stride=1),\n",
" act_layer(),\n",
" nn.MaxPool2d(2, stride=2, padding=1),\n",
" nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1),\n",
" act_layer(),\n",
" nn.MaxPool2d(2, stride=2, padding=0)]\n",
" for layer in layers:\n",
" if hasattr(layer, 'weight') and layer.weight.data.dim() > 1:\n",
" _initialize(init_method, layer)\n",
" \n",
" self.encoder = nn.Sequential(*layers).to(device)\n",
" \n",
" modules = []\n",
" modules += [[nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),\n",
" act_layer()]]\n",
" modules += [[nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),\n",
" act_layer()]]\n",
" modules += [[nn.Conv2d(width, 1, kernel_size=3, stride=1, padding=1),\n",
" nn.Sigmoid()]]\n",
" self.decoders = []\n",
" for module in modules:\n",
" [_initialize(init_method, layer) for layer in module\n",
" if hasattr(layer, 'weight') and layer.weight.data.dim() > 1]\n",
" self.decoders += [nn.Sequential(*module).to(device)]\n",
" \n",
" def forward(self, x):\n",
" x = self.encoder(x)\n",
" \n",
" for i, decoder in enumerate(self.decoders):\n",
" x = decoder(x)\n",
" if i < len(self.decoders) - 1:\n",
" x = F.interpolate(x, scale_factor=2)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from pprint import pprint\n",
"class Trim(skorch.NeuralNetRegressor):\n",
" \"\"\"\n",
" This wrapper trims the arguments in `params` to make \n",
" sure they're accepted by the optimizer\n",
" \"\"\"\n",
" def __init__(self, verbose=True, **kwargs):\n",
" if kwargs['optimizer'] != 'Adam':\n",
" kwargs.pop('optimizer__amsgrad', None)\n",
" if kwargs['optimizer'] == 'Adam':\n",
" kwargs.pop('optimizer__lr', None)\n",
" if kwargs['optimizer'] != 'SGD':\n",
" kwargs.pop('optimizer__nesterov')\n",
" kwargs.pop('optimizer__momentum')\n",
" kwargs['optimizer'] = getattr(torch.optim, kwargs['optimizer'])\n",
" pprint({k: v for k, v in kwargs.items() if k != 'module'})\n",
" super().__init__(**kwargs)\n",
"\n",
"params = {\n",
" 'module__init': ['xavier_uniform_',\n",
" 'xavier_normal_',\n",
" 'kaiming_uniform_',\n",
" 'kaiming_normal_',\n",
" ],\n",
" 'module__activation': ['ReLU', 'LeakyReLU', 'ELU', 'PReLU'],\n",
" 'optimizer': ['SGD',\n",
" 'ASGD',\n",
" 'Adam',\n",
" 'Adagrad',\n",
" 'RMSprop'], # optimizers in Adam's paper + ASGD\n",
" 'batch_size': [32, 64, 128, 256, 512],\n",
" 'optimizer__lr': np.logspace(2, -2, num=1000), # all optimizers but Adam\n",
" 'optimizer__weight_decay': [0] + np.logspace(-8, -3, num=1000).tolist(), # all optimizers\n",
" 'optimizer__nesterov': [True], # only for SGD\n",
" 'optimizer__momentum': np.logspace(-4, 0, num=1000), # only for SGD\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fitting"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'torch.nn.modules.activation.ReLU'>\n",
"{'batch_size': 64,\n",
" 'criterion': <class 'torch.nn.modules.loss.BCELoss'>,\n",
" 'device': 'cuda',\n",
" 'max_epochs': 20,\n",
" 'module__activation': 'PReLU',\n",
" 'module__init': 'xavier_uniform_',\n",
" 'optimizer': <class 'torch.optim.adam.Adam'>,\n",
" 'optimizer__weight_decay': 1e-05,\n",
" 'warm_start': True}\n"
]
}
],
"source": [
"from torchvision.datasets import FashionMNIST\n",
"from torchvision import transforms\n",
"\n",
"device = 'cpu'\n",
"model = Autoencoder()\n",
"if torch.cuda.is_available():\n",
" device = 'cuda'\n",
" model = model.to('cuda')\n",
" \n",
"net = Trim(\n",
" module=model,\n",
" module__init='xavier_uniform_',\n",
" module__activation='PReLU',\n",
" max_epochs=20,\n",
" optimizer='Adam',\n",
" optimizer__lr=1.0,\n",
" optimizer__weight_decay=1e-5,\n",
" optimizer__nesterov=True,\n",
" optimizer__momentum=1e-2,\n",
" batch_size=64,\n",
" # train_split=None, # if only use training data; we will eventually\n",
" #criterion=torch.nn.MSELoss,\n",
" criterion=torch.nn.BCELoss,\n",
" device='cuda',\n",
" warm_start=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Re-initializing module!\n",
"<class 'torch.nn.modules.activation.PReLU'>\n",
" epoch train_loss valid_loss dur\n",
"------- ------------ ------------ ------\n",
" 1 \u001b[36m0.1290\u001b[0m \u001b[32m0.1144\u001b[0m 4.5228\n",
" 2 \u001b[36m0.1053\u001b[0m \u001b[32m0.1074\u001b[0m 4.5657\n",
" 3 \u001b[36m0.1034\u001b[0m \u001b[32m0.1032\u001b[0m 4.5263\n",
" 4 \u001b[36m0.1021\u001b[0m \u001b[32m0.1018\u001b[0m 4.4973\n",
" 5 \u001b[36m0.1014\u001b[0m \u001b[32m0.1009\u001b[0m 4.5685\n",
" 6 \u001b[36m0.1006\u001b[0m \u001b[32m0.1000\u001b[0m 4.5073\n",
" 7 \u001b[36m0.1002\u001b[0m 0.1037 4.5860\n",
" 8 \u001b[36m0.0998\u001b[0m 0.1045 4.5521\n",
" 9 \u001b[36m0.0994\u001b[0m \u001b[32m0.0999\u001b[0m 4.5896\n",
" 10 \u001b[36m0.0991\u001b[0m 0.1010 4.5799\n",
" 11 \u001b[36m0.0989\u001b[0m \u001b[32m0.0996\u001b[0m 4.5751\n",
" 12 \u001b[36m0.0986\u001b[0m 0.1014 4.5972\n",
" 13 \u001b[36m0.0984\u001b[0m \u001b[32m0.0990\u001b[0m 4.4776\n",
" 14 \u001b[36m0.0982\u001b[0m \u001b[32m0.0987\u001b[0m 4.6029\n",
" 15 \u001b[36m0.0981\u001b[0m \u001b[32m0.0983\u001b[0m 4.6004\n",
" 16 \u001b[36m0.0980\u001b[0m \u001b[32m0.0982\u001b[0m 4.5679\n",
" 17 \u001b[36m0.0979\u001b[0m \u001b[32m0.0982\u001b[0m 4.5686\n",
" 18 \u001b[36m0.0979\u001b[0m \u001b[32m0.0979\u001b[0m 4.5877\n",
" 19 \u001b[36m0.0978\u001b[0m \u001b[32m0.0976\u001b[0m 4.5966\n",
" 20 \u001b[36m0.0978\u001b[0m 0.0989 4.6106\n"
]
},
{
"data": {
"text/plain": [
"<class '__main__.Trim'>[initialized](\n",
" module_=Autoencoder(\n",
" (encoder): Sequential(\n",
" (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): PReLU(num_parameters=1)\n",
" (2): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
" (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (4): PReLU(num_parameters=1)\n",
" (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" ),\n",
")"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net.fit(noisy_train, clean_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Output"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"clean_test_hat = net.forward(noisy_test)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x216 with 36 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"w = 1.0\n",
"ncols = 12\n",
"# idx = np.random.choice(len(clean_test), size=ncols)\n",
"\n",
"fig, axs = plt.subplots(nrows=3, ncols=ncols, figsize=(ncols*w, 3*w))\n",
"for col, (lower, middle, upper, i) in enumerate(zip(axs[2], axs[1], axs[0], idx)):\n",
" if col == 0:\n",
" upper.text(-40, 15, 'ground truth')\n",
" middle.text(-20, 15, 'input')\n",
" lower.text(-25, 15, 'output')\n",
"\n",
" i = np.random.randint(len(clean_test))\n",
" clean = clean_test[i].reshape(28, 28)\n",
" predicted = clean_test_hat[i].reshape(28, 28)\n",
" noisy = noisy_test[i].reshape(28, 28)\n",
" kwargs = {'cbar': False, 'xticklabels': False, 'yticklabels': False, 'cmap': 'gray'}\n",
" sns.heatmap(clean, ax=upper, **kwargs)\n",
" sns.heatmap(noisy, ax=middle, **kwargs)\n",
" sns.heatmap(predicted, ax=lower, **kwargs)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment