{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Computing Nonvacuous Generalization Bounds for Deep (Stochastic) Neural Networks with Many More Parameters than Training Data\n",
"\n",
"**Gintare Karolina Dziugaite, Daniel M. Roy (2017)**\n",
"\n",
"- [Paper](https://arxiv.org/pdf/1703.11008.pdf)\n",
"- [PAC-Bayes Tutorial](https://arxiv.org/pdf/1901.05353.pdf)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Generalization Bounds\n",
"\n",
"PAC bounds:\n",
"\n",
"$$\\mathbb P(\\mathcal E (c) \\le \\delta) \\ge 1 - \\epsilon,$$\n",
"- $\\mathcal E (c)$ is the error rate of the classifier $c$\n",
"- $\\delta, \\epsilon \\in (0, 1)$\n",
"\n",
"Here $\\delta$ is a threshold that usually depends on $\\epsilon$ and on the data.\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Example: Simple Generalization Bound\n",
"\n",
"- Countable set of classifiers $c \\in \\mathcal C$, and a prior distribution $P(c)$ over them.\n",
"\n",
"Then, with probability at least $1 - \\delta$ over the draw of the training set, _every_ classifier $c$ that correctly classifies all $m$ training examples satisfies\n",
"\n",
"$$\\mathcal E (c) \\le \\frac {\\log \\frac 1 {P(c)} + \\log \\frac 1 \\delta}{m}$$\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"**Proof:**\n",
"Fix a classifier $c$ whose error rate exceeds the bound:\n",
"$$\n",
" \\mathcal E (c) > \\frac {\\log \\frac 1 {P(c)} + \\log \\frac 1 \\delta}{m}.\n",
"$$\n",
"The probability that it nevertheless agrees with all $m$ (i.i.d.) data points is\n",
"$$\n",
" (1 - \\mathcal E (c))^m \\le e^{-\\mathcal E(c) m} \\le \n",
" e^{- \\left(\\log \\frac 1 {P(c)} + \\log \\frac 1 \\delta\\right)} = P(c) \\delta.\n",
"$$\n",
"By a union bound over all such classifiers, we get the statement of the theorem."
]
},
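{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"The union-bound step can be spelled out as follows (a sketch; the set $\\mathcal B$ of \"bad\" classifiers is notation introduced here for illustration). Let $\\mathcal B$ be the set of classifiers whose error rate exceeds the bound. This set does not depend on the training data, and each $c \\in \\mathcal B$ agrees with all $m$ training examples with probability at most $P(c) \\delta$, so\n",
"$$\n",
" \\mathbb P(\\exists c \\in \\mathcal B \\text{ consistent with all } m \\text{ examples}) \\le \\sum_{c \\in \\mathcal B} P(c) \\delta \\le \\delta \\sum_{c \\in \\mathcal C} P(c) = \\delta.\n",
"$$\n",
"Hence, with probability at least $1 - \\delta$, no classifier that fits the training set violates the bound."
]
},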
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## PAC-Bayes Bound\n",
"\n",
"Idea: get generalization bounds for stochastic classifiers.\n",
"\n",
"Setting:\n",
"- Binary classification with a neural network\n",
"- We will get a generalization bound for a model average under an arbitrary distribution $Q$ over weights"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Notation\n",
"\n",
"- $S_m$ — dataset of $m$ datapoints\n",
"- $\\hat e(Q, S_m) = \\mathbb E_{h \\sim Q} (\\hat e(h, S_m))$ — empirical error of the model average under distribution $Q$ on $S_m$\n",
"- $e(Q)$ — true (expected) error of the model average under distribution $Q$\n",
"- $KL(\\hat e(Q, S_m) \\vert \\vert e(Q))$ — KL-divergence between _Bernoulli_ random variables with success probabilities $\\hat e(Q, S_m)$ and $e(Q)$:\n",
"$$\n",
" \\hat e(Q, S_m) \\cdot \\log \\frac{\\hat e(Q, S_m)}{e(Q)} + (1 - \\hat e(Q, S_m)) \\cdot \\log \\frac{1 - \\hat e(Q, S_m)}{1 -e(Q)}\n",
"$$\n",
"<span style=\"color:blue\">This will be our way of measuring difference between true and empirical error rate!</span>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### The Bound \n",
"\n",
"**PAC-Bayes Bound** ([McAllester 1999](https://link.springer.com/article/10.1023/A:1007618624809)):\n",
"\n",
"Suppose $P$ is a prior distribution over the parameters of the network (it cannot depend on the training data). Then with probability at least $1 - \\delta$ over the draw of $S_m$, for all distributions $Q$ over the parameters simultaneously we have:\n",
"$$\n",
" KL(\\hat e(Q, S_m) \\vert \\vert e(Q)) \\le\n",
" \\frac{KL(Q \\vert \\vert P) + \\log \\frac m \\delta}{m-1}.\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Idea\n",
"\n",
"$$\n",
" KL(\\hat e(Q, S_m) \\vert \\vert e(Q)) \\le\n",
" \\frac{KL(Q \\vert \\vert P) + \\log \\frac m \\delta}{m-1}.\n",
"$$\n",
"\n",
"The inequality implicitly gives an upper bound on $e(Q)$. Let's try to optimize the bound with respect to $Q$!\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Optimizing the bound\n",
"\n",
"$$\n",
" KL(\\hat e(Q, S_m) \\vert \\vert e(Q)) \\le\n",
" \\frac{KL(Q \\vert \\vert P) + \\log \\frac m \\delta}{m-1}.\n",
"$$\n",
"\n",
"First, we only have an implicit dependence on $e(Q)$ in the bound. We need to _invert the KL_."
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"import math\n",
"from matplotlib import pyplot as plt\n",
"import seaborn as sns\n",
"sns.set_style('whitegrid')\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
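},
"outputs": [],
"source": [
"def kl_bernoulli(q, p, eps=1e-12):\n",
"    # Clip both arguments away from 0 and 1 so the logarithms stay finite.\n",
"    q, p = min(max(q, eps), 1 - eps), min(max(p, eps), 1 - eps)\n",
"    return q * math.log(q / p) + (1 - q) * math.log((1 - q) / (1 - p))"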
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(7, 5))\n",
"ps = np.linspace(0.05, 0.95, 20)\n",
"qs = np.linspace(0.1, 0.9, 5)\n",
"kls = [[kl_bernoulli(q, p) for p in ps] for q in qs]\n",
"for i, q in enumerate(qs):\n",
"    plt.plot(ps, kls[i], label=r\"$q = {:.1f}$\".format(q), linewidth=2)\n",
"\n",
"plt.ylabel(r\"$KL(q || p)$\", fontsize=16)\n",
"plt.xlabel(\"p\", fontsize=16)\n",
"plt.legend(fontsize=14);"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Inverting the KL\n",
"\n",
"\n",
"For evaluating the bound, we can invert $KL$ numerically. However, in order to train the distribution $Q$ we need a way of getting the gradients. We will use Pinsker's inequality, which for Bernoulli distributions reads:\n",
"\n",
"$$KL(q \\vert \\vert p) \\ge 2(q - p)^2.$$\n",
"\n",
"So, if $KL(q \\vert \\vert p) \\le \\epsilon$, we have \n",
"$$\n",
"p \\le q + \\sqrt{\\frac \\epsilon 2}.\n",
"$$\n",
"Combining this expression with the McAllester bound, we get\n",
"$$\n",
" e(Q) \\le \\hat e(Q, S_m) + \\sqrt{\n",
" \\frac{KL(Q \\vert \\vert P) + \\log \\frac m \\delta}{2(m-1)}\n",
" }\n",
"$$"
]
},
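{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"To make the inversion step explicit: since $2(q - p)^2 \\le KL(q \\vert \\vert p) \\le \\epsilon$, we get $|p - q| \\le \\sqrt{\\epsilon / 2}$, which gives the upper bound on $p$ above. As a rough sense of scale, here is the complexity term for purely hypothetical values (not taken from the paper or from the experiment below): $m = 60000$, $\\delta = 0.025$ and $KL(Q \\vert \\vert P) = 5000$,\n",
"$$\n",
" \\sqrt{\\frac{5000 + \\log \\frac{60000}{0.025}}{2 \\cdot 59999}} \\approx \\sqrt{\\frac{5000 + 14.7}{119998}} \\approx 0.20.\n",
"$$\n",
"So the relaxed bound can only be nonvacuous when $KL(Q \\vert \\vert P)$ is small compared to $m$."
]
},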
{
"cell_type": "code",
"execution_count": 90,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(7, 5))\n",
"ps = np.linspace(0.05, 0.95, 20)\n",
"qs = [0.02]\n",
"kls = [[kl_bernoulli(q, p) for p in ps] for q in qs]\n",
"for i, q in enumerate(qs):\n",
"    plt.plot(ps, kls[i], label=r\"$KL(q || p), q = {:.2f}$\".format(q), linewidth=2)\n",
"    plt.plot(ps, [2 * (q - p)**2 for p in ps], label=r\"$2 (q - p)^2$\", linewidth=2)\n",
"\n",
"# plt.ylabel(\"KL(q || p)\", fontsize=16)\n",
"plt.xlabel(\"p\", fontsize=16)\n",
"plt.legend(fontsize=14);"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
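},
"source": [
"### Parameterize the $Q$\n",
"\n",
"For $Q$ we will use a Gaussian $N(m, s)$ with mean $m$ and diagonal covariance matrix $s$. (Inside $N(\\cdot, \\cdot)$ the symbol $m$ is the mean vector; everywhere else it is the number of datapoints.)\n",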
"\n",
"We will optimize the bound\n",
"\n",
"$$\n",
" e(Q) \\le \\hat e(Q, S_m) + \\sqrt{\n",
" \\frac{KL(Q \\vert \\vert P) + \\log \\frac m \\delta}{2(m-1)}\n",
" } = \n",
" \\mathbb E_{h \\sim N(m, s)} \\hat e(h, S_m)\n",
" + \\sqrt{\n",
" \\frac{KL(N(m, s) \\vert \\vert P) + \\log \\frac m \\delta}{2(m-1)}}.\n",
"$$\n",
"\n",
"<span style=\"color:blue\">The bound is of the form _Expected Loss + Complexity penalty_. Notice the similarity with variational ELBO.</span>\n",
"\n",
"We will optimize the bound with respect to $m, s$.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"### Prior Choice\n",
"\n",
"- We will use $P(w) = N(\\mu, \\lambda I)$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"- For $\\mu$ we will use random initialization $w_0$\n",
"    - <span style=\"color:blue\">The idea is to have a compact prior. In particular, we want to avoid covering all the symmetries in the parameterization.</span>"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"- $\\lambda$ will be trained to optimize the bound\n",
"    - We need to account for optimizing $\\lambda$ in the bound.\n",
"    - Use a discrete set of values for $\\lambda$ of the form $\\lambda = 0.1 \\exp(- j / 100)$, $j \\in \\mathbb N$\n",
"    - Use a different $\\delta_j$ for each value of $\\lambda$, chosen so that they sum up to $\\delta$: $\\delta_j = \\frac {6 \\delta} {\\pi^2 j^2}.$\n",
"    - By the union bound, the bound will then hold simultaneously for all $\\lambda$ with probability at least $1 - \\delta$"
]
},
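{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Where the extra terms in the final bound come from (a short derivation under the choices stated above; writing $b = 100$ and $c = 0.1$ for the constants in the grid $\\lambda = c \\exp(-j / b)$). Applying the PAC-Bayes bound with confidence $\\delta_j = \\frac{6 \\delta}{\\pi^2 j^2}$ for the $j$-th value of $\\lambda$ replaces $\\log \\frac m \\delta$ by\n",
"$$\n",
" \\log \\frac m {\\delta_j} = \\log \\frac{\\pi^2 m}{6 \\delta} + 2 \\log j, \\qquad j = b \\log \\frac c \\lambda,\n",
"$$\n",
"and the failure probabilities add up to $\\sum_{j \\ge 1} \\delta_j = \\frac{6 \\delta}{\\pi^2} \\sum_{j \\ge 1} \\frac 1 {j^2} = \\delta$."
]
},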
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Final bound:\n",
"\n",
"$$\n",
" \\mathbb E_{h \\sim N(m, s)} \\hat e(h, S_m)\n",
" + \\sqrt{\n",
" \\frac{KL(N(m, s) \\vert \\vert N(w_0, \\lambda I)) + 2 \\log(b \\log \\frac c \\lambda) + \\log \\frac {\\pi^2 m} {6 \\delta}}{2(m-1)}}.\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"We minimize the bound with respect to $m, s$ and $\\lambda$!"
]
},
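{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"For reference, the KL term has a closed form. The expression below is the standard KL divergence between Gaussians, written under the assumption that $s = \\mathrm{diag}(s_1, \\dots, s_d)$ holds the posterior variances and $\\lambda$ is the prior variance ($d$ and $s_i$ are notation introduced here):\n",
"$$\n",
" KL(N(m, s) \\vert \\vert N(w_0, \\lambda I)) = \\frac 1 2 \\left( \\frac{\\sum_{i=1}^d s_i}{\\lambda} + \\frac{\\|m - w_0\\|^2}{\\lambda} - d + d \\log \\lambda - \\sum_{i=1}^d \\log s_i \\right).\n",
"$$\n",
"Every term is differentiable in $m$, $s$ and $\\lambda$, so the complexity term can be minimized with gradient-based optimizers."
]
},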
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Example: MLP on MNIST"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/izmailovpavel/anaconda3/envs/py37/lib/python3.7/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.decomposition.pca module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.decomposition. Anything that cannot be imported from sklearn.decomposition is now part of the private API.\n",
" warnings.warn(message, FutureWarning)\n"
]
}
],
"source": [
"import math\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"import numpy as np\n",
"import tqdm\n",
"\n",
"from swag import data, models, utils, losses\n",
"from swag.posteriors import SWAG\n",
"from swag.utils import eval, train_epoch\n",
"from swag.losses import cross_entropy"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"You are going to run models on the test set. Are you sure?\n"
]
}
],
"source": [
"# Making MNIST dataloaders for binary classification\n",
"\n",
"loaders, num_classes = data.loaders(\n",
"    \"MNIST\",\n",
"    \"~/datasets/\",\n",
"    128,\n",
"    4,\n",
"    torchvision.transforms.ToTensor(),\n",
"    torchvision.transforms.ToTensor(),\n",
"    use_validation=False,\n",
"    split_classes=None,\n",
"    shuffle_train=False\n",
")\n",
"# Relabel digits 0-4 as -1 and digits 5-9 as +1\n",
"loaders['train'].dataset.targets = 1 - 2*(loaders['train'].dataset.targets < 5).int()\n",
"loaders['test'].dataset.targets = 1 - 2*(loaders['test'].dataset.targets < 5).int()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"# Stochastic Linear Layer\n",
"\n",
"class StochLinear(torch.nn.Module):\n",
"    def __init__(self, in_features, out_features, bias=True,\n",
"                 init_inv_softplus_sigma=-3.0, eps=1e-6):\n",
"        super().__init__()\n",
"        self.weight_mu = torch.nn.Parameter(torch.zeros(out_features, in_features))\n",
"        self.weight_ispsigma = torch.nn.Parameter(torch.empty(out_features, in_features).fill_(\n",
"            init_inv_softplus_sigma))\n",
"        if bias:\n",
"            self.bias_mu = torch.nn.Parameter(torch.zeros(out_features))\n",
"            self.bias_ispsigma = torch.nn.Parameter(torch.empty(out_features).fill_(init_inv_softplus_sigma))\n",
"        self.eps = eps\n",
"        self.with_bias = bias\n",
"\n",
"    def forward(self, x):\n",
"        # Reparameterization: sample w = mu + sigma * noise on every forward pass\n",
"        weight = self.weight_mu + torch.randn_like(self.weight_mu) * self.weight_sigma\n",
"        if self.with_bias:\n",
"            bias = self.bias_mu + torch.randn_like(self.bias_mu) * self.bias_sigma\n",
"        else:\n",
"            bias = None\n",
"        return torch.nn.functional.linear(x, weight, bias)\n",
"\n",
"    def copy_weights(self, linear_layer):\n",
"        # Initialize the means from a deterministic (pre-trained) linear layer\n",
"        self.weight_mu.detach().copy_(linear_layer.weight)\n",
"        if self.with_bias:\n",
"            self.bias_mu.detach().copy_(linear_layer.bias)\n",
"\n",
"    @property\n",
"    def weight_sigma(self):\n",
"        # softplus keeps the standard deviation positive\n",
"        return torch.nn.functional.softplus(self.weight_ispsigma) + self.eps\n",
"\n",
"    @property\n",
"    def bias_sigma(self):\n",
"        return torch.nn.functional.softplus(self.bias_ispsigma) + self.eps"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"class PACBound:\n",
"    def __init__(self, prior_mean, lambda_init, likelihood_fn, num_data,\n",
"                 delta=0.025, b=100, c=0.1):\n",
"        self.prior_mu = prior_mean\n",
"        self.log_lambda_ = torch.tensor([math.log(lambda_init)],\n",
"                                        requires_grad=True, device=prior_mean.device)\n",
"        self.likelihood_fn = likelihood_fn\n",
"        self.num_data = torch.tensor([num_data], requires_grad=False,\n",
"                                     dtype=float, device=prior_mean.device)\n",
"        self.delta = delta\n",
"        self.b = b\n",
"        self.c = c\n",
"\n",
"    def kl(self, stoch_network):\n",
"        mu_net, sigma_net = self._get_mu_sigma(stoch_network)\n",
"        prior_sigma = torch.ones_like(sigma_net) * self.prior_sigma\n",
"        return self._kl_gaussians(mu_net, sigma_net, self.prior_mu, prior_sigma)\n",
"\n",
"    def train_bound(self, stoch_network, x, y):\n",
"        # Training objective: surrogate loss + sqrt(B_RE / 2)\n",
"        preds = stoch_network(x).flatten()\n",
"        acc = (preds * y > 0).float().mean()\n",
"        likelihood = self.likelihood_fn(preds, y)\n",
"        regularizer = torch.sqrt(self.B_RE(stoch_network) / 2)\n",
"        return likelihood + regularizer, preds, \\\n",
"            {'accuracy': acc, 'regularizer': regularizer}\n",
"\n",
"    def B_RE(self, stoch_network):\n",
"        # (KL + 2 log(b log(c / lambda)) + log(pi^2 m / (6 delta))) / (m - 1)\n",
"        kl = self.kl(stoch_network)\n",
"        return (kl +\n",
"                2 * torch.log(torch.abs(self.b * (math.log(self.c) - self.log_lambda_))) +\n",
"                math.log(math.pi**2 * self.num_data / (6 * self.delta))) / (self.num_data - 1)\n",
"\n",
"    @property\n",
"    def prior_sigma(self):\n",
"        # lambda is the prior variance in P = N(w_0, lambda I), so the std is sqrt(lambda)\n",
"        return torch.sqrt(torch.exp(self.log_lambda_))\n",
"\n",
"    @staticmethod\n",
"    def _kl_gaussians(mu_1, sigma_1, mu_2, sigma_2):\n",
"        # sigma_i -- vector of stds\n",
"        d = mu_1.numel()\n",
"        kl = 0.5 * (torch.sum(torch.log(sigma_2) - torch.log(sigma_1)) * 2 +\n",
"                    torch.sum(sigma_1**2 / sigma_2**2) +\n",
"                    torch.sum((mu_2 - mu_1)**2 / sigma_2**2) - d)\n",
"        return kl\n",
"\n",
"    @staticmethod\n",
"    def _get_mu_sigma(stoch_network):\n",
"        # Flatten the means and stds of all stochastic layers into two vectors\n",
"        mu_net, sigma_net = [torch.cat(a) for a in\n",
"                             zip(*[\n",
"                                 [torch.cat([layer.weight_mu.flatten(), layer.bias_mu.flatten()]),\n",
"                                  torch.cat([layer.weight_sigma.flatten(), layer.bias_sigma.flatten()])]\n",
"                                 for layer in stoch_network\n",
"                                 if isinstance(layer, StochLinear)])\n",
"                             ]\n",
"        return mu_net, sigma_net"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"def logistic_loss(preds, y):\n",
"    # Logistic loss in bits (divided by log 2), so that it upper-bounds the 0-1 loss\n",
"    loss = torch.log(1 + torch.exp(-preds.flatten() * y)) / math.log(2)\n",
"    loss = loss.mean(dim=0)\n",
"\n",
"    return loss\n",
"\n",
"# single hidden layer\n",
"def get_model():\n",
"    return torch.nn.Sequential(torch.nn.Flatten(),\n",
"                               torch.nn.Linear(28**2, 600),\n",
"                               torch.nn.ReLU(),\n",
"                               torch.nn.Linear(600, 1))\n",
"\n",
"\n",
"def get_stoch_model(isp_sigma=-3.):\n",
"    return torch.nn.Sequential(torch.nn.Flatten(),\n",
"                               StochLinear(28**2, 600, init_inv_softplus_sigma=isp_sigma),\n",
"                               torch.nn.ReLU(),\n",
"                               StochLinear(600, 1, init_inv_softplus_sigma=isp_sigma))\n",
"\n",
"def logistic_loss_swag(model, x, y):\n",
"    preds = model(x).flatten()\n",
"    acc = (preds * y > 0).float().mean()\n",
"    return logistic_loss(preds, y), preds, {'accuracy': acc}"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10/10 [00:16<00:00, 1.61s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train: {'accuracy': tensor(0.9966, device='cuda:0')}\n",
"Test: {'accuracy': tensor(0.9815, device='cuda:0')}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# we will start by pre-training the model\n",
"\n",
"model = get_model()\n",
"model.cuda()\n",
"\n",
"# Save the random initialization for the prior\n",
"torch.save(model.state_dict(), \"model_init.pt\")\n",
"\n",
"loss_fn = logistic_loss_swag\n",
"opt = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)\n",
"\n",
"for epoch in tqdm.tqdm(range(10)):\n",
"    train_res = utils.train_epoch(loaders['train'], model, loss_fn, opt, regression=True)\n",
"    test_res = eval(loaders[\"test\"], model, loss_fn, regression=True)\n",
"print(\"Train: \", train_res['stats'])\n",
"print(\"Test: \", test_res['stats'])"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [],
"source": [
"# Initialize the stochastic model and the PAC bound\n",
"\n",
"init_model = get_model()\n",
"init_model.cuda()\n",
"init_model.load_state_dict(torch.load(\"model_init.pt\"))\n",
"\n",
"prior_mu = torch.cat([p.detach().flatten() for p in init_model.parameters()])\n",
"pac_bound = PACBound(prior_mu, lambda_init=1e-3, likelihood_fn=logistic_loss,\n",
"                     num_data=len(loaders['train'].dataset))\n",
"\n",
"stoch_model = get_stoch_model(isp_sigma=-2)\n",
"for stoch_layer, layer in zip(stoch_model, model):\n",
"    if isinstance(stoch_layer, StochLinear):\n",
"        stoch_layer.copy_weights(layer)\n",
"stoch_model.cuda();"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 50/50 [02:06<00:00, 2.52s/it]\n"
]
},
{
"data": {
"text/plain": [
"Text(0.5, 0, 'Iteration')"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1008x360 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Optimize the bound\n",
"loss_fn = pac_bound.train_bound\n",
"opt = torch.optim.Adam([{'name': 'stoch_net','params': stoch_model.parameters(), 'lr': 5e-3},\n",
"                        {'name': 'lambda','params': pac_bound.log_lambda_, 'lr': 1e-2}]\n",
"                       )\n",
"lambdas = []\n",
"train_accs = []\n",
"b_re_values = []\n",
"\n",
"for epoch in tqdm.tqdm(range(50)):\n",
"    train_res = utils.train_epoch(loaders['train'], stoch_model, loss_fn, opt, regression=True)\n",
"\n",
"    train_accs.append(train_res['stats']['accuracy'].item())\n",
"    lambdas.append(torch.exp(pac_bound.log_lambda_).item())\n",
"    b_re_values.append(pac_bound.B_RE(stoch_model).item())\n",
"\n",
"\n",
"f, arr = plt.subplots(1, 2, figsize=(14,5))\n",
"\n",
"arr[0].plot(train_accs, lw=2)\n",
"arr[0].set_ylabel(\"Train Accuracy\", fontsize=16)\n",
"arr[0].set_xlabel(\"Iteration\", fontsize=16)\n",
"arr[1].plot(lambdas, lw=2)\n",
"arr[1].set_ylabel(r\"$\\lambda$\", fontsize=16)\n",
"arr[1].set_xlabel(\"Iteration\", fontsize=16)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/izmailovpavel/.local/lib/python3.7/site-packages/ipykernel_launcher.py:10: RuntimeWarning: divide by zero encountered in double_scalars\n",
" # Remove the CWD from sys.path while we load stuff.\n"
]
},
{
"data": {
"text/plain": [
"Text(0.5, 0, 'Iteration')"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Let's visualize how the bound on the true error changes over the epochs\n",
"ts = np.linspace(1e-4, 1 - 1e-4, 1000)  # open interval: avoids log(0) and division by zero\n",
"error_bounds = []\n",
"for b_re, acc in zip(b_re_values, train_accs):\n",
"    q = max(1 - acc, 1e-6)  # empirical error, kept strictly positive\n",
"    # Invert the KL on a grid: the largest t >= q with KL(q || t) <= b_re\n",
"    feasible = [t for t in ts if t >= q and kl_bernoulli(q, t) <= b_re]\n",
"    error_bounds.append(max(feasible) if feasible else q)\n",
"\n",
"plt.figure(figsize=(7,5))\n",
"plt.plot(error_bounds, lw=2)\n",
"plt.ylabel(\"Error Bound\", fontsize=16)\n",
"plt.xlabel(\"Iteration\", fontsize=16)\n",
"\n",
"# Note that here we ignore a few things: error estimation for the data term and also quantization of lambda"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Results\n",
"\n",
"The authors apply the method to several small fully connected architectures on MNIST and get non-vacuous error bounds. Column headers give the hidden-layer sizes; R-600 is trained on randomly labeled data.\n",
"\n",
"<style type=\"text/css\">\n",
".tg {border-collapse:collapse;border-spacing:0;}\n",
".tg td{font-family:Arial, sans-serif;font-size:14px;padding:10px 5px;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:black;}\n",
".tg th{font-family:Arial, sans-serif;font-size:14px;font-weight:normal;padding:10px 5px;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:black;}\n",
".tg .tg-0pky{border-color:inherit;text-align:left;vertical-align:top}\n",
".tg .tg-0lax{text-align:left;vertical-align:top}\n",
"</style>\n",
"<table class=\"tg\">\n",
" <tr>\n",
" <th class=\"tg-0pky\"></th>\n",
" <th class=\"tg-0pky\">600</th>\n",
" <th class=\"tg-0pky\">1200</th>\n",
" <th class=\"tg-0pky\">600, 600</th>\n",
" <th class=\"tg-0pky\">1200, 1200</th>\n",
" <th class=\"tg-0lax\">600, 600, 600</th>\n",
" <th class=\"tg-0lax\">R-600</th>\n",
" </tr>\n",
" <tr>\n",
" <td class=\"tg-0pky\">Train Error</td>\n",
" <td class=\"tg-0pky\">0.001</td>\n",
" <td class=\"tg-0pky\">0.002</td>\n",
" <td class=\"tg-0pky\">0.</td>\n",
" <td class=\"tg-0pky\">0.</td>\n",
" <td class=\"tg-0lax\">0.</td>\n",
" <td class=\"tg-0lax\">0.007</td>\n",
" </tr>\n",
" <tr>\n",
" <td class=\"tg-0pky\">Test Error</td>\n",
" <td class=\"tg-0pky\">0.018</td>\n",
" <td class=\"tg-0pky\">0.018</td>\n",
" <td class=\"tg-0pky\">0.016</td>\n",
" <td class=\"tg-0pky\">0.016</td>\n",
" <td class=\"tg-0lax\">0.15</td>\n",
" <td class=\"tg-0lax\">0.508</td>\n",
" </tr>\n",
" <tr>\n",
" <td class=\"tg-0pky\">Bound</td>\n",
" <td class=\"tg-0pky\">0.161</td>\n",
" <td class=\"tg-0pky\">0.179</td>\n",
" <td class=\"tg-0pky\">0.186</td>\n",
" <td class=\"tg-0pky\">0.223</td>\n",
" <td class=\"tg-0lax\">0.201</td>\n",
" <td class=\"tg-0lax\">1.352</td>\n",
" </tr>\n",
" <tr>\n",
" <td class=\"tg-0pky\">#Parameters</td>\n",
" <td class=\"tg-0pky\">471k</td>\n",
" <td class=\"tg-0pky\">943k</td>\n",
" <td class=\"tg-0pky\">832k</td>\n",
" <td class=\"tg-0pky\">2384k</td>\n",
" <td class=\"tg-0lax\">1193k</td>\n",
" <td class=\"tg-0lax\">472k</td>\n",
" </tr>\n",
"</table>\n",
"\n",
"- Nonvacuous bounds in the over-parameterized regime\n",
"- Still very much not tight\n",
"- On randomly labeled data (last column) the bound is vacuous (above 1), consistent with the chance-level test error"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "py37",
"language": "python",
"name": "py37"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}