@sgugger
Created September 5, 2018 01:36
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class WeightDropout(nn.Module):
    "A module that wraps another layer, in which some weights will be replaced by 0 during training."

    def __init__(self, module, dropout, layer_names=['weight_hh_l0']):
        super().__init__()
        self.module,self.dropout,self.layer_names = module,dropout,layer_names

    def _setweights(self):
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            w1 = F.dropout(raw_w, p=self.dropout, training=self.training)
            #Hacky version: replaces the parameter named layer by a plain tensor.
            #In 0.4.1: works fine.
            #In master: raises "got an incorrect number of RNN parameters".
            #What we need is some way to replace the parameter named layer by this new value w1 while keeping the
            #graph history, so that the gradients of raw_w are computed and raw_w is then updated by the optimizer.
            del self.module._parameters[layer]
            setattr(self.module, layer, w1)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)

    def reset(self):
        for layer in self.layer_names:
            #Makes a copy of the weights of the selected layers and registers it as '{layer}_raw'.
            w = getattr(self.module, layer)
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
        if hasattr(self.module, 'reset'): self.module.reset()

    def update_raw(self):
        for layer in self.layer_names:
            w = getattr(self.module, layer)
            mask = w != 0.
            #Undo the 1/(1-p) dropout scaling on the weights that were kept and write them back into '{layer}_raw'.
            getattr(self, f'{layer}_raw').data[mask] = w.data[mask] * (1-self.dropout)
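#Aside: a minimal sketch of the mechanism the comments above rely on. nn.Linear is used here purely
#to sidestep the RNN parameter check, and lin/lin_raw are illustrative names, not part of the module
#above: assigning the dropped-out *tensor* as a plain attribute keeps the raw parameter in the
#autograd graph, so its .grad is populated on backward.
lin = nn.Linear(4, 4)
lin_raw = nn.Parameter(lin.weight.data.clone())
del lin._parameters['weight']
lin.weight = F.dropout(lin_raw, p=0.5, training=True)  #plain tensor, not an nn.Parameter
lin(torch.randn(3, 4)).sum().backward()
print(lin_raw.grad is not None)  #True: gradients flow back to the raw parameter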
module = nn.LSTM(20, 20)
dp_module = WeightDropout(module, 0.5)
dp_module.reset()
opt = optim.SGD(dp_module.parameters(), 10)
dp_module.train()
x = torch.randn(2,5,20)
x.requires_grad_(requires_grad=True)
h = (torch.zeros(1,5,20), torch.zeros(1,5,20))
#Error will come here in Master
x,h = dp_module(x,h)
target = torch.randint(0,20,(10,)).long()
loss = F.nll_loss(x.view(-1,20), target)
loss.backward()
opt.step()
w, w_raw = getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module,'weight_hh_l0_raw')
print(w.grad)
print(w_raw.grad)
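#For reference, a hedged sketch (not the fix used above) similar to the WeightDrop pattern from the
#AWD-LSTM codebase: the parameter is removed from the wrapped module once, at wrap time, and a
#dropped-out plain tensor is re-assigned to the attribute on every forward, so gradients still reach
#the *_raw parameters. The class name WeightDropAlt is illustrative; whether this passes the RNN
#parameter check on current master has not been verified here.
import warnings

class WeightDropAlt(nn.Module):
    "Wraps a module and applies dropout to the given weights, removing the original parameters at wrap time."
    def __init__(self, module, dropout, layer_names=['weight_hh_l0']):
        super().__init__()
        self.module,self.dropout,self.layer_names = module,dropout,layer_names
        for layer in self.layer_names:
            w = getattr(self.module, layer)
            del self.module._parameters[layer]  #remove the original parameter once
            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))
        if isinstance(self.module, nn.RNNBase):
            #cuDNN weight flattening assumes registered parameters; make it a no-op.
            self.module.flatten_parameters = lambda *args, **kwargs: None

    def _setweights(self):
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            #Plain tensor, not nn.Parameter, so raw_w stays in the autograd graph.
            setattr(self.module, layer, F.dropout(raw_w, p=self.dropout, training=self.training))

    def forward(self, *args):
        self._setweights()
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  #silence the non-contiguous weights warning
            return self.module(*args)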