@ceshine
Last active Sep 30, 2017
Key Code Blocks of PyTorch RNN Dropout Implementation
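```python
# https://github.com/salesforce/awd-lstm-lm/blob/dfd3cb0235d2caf2847a4d53e1cbd495b781b5d2/locked_dropout.py#L5
import torch.nn as nn
from torch.autograd import Variable

class LockedDropout(nn.Module):
    # ...
    def forward(self, x, dropout=0.5):
        # No-op at eval time or when the dropout rate is 0
        if not self.training or not dropout:
            return x
        # Mask has size 1 along dim 0 (time), so x is (seq_len, batch, features)
        # and one Bernoulli(1 - p) draw per (batch, feature) is reused at every step
        m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
        # Inverted dropout: scale kept units by 1 / (1 - p)
        mask = Variable(m, requires_grad=False) / (1 - dropout)
        # Repeat the same mask along the time dimension
        mask = mask.expand_as(x)
        return mask * x
```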
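The snippet above uses `Variable` and `.data`, which date from pre-0.4 PyTorch. As a hedged sketch, assuming PyTorch >= 0.4 (where `Variable` and `Tensor` were merged), the same locked-dropout idea can be written with plain tensors. `LockedDropoutModern` is a hypothetical name used only for illustration; it is not part of awd-lstm-lm. The demo checks that every time step sees the same dropped units, unlike `nn.Dropout`, which samples a fresh mask per element.

```python
import torch
import torch.nn as nn


class LockedDropoutModern(nn.Module):
    """Hypothetical modern-API rewrite of LockedDropout (no Variable, no .data).

    Assumes input shaped (seq_len, batch, features), as in the original snippet.
    """

    def forward(self, x: torch.Tensor, dropout: float = 0.5) -> torch.Tensor:
        # No-op at eval time or when the dropout rate is 0
        if not self.training or not dropout:
            return x
        # One Bernoulli(1 - p) draw per (batch, feature); new_empty creates a
        # tensor with no autograd history, so the mask never requires grad
        mask = x.new_empty((1, x.size(1), x.size(2))).bernoulli_(1 - dropout)
        # Inverted-dropout scaling keeps the expected output equal to the input
        mask = mask / (1 - dropout)
        # Broadcasting over dim 0 (time) replaces the explicit expand_as
        return mask * x


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.ones(5, 2, 4)  # (seq_len=5, batch=2, features=4)
    locked = LockedDropoutModern().train()
    y = locked(x, dropout=0.5)
    # Locked dropout: the same units are zeroed at every time step
    print(all(torch.equal(y[0], y[t]) for t in range(x.size(0))))
    # Standard dropout: a fresh mask per element, so time steps generally differ
    z = nn.Dropout(p=0.5).train()(x)
    print(torch.equal(z[0], z[1]))
```

Relying on broadcasting in `mask * x` makes `expand_as` unnecessary; keeping the explicit `expand_as` as the original does is equally valid and only makes the shape intent more visible.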