Gist by @nbertagnolli, last active January 7, 2022 23:56
The most basic form of dropout
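```python
import torch


class Dropout(torch.nn.Module):
    def __init__(self, p: float = 0.5):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError("p must be a probability in [0, 1]")
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0:
            return x  # dropout is the identity at evaluation time
        if self.p == 1:
            return torch.zeros_like(x)  # every unit is dropped
        # Draw an independent keep/drop decision for every element, on x's device.
        mask = torch.rand_like(x) >= self.p
        # Scale kept units by 1/(1-p) so the expected activation matches eval time.
        return x * mask / (1 - self.p)
```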
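To check the behaviour of this kind of dropout in isolation, here is a minimal functional sketch. `inverted_dropout` is a hypothetical helper name, not part of the gist or of PyTorch. It assumes inverted dropout (zero each element with probability `p`, scale survivors by `1/(1-p)` during training), which is the convention the `torch.nn.Dropout` documentation describes. The script verifies empirically that survivors are rescaled, that roughly a fraction `p` of entries is dropped, that the mean is approximately preserved, and that evaluation mode is the identity.

```python
import torch


def inverted_dropout(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Hypothetical helper: zero each element with probability p and
    scale survivors by 1/(1-p) so that E[output] == x."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be a probability in [0, 1]")
    if not training or p == 0.0:
        return x
    if p == 1.0:
        return torch.zeros_like(x)
    mask = torch.rand_like(x) >= p  # independent keep decision per element
    return x * mask / (1.0 - p)


torch.manual_seed(0)
x = torch.ones(1000, 100)
y = inverted_dropout(x, p=0.3)

# Surviving elements are all 1 / (1 - 0.3); dropped elements are exactly 0.
survivors = y[y != 0]
print(torch.allclose(survivors, torch.full_like(survivors, 1 / 0.7)))
# Fraction of dropped entries (near 0.3) and overall mean (near 1.0).
print((y == 0).float().mean().item(), y.mean().item())
# Rows get their own masks, so this comparison is almost surely False.
print(torch.equal(y[0] == 0, y[1] == 0))
# In evaluation mode the input passes through unchanged.
print(torch.equal(inverted_dropout(x, p=0.3, training=False), x))
```

Sampling the mask with `torch.rand_like(x)` gives every element its own draw and keeps the mask on the same device and dtype as `x`, so the same approach works for batched, multi-dimensional, and GPU inputs.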