
@AruniRC
Created December 15, 2017 23:00
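```python
import torch.nn as nn


def _initialize_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            pass  # leave the default PyTorch init
            # m.weight.data.zero_()
            # if m.bias is not None:
            #     m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            # The upsampling kernel is built for square kernels only.
            assert m.kernel_size[0] == m.kernel_size[1]
            # get_upsampling_weight is not defined in this snippet; it must
            # be provided elsewhere in the same module.
            initial_weight = get_upsampling_weight(
                m.in_channels, m.out_channels, m.kernel_size[0])
            m.weight.data.copy_(initial_weight)
```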
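The snippet calls `get_upsampling_weight` without defining it. A minimal sketch of such a helper follows, assuming it should return the bilinear-interpolation kernel commonly used to initialize FCN-style `ConvTranspose2d` upsampling layers. It may not match the helper the gist actually relies on. The returned shape `(in_channels, out_channels, k, k)` matches the `ConvTranspose2d` weight layout when `groups=1`. Channel `i` maps only to output channel `i`, and all cross-channel entries are zero. It returns a NumPy array, so it must be converted before calling `copy_`.

```python
import numpy as np


def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Hypothetical helper: bilinear upsampling kernel for ConvTranspose2d.

    Returns an array of shape (in_channels, out_channels, k, k) in which
    weight[i, i] is a 2-D bilinear filter and every other entry is zero.
    """
    factor = (kernel_size + 1) // 2
    # Filter center: on a pixel for odd kernels, between pixels for even ones.
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    rows, cols = np.ogrid[:kernel_size, :kernel_size]
    # Separable triangle (tent) filter = bilinear interpolation weights.
    filt = (1 - np.abs(rows - center) / factor) * \
           (1 - np.abs(cols - center) / factor)
    weight = np.zeros(
        (in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
    # Only diagonal channel pairs get the filter (channels upsampled independently).
    n = min(in_channels, out_channels)
    weight[range(n), range(n), :, :] = filt
    return weight
```

With `kernel_size=4` (the usual choice for stride-2 upsampling), the 1-D profile is `[0.25, 0.75, 0.75, 0.25]`. Inside `_initialize_weights`, the result would be wrapped as `torch.from_numpy(get_upsampling_weight(...)).float()` before `m.weight.data.copy_(...)`.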