Skip to content

Instantly share code, notes, and snippets.

@stsievert
Last active September 10, 2018 18:52
Show Gist options
  • Save stsievert/33bdc47b52ffc085dddd75e9b719cc07 to your computer and use it in GitHub Desktop.
Save stsievert/33bdc47b52ffc085dddd75e9b719cc07 to your computer and use it in GitHub Desktop.
Testing patience for hyperparam search
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prior work for default patience:\n",
"\n",
"* `torch.optim.lr_scheduler.ReduceLROnPlateau` defaults to `patience=10` epochs.\n",
" * https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau\n",
"* `keras.callbacks.ReduceLROnPlateau` defaults to `patience=10` epochs.\n",
" * https://github.com/tensorflow/tensorflow/blob/4dcfddc5d12018a5a0fdca652b9221ed95e9eb23/tensorflow/python/keras/callbacks.py#L891\n",
"* Keras-MXNet (Keras with an MXNet backend) defaults to `patience=10` epochs in its `ReduceLROnPlateau`.\n",
" * https://github.com/awslabs/keras-apache-mxnet/blob/01d59d3f91ffb13d73cadc11db94a30e4b05c2f8/keras/callbacks.py#L991\n",
"* \"Random Search for Hyper-Parameter Optimization\" says \"We permitted a minimum of 100 and a maximum of 1000 iterations over the training data, stopping if ever, at iteration $t$ , [if] the best validation performance was observed before iteration $t / 2$.\"\n",
" * by Bergstra and Bengio, http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf (mentioned in the sklearn docs at http://scikit-learn.org/stable/modules/grid_search.html#randomized-parameter-optimization)\n",
" \n",
"An approach like the last one seems ideal and is similar to the doubling trick. However, it requires a minimum number of iterations, and that minimum is exactly what we're trying to choose.\n",
"\n",
"I think a static patience parameter is best.\n",
"\n",
"I am inclined to have this in terms of epochs; I think there's some paper that says \"run for X epochs\" for convergence.\n",
"\n",
"It looks like we should default to 10; that's what everyone else is doing (including Facebook, Amazon and Google)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table style=\"border: 2px solid white;\">\n",
"<tr>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Client</h3>\n",
"<ul>\n",
" <li><b>Scheduler: </b>tcp://dask-scheduler:8786\n",
" <li><b>Dashboard: </b><a href='http://dask-scheduler:8787/status' target='_blank'>http://dask-scheduler:8787/status</a>\n",
"</ul>\n",
"</td>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Cluster</h3>\n",
"<ul>\n",
" <li><b>Workers: </b>16</li>\n",
" <li><b>Cores: </b>32</li>\n",
" <li><b>Memory: </b>96.00 GB</li>\n",
"</ul>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"<Client: scheduler='tcp://10.52.111.5:8786' processes=16 cores=32>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import distributed\n",
"from distributed import Client\n",
"client = Client()\n",
"client"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'tcp://10.52.101.3:32807': 2,\n",
" 'tcp://10.52.102.3:42949': 2,\n",
" 'tcp://10.52.103.4:39683': 2,\n",
" 'tcp://10.52.104.3:35883': 2,\n",
" 'tcp://10.52.105.3:44301': 2,\n",
" 'tcp://10.52.106.3:44269': 2,\n",
" 'tcp://10.52.107.3:45867': 2,\n",
" 'tcp://10.52.108.3:34393': 2,\n",
" 'tcp://10.52.109.3:35189': 2,\n",
" 'tcp://10.52.110.3:39023': 2,\n",
" 'tcp://10.52.111.4:36131': 2,\n",
" 'tcp://10.52.112.3:41667': 2,\n",
" 'tcp://10.52.114.3:43839': 2,\n",
" 'tcp://10.52.115.3:46805': 2,\n",
" 'tcp://10.52.116.3:36847': 2,\n",
" 'tcp://10.52.117.3:34885': 2}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client.ncores()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"def debug_loop():\n",
" subprocess.call(\"pip install git+https://github.com/stsievert/dask-ml@hyperband-scale\".split(\" \"))\n",
" import dask_ml\n",
" return dask_ml.__version__"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 136 ms, sys: 21.6 ms, total: 157 ms\n",
"Wall time: 3.26 s\n"
]
},
{
"data": {
"text/plain": [
"'0.4.2.dev395+g7b1cf92'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time debug_loop()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from dask_ml.model_selection._successive_halving import _SHA\n",
"#_SHA.fit??"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 11.4 ms, sys: 2.42 ms, total: 13.9 ms\n",
"Wall time: 3.67 s\n"
]
},
{
"data": {
"text/plain": [
"{'tcp://10.52.101.3:32807': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.102.3:42949': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.103.4:39683': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.104.3:35883': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.105.3:44301': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.106.3:44269': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.107.3:45867': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.108.3:34393': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.109.3:35189': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.110.3:39023': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.111.4:36131': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.112.3:41667': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.114.3:43839': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.115.3:46805': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.116.3:36847': '0.4.2.dev395+g7b1cf92',\n",
" 'tcp://10.52.117.3:34885': '0.4.2.dev395+g7b1cf92'}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time client.run(debug_loop)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.5 ms, sys: 1.44 ms, total: 4.95 ms\n",
"Wall time: 206 ms\n"
]
},
{
"data": {
"text/html": [
"<table style=\"border: 2px solid white;\">\n",
"<tr>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Client</h3>\n",
"<ul>\n",
" <li><b>Scheduler: </b>tcp://dask-scheduler:8786\n",
" <li><b>Dashboard: </b><a href='http://dask-scheduler:8787/status' target='_blank'>http://dask-scheduler:8787/status</a>\n",
"</ul>\n",
"</td>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Cluster</h3>\n",
"<ul>\n",
" <li><b>Workers: </b>16</li>\n",
" <li><b>Cores: </b>32</li>\n",
" <li><b>Memory: </b>96.00 GB</li>\n",
"</ul>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"<Client: scheduler='tcp://10.52.111.5:8786' processes=16 cores=32>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time client.restart()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.41 ms, sys: 3.43 ms, total: 7.85 ms\n",
"Wall time: 985 ms\n"
]
}
],
"source": [
"%time client.upload_file('autoencoder.py')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.4.2.dev395+g7b1cf92'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from dask_ml.model_selection._successive_halving import stop_on_plateau\n",
"from dask_ml.model_selection import HyperbandCV\n",
"import dask_ml\n",
"dask_ml.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data\n",
"See below for an image."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import noisy_mnist\n",
"_X, _y = noisy_mnist.dataset()#n=10 * 1024)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(dask.array<array, shape=(70000, 784), dtype=float32, chunksize=(35000, 784)>,\n",
" dask.array<array, shape=(70000, 784), dtype=float32, chunksize=(35000, 784)>)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import dask.array as da\n",
"n, d = _X.shape\n",
"X = da.from_array(_X, chunks=(n // 2, d))\n",
"y = da.from_array(_y, chunks=n // 2)\n",
"X, y"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x144 with 10 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"cols = 5\n",
"w = 1.0\n",
"fig, axs = plt.subplots(figsize=(cols*w, 2*w), ncols=cols, nrows=2)\n",
"for col, (upper, lower) in enumerate(zip(axs[0], axs[1])):\n",
" if col == 0:\n",
" upper.text(-28, 14, 'ground\\ntruth')\n",
" lower.text(-28, 14, 'input')\n",
" i = np.random.choice(len(X))\n",
" noisy = X[i].reshape(28, 28)\n",
" clean = y[i].reshape(28, 28)\n",
" kwargs = {'cbar': False, 'xticklabels': False, 'yticklabels': False, 'cmap': 'gray'}\n",
" sns.heatmap(noisy, ax=lower, **kwargs)\n",
" sns.heatmap(clean, ax=upper, **kwargs)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I use a deep learning library (PyTorch) for this model, accessed through [skorch], a scikit-learn compatible interface for PyTorch.\n",
"\n",
"[skorch]:https://github.com/dnouri/skorch"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from autoencoder import Autoencoder, NegLossScore\n",
"import torch\n",
"\n",
"model = NegLossScore(module=Autoencoder,\n",
" criterion=torch.nn.BCELoss,\n",
" warm_start=True,\n",
" train_split=None,\n",
" max_epochs=1,\n",
" callbacks=[])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I don't show it here; I'd rather concentrate on tuning hyperparameters. Briefly, it's a simple fully connected autoencoder (784 → 196 → 49 → 784 units) with a latent dimension of 49."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Parameters\n",
"\n",
"The parameters I am interested in tuning are\n",
"\n",
"* model\n",
" * initialization\n",
" * activation function\n",
" * weight decay (which is similar to $\\ell_2$ regularization)\n",
"* optimizer\n",
" * which optimizer to use (e.g., Adam, SGD)\n",
" * batch size used to approximate gradient\n",
" * learning rate (but not for Adam)\n",
" * momentum for SGD\n",
" \n",
"After looking at the results, I think I was too exploratory in my tuning of step size. I should have experimented with it more to determine a reasonable range."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"params = {\n",
" 'module__init': ['xavier_uniform_',\n",
" 'xavier_normal_',\n",
" 'kaiming_uniform_',\n",
" 'kaiming_normal_',\n",
" ],\n",
" 'module__activation': ['ReLU', 'LeakyReLU', 'ELU', 'PReLU'],\n",
" 'optimizer': ['SGD'] * 5 + ['Adam'] * 2,\n",
" 'batch_size': [32, 64, 128, 256, 512],\n",
" 'optimizer__lr': np.logspace(1, -1.5, num=1000),\n",
" 'optimizer__weight_decay': [0]*200 + np.logspace(-7, -3, num=1000).tolist(),\n",
" 'optimizer__nesterov': [True],\n",
" 'optimizer__momentum': np.linspace(0, 1, num=1000),\n",
" 'train_split': [None],\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import ParameterSampler\n",
"import torch\n",
"\n",
"def trim_params(**kwargs):\n",
" if kwargs['optimizer'] != 'Adam':\n",
" kwargs.pop('optimizer__amsgrad', None)\n",
" if kwargs['optimizer'] == 'Adam':\n",
" kwargs.pop('optimizer__lr', None)\n",
" if kwargs['optimizer'] != 'SGD':\n",
" kwargs.pop('optimizer__nesterov', None)\n",
" kwargs.pop('optimizer__momentum', None)\n",
" kwargs['optimizer'] = getattr(torch.optim, kwargs['optimizer'])\n",
" return kwargs"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# # for debugging; ignore this cell\n",
"# from sklearn.linear_model import SGDClassifier\n",
"# from sklearn.datasets import make_classification\n",
"# from sklearn.model_selection import ParameterSampler\n",
"# import dask.array as da\n",
"# import numpy as np\n",
"# model = SGDClassifier()\n",
"# params = {'alpha': np.logspace(-7, 0, num=int(1e6))}\n",
"\n",
"# n, d = int(10e3), 700\n",
"# _X = np.random.rand(n, d)\n",
"# _beta = np.random.rand(d)\n",
"# _y = np.sign(_X @ _beta + d * 0.1 * np.random.randn(n))\n",
"# X = da.from_array(_X, chunks=(n // 10, d))\n",
"# y = da.from_array(_y, chunks=n // 10)\n",
"# X, y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameter optimization"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"from dask_ml.model_selection import train_test_split\n",
"from sklearn.linear_model import SGDClassifier\n",
"from dask_ml.model_selection import HyperbandCV\n",
"\n",
"def test_hyperband(model, params, X, y, max_iter=27, patience=np.inf, tol=1e-4):\n",
" fit_params = {}\n",
" if isinstance(model, SGDClassifier):\n",
" fit_params = {'classes': da.unique(y).compute()}\n",
" param_list = list(ParameterSampler(params, max_iter * 100))\n",
" else:\n",
" param_list = [trim_params(**param)\n",
" for param in ParameterSampler(params, max_iter * 100)]\n",
" \n",
" search = HyperbandCV(model, param_list, max_iter, patience=patience)\n",
" search.fit(X, y, **fit_params)\n",
" \n",
" meta = {'max_iter': max_iter, 'patience': patience, 'tol': tol, \"alg\": \"hyperband\"}\n",
" [h.update(meta) for h in search.history_]\n",
" return search, search.history_"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y)\n",
"max_iter = 243"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"all_history = {}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"20 0 -0.09480269998311996\n",
"40 0 -0.09297627210617065\n",
"243 0 -0.09296257048845291\n",
"20 1 -0.09345395863056183\n",
"40 1 -0.09191805869340897\n",
"243 1 -0.09264714270830154\n"
]
}
],
"source": [
"searches = []\n",
"\n",
"# [inf, 5, 10] epochs respectively\n",
"# P = [max_iter, 20, 10]\n",
"P = [20, 40, max_iter]\n",
"# 10, 20, 40 and 121 epochs respectively\n",
"for p in P:\n",
" all_history[f\"hyperband-p{p}\"] = []\n",
"for _ in range(3):\n",
" for p in P:\n",
" search, hist = test_hyperband(model, params, X_train, y_train, max_iter=max_iter, patience=p)\n",
" if p == max_iter:\n",
" searches += [search]\n",
" all_history[f\"hyperband-p{p}\"] += [hist]\n",
" print(p, _, search.best_score_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing output of best estimator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"noisy_test = X_test.compute()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"search = searches[0]\n",
"clean_hat = search.best_estimator_.predict(noisy_test)\n",
"clean_hat.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"cols = 5\n",
"w = 1.0\n",
"fig, axs = plt.subplots(figsize=(cols*w, 3*w), ncols=cols, nrows=3)\n",
"for col, (upper, middle, lower) in enumerate(zip(axs[0], axs[1], axs[2])):\n",
" if col == 0:\n",
" upper.text(-28, 14, 'ground\\ntruth')\n",
" middle.text(-28, 14, 'input')\n",
" lower.text(-28, 14, 'output')\n",
" i = np.random.choice(len(X_test))\n",
" noisy = X_test[i].reshape(28, 28)\n",
" clean = y_test[i].reshape(28, 28)\n",
" clean_hat_i = clean_hat[i].reshape(28, 28)\n",
" kwargs = {'cbar': False, 'xticklabels': False, 'yticklabels': False, 'cmap': 'gray'}\n",
" sns.heatmap(noisy, ax=middle, **kwargs)\n",
" sns.heatmap(clean, ax=upper, **kwargs)\n",
" sns.heatmap(clean_hat_i, ax=lower, **kwargs)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setting parameters for Hyperband\n",
"We need to know two things:\n",
"\n",
"1. how many \"epochs\" or \"passes through data\" to train model\n",
"2. how many configs to evaluate\n",
" * this is some measure of how complex the search space is\n",
" \n",
"This determines\n",
"\n",
"* the `max_iter` argument for `HyperbandCV`\n",
"* the chunk size of the array to pass in\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comparison with early stopping"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dask_ml.model_selection._successive_halving import _HistoryRecorder, stop_on_plateau\n",
"from dask_ml.model_selection._incremental import fit\n",
"from dask_ml.model_selection import train_test_split\n",
"from sklearn.model_selection import ParameterSampler\n",
"import random\n",
"\n",
"def test_rand(model, params, X, y, max_iter, num_models, num_calls, patience=10, tol=1e-4):\n",
" rand_search = _HistoryRecorder(stop_on_plateau, patience=patience, tol=tol, max_iter=num_calls)\n",
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15)\n",
" \n",
" if isinstance(model, SGDClassifier):\n",
" rand_params = list(ParameterSampler(params, int(num_models)))\n",
" fit_params = {'classes': da.unique(y).compute()}\n",
" else:\n",
" rand_params = [trim_params(**param)\n",
" for param in ParameterSampler(params, int(num_models))]\n",
" fit_params = {}\n",
"\n",
" _ = fit(\n",
" model,\n",
" rand_params,\n",
" X_train,\n",
" y_train,\n",
" X_test,\n",
" y_test,\n",
" additional_partial_fit_calls=rand_search.fit,\n",
" fit_params=fit_params,\n",
" random_state=42\n",
" )\n",
" meta = {'max_iter': max_iter, 'patience': patience, 'tol': tol, \"alg\": \"stop_on_plateau\"}\n",
" [h.update(meta) for h in rand_search.history]\n",
" return rand_search.history"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"total_calls = search.metadata()['partial_fit_calls']\n",
"num_calls = max_iter\n",
"num_models = max(sum(client.ncores().values()), total_calls // num_calls)\n",
"num_calls, num_models, search.metadata()['partial_fit_calls']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#P = [max_iter, 10, 20]\n",
"P = [20, 40, 80, max_iter]\n",
"P = [20, 40, 80, 120, max_iter]\n",
"for p in P:\n",
" all_history[f\"random-p{p}\"] = []\n",
"for p in P:\n",
" for _ in range(3):\n",
" all_history[f\"random-p{p}\"] += [test_rand(model, params, X, y, max_iter, num_models, num_calls, patience=p)]\n",
" print(p, _)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Performance\n",
"`HyperbandCV` will find close to the best possible parameters with the given computational budget.*\n",
"\n",
"<sup>* \"will\" := with high probability,\n",
"\"close\" := within log factors,\n",
"\"best possible\" in expected value.</sup>\n"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'hi'"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"'hi'"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"from pprint import pprint\n",
"import toolz\n",
"\n",
"def shape_history(hist, **kwargs):\n",
" history = sorted(hist, key=lambda item: item['wall_time'])\n",
" \n",
" out = []\n",
" scores = {}\n",
" calls = {}\n",
" train_time = {}\n",
" \n",
" start = min(h['wall_time'] for h in history)\n",
" for h in history:\n",
" scores[h['model_id']] = h['score']\n",
" calls[h['model_id']] = h['partial_fit_calls']\n",
" train_time[h['model_id']] = h['partial_fit_time'] + h['score_time']\n",
" p = h[\"patience\"]\n",
" out += [{'wall_time': h['wall_time'] - start,\n",
" 'best_score': max(scores.values()),\n",
" 'cumulative_partial_fit_calls': sum(calls.values()),\n",
" 'alg': h['alg'],\n",
" 'adaptive': not \"hyperband\" in h[\"alg\"],\n",
" 'train_time': sum(train_time.values()),\n",
" 'model_id': h[\"model_id\"],\n",
" 'patience': p if not np.isinf(p) else -1,\n",
" 'patience_': \"p=\" + str(p) if not np.isinf(p) else \"p=inf\",\n",
" 'tol': h[\"tol\"],\n",
" 'base_alg': \"hyperband\" if \"hyperband\" in h[\"alg\"] else \"stop_on_plateau\",\n",
" **kwargs\n",
" }]\n",
" return out\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"shaped_history = [shape_history(h, repeat=repeat, alg=alg)\n",
" for alg, hist in all_history.items()\n",
" for repeat, h in enumerate(hist)]\n",
"alg_shaped_histories = {alg: [shape_history(h, repeat=repeat, alg=alg)\n",
" for repeat, h in enumerate(hist)]\n",
" for alg, hist in all_history.items()}\n",
"history = sum(shaped_history, [])\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['hyperband-p20' 'hyperband-p40' 'hyperband-p243']\n",
"[ 20 40 243]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>adaptive</th>\n",
" <th>alg</th>\n",
" <th>base_alg</th>\n",
" <th>best_score</th>\n",
" <th>cumulative_partial_fit_calls</th>\n",
" <th>model_id</th>\n",
" <th>patience</th>\n",
" <th>patience_</th>\n",
" <th>repeat</th>\n",
" <th>tol</th>\n",
" <th>train_time</th>\n",
" <th>wall_time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>False</td>\n",
" <td>hyperband-p20</td>\n",
" <td>hyperband</td>\n",
" <td>-0.542950</td>\n",
" <td>1</td>\n",
" <td>bracket=2-0</td>\n",
" <td>20</td>\n",
" <td>p=20</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>8.795053</td>\n",
" <td>0.000000e+00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>False</td>\n",
" <td>hyperband-p20</td>\n",
" <td>hyperband</td>\n",
" <td>-0.542950</td>\n",
" <td>2</td>\n",
" <td>bracket=2-1</td>\n",
" <td>20</td>\n",
" <td>p=20</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>12.844531</td>\n",
" <td>9.536743e-07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>False</td>\n",
" <td>hyperband-p20</td>\n",
" <td>hyperband</td>\n",
" <td>-0.253263</td>\n",
" <td>3</td>\n",
" <td>bracket=2-2</td>\n",
" <td>20</td>\n",
" <td>p=20</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>17.382286</td>\n",
" <td>1.430511e-06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>False</td>\n",
" <td>hyperband-p20</td>\n",
" <td>hyperband</td>\n",
" <td>-0.211405</td>\n",
" <td>4</td>\n",
" <td>bracket=2-3</td>\n",
" <td>20</td>\n",
" <td>p=20</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>23.095830</td>\n",
" <td>1.907349e-06</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>False</td>\n",
" <td>hyperband-p20</td>\n",
" <td>hyperband</td>\n",
" <td>-0.211405</td>\n",
" <td>5</td>\n",
" <td>bracket=2-4</td>\n",
" <td>20</td>\n",
" <td>p=20</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>28.456335</td>\n",
" <td>2.145767e-06</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" adaptive alg base_alg best_score \\\n",
"0 False hyperband-p20 hyperband -0.542950 \n",
"1 False hyperband-p20 hyperband -0.542950 \n",
"2 False hyperband-p20 hyperband -0.253263 \n",
"3 False hyperband-p20 hyperband -0.211405 \n",
"4 False hyperband-p20 hyperband -0.211405 \n",
"\n",
" cumulative_partial_fit_calls model_id patience patience_ repeat tol \\\n",
"0 1 bracket=2-0 20 p=20 0 0.0 \n",
"1 2 bracket=2-1 20 p=20 0 0.0 \n",
"2 3 bracket=2-2 20 p=20 0 0.0 \n",
"3 4 bracket=2-3 20 p=20 0 0.0 \n",
"4 5 bracket=2-4 20 p=20 0 0.0 \n",
"\n",
" train_time wall_time \n",
"0 8.795053 0.000000e+00 \n",
"1 12.844531 9.536743e-07 \n",
"2 17.382286 1.430511e-06 \n",
"3 23.095830 1.907349e-06 \n",
"4 28.456335 2.145767e-06 "
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"df = pd.DataFrame(history)\n",
"df.to_csv('2018-09-10-history.csv')\n",
"print(df.alg.unique())\n",
"print(df.patience.unique())\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"\n",
"w = 5\n",
"fig, ax = plt.subplots(figsize=(w, w))\n",
"\n",
"# adaptiveness = sorted([a for a in df.alg.unique() if a not in ['hyperband', 'stop_on_plateau']])\n",
"# adaptiveness = ['stop_on_plateau', 'hyperband']\n",
"\n",
"x = \"cumulative_partial_fit_calls\"\n",
"show = df.copy()\n",
"\n",
"sns.lineplot(\n",
" x=x,\n",
" y='best_score',\n",
" hue=\"alg\",\n",
" data=show,\n",
" ax=ax,\n",
" estimator='mean',\n",
")\n",
"ax.grid(linestyle='--')\n",
"# ax.set_ylim(-0.13, -0.09)\n",
"# ax.set_ylim(0.85, 0.875)\n",
"# plt.savefig('./successive-halving-comparison.png', dpi=300, bbox_inches='tight')\n",
"ax.set_ylim(-0.12, -0.09)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"def shape_histories(histories, **kwargs):\n",
" histories = sum(histories, [])\n",
" _df = pd.DataFrame(histories)\n",
" _df = _df.sort_values(by=\"wall_time\")\n",
" \n",
" data = []\n",
" most_recent_scores = {}\n",
" most_recent_times = {}\n",
" for _, datum in _df.iterrows():\n",
" most_recent_scores[datum[\"repeat\"]] = datum[\"best_score\"]\n",
" \n",
" most_recent_times[datum[\"repeat\"]] = datum[\"wall_time\"]\n",
" scores = np.array(list(most_recent_scores.values()))\n",
" \n",
" data += [{\"best_score_mean\": scores.mean(),\n",
" \"best_score_std\": np.std(scores),\n",
" \"best_score_max\": scores.max(),\n",
" \"best_score_median\": np.median(scores),\n",
" \"best_score_min\": scores.min(),\n",
" \"wall_time (s)\": datum[\"wall_time\"],\n",
" \"wall_time (min)\": datum[\"wall_time\"] / 60,\n",
" \"alg\": datum[\"alg\"],\n",
" \"patience\": datum[\"patience\"],\n",
" \"tol\": datum[\"tol\"],\n",
" **kwargs}]\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>alg</th>\n",
" <th>best_score_max</th>\n",
" <th>best_score_mean</th>\n",
" <th>best_score_median</th>\n",
" <th>best_score_min</th>\n",
" <th>best_score_std</th>\n",
" <th>patience</th>\n",
" <th>tol</th>\n",
" <th>wall_time (min)</th>\n",
" <th>wall_time (s)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>hyperband-p20</td>\n",
" <td>-0.542950</td>\n",
" <td>-0.542950</td>\n",
" <td>-0.542950</td>\n",
" <td>-0.54295</td>\n",
" <td>0.000000</td>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>hyperband-p20</td>\n",
" <td>-0.130327</td>\n",
" <td>-0.336639</td>\n",
" <td>-0.336639</td>\n",
" <td>-0.54295</td>\n",
" <td>0.206311</td>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>hyperband-p20</td>\n",
" <td>-0.130327</td>\n",
" <td>-0.433896</td>\n",
" <td>-0.542950</td>\n",
" <td>-0.62841</td>\n",
" <td>0.217472</td>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000000e+00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>hyperband-p20</td>\n",
" <td>-0.130327</td>\n",
" <td>-0.359529</td>\n",
" <td>-0.405311</td>\n",
" <td>-0.54295</td>\n",
" <td>0.171535</td>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>7.947286e-09</td>\n",
" <td>4.768372e-07</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>hyperband-p20</td>\n",
" <td>-0.130327</td>\n",
" <td>-0.359529</td>\n",
" <td>-0.405311</td>\n",
" <td>-0.54295</td>\n",
" <td>0.171535</td>\n",
" <td>20</td>\n",
" <td>0.0</td>\n",
" <td>1.192093e-08</td>\n",
" <td>7.152557e-07</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" alg best_score_max best_score_mean best_score_median \\\n",
"0 hyperband-p20 -0.542950 -0.542950 -0.542950 \n",
"1 hyperband-p20 -0.130327 -0.336639 -0.336639 \n",
"2 hyperband-p20 -0.130327 -0.433896 -0.542950 \n",
"3 hyperband-p20 -0.130327 -0.359529 -0.405311 \n",
"4 hyperband-p20 -0.130327 -0.359529 -0.405311 \n",
"\n",
" best_score_min best_score_std patience tol wall_time (min) \\\n",
"0 -0.54295 0.000000 20 0.0 0.000000e+00 \n",
"1 -0.54295 0.206311 20 0.0 0.000000e+00 \n",
"2 -0.62841 0.217472 20 0.0 0.000000e+00 \n",
"3 -0.54295 0.171535 20 0.0 7.947286e-09 \n",
"4 -0.54295 0.171535 20 0.0 1.192093e-08 \n",
"\n",
" wall_time (s) \n",
"0 0.000000e+00 \n",
"1 0.000000e+00 \n",
"2 0.000000e+00 \n",
"3 4.768372e-07 \n",
"4 7.152557e-07 "
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"histories = [shape_histories(alg_hist, alg=alg) for alg, alg_hist in alg_shaped_histories.items()]\n",
"time_df = pd.DataFrame(sum(histories, []))\n",
"time_df.to_csv('2018-09-10-shaped-history-more-freq-score.csv')\n",
"time_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0,0.5,'Best score')"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"fig, ax = plt.subplots()\n",
"\n",
"key = \"wall_time (s)\"\n",
"for alg in time_df.alg.unique():\n",
" s = time_df[time_df.alg == alg]\n",
" ax.plot(s[key].values, s[\"best_score_mean\"].values, label=alg)\n",
" ax.fill_between(\n",
" s[key],\n",
" s.best_score_mean - s.best_score_std,\n",
" s.best_score_mean + s.best_score_std,\n",
" alpha=0.2,\n",
" )\n",
"ax.set_ylim(-0.12, -0.09)\n",
"ax.legend(loc=\"lower right\", title=\"algorithm\")\n",
"ax.set_xlabel(key)\n",
"ax.grid(linestyle='--')\n",
"ax.set_ylabel(\"Best score\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Parameter visualization"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"hist = pd.DataFrame(search.cv_results_)\n",
"hist['param_optimizer_'] = hist['param_optimizer'].apply(lambda opt: str(opt).replace('<class', '').strip('>'))\n",
"hist['test_loss'] = -1 * hist['test_score']\n",
"hist = hist.sort_values(by='test_loss')\n",
"hist['rank'] = np.arange(len(hist)) + 1"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x720 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"w = 5\n",
"fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(2*w, 2*w))\n",
"axs = axs.flat[:]\n",
"hues = ['param_optimizer', 'param_module__init', 'param_module__activation', 'param_batch_size']\n",
"for ax, hue in zip(axs, hues):\n",
" cmap = None\n",
" if 'batch_size' in hue:\n",
" cmap = 'viridis'\n",
" sns.barplot(\n",
" x='rank', \n",
" y='test_loss',\n",
" hue=hue,\n",
" data=hist,\n",
" ax=ax,\n",
" palette=cmap,\n",
" dodge=False,\n",
" )\n",
" ax.set_xlim(-1.5, 50)\n",
" ax.set_ylim(0, 0.14)\n",
" ax.grid(linestyle='--', which='y')\n",
" ax.legend(loc='lower right')\n",
" ax.set_title(hue.replace('param_', ''))\n",
" ax.tick_params(labelbottom=False)\n",
"plt.savefig('2018-09-10-global-params.png', dpi=300)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"sgd_alg = [a for a in hist.param_optimizer_.unique() if 'sgd' in a.lower()][0]\n",
"sgd = hist[hist.param_optimizer_ == sgd_alg]"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\" 'torch.optim.sgd.SGD'\"]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x720 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"w = 5\n",
"fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(2*w, 2*w))\n",
"axs = axs.flat[:]\n",
"variables = ['param_optimizer__lr', 'param_optimizer__momentum', 'param_optimizer__weight_decay', 'param_batch_size']\n",
"print(sgd.param_optimizer_.unique())\n",
"for ax, var in zip(axs, variables):\n",
" show = sgd.copy()\n",
" show = show.sort_values(by=var)\n",
" if 'weight_decay' in var:\n",
" show[var] += 1e-8\n",
" show = show[show.test_loss < 0.16]\n",
" sns.scatterplot(\n",
" x=var,\n",
" y='test_loss',\n",
" data=show,\n",
" hue='test_loss',\n",
" palette='magma',\n",
" legend=False,\n",
" ax=ax,\n",
" )\n",
" if 'lr' in var:\n",
" ax.set_xscale('log', basex=10)\n",
" if 'batch_size' in var:\n",
" ax.set_xscale('log', basex=2)\n",
" if 'weight_decay' in var:\n",
" ax.set_xlim(5e-9, 1e-3)\n",
" ax.set_xscale('log', basex=10)\n",
" \n",
" ax.grid(linestyle='--')\n",
" ax.set_ylim(0.08, 0.16)\n",
"plt.savefig('2018-09-10-sgd-params.png', dpi=300)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import skorch.utils
from skorch import NeuralNetRegressor
import torch.nn as nn
import torch
import skorch
def _initialize(method, layer, gain=1):
weight = layer.weight.data
# _before = weight.data.clone()
kwargs = {'gain': gain} if 'xavier' in str(method) else {}
method(weight.data, **kwargs)
# assert torch.all(weight.data != _before)
class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder maps 784 inputs -> 196 -> 49, with ``activation`` after each
    linear layer. The decoder maps the 49-dim code straight back to 784
    values through a sigmoid.

    Parameters
    ----------
    activation : str, default 'ReLU'
        Name of a ``torch.nn`` activation class. ``inplace=True`` is passed
        only when that class accepts an ``inplace`` argument, so activations
        such as ``Tanh``, ``Sigmoid`` or ``PReLU`` also work.
    init : str, default 'xavier_uniform_'
        Name of a ``torch.nn.init`` function applied to every weight matrix.
    **kwargs
        Accepted and ignored (kept for call-site compatibility).
    """

    def __init__(self, activation='ReLU', init='xavier_uniform_',
                 **kwargs):
        import inspect  # local: only used to inspect the activation's signature

        super().__init__()
        self.activation = activation
        self.init = init
        # Number of forward passes performed so far (incremented in forward).
        self._iters = 0

        init_method = getattr(torch.nn.init, init)
        act_layer = getattr(nn, activation)
        # Only request an in-place activation when the class supports it;
        # passing ``inplace`` to e.g. ``nn.Tanh`` raises TypeError.
        accepts_inplace = 'inplace' in inspect.signature(act_layer).parameters
        act_kwargs = {'inplace': True} if accepts_inplace else {}

        # Gain from ``calculate_gain`` for (leaky) ReLU; 1 for anything else.
        gain = 1
        if self.activation in ['LeakyReLU', 'ReLU']:
            name = 'leaky_relu' if self.activation == 'LeakyReLU' else 'relu'
            gain = torch.nn.init.calculate_gain(name)

        inter_dim = 28 * 28 // 4
        latent_dim = inter_dim // 4

        def init_weights(layers):
            # Re-initialize matrix-shaped weights (the Linear layers); layers
            # without a weight, or with a 1-D weight such as PReLU's slope,
            # are left untouched.
            for layer in layers:
                if hasattr(layer, 'weight') and layer.weight.data.dim() > 1:
                    _initialize(init_method, layer, gain=gain)
            return layers

        self.encoder = nn.Sequential(*init_weights([
            nn.Linear(28 * 28, inter_dim),
            act_layer(**act_kwargs),
            nn.Linear(inter_dim, latent_dim),
            act_layer(**act_kwargs),
        ]))
        # The decoder is a single linear layer. A deeper
        # (latent -> inter -> 784) variant used to be built here and then
        # immediately discarded; it has been removed as dead code.
        self.decoder = nn.Sequential(*init_weights([
            nn.Linear(latent_dim, 28 * 28),
            nn.Sigmoid(),
        ]))

    def forward(self, x):
        """Reconstruct ``x``; the output has the same shape as the input.

        Everything after the first dimension of ``x`` must flatten to 784
        values, e.g. ``(batch, 784)`` or ``(batch, 28, 28)``.
        """
        self._iters += 1
        shape = x.size()
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
        x = x.reshape(x.shape[0], -1)
        x = self.encoder(x)
        x = self.decoder(x)
        return x.reshape(shape)
class NegLossScore(NeuralNetRegressor):
    """skorch regressor whose ``score`` is the negated loss on ``(X, y)``.

    It also counts ``partial_fit`` calls in ``steps`` and removes every
    callback after initialization.
    """

    # Number of ``partial_fit`` calls. The class-level 0 is the starting
    # value; the first increment creates a per-instance attribute.
    steps = 0

    def partial_fit(self, *args, **kwargs):
        """Delegate to the parent ``partial_fit`` and count the call.

        Returns ``self``, as the scikit-learn estimator API requires, so that
        chaining such as ``net.partial_fit(X, y).score(X, y)`` works.
        """
        super().partial_fit(*args, **kwargs)
        self.steps += 1
        return self

    def score(self, X, y):
        """Return the negated loss of the net's predictions on ``X`` vs ``y``.

        Negating the loss makes higher scores better. The current step count
        and loss are printed on every call.
        """
        X = skorch.utils.to_tensor(X, device=self.device)
        y = skorch.utils.to_tensor(y, device=self.device)
        self.initialize_criterion()
        y_hat = self.predict(X)
        y_hat = skorch.utils.to_tensor(y_hat, device=self.device)
        loss = super().get_loss(y_hat, y, X=X, training=False).item()
        print(f'steps = {self.steps}, loss = {loss}')
        return -loss

    def initialize(self, *args, **kwargs):
        """Initialize the net, then drop all callbacks (including defaults).

        Returns ``self`` so that ``net = NegLossScore(...).initialize()``
        works like it does for the parent class.
        """
        super().initialize(*args, **kwargs)
        self.callbacks_ = []
        return self
from keras.datasets import mnist
import numpy as np
import skimage.util
import random
import skimage.filters
import skimage
import scipy.signal
def noise_img(x):
    """Return ``x`` corrupted by ``skimage.util.random_noise`` in Gaussian mode.

    The Gaussian variance is drawn with ``np.random.uniform(0.10, 0.15)``.
    A salt-and-pepper parameter set is also drawn (first), but it is never
    applied. Drawing it keeps the global NumPy random stream advancing
    exactly as before.
    """
    salt_pepper = dict(mode="s&p", amount=np.random.uniform(0.1, 0.1))
    gaussian = dict(mode="gaussian", var=np.random.uniform(0.10, 0.15))
    chosen = gaussian  # only the Gaussian option is used
    return skimage.util.random_noise(x, **chosen)
def train_formatting(img):
    """Return ``img`` as a flat float32 vector of 784 values.

    ``img`` must be an array holding exactly 28 * 28 elements (any shape);
    otherwise ``reshape`` raises ``ValueError``. The result is a new array,
    not a view of ``img``.
    """
    as_float = img.reshape(28, 28).astype("float32")
    return as_float.ravel()
def blur_img(img):
    """Apply a random, slightly rotated horizontal line blur to ``img``.

    Parameters
    ----------
    img : 1-D array
        A flattened square image with ``n * n`` pixels.

    Returns
    -------
    1-D array
        The blurred image, flattened, with the same length as ``img``.

    Raises
    ------
    ValueError
        If ``img`` is not 1-D or its length is not a perfect square.
    """
    img = np.asarray(img)
    # Explicit checks instead of ``assert``, which is stripped under ``-O``.
    if img.ndim != 1:
        raise ValueError(f"expected a flattened (1-D) image, got ndim={img.ndim}")
    n = int(np.sqrt(img.shape[0]))
    if n * n != img.shape[0]:
        raise ValueError(f"image length {img.shape[0]} is not a perfect square")
    img = img.reshape(n, n)

    # Import the submodule explicitly: ``import skimage`` alone does not
    # guarantee that ``skimage.transform`` is loaded.
    import skimage.transform

    # Kernel: a horizontal run of 2*w ones through the centre row, with
    # w in {1, 2}. It is rotated by np.random.uniform(-5, 5) (passed to
    # skimage.transform.rotate) and normalized to sum to 1. The random draws
    # keep their original order: np.random first, then random.
    angle = np.random.uniform(-5, 5)
    w = random.choice(range(1, 3))
    h = np.zeros((n, n))
    h[n // 2, n // 2 - w : n // 2 + w] = 1
    h = skimage.transform.rotate(h, angle)
    h /= h.sum()

    y = scipy.signal.convolve(img, h, mode="same")
    return y.flat[:]
def dataset(n=None, transforms=None):
    """Build (noisy, clean) pairs of flattened MNIST digits.

    The MNIST train and test images from ``keras.datasets.mnist`` are
    concatenated, divided by 255 and flattened to 784-vectors.

    Parameters
    ----------
    n : int, optional
        Keep only the first ``n`` images; ``None`` (or 0) keeps all of them.
    transforms : iterable of callables, optional
        Per-image corruptions, each mapping a flat image vector to a flat
        image vector. They are applied to the noisy copy in an order chosen
        by ``random.shuffle``. Defaults to ``[noise_img]``; for example, pass
        ``[noise_img, blur_img]`` to also blur. The caller's iterable is
        copied, never shuffled in place.

    Returns
    -------
    noisy, clean : float32 arrays
        One row per image; ``clean`` has shape ``(num_images, 784)``.
    """
    (x_train, _), (x_test, _) = mnist.load_data()
    x = np.concatenate((x_train, x_test))
    if n:
        x = x[:n]
    x = x.astype("float32") / 255.
    x = np.reshape(x, (len(x), 28 * 28))
    y = np.apply_along_axis(train_formatting, 1, x)

    clean = y.copy()
    noisy = y.copy()
    order = [noise_img] if transforms is None else list(transforms)
    random.shuffle(order)
    for fn in order:
        noisy = np.apply_along_axis(fn, 1, noisy)
    return noisy.astype("float32"), clean.astype("float32")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment