Created August 20, 2019 13:07
Gist: JohnGiorgi/c030de1dd8cb84ad0970d1cc87e2ed86
A PyTorch module for word dropout. During training, it randomly replaces some words (token indices) with a specified word index, e.g. that of an UNK token.
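The core operation is small enough to sketch as a plain function. This is a minimal illustration, not the gist's code: the name `word_dropout` and its `training` flag are hypothetical. It draws the mask with `torch.rand` and `Tensor.masked_fill` instead of the module's `bernoulli_`/`torch.where` pair, and still selects each position independently with probability `p`.

```python
import torch


def word_dropout(tokens: torch.Tensor, p: float, replacement: int, training: bool = True) -> torch.Tensor:
    """Hypothetical helper (not part of the gist): replace each index in `tokens`
    with `replacement` independently with probability `p`."""
    if not training or p == 0.0:
        return tokens  # no-op outside training, mirroring the module's eval mode
    # Boolean mask: True where a position is selected for replacement (probability p).
    drop = torch.rand(tokens.shape, device=tokens.device) < p
    # masked_fill returns a new tensor, so the caller's batch is not modified.
    return tokens.masked_fill(drop, replacement)


torch.manual_seed(0)
batch = torch.randint(2, 100, (4, 8))  # indices 2..99, so index 1 never occurs naturally
dropped = word_dropout(batch, p=0.5, replacement=1)
print(batch)
print(dropped)
```

Every position in `dropped` holds either its original index or `1`. The module version likewise returns a new tensor, because `torch.where` allocates its output rather than writing into `input`.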
import torch
from torch.nn.modules.dropout import _DropoutNd


class WordDropout(_DropoutNd):
    """During training, randomly replaces elements of the input tensor with
    `dropout_constant`, each independently with probability `p`, using samples
    from a Bernoulli distribution. A new mask is drawn on every forward call.

    Input is expected to be a 2D tensor of token indices representing tokenized sentences.
    During evaluation, this module is a no-op and returns `input` unchanged.

    Args:
        p (float): probability of an element being replaced with `dropout_constant`. Default: 0.1.
        dropout_constant (int): value to replace dropped-out elements with,
            e.g. the index of an UNK token. Default: 1.

    Shape:
        - Input: `(N, T)`, where `N` is the batch dimension and `T` is the number of indices per sample.
        - Output: `(N, T)`, the same shape as the input.

    Examples:
        >>> bs = 32
        >>> sent_len = 512  # max length of padded sentences
        >>> V = 10000  # vocab size
        >>> m = WordDropout(p=0.2)
        >>> input = torch.randint(0, V, (bs, sent_len))
        >>> output = m(input)
        >>> output.shape
        torch.Size([32, 512])
    """

    def __init__(self, p=0.1, dropout_constant=1):
        super().__init__(p)
        self.dropout_constant = dropout_constant

    def forward(self, input):
        # No-op in eval mode or when p == 0.
        if not self.training or not self.p:
            return input
        # Keep-mask: each element is True (kept) with probability 1 - p.
        keep = torch.empty_like(input).bernoulli_(1 - self.p).bool()
        # Keep the original index where the mask is True; otherwise use dropout_constant.
        return torch.where(keep, input, torch.full_like(input, self.dropout_constant))
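As a rough check on what `p` means in practice: each of the N·T positions is selected for replacement independently with probability p, so the number of selected positions K in one forward pass is binomially distributed. For the docstring's example (N = 32, T = 512, p = 0.2):

$$
K \sim \mathrm{Binomial}(NT,\, p), \qquad
\mathbb{E}[K] = NTp = 32 \cdot 512 \cdot 0.2 = 3276.8, \qquad
\sqrt{\operatorname{Var}[K]} = \sqrt{NTp(1-p)} = \sqrt{16384 \cdot 0.16} = 51.2
$$

A position whose original index already equals `dropout_constant` is selected but left unchanged, so the number of tokens that actually change can be slightly lower than K.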