Skip to content

Instantly share code, notes, and snippets.

@sgugger
Last active September 5, 2018 19:52
Show Gist options
  • Save sgugger/f35d53e4d50f2e1cea0f4ffaf00298c3 to your computer and use it in GitHub Desktop.
Save sgugger/f35d53e4d50f2e1cea0f4ffaf00298c3 to your computer and use it in GitHub Desktop.
Notebooks/Weight_drop.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%load_ext autoreload\n%autoreload 2",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class WeightDropout(nn.Module):\n \"A module that warps another layer in which some weights will be replaced by 0 during training.\"\n \n def __init__(self, module, dropout, layer_names=['weight_hh_l0']):\n super().__init__()\n self.module,self.dropout,self.layer_names = module,dropout,layer_names\n \n def _setweights(self):\n for layer in self.layer_names:\n raw_w = getattr(self, f'{layer}_raw')\n self.module._parameters[layer] = F.dropout(raw_w, p=self.dropout, training=self.training)\n \n def forward(self, *args):\n self._setweights()\n return self.module.forward(*args)\n \n def reset(self):\n for layer in self.layer_names:\n #Makes a copy of the weights of the selected layers.\n w = getattr(self.module, layer)\n self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))\n if hasattr(self.module, 'reset'): self.module.reset()",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "module = nn.LSTM(20, 20)\ndp_module = WeightDropout(module, 0.5)\ndp_module.reset()\nopt = optim.SGD(dp_module.parameters(), 10)\ndp_module.train()",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "WeightDropout(\n (module): LSTM(20, 20)\n)"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "x = torch.randn(2,5,20)\nx.requires_grad_(requires_grad=True)\nh = (torch.zeros(1,5,20), torch.zeros(1,5,20))\nx,h = dp_module(x,h)",
"execution_count": 5,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "target = torch.randint(0,20,(10,)).long()\nloss = F.nll_loss(x.view(-1,20), target)\nloss.backward()\nopt.step()",
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "w, w_raw = getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module,'weight_hh_l0_raw')\nw.grad, w_raw.grad",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "(None, tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n [-0.0010, 0.0003, 0.0000, ..., 0.0005, -0.0000, 0.0000],\n [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n ...,\n [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]))"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/f35d53e4d50f2e1cea0f4ffaf00298c3"
},
"gist": {
"id": "f35d53e4d50f2e1cea0f4ffaf00298c3",
"data": {
"description": "Notebooks/Weight_drop.ipynb",
"public": true
}
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.6.6",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment