@JohnGiorgi
Created August 20, 2019 13:07
A torch module for word dropout. During training, it randomly replaces some words with a specified token (e.g., an UNK token).
import torch
from torch.nn.modules.dropout import _DropoutNd


class WordDropout(_DropoutNd):
    """During training, randomly replaces some of the elements of the input tensor with
    `dropout_constant` with probability `p` using samples from a Bernoulli distribution. Each
    element is replaced independently on every forward call.

    Input is expected to be a 2D tensor of indices representing tokenized sentences.
    During evaluation, this module is a no-op, returning `input`.

    Args:
        p (float): probability of an element to be replaced with `dropout_constant`. Default: 0.1.
        dropout_constant (int): Value to replace dropped out elements with. Default: 1.

    Shape:
        - Input: `(N, T)`. `N` is the batch dimension and `T` is the number of indices per sample.
        - Output: `(N, T)`. Output is of the same shape as input.

    Examples:
        >>> bs = 32
        >>> sent_len = 512  # Max len of padded sentences
        >>> V = 10000  # Vocab size
        >>> m = WordDropout(p=0.2)
        >>> input = torch.randint(0, V, (bs, sent_len))
        >>> output = m(input)
    """

    def __init__(self, p=0.1, dropout_constant=1):
        super(WordDropout, self).__init__(p)
        self.dropout_constant = dropout_constant

    def forward(self, input):
        # No-op in evaluation mode or when p == 0.
        if not self.training or not self.p:
            return input
        # keep[i, j] is True with probability 1 - p, sampled independently per element.
        keep = torch.empty_like(input, dtype=torch.float).bernoulli_(1 - self.p).bool()
        # Kept positions pass through; dropped positions become dropout_constant.
        return torch.where(keep, input, torch.full_like(input, self.dropout_constant))
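
The masking step in `forward` can be tried on its own to check that roughly a fraction `p` of the indices is replaced. This is a minimal sketch, assuming PyTorch 1.2 or later. The helper name `replace_tokens` and the choice of index 1 for the UNK token are hypothetical and are not part of the gist.

```python
import torch


def replace_tokens(input, p=0.1, dropout_constant=1):
    """Hypothetical standalone version of the masking step in WordDropout.forward."""
    # keep[i, j] is True with probability 1 - p, sampled independently per element.
    keep = torch.empty_like(input, dtype=torch.float).bernoulli_(1 - p).bool()
    # Kept positions pass through; dropped positions become dropout_constant.
    return torch.where(keep, input, torch.full_like(input, dropout_constant))


torch.manual_seed(0)
unk = 1  # assumed index of the UNK token
# Sample indices from [2, 10000) so that no original token equals unk.
tokens = torch.randint(2, 10000, (32, 512))
dropped = replace_tokens(tokens, p=0.2, dropout_constant=unk)

# Fraction of positions replaced; expected to be close to p for a tensor this large.
frac = (dropped == unk).float().mean().item()
print(f"replaced fraction: {frac:.3f}")
```

Note that the mask is sampled over every position, so padding indices are as likely to be replaced as real tokens. If that matters for a given model, one option is to exclude padded positions from the mask before calling `torch.where`.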