@ceshine
Last active September 30, 2017 03:47
Key Code Blocks of PyTorch RNN Dropout Implementation
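# https://github.com/salesforce/awd-lstm-lm/blob/dfd3cb0235d2caf2847a4d53e1cbd495b781b5d2/embed_regularize.py#L6
from torch.autograd import Variable


def embedded_dropout(embed, words, dropout=0.1, scale=None):
    if dropout:
        # Embedding (word-level) dropout: draw one Bernoulli(1 - dropout) value per
        # vocabulary row, broadcast it across the embedding dimension, and rescale
        # by 1 / (1 - dropout) so each weight keeps its expected value.
        mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
        mask = Variable(mask)
        masked_embed_weight = mask * embed.weight
    else:
        masked_embed_weight = embed.weight
    if scale:
        # Optional scaling tensor, broadcast to the shape of the weight matrix.
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight
    padding_idx = embed.padding_idx
    if padding_idx is None:
        # Mirrors nn.Embedding.forward of the same PyTorch era, where -1 means "no padding index".
        padding_idx = -1
    # Legacy (PyTorch 0.2/0.3-era) internal backend call: look up `words`
    # in the masked weight matrix instead of the original embed.weight.
    X = embed._backend.Embedding.apply(words, masked_embed_weight,
        padding_idx, embed.max_norm, embed.norm_type,
        embed.scale_grad_by_freq, embed.sparse
    )
    return X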
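The snippet above targets the PyTorch API of late 2017: `Variable` has since been merged into `Tensor`, and the private `embed._backend` hook was removed in later releases. A minimal sketch of the same embedding-dropout logic for recent PyTorch versions, assuming `torch.nn.functional.embedding` as the lookup, could look like this. It is an adaptation, not the upstream code. Two deliberate changes: `embed.padding_idx` is passed through unchanged (including `None`), because `F.embedding` treats a negative index as counting from the end of the vocabulary, so substituting -1 would silently mark the last row as padding; and `if scale is not None` replaces `if scale:`, which raises for multi-element tensors in current PyTorch.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def embedded_dropout(embed: nn.Embedding, words: torch.Tensor,
                     dropout: float = 0.1,
                     scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Sketch of word-level embedding dropout for recent PyTorch releases."""
    weight = embed.weight
    if dropout:
        # One Bernoulli(1 - dropout) draw per vocabulary row (shape V x 1), so a
        # dropped word loses its whole embedding vector for this forward pass.
        mask = weight.new_empty((weight.size(0), 1)).bernoulli_(1 - dropout)
        # Broadcast across the embedding dimension and rescale (inverted dropout).
        masked_weight = mask.expand_as(weight) / (1 - dropout) * weight
    else:
        masked_weight = weight
    if scale is not None:
        masked_weight = scale.expand_as(masked_weight) * masked_weight
    # Functional lookup with the module's own settings; padding_idx=None is valid here.
    return F.embedding(words, masked_weight, embed.padding_idx, embed.max_norm,
                       embed.norm_type, embed.scale_grad_by_freq, embed.sparse)
```

Because the mask is drawn per vocabulary row rather than per token, every occurrence of a word in the batch is either kept (and scaled) or dropped together. Neither version checks `embed.training`, so the caller is responsible for passing `dropout=0` at evaluation time.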