@simongrest
Created March 28, 2019 21:19
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# All you need is a good init and Orthonormal Initialization\n",
"#### A simple implementation based on two papers:\n",
"- Dmytro Mishkin and Jiri Matas, *All you need is a good init* (https://arxiv.org/abs/1511.06422)\n",
"- Andrew Saxe, James McClelland and Surya Ganguli, *Exact solutions to the nonlinear dynamics of learning in deep linear neural networks* (https://arxiv.org/abs/1312.6120)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from fastai import datasets\n",
"import gzip, pickle\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Layer Sequential Unit-Variance Initialization\n",
"The main idea in the *All you need is a good init* paper is an algorithm the authors call Layer-Sequential Unit-Variance (LSUV) initialization. Rather than deriving a formula for scaling the weights from the dimensions of each layer, LSUV takes an empirical approach: you feed a mini-batch of input data through the network layer by layer and rescale the initial weights of each layer until the variance of its output is sufficiently close to 1.\n",
"\n",
"```\n",
"for each layer L do:\n",
"    initialize the weights W_L of L with a reasonable starting point\n",
"        (see the discussion of the Saxe et al. paper below)\n",
"    T_i = 0\n",
"    do:\n",
"        T_i = T_i + 1\n",
"        do the forward pass with a mini-batch xb\n",
"        calculate the variance of the layer's output, Var(L(xb))\n",
"        scale the weights: W_L = W_L / sqrt(Var(L(xb)))\n",
"    while |Var(L(xb)) − 1.0| ≥ tol_var and T_i < t_max\n",
"```\n",
"\n",
"The function below is my implementation of the above pseudo-code."
]
},
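{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def LSUV(model, xb=None, tol_var=0.01, t_max=100):\n",
"    # xb: the mini-batch used to measure output variances; if omitted, fall\n",
"    # back to the global batch `x` that is defined later in this notebook\n",
"    o = x if xb is None else xb\n",
"    with torch.no_grad():  # rescaling the weights does not need autograd\n",
"        for m in model:\n",
"            if hasattr(m, 'weight'):\n",
"                t = 0\n",
"                u = m(o)\n",
"                # rescale until the output variance is within tol_var of 1\n",
"                while (u.var() - 1).abs() > tol_var and t < t_max:\n",
"                    t += 1\n",
"                    m.weight.div_(u.std())  # W_L = W_L / sqrt(Var(L(xb)))\n",
"                    u = m(o)\n",
"                o = u\n",
"            else:\n",
"                o = m(o)\n",
"    return model"
]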
},
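{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why can this loop stop after very few iterations? For a layer without a bias, the output is linear in the weights. Dividing the weights by the output standard deviation therefore brings the output variance to 1 (up to floating-point error) in a single step. The cell below is a minimal sketch of that one step. It assumes a bias-free `nn.Conv2d` layer and a random stand-in batch `xb`. Both are illustrative and not part of the original experiment. With a bias, only the $\\mathbf W \\mathbf x$ part of $\\mathbf W \\mathbf x + \\mathbf b$ is rescaled, which is why `LSUV` repeats the step up to `t_max` times."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"xb = torch.randn(256, 1, 28, 28)  # illustrative random batch, not MNIST\n",
"conv = torch.nn.Conv2d(1, 8, 5, stride=2, padding=2, bias=False)  # bias-free layer\n",
"\n",
"with torch.no_grad():\n",
"    u = conv(xb)\n",
"    print('variance before:', u.var().item())\n",
"    conv.weight.div_(u.std())  # one LSUV step: W = W / sqrt(Var(L(xb)))\n",
"    print('variance after one step:', conv(xb).var().item())  # approximately 1.0"
]
},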
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Orthogonal Initialization\n",
"\n",
"The idea of using an orthogonal initialization was introduced by Andrew Saxe et al. in their paper *Exact solutions to the nonlinear dynamics of learning in deep linear neural networks* (https://arxiv.org/abs/1312.6120). The idea here is to choose the initial values of the weight matrices so that the dot product of any two distinct rows is zero while each row has norm 1. This can be written in terms of the *Kronecker delta* $\\delta_{ij}$:\n",
"\n",
"$$\\delta_{ij} = \\begin{cases}\n",
"0 &\\text{if } i \\neq j, \\\\\n",
"1 &\\text{if } i=j. \\end{cases}$$\n",
"\n",
"So for each pair of rows $\\mathbf w_{i}$, $\\mathbf w_{j}$ of the weight matrix we want\n",
"\n",
"$${\\mathbf w_{i}}^\\mathrm{T}\\mathbf w_{j} = \\delta_{ij},$$\n",
"\n",
"or equivalently, stacking the rows into the matrix $\\mathbf W$, $\\mathbf W \\mathbf W^\\mathrm{T} = \\mathbf I$.\n",
"\n",
"Weight matrices with orthonormal rows have two properties that are useful in neural networks:\n",
"\n",
"- When $\\mathbf W$ is square (and hence orthogonal), it preserves scale: for every vector $\\mathbf x$ we have $\\|\\mathbf W \\mathbf x\\|=\\|\\mathbf x\\|$. This helps keep the scale of the activations stable as they pass through the linear layers of the network.\n",
"- The rows are orthogonal to each other. Intuitively, this encourages different rows of the weight matrix to learn different features of the input.\n",
"\n",
"### Convolutional layers\n",
"\n",
"The same idea carries over to convolutional layers with a small adjustment. The kernel $\\mathbf W$ of each output channel has shape $c_{in} \\times k \\times k$. It is multiplied elementwise with a patch of the input $\\mathbf X$ of the same shape, and the products are summed. If we flatten the kernel and the patch into vectors of length $c_{in} k^2$, this step of the convolution is just the dot product of the two vectors.\n",
"\n",
"We would like the kernels of the different output channels to learn different things. So, again intuitively, we want the flattened length-$c_{in} k^2$ kernels to be orthonormal to each other. This is only possible when the number of output channels is at most $c_{in} k^2$, which holds for every layer in the model below.\n",
"\n",
"One way to achieve this is to start with a matrix of random weights, with one row per output channel. We can then use singular value decomposition (https://en.wikipedia.org/wiki/Singular_value_decomposition) to extract a matrix with orthonormal rows of the required shape. For our purposes, what matters about singular value decomposition is that it expresses a matrix $\\mathbf{M}$ as the product of three factors: a matrix $\\mathbf{U}$ with orthonormal columns, a diagonal matrix $\\boldsymbol{\\Sigma}$ and a matrix $\\mathbf{V}^{\\mathrm{T}}$ with orthonormal rows:\n",
"\n",
"$$\\mathbf{M} = \\mathbf{U} \\boldsymbol{\\Sigma} \\mathbf{V}^{\\mathrm{T}}$$ \n",
"\n",
"We can take the rows of $\\mathbf{V}^{\\mathrm{T}}$, reshape them to $c_{in} \\times k \\times k$ and use them to initialize our kernels.\n"
]
},
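{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick numerical check of the SVD argument, the sketch below uses illustrative sizes: 8 output channels, 3 input channels and $3 \\times 3$ kernels. These numbers are assumptions, not taken from the model. Note that `torch.svd` returns $\\mathbf V$ rather than $\\mathbf V^{\\mathrm T}$, so we transpose it. Newer PyTorch versions also provide `torch.linalg.svd`, which returns $\\mathbf V^{\\mathrm T}$ directly. The last lines check the scale-preservation property for a square orthogonal matrix."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"out_ch, in_ch, k = 8, 3, 3                  # illustrative layer sizes\n",
"M = torch.randn(out_ch, in_ch * k * k)      # one random row per output channel\n",
"U, S, V = torch.svd(M)                      # M = U diag(S) V^T\n",
"Vt = V.t()                                  # shape (8, 27) with orthonormal rows\n",
"print(torch.allclose(Vt @ Vt.t(), torch.eye(out_ch), atol=1e-5))  # True\n",
"kernels = Vt.reshape(out_ch, in_ch, k, k)   # one 3x3x3 kernel per output channel\n",
"\n",
"# a square orthogonal matrix preserves the norm of any vector\n",
"Q, _, _ = torch.svd(torch.randn(5, 5))\n",
"x5 = torch.randn(5)\n",
"print(torch.allclose((Q @ x5).norm(), x5.norm(), atol=1e-5))  # True"
]
},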
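{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class OrthInitConv2D(torch.nn.Conv2d):\n",
"    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n",
"                 padding=0, dilation=1, groups=1, bias=True):\n",
"        super().__init__(in_channels, out_channels, kernel_size, stride,\n",
"                         padding, dilation, groups, bias)\n",
"\n",
"    # nn.Conv2d.__init__ calls reset_parameters, so overriding it replaces the default init\n",
"    def reset_parameters(self):\n",
"        with torch.no_grad():\n",
"            self.weight.normal_(0, 1)\n",
"            if self.bias is not None:\n",
"                self.bias.zero_()\n",
"            # one row per output channel; assumes out_channels <= in_channels*k*k\n",
"            W = self.weight.view(self.weight.shape[0], -1)\n",
"            # torch.svd returns V (not V^T), with W = U diag(S) V^T\n",
"            _, _, V = torch.svd(W)\n",
"            # the rows of V^T are orthonormal; reshape them back into kernels\n",
"            self.weight.copy_(V.t().reshape(self.weight.shape))"
]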
},
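{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch also has a built-in orthogonal initializer, `torch.nn.init.orthogonal_`. Like `reset_parameters` above, it flattens all but the first dimension of the weight tensor. It uses a QR decomposition rather than an SVD, so it is a swapped-in alternative rather than this notebook's method. The sketch below applies it to a weight shaped like the third layer of the model (32 output channels, 16 input channels, $3 \\times 3$ kernels) and checks that the flattened rows are orthonormal."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"conv = torch.nn.Conv2d(16, 32, 3)\n",
"torch.nn.init.orthogonal_(conv.weight)  # trailing dimensions are flattened internally\n",
"W = conv.weight.detach().view(conv.weight.shape[0], -1)  # shape (32, 144)\n",
"print(torch.allclose(W @ W.t(), torch.eye(32), atol=1e-5))  # True"
]
},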
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple example using this initialisation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get MNIST data and normalize it"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'\n",
"\n",
"def get_data():\n",
"    path = datasets.download_data(MNIST_URL, ext='.gz')\n",
"    with gzip.open(path, 'rb') as f:\n",
"        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')\n",
"    return map(torch.tensor, (x_train,y_train,x_valid,y_valid))\n",
"\n",
"def normalize(x, m, s): return (x-m)/s"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"x_train,y_train,x_valid,y_valid = get_data()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"train_mean,train_std = x_train.mean(),x_train.std()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"x_train = normalize(x_train, train_mean, train_std)\n",
"x_valid = normalize(x_valid, train_mean, train_std)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get a batch of 1000 inputs"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"x = x_train[:1000].view([-1,1,28,28])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a simple CNN"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class Lambda(torch.nn.Module):\n",
"    def __init__(self, func):\n",
"        super().__init__()\n",
"        self.func = func\n",
"\n",
"    def forward(self, x): return self.func(x)\n",
"\n",
"def flatten(x): return x.view(x.shape[0], -1)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def get_model(convtype=torch.nn.Conv2d, extra_depth=1):\n",
"    model = torch.nn.Sequential(\n",
"        convtype(1,8,5,stride=2,padding=2),\n",
"        convtype(8,16,3,stride=2,padding=1),\n",
"        convtype(16,32,3,stride=2,padding=1),\n",
"        *[convtype(32,32,3,stride=2,padding=1) for i in range(extra_depth)]\n",
"    )\n",
"    return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compare the output standard deviation under PyTorch's default `nn.Conv2d` init and under orthogonal init followed by `LSUV`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Output standard deviation over 100 random initializations, for a shallow model and a deep model (`extra_depth=30`)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"shallow_pytorch_stds = [get_model()(x).std().item() for i in range(100)]\n",
"deep_pytorch_stds = [get_model(extra_depth=30)(x).std().item() for i in range(100)]\n",
"shallow_orthnormal_stds = [LSUV(get_model(convtype=OrthInitConv2D))(x).std().item() for i in range(100)]\n",
"deep_orthnormal_stds = [LSUV(get_model(convtype=OrthInitConv2D, extra_depth=30))(x).std().item() for i in range(100)]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1440x1080 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(2,2, figsize=(20,15))\n",
"\n",
"ax[0,0].hist(shallow_pytorch_stds, bins=10, density=False);\n",
"ax[0,0].set_title('Shallow PyTorch Network Standard Deviation');\n",
"\n",
"ax[0,1].hist(shallow_orthnormal_stds, bins=10, density=False, range=[1-0.0001,1+0.0001]);\n",
"ax[0,1].set_title('Shallow Orthonormal + LSUV Network Standard Deviation');\n",
"\n",
"ax[1,0].hist(deep_pytorch_stds, bins=10, density=False);\n",
"ax[1,0].set_title('Deep PyTorch Network Standard Deviation');\n",
"\n",
"ax[1,1].hist(deep_orthnormal_stds, bins=10, density=False, range=[1-0.0001,1+0.0001]);\n",
"ax[1,1].set_title('Deep Orthonormal + LSUV Network Standard Deviation');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}