@rtqichen
Last active May 11, 2023 06:58
Pytorch weight normalization - works for all nn.Module (probably)
## Weight norm is now added to pytorch as a pre-hook, so use that instead :)
import torch
import torch.nn as nn
from torch.nn import Parameter
from functools import wraps


class WeightNorm(nn.Module):
    append_g = '_g'
    append_v = '_v'

    def __init__(self, module, weights):
        super(WeightNorm, self).__init__()
        self.module = module
        self.weights = weights
        self._reset()

    def _reset(self):
        for name_w in self.weights:
            w = getattr(self.module, name_w)

            # construct g,v such that w = g/||v|| * v
            g = torch.norm(w)
            v = w/g.expand_as(w)
            g = Parameter(g.data)
            v = Parameter(v.data)
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v

            # remove w from parameter list
            del self.module._parameters[name_w]

            # add g and v as new parameters
            self.module.register_parameter(name_g, g)
            self.module.register_parameter(name_v, v)

    def _setweights(self):
        for name_w in self.weights:
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v
            g = getattr(self.module, name_g)
            v = getattr(self.module, name_v)
            w = v*(g/torch.norm(v)).expand_as(v)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)
##############################################################
## An older version using a python decorator but might be buggy.
## Does not work when the module is replicated (e.g. nn.DataParallel)

def _decorate(forward, module, name, name_g, name_v):
    @wraps(forward)
    def decorated_forward(*args, **kwargs):
        g = module.__getattr__(name_g)
        v = module.__getattr__(name_v)
        w = v*(g/torch.norm(v)).expand_as(v)
        module.__setattr__(name, w)
        return forward(*args, **kwargs)
    return decorated_forward


def weight_norm(module, name):
    param = module.__getattr__(name)

    # construct g,v such that w = g/||v|| * v
    g = torch.norm(param)
    v = param/g.expand_as(param)
    g = Parameter(g.data)
    v = Parameter(v.data)
    name_g = name + '_g'
    name_v = name + '_v'

    # remove w from parameter list
    del module._parameters[name]

    # add g and v as new parameters
    module.register_parameter(name_g, g)
    module.register_parameter(name_v, v)

    # construct w every time before forward is called
    module.forward = _decorate(module.forward, module, name, name_g, name_v)
    return module
import torch
import torch.nn as nn
from pytorch_weight_norm import WeightNorm
x = torch.autograd.Variable(torch.randn(5,10,30,30))
m = nn.ConvTranspose2d(10,20,3)
y = m(x)
print(m._parameters.keys())
# odict_keys(['weight', 'bias'])
m = WeightNorm(m, ['weight'])
y_wn = m(x)
print(m.module._parameters.keys())
# odict_keys(['bias', 'weight_g', 'weight_v'])
print(torch.norm(y-y_wn).data[0])
# 1.3324766769073904e-05 (not important to get this smaller)
## can also use within sequential
## and is also stackable
net = nn.Sequential(
    WeightNorm(nn.Linear(30,10), ['weight']),
    nn.ReLU(),
    WeightNorm(nn.Linear(10,20), ['weight', 'bias']),
)
@Smerity

Smerity commented Jun 14, 2017

I love this implementation - very elegant :) I'll likely steal this pattern for future code!

The only issue that I've run into is that calling torch.save(model, f) results in:
_pickle.PicklingError: Can't pickle <function RNNBase.forward at 0x7f6d6c9fc9d8>: it's not the same object as torch.nn.modules.rnn.RNNBase.forward

I'm not certain if there's a way to make pickle or PyTorch happy with the monkey patching that's involved. I'd imagine this would be a problem if you saved the weights directly rather than saving a pickled version of the model, but I'm uncertain.
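(For anyone hitting the same error: a possible workaround, untested against this gist, is to persist only the parameters with state_dict() instead of pickling the whole module, since the state dict never touches the patched forward. build_model below is just a stand-in for however the wrapped model gets constructed.)

import torch

# hypothetical workaround: save parameters only, then rebuild the
# (wrapped) module and load the weights back in
torch.save(model.state_dict(), 'model.pt')

model = build_model()   # placeholder: re-create and re-wrap the module
model.load_state_dict(torch.load('model.pt'))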

@greaber

greaber commented Jun 18, 2017

I am seeing the memory usage of my model more than quadruple when using this code (wrapping linear layers and GRUCells), but I can't figure out what is causing it.

@Smerity, I didn't encounter the problem you report -- not sure why.

@greaber

greaber commented Jun 19, 2017

The problem is that the code was recomputing and allocating new storage for w on every call of forward, which is fine for feed-forward nets but not for RNNs. I made a modified version that only recomputes w the first time forward is called and then after each backprop. I also modified the code so that you can pass a list of parameter names to weight_norm and it will wrap all of them. (I made some other changes to the original code as well, but I don't think they are significant.)

import torch
from torch.nn import Parameter
from functools import wraps

def _make_hook(module, name):
    def hook(module, grad_input, grad_output):
        module.recompute_w[name] = True
        return None
    return hook

def _decorate(forward, module, name):
    @wraps(forward)
    def decorated_forward(*args, **kwargs):
        g = getattr(module, name + '_g')
        v = getattr(module, name + '_v')
        if module.recompute_w[name]:
            setattr(module, name, v*(g/torch.norm(v)).expand_as(v))
            module.recompute_w[name] = False
        return forward(*args, **kwargs)
    return decorated_forward

def weight_norm(module, name):

    if isinstance(name, list):
        for x in name:
            module = weight_norm(module, x)
        return module

    param = getattr(module, name)

    # construct g,v such that w = g/||v|| * v
    g = torch.norm(param.data)
    v = param.data / g
    delattr(module, name)
    setattr(module, name + '_g', Parameter(torch.Tensor([g])))
    setattr(module, name + '_v', Parameter(v))

    # construct w every time before forward is called
    module.forward = _decorate(module.forward, module, name)

    if not hasattr(module, 'recompute_w'):
        module.recompute_w = dict()
    module.recompute_w[name] = True
    module.register_backward_hook(_make_hook(module, name))

    return module
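(For what it's worth, a minimal usage sketch of this modified weight_norm, assuming an nn.GRUCell and its standard parameter names weight_ih / weight_hh; the sizes are arbitrary.)

import torch
import torch.nn as nn
from torch.autograd import Variable

# wrap both weight matrices of a GRUCell with the modified weight_norm above
cell = weight_norm(nn.GRUCell(30, 20), ['weight_ih', 'weight_hh'])

x = Variable(torch.randn(8, 5, 30))   # 8 timesteps, batch of 5
h = Variable(torch.zeros(5, 20))
for t in range(x.size(0)):
    h = cell(x[t], h)                 # w is only rebuilt when the recompute flag is set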

@rtqichen
Author

rtqichen commented Jun 22, 2017

Oops, I didn't realize this got popular. It was only meant to be shared temporarily with a friend, and I've actually had a much better version since.

I've uploaded the newer version (which basically wraps the module in another PyTorch module instead of using a Python decorator), and it should be much more robust (e.g. saving should be fine and it allows multi-GPU support), but I haven't actually tested it for quite a while.

@Smerity thanks! I also had trouble saving, but at some point it worked... a Python decorator is likely not the best way to go about it. Thankfully it's not the only way :)

@greaber that looks like a nice feature! But I haven't included it in the gist because of the extra hook required and because I'm a bit rusty with PyTorch atm.

@Smerity

Smerity commented Jun 22, 2017

@rtqichen - thanks for posting the original and the updated version! Even if it was only meant for a friend, it was certainly appreciated ^_^

@greaber - I hadn't even thought of that! I presume you're not using the cuDNN LSTM and are instead running an LSTM cell timestep by timestep? If you were using the cuDNN LSTM it would avoid this issue, as it should only call forward once per set of inputs, I think..?

@skaae

skaae commented Jun 27, 2017

This breaks printing of modules for conv layers. A quick fix is to add the following to _reset:

            if name_w == 'bias':
                self.module.bias = None

EDIT: Thanks for sharing your code :)
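(In case the placement is unclear, here is roughly what _reset looks like with that check inserted right after the parameter is removed; a sketch, untested, and the class name is made up.)

class WeightNormPrintable(WeightNorm):
    """Same as WeightNorm above, but printing still works when 'bias' is normalized."""

    def _reset(self):
        for name_w in self.weights:
            w = getattr(self.module, name_w)

            # construct g,v such that w = g/||v|| * v
            g = torch.norm(w)
            v = w/g.expand_as(w)
            g = Parameter(g.data)
            v = Parameter(v.data)
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v

            # remove w from parameter list
            del self.module._parameters[name_w]

            # leave a placeholder so e.g. nn.Conv2d.__repr__ can still read
            # self.bias before _setweights() has run
            if name_w == 'bias':
                self.module.bias = None

            # add g and v as new parameters
            self.module.register_parameter(name_g, g)
            self.module.register_parameter(name_v, v)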

@hanzhanggit

@rtqichen I could not find the data-dependent init. I thought it was important for weight norm. Isn't it?
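(The data-dependent init from the paper isn't part of this gist. For reference, a rough sketch of the idea using the built-in torch.nn.utils.weight_norm, where weight_g holds one scale per output unit of an nn.Linear; the helper name and the init batch here are made up, so treat it as an illustration rather than a drop-in.)

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils import weight_norm

def data_dependent_init(layer, x_batch, eps=1e-8):
    # forward a batch with g = 1 and b = 0, then rescale so the
    # pre-activations have zero mean and unit variance per output unit
    # (Salimans & Kingma, 2016; the paper also draws v from N(0, 0.05) first)
    layer.weight_g.data.fill_(1.0)
    layer.bias.data.zero_()
    t = layer(x_batch)
    mean = t.mean(0).data.view(-1)
    std = t.std(0).data.view(-1)
    inv_std = (std + eps).reciprocal()
    layer.weight_g.data.copy_(inv_std.view_as(layer.weight_g.data))
    layer.bias.data.copy_(-mean * inv_std)
    return layer

layer = data_dependent_init(weight_norm(nn.Linear(30, 10)),
                            Variable(torch.randn(64, 30)))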

@dlmacedo

How should we incorporate the PyTorch 0.2.0 support for weight normalization into new RNN projects?

http://pytorch.org/docs/master/nn.html#torch.nn.utils.weight_norm
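(For reference, applying the built-in utility to a recurrent cell looks roughly like this; the sizes are arbitrary. For the fused nn.LSTM / nn.GRU modules the per-layer parameter names are weight_ih_l0, weight_hh_l0, and so on, though I haven't checked how well that plays with cuDNN's flattened weights.)

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils import weight_norm

# normalize both weight matrices of an LSTM cell
cell = weight_norm(nn.LSTMCell(10, 20), name='weight_ih')
cell = weight_norm(cell, name='weight_hh')
print(cell._parameters.keys())
# expect bias_ih, bias_hh plus weight_ih_g/_v and weight_hh_g/_v

x = Variable(torch.randn(5, 10))
h = Variable(torch.zeros(5, 20))
c = Variable(torch.zeros(5, 20))
h, c = cell(x, (h, c))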

@ypxie

ypxie commented Aug 26, 2017

Hello, thanks for sharing this elegant implementation. Where can I find the newer, updated version?
Thanks!

@xwuaustin

xwuaustin commented Sep 13, 2017

@rtqichen Thanks for contributing this code.
To all: @greaber @Smerity @ypxie @hanzhanggit
Hi everyone reading this post. I have some questions regarding weight_norm. It would be great if you could help.
I tried to apply weight_norm to each convolution and linear layer (see the code here: https://github.com/xwuaustin/weight_norm/blob/master/cifar10_tutorial_weightNorm.py). However, over the first 10 epochs the training loss on CIFAR-10 shows no difference from the original setting (see the picture below; 6 iterations equal 1 epoch).

Now questions:

1. Is there something wrong with the code I modified? I used the code from the cifar10_tutorial in PyTorch. All I did was add weightNorm at each layer.

import torch.nn.utils.weight_norm as weightNorm

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ### we use weight normalization after each convolution and linear transform
        self.conv1 = weightNorm(nn.Conv2d(3, 6, 5), name="weight")
        #print(self.conv1._parameters.keys())
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = weightNorm(nn.Conv2d(6, 16, 5), name="weight")
        self.fc1 = weightNorm(nn.Linear(16 * 5 * 5, 120), name="weight")
        self.fc2 = weightNorm(nn.Linear(120, 84), name="weight")
        self.fc3 = weightNorm(nn.Linear(84, 10), name="weight")

2. Is the update of the weights and bias, namely 'weight_g' and 'weight_v', done using the formulation:

3. Can we do the initialization as the paper suggested?

Thanks. Looking forward to your responses. :)

@foowaa

foowaa commented Oct 21, 2021

Excellent work solving the weight_norm(...) deep-copy problem!
Thank you
