@nbertagnolli
Last active January 7, 2022 23:56
A real implementation of dropout in PyTorch with proper normalization (surviving activations are rescaled by 1 / (1 - p) during training).
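The normalization is the usual inverted-dropout rescaling. Each activation x_i is multiplied by a mask value m_i that is 1 with probability 1 - p and 0 otherwise, and it is then divided by 1 - p. In expectation the training-time output therefore equals the input, so no rescaling is needed in eval mode. (The notation m_i and x̃_i is introduced here for illustration.)

```latex
\begin{aligned}
m_i &\sim \mathrm{Bernoulli}(1 - p), \qquad \tilde{x}_i = \frac{m_i \, x_i}{1 - p} \\
\mathbb{E}[\tilde{x}_i] &= (1 - p) \cdot \frac{x_i}{1 - p} + p \cdot 0 = x_i
\end{aligned}
```

The factor 1 / (1 - p) is undefined at p = 1, where every unit is dropped, so that case has to be handled separately.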
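import torch


class TrueDropout(torch.nn.Module):
    def __init__(self, p: float = 0.5):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(f"p must be a probability in [0, 1], got {p}")
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only drop units in training mode; model.eval() makes this the identity.
        if not self.training or self.p == 0:
            return x
        if self.p == 1:
            return torch.zeros_like(x)  # every unit dropped; avoids dividing by zero
        # Keep each element independently with probability 1 - p (mask on x's device).
        keep = torch.rand_like(x) >= self.p
        # Inverted dropout: rescale survivors by 1 / (1 - p) so E[output] == x.
        return x * keep / (1 - self.p)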
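The following sanity check is a standalone sketch, so it runs without the class. It assumes a hypothetical helper, `inverted_dropout`, that applies inverted dropout with an independent keep mask per element. The script compares that helper with PyTorch's built-in `torch.nn.functional.dropout` in both training and eval modes. The tensor size, seed and p = 0.3 are arbitrary illustration choices.

```python
import torch
import torch.nn.functional as F


def inverted_dropout(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Hypothetical helper: inverted dropout with an independent mask per element."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    if not training or p == 0.0:
        return x
    if p == 1.0:
        return torch.zeros_like(x)  # all units dropped; 1 / (1 - p) is undefined
    keep = torch.rand_like(x) >= p  # True with probability 1 - p
    return x * keep / (1.0 - p)  # rescale survivors so E[output] == x


torch.manual_seed(0)
p = 0.3
x = torch.ones(1000, 1000)

ours = inverted_dropout(x, p)
builtin = F.dropout(x, p=p, training=True)

# The fraction of zeroed entries should be close to p for both versions.
print(f"dropped fraction (ours):    {(ours == 0).float().mean().item():.3f}")
print(f"dropped fraction (builtin): {(builtin == 0).float().mean().item():.3f}")
# Survivors are scaled to 1 / (1 - p), so the mean should stay close to 1.
print(f"mean (ours):    {ours.mean().item():.3f}")
print(f"mean (builtin): {builtin.mean().item():.3f}")

# In eval mode both versions leave the input unchanged.
assert torch.equal(inverted_dropout(x, p, training=False), x)
assert torch.equal(F.dropout(x, p=p, training=False), x)
```

Both dropped fractions should land near p and both means near 1. That agreement is what the 1 / (1 - p) rescaling is meant to guarantee.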