@ceshine
Last active September 30, 2017 02:13
Key Code Blocks of PyTorch RNN Dropout Implementation
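```python
# https://github.com/salesforce/awd-lstm-lm/blob/dfd3cb0235d2caf2847a4d53e1cbd495b781b5d2/locked_dropout.py#L5
import torch.nn as nn
from torch.autograd import Variable

class LockedDropout(nn.Module):
    # ...
    def forward(self, x, dropout=0.5):
        # No-op at evaluation time or when dropout is disabled
        if not self.training or not dropout:
            return x
        # One Bernoulli keep-mask of shape (1, batch, features), shared by every time step
        m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
        # Inverted dropout: scale kept units by 1 / (1 - dropout)
        mask = Variable(m, requires_grad=False) / (1 - dropout)
        mask = mask.expand_as(x)
        return mask * x
```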
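For comparison, here is a minimal sketch of the same locked-dropout idea written against the current PyTorch tensor API, where `Variable` has been merged into `Tensor`. It assumes the sequence-first `(seq_len, batch, features)` layout implied by the mask shape above, and it is an illustrative rewrite, not the awd-lstm-lm code itself. The mask is sampled once per `(batch, feature)` position and broadcast over the time dimension, so the same units are dropped at every step.

```python
import torch
import torch.nn as nn


class LockedDropout(nn.Module):
    """Dropout that reuses one mask across all time steps.

    Assumes input shaped (seq_len, batch, features).
    """

    def forward(self, x: torch.Tensor, dropout: float = 0.5) -> torch.Tensor:
        # Identity at evaluation time or when the dropout rate is zero.
        if not self.training or not dropout:
            return x
        # Sample a keep-mask with a leading size-1 time dimension.
        mask = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
        # Inverted dropout: rescale so the expected activation is unchanged.
        mask = mask / (1 - dropout)
        # Broadcasting repeats the same mask across the time dimension,
        # so an explicit expand_as is not needed.
        return mask * x
```

With `dropout=0.5`, kept entries are multiplied by 2 and dropped entries become 0; applying the layer in training mode to `torch.ones(5, 3, 4)` gives an output whose slices `y[t]` are identical for every time step `t`.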