Skip to content

Instantly share code, notes, and snippets.

@davidnvq
Created June 5, 2019 09:55
Show Gist options
  • Save davidnvq/94a305570891ec7fc422b9886d378636 to your computer and use it in GitHub Desktop.
Save davidnvq/94a305570891ec7fc422b9886d378636 to your computer and use it in GitHub Desktop.
Test Dropout
import torch
import torch.nn as nn
nn.Dropout(p=0.0).eval()(torch.ones(4, 10))
"""Output
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
"""
nn.Dropout(p=1.0)(torch.ones(4, 10))
"""Output
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
"""
nn.Dropout(p=0.0)(torch.ones(4, 10))
"""Output
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
"""
nn.Dropout(p=0.5)(torch.ones(4, 10))
"""Output
tensor([[2., 2., 0., 0., 0., 2., 0., 2., 2., 2.],
[0., 0., 2., 0., 2., 0., 0., 2., 2., 2.],
[2., 0., 2., 2., 0., 0., 2., 2., 2., 2.],
[2., 2., 2., 2., 0., 2., 0., 2., 2., 2.]])
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment