Skip to content

Instantly share code, notes, and snippets.

@sgugger
Created September 5, 2018 01:54
Show Gist options
  • Save sgugger/84b7710b9ff6645940719c4cd546afdb to your computer and use it in GitHub Desktop.
Save sgugger/84b7710b9ff6645940719c4cd546afdb to your computer and use it in GitHub Desktop.
Notebooks/Weight_drop.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%load_ext autoreload\n%autoreload 2",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
class WeightDropout(nn.Module):
    """A module that wraps another layer, zeroing some of its weights with
    probability `dropout` during training (DropConnect on selected weights).

    Args:
        module: the wrapped layer (e.g. an ``nn.LSTM``).
        dropout: probability of zeroing each selected weight at train time.
        layer_names: names of the wrapped module's weight attributes to drop.
    """

    def __init__(self, module, dropout, layer_names=('weight_hh_l0',)):
        # Tuple default instead of a mutable list (backward-compatible:
        # layer_names is only ever iterated).
        super().__init__()
        self.module, self.dropout, self.layer_names = module, dropout, layer_names

    def _setweights(self):
        """Apply dropout to the raw weight copies and plug the results back
        into the wrapped module under the original attribute names."""
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            w = F.dropout(raw_w, p=self.dropout, training=self.training)
            # Hacky version: replaces the parameter named `layer` by a tensor.
            # In 0.4.1: works fine.
            # In Master: will return an error "got an incorrect number of RNN
            # parameters". What we need is some way to replace the parameter
            # while keeping the graph history so that the gradients of raw_w
            # are computed and raw_w is updated in the optimizer.
            del self.module._parameters[layer]
            setattr(self.module, layer, w)

    def forward(self, *args):
        # Refresh the dropped weights on every call, then delegate.
        self._setweights()
        return self.module.forward(*args)

    def reset(self):
        """Register a copy of each selected weight on this wrapper as a
        parameter named ``{layer}_raw``; the optimizer should train these."""
        for layer in self.layer_names:
            w = getattr(self.module, layer)
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
        if hasattr(self.module, 'reset'):
            self.module.reset()

    def update_raw(self):
        """Copy the surviving (non-zero) dropped weights back into the raw
        copies, undoing dropout's 1/(1-p) rescaling.

        Bug fix: the original indexed ``self.raw_weights[layer]``, an
        attribute that is never created (``reset`` registers parameters named
        ``{layer}_raw`` instead), so this method always raised AttributeError.
        """
        for layer in self.layer_names:
            w = getattr(self.module, layer)
            raw_w = getattr(self, f'{layer}_raw')
            mask = w != 0.
            # In-place data update so autograd history is not disturbed.
            raw_w.data[mask] = w.data[mask] * (1 - self.dropout)
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
# Wrap an LSTM in WeightDropout and set up plain SGD.
# The learning rate is deliberately huge (10) so one step visibly moves the weights.
module = nn.LSTM(20, 20)
dp_module = WeightDropout(module, dropout=0.5)
dp_module.reset()
opt = optim.SGD(dp_module.parameters(), lr=10)
dp_module.train()
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "WeightDropout(\n (module): LSTM(20, 20)\n)"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The error will occur at this point when running on PyTorch master."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
# Random input batch (seq_len=2, batch=5, features=20) with gradients enabled,
# plus an all-zero initial hidden/cell state; run it through the wrapped LSTM.
x = torch.randn(2, 5, 20).requires_grad_(True)
h = (torch.zeros(1, 5, 20), torch.zeros(1, 5, 20))
x, h = dp_module(x, h)
"execution_count": 5,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "If there is no error in 0.4.1, check that w_raw has received gradients."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
# Random class targets for the 10 (= 2*5) output positions. NLL loss on the raw
# LSTM output is numerically meaningless but enough to drive a backward pass
# and an optimizer step for the demo.
flat_out = x.view(-1, 20)
target = torch.randint(0, 20, (10,)).long()
loss = F.nll_loss(flat_out, target)
loss.backward()
opt.step()
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
# Inspect the gradients: the dropped weight on the wrapped module is a plain
# tensor (no .grad), while the raw parameter on the wrapper accumulated them.
w = dp_module.module.weight_hh_l0
w_raw = dp_module.weight_hh_l0_raw
w.grad, w_raw.grad
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "(None, tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n ...,\n [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n [ 0.0000, 0.0002, -0.0000, ..., -0.0000, 0.0002, -0.0003]]))"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.6.6",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "Notebooks/Weight_drop.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment