Skip to content

Instantly share code, notes, and snippets.

@tjyuyao
Created July 27, 2020 07:24
Show Gist options
  • Save tjyuyao/a8ea1a231443692a42650ff2003488f8 to your computer and use it in GitHub Desktop.
Save tjyuyao/a8ea1a231443692a42650ff2003488f8 to your computer and use it in GitHub Desktop.
This dropout always drops a constant number of units from every sample.
import torch
def steady_dropout(x, prob=0.2):
    """Zero out a fixed number of randomly chosen features in every row of ``x``.

    Unlike ``torch.nn.functional.dropout``, which drops each unit
    independently, this drops exactly ``round(n_feat * prob)`` features per
    row, so every sample loses the same number of units. Surviving values are
    rescaled by ``1 / keep`` (inverted dropout), where ``keep`` is the fraction
    of features actually kept, so each feature's expected value is preserved.

    Args:
        x: 2-D tensor of shape ``(batch, feature)``. It is not modified.
        prob: Target fraction of features to drop, in ``[0, 1]``.

    Returns:
        A new tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``x`` is not 2-D or ``prob`` is outside ``[0, 1]``.
    """
    if x.dim() != 2:
        raise ValueError(
            f"expected data shape of (batch, feature), while getting {tuple(x.shape)}"
        )
    if not 0.0 <= prob <= 1.0:
        raise ValueError(f"prob must be in [0, 1], got {prob}")

    n_batch, n_feat = x.shape
    n_select = int(round(n_feat * prob))
    if n_select == 0:
        # Nothing to drop (this also covers n_feat == 0, where the keep
        # fraction would be 0/0).
        return x.clone()
    if n_select == n_feat:
        # Everything is dropped; rescaling by 1/0 would produce NaN.
        return torch.zeros_like(x)

    # Draw an independent random permutation per row. The first n_select
    # entries of each permutation are the columns to drop in that row.
    order = torch.rand(n_batch, n_feat, device=x.device).argsort(dim=1)
    drop = torch.zeros(n_batch, n_feat, dtype=torch.bool, device=x.device)
    drop.scatter_(1, order[:, :n_select], True)

    keep = (n_feat - n_select) / n_feat
    # masked_fill is out-of-place, so the caller's tensor and any autograd
    # graph attached to it are left untouched.
    return x.masked_fill(drop, 0) / keep
if __name__ == '__main__':
    # Demo: apply steady_dropout (default prob) to a small random batch of
    # 5 samples x 10 features and print the result.
    demo_batch = torch.randn(5, 10)
    demo_output = steady_dropout(demo_batch)
    print(demo_output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment