@ikhlestov
Created September 12, 2017 17:18
pytorch: weights initialization
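import math

import torch
import torch.nn as nn
from torch.autograd import Variable  # legacy: since PyTorch 0.4, Variable(t) simply returns a Tensor

# new way with the `init` module (the in-place versions end with an underscore)
w = torch.empty(3, 5)
nn.init.normal_(w)
# works for Variables too
w2 = Variable(w)
nn.init.normal_(w2)
# old-style direct access to the tensor's .data attribute
w2.data.normal_()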
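# example for some module: re-initialize layers by their class name
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # conv weights: mean 0, std 0.02
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # BatchNorm scale: mean 1, std 0.02; shift set to 0
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)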
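# for-loop approach with direct access
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        # hypothetical layers, so that self.modules() has something to visit
        self.conv = nn.Conv2d(3, 16, 3)
        self.bn = nn.BatchNorm2d(16)
        self.fc = nn.Linear(16, 10)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style init over fan-out: n = kernel_h * kernel_w * out_channels
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()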
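The MyModel loop draws conv weights with std sqrt(2 / n), where n = kernel_h * kernel_w * out_channels, i.e. He (Kaiming) initialization computed over fan-out. A minimal sketch of the same rule written with torch.nn.init, assuming a PyTorch version that has the in-place `*_` init functions; the helper name init_like_mymodel and the layer sizes are hypothetical.

```python
import math

import torch
import torch.nn as nn


def init_like_mymodel(model: nn.Module) -> None:
    """Hypothetical helper: the MyModel loop expressed with torch.nn.init."""
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # fan_out = out_channels * kernel_h * kernel_w; the ReLU gain sqrt(2)
            # gives std = sqrt(2 / fan_out), the same value as the manual loop.
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.zeros_(m.bias)


# Hypothetical container used only to exercise the helper (never run forward).
model = nn.Sequential(nn.Conv2d(64, 256, 3), nn.BatchNorm2d(256), nn.Linear(10, 10))
init_like_mymodel(model)

conv = model[0]
n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels
print(f"target std {math.sqrt(2.0 / n):.4f}, sample std {conv.weight.std().item():.4f}")
```

The sample std should land close to the target because the conv weight has over 100k entries; nn.init wraps the writes in torch.no_grad(), so the .data access is not needed.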
@febriy

febriy commented Jul 3, 2019

Hi there,

If you don't mind sharing, may I know what is happening in the code here:

# example for some module
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

Thank you!

@illarion-rl

illarion-rl commented Jul 3, 2019

Hello @febriy

It's just an example function that can be applied to the whole network to initialize the matching layers (in this case, convolution and BatchNorm layers); net.apply() calls it recursively on every submodule. Here is an example:

net = nn.Sequential(
    nn.Linear(2, 2),
    nn.Conv2d(1, 20, 5),
    nn.BatchNorm2d(20),
)
net.apply(weights_init)

In the code above, the Conv2d and BatchNorm layers are reinitialized by the weights_init function, while the Linear layer keeps its default initialization.
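For reference, a self-contained, runnable version of that example, assuming nn.BatchNorm2d as the BatchNorm layer. The snapshot of the Linear weights is added here only to make the "Linear is left alone" behaviour visible; it is not part of the original thread.

```python
import torch
import torch.nn as nn


def weights_init(m):
    # Same name-based dispatch as the gist's weights_init.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


net = nn.Sequential(
    nn.Linear(2, 2),
    nn.Conv2d(1, 20, 5),
    nn.BatchNorm2d(20),
)

# Snapshot the Linear weights to show they are not touched.
linear_before = net[0].weight.detach().clone()

# apply() visits every submodule recursively, then the Sequential itself
# ("Sequential" matches neither branch, so it is skipped).
net.apply(weights_init)

print("Linear unchanged:", torch.equal(net[0].weight, linear_before))
print("BatchNorm bias all zero:", bool((net[2].bias == 0).all()))
```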

@ahmedghali

Why not the Linear layer too?

@ikhlestov
Author

@ghaliahmed If you're asking why I don't initialize the Linear layer: that's just because this code is meant as an example, not as production code.
Or did you mean something else?
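A minimal sketch of what covering Linear layers as well could look like: a third branch in the same name-based dispatch. Xavier (Glorot) uniform weights with zero bias is an assumed choice for illustration, not something the gist prescribes, and the function name weights_init_with_linear is hypothetical.

```python
import torch
import torch.nn as nn


def weights_init_with_linear(m):
    """Hypothetical variant of weights_init that also handles nn.Linear."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        # Assumed choice: Xavier/Glorot uniform weights, zero bias.
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


net = nn.Sequential(nn.Linear(2, 2), nn.Conv2d(1, 20, 5), nn.BatchNorm2d(20))
net.apply(weights_init_with_linear)
print(net[0].weight)
print(net[0].bias)
```

For a 2x2 Linear layer, Xavier uniform samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)) = sqrt(1.5), roughly 1.2247.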

@ahmedghali

@ikhlestov thanks for your response
