@cswhjiang
Created August 25, 2017 06:35
import torch
from torch.nn import Parameter
from functools import wraps


class WeightDrop(torch.nn.Module):
    """Wraps a module (typically an RNN) and applies dropout (DropConnect)
    to the named weight matrices before every forward pass."""

    def __init__(self, module, weights, dropout=0, variational=False):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def _setup(self):
        # Replace each target weight with a '<name>_raw' parameter; the
        # dropped-out version is recreated on every forward pass.
        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', Parameter(w.data))
        self.module._all_weights = [list(self.module._parameters.keys())]

    def _setweights(self):
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            w = None
            if self.variational:
                # Variational dropout: sample one mask entry per row and
                # share it across the columns of the weight matrix.
                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
                if raw_w.is_cuda: mask = mask.cuda()
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                w = mask.expand_as(raw_w) * raw_w
            else:
                # Standard DropConnect: drop individual weight entries.
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)


if __name__ == '__main__':
    a = torch.nn.LSTM(10, 20, 1)
    w = WeightDrop(a, ['weight_hh_l0'])
    print(a._all_weights)
    print(a._parameters.keys())
    a.cuda()
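
    # --- Illustrative usage sketch (not part of the original gist) ---
    # Assumes a CUDA device is available, since the demo above already
    # moves the LSTM to the GPU. Runs a dummy sequence through the
    # wrapped LSTM; dropout on weight_hh_l0 is re-sampled inside
    # forward() via _setweights(). Input shape follows the LSTM
    # constructed above: (seq_len, batch, input_size).
    x = torch.autograd.Variable(torch.randn(5, 3, 10)).cuda()
    out, (h_n, c_n) = w(x)
    print(out.size())  # expected: torch.Size([5, 3, 20])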