# PyTorch: weights initialization examples.
import math

import torch
from torch import nn
from torch.autograd import Variable
# New way: initialise tensors in place with the `torch.nn.init` module.
# `torch.empty` makes it explicit that the allocation itself is NOT initialised
# (the legacy `torch.Tensor(3, 5)` constructor also returns uninitialised memory),
# so an init call must follow before the values are used.
w = torch.empty(3, 5)
torch.nn.init.normal_(w)
# The same `init` functions work on Variables too. Variable and Tensor were
# merged in PyTorch 0.4 and the `Variable` wrapper is deprecated; it is kept
# here only so the `w2` name keeps existing.
w2 = Variable(w)
torch.nn.init.normal_(w2)
# Old style: mutate the tensor's `.data` attribute directly (bypasses autograd).
w2.data.normal_()
# example for some module
def weights_init(m):
    """Initialise a single module in place, dispatching on its class name.

    Presumably meant to be passed to ``model.apply(weights_init)`` so that it
    runs on every submodule.

    - Class name contains ``'Conv'``: weight ~ N(0, 0.02**2).
    - Class name contains ``'BatchNorm'``: weight ~ N(1, 0.02**2), bias set to 0.
    - Any other module is left untouched.

    Modules that match by name but carry no ``weight``/``bias`` tensor (e.g. a
    user-defined ``ConvBlock`` container, or ``BatchNorm2d(affine=False)``,
    whose weight and bias are ``None``) are skipped instead of raising
    ``AttributeError``.

    Args:
        m: the module to initialise (typically an ``nn.Module``).
    """
    # NOTE(review): these statements were lost in the pasted source; only the
    # 0.02 std survived. The means 0.0 / 1.0 follow the common DCGAN-style
    # convention -- confirm against the intended original.
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
# for loop approach with direct access
class MyModel(nn.Module):
    """Example model that initialises its own layers by walking ``self.modules()``.

    Submodules must be assigned in ``__init__`` *before* the initialisation
    loop. As written, the model defines no layers, so the loop only visits
    ``self`` and changes nothing.
    """

    def __init__(self):
        # nn.Module.__init__ must run first: it creates the internal registries
        # that ``self.modules()`` walks. Without it, the loop below raises.
        super().__init__()
        # NOTE(review): define layers here, e.g. ``self.conv = nn.Conv2d(...)``.
        # NOTE(review): the branch bodies below were lost in the pasted source
        # and are reconstructed from the surviving fragments -- confirm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # std = sqrt(2 / n) with n = kh * kw * out_channels, i.e. He
                # (Kaiming) normal initialisation using the fan-out.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias set to None.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # Linear(bias=False) has bias set to None; weights keep the
                # PyTorch default initialisation.
                if m.bias is not None:
                    m.bias.data.zero_()