Skip to content

Instantly share code, notes, and snippets.

@rtqichen
Last active May 11, 2023 06:58
Show Gist options
  • Save rtqichen/b22a9c6bfc4f36e605a7b3ac1ab4122f to your computer and use it in GitHub Desktop.
Save rtqichen/b22a9c6bfc4f36e605a7b3ac1ab4122f to your computer and use it in GitHub Desktop.
PyTorch weight normalization — works for all nn.Module subclasses (probably)
## Weight norm is now built into PyTorch (as a forward pre-hook), so use that instead :)
import torch
import torch.nn as nn
from torch.nn import Parameter
from functools import wraps
class WeightNorm(nn.Module):
    """Weight-normalize selected parameters of an arbitrary module.

    Each parameter ``w`` named in ``weights`` is removed from ``module`` and
    replaced by two new parameters, ``w + '_g'`` and ``w + '_v'``, chosen so
    that ``w == g / ||v|| * v``.  Before every forward pass ``w`` is rebuilt
    from ``g`` and ``v`` and assigned back onto ``module`` as a plain tensor
    attribute, so gradients flow into ``g`` and ``v``.

    Note: ``||v||`` is ``torch.norm`` over the whole tensor, so ``g`` is a
    single scalar per parameter (not one magnitude per output unit).

    Args:
        module: the ``nn.Module`` to wrap.
        weights: names of ``module``'s parameters to reparameterize; either
            an iterable of names or a single name given as a string.
    """

    # Suffixes used to name the magnitude (g) and direction (v) parameters.
    append_g = '_g'
    append_v = '_v'

    def __init__(self, module, weights):
        super(WeightNorm, self).__init__()
        self.module = module
        # Accept a single name, and materialize iterables so that a generator
        # is not exhausted by _reset() before _setweights() iterates it again.
        if isinstance(weights, str):
            weights = [weights]
        self.weights = list(weights)
        self._reset()

    def _reset(self):
        """Replace every parameter named in ``self.weights`` by a (g, v) pair."""
        for name_w in self.weights:
            w = getattr(self.module, name_w)
            # construct g, v such that w = g / ||v|| * v
            g = torch.norm(w)
            v = w / g.expand_as(w)
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v
            # remove w from the parameter list; it is recomputed in forward()
            del self.module._parameters[name_w]
            # add g and v as new leaf parameters, detached from w's graph
            self.module.register_parameter(name_g, Parameter(g.data))
            self.module.register_parameter(name_v, Parameter(v.data))

    def _setweights(self):
        """Recompute each ``w = g / ||v|| * v`` and set it on the module."""
        for name_w in self.weights:
            g = getattr(self.module, name_w + self.append_g)
            v = getattr(self.module, name_w + self.append_v)
            w = v * (g / torch.norm(v)).expand_as(v)
            # w is no longer a registered parameter, so this stores a plain
            # tensor attribute that the wrapped module's forward reads.
            setattr(self.module, name_w, w)

    def forward(self, *args, **kwargs):
        """Rebuild the normalized weights, then run the wrapped module.

        The module is invoked through ``__call__`` (not ``.forward``) so any
        hooks registered on it still run, and keyword arguments are passed on.
        """
        self._setweights()
        return self.module(*args, **kwargs)
##############################################################
## An older version using a Python decorator; it might be buggy.
## Does not work when the module is replicated (e.g. nn.DataParallel).
def _decorate(forward, module, name, name_g, name_v):
@wraps(forward)
def decorated_forward(*args, **kwargs):
g = module.__getattr__(name_g)
v = module.__getattr__(name_v)
w = v*(g/torch.norm(v)).expand_as(v)
module.__setattr__(name, w)
return forward(*args, **kwargs)
return decorated_forward
def weight_norm(module, name):
    """Apply weight normalization to ``module.<name>`` in place (legacy API).

    The parameter ``module.<name>`` is replaced by two parameters,
    ``<name>_g`` (scalar magnitude, ``torch.norm`` over the whole tensor) and
    ``<name>_v`` (direction). ``module.forward`` is wrapped so that
    ``<name> = g / ||v|| * v`` is recomputed before each call.

    Args:
        module: module that owns the parameter.
        name: name of the parameter to reparameterize (e.g. ``'weight'``).

    Returns:
        The same ``module``, modified in place.
    """
    param = getattr(module, name)
    # construct g, v such that w = g / ||v|| * v
    g = torch.norm(param)
    v = param / g.expand_as(param)
    name_g = name + '_g'
    name_v = name + '_v'
    # remove w from the parameter list; it is recomputed before each forward
    del module._parameters[name]
    # add g and v as new leaf parameters, detached from w's graph
    module.register_parameter(name_g, Parameter(g.data))
    module.register_parameter(name_v, Parameter(v.data))
    # construct w every time before forward is called
    module.forward = _decorate(module.forward, module, name, name_g, name_v)
    return module
import torch
import torch.nn as nn
from pytorch_weight_norm import WeightNorm
# Compare a ConvTranspose2d layer's output before and after wrapping it.
# Plain tensors are used; torch.autograd.Variable is a deprecated no-op wrapper.
x = torch.randn(5, 10, 30, 30)
m = nn.ConvTranspose2d(10, 20, 3)
y = m(x)
print(m._parameters.keys())
# odict_keys(['weight', 'bias'])
m = WeightNorm(m, ['weight'])
y_wn = m(x)
print(m.module._parameters.keys())
# odict_keys(['bias', 'weight_g', 'weight_v'])
# torch.norm returns a 0-dim tensor; .item() converts it to a Python float
# (indexing it with [0] raises IndexError on PyTorch >= 0.4).
print(torch.norm(y - y_wn).item())
# 1.3324766769073904e-05 (not important to get this smaller)

## can also use within sequential
## and is also stackable
net = nn.Sequential(
    WeightNorm(nn.Linear(30, 10), ['weight']),
    nn.ReLU(),
    WeightNorm(nn.Linear(10, 20), ['weight', 'bias']),
)
@ypxie
Copy link

ypxie commented Aug 26, 2017

Hello, thanks for sharing this elegant implementation. Where can I find the newer, updated version?
Thanks!

@xwuaustin
Copy link

xwuaustin commented Sep 13, 2017

@rtqichen Thanks for contributing this code.
@ all @greaber @Smerity @ypxie @ hanzhanggit
Hi to everyone reading this post. I have some questions regarding weight_norm. It would be great if you could help.
I tried to apply weight_norm to each convolution and linear layer (see the code here https://github.com/xwuaustin/weight_norm/blob/master/cifar10_tutorial_weightNorm.py ). However, the CIFAR-10 training loss shows no difference from the original setting (see the picture below) during the first 10 epochs (6 iterations equal 1 epoch).

Now my questions:

1. Is there something wrong with the code I modified? I used the code from the cifar10_tutorial in PyTorch. All I did was add weightNorm to each layer.

import torch.nn as nn
from torch.nn.utils import weight_norm as weightNorm


class Net(nn.Module):
    """CIFAR-10 tutorial network with weight norm on every conv/linear layer.

    Only the constructor is shown in this snippet; the network's forward
    method is not included here.
    """

    def __init__(self):
        super(Net, self).__init__()
        ### we use weight normalization after each convolution and linear transform
        self.conv1 = weightNorm(nn.Conv2d(3, 6, 5), name="weight")
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = weightNorm(nn.Conv2d(6, 16, 5), name="weight")
        self.fc1 = weightNorm(nn.Linear(16 * 5 * 5, 120), name="weight")
        self.fc2 = weightNorm(nn.Linear(120, 84), name="weight")
        self.fc3 = weightNorm(nn.Linear(84, 10), name="weight")

2. Are the weight parameters 'weight_g' and 'weight_v' (and the bias) updated using the following formulation? (The formula image was not preserved in this copy.)

3. Can we initialize the parameters as the paper suggests?

Thanks. Looking forward to your responses. :)

@foowaa
Copy link

foowaa commented Oct 21, 2021

Excellent work solving the weight_norm(...) deepcopy problem!
Thank you

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment