Skip to content

Instantly share code, notes, and snippets.

@sajadn
Last active April 29, 2020 18:38
Show Gist options
  • Save sajadn/f29b14cd52023b614c334708c2a4345f to your computer and use it in GitHub Desktop.
Data dependent initialization of weight norm layers (convolutional)
import torch
import torch.nn as nn
from torch import _weight_norm
from torch.nn.utils import weight_norm
def data_dependent_init(model, data_batch):
    """Data-dependent initialization of weight-normalized Conv2d layers.

    Runs one forward pass of ``model`` on ``data_batch`` and rescales every
    weight-normalized ``Conv2d`` so its output has zero mean and unit std
    per output channel (Salimans & Kingma, 2016). The model is modified in
    place; all temporary hooks are removed before returning.

    Args:
        model: an ``nn.Module`` whose Conv2d layers are wrapped with
            ``torch.nn.utils.weight_norm`` (i.e. expose ``weight_g`` /
            ``weight_v``).
        data_batch: a representative input batch; it should be zero-centered
            for the computed statistics to be meaningful.
    """
    def init_hook_(module, inputs, output):
        # Per-channel statistics of the pre-init activation
        # (batch, height, width reduced away).
        std, mean = torch.std_mean(output, dim=[0, 2, 3])
        g = module.weight_g
        # Dividing g by std scales the conv output to unit variance;
        # weight_g has shape (out_channels, 1, 1, 1) for dim=0 weight norm.
        g.data = g.data / std.reshape((len(std), 1, 1, 1))
        b = module.bias
        # Shift/scale the bias so the post-init output is zero-centered.
        b.data = (b.data - mean) / std
        # Recompute `weight` from the updated g so the value returned below
        # (and hence the input seen by downstream layers during this init
        # pass) is already normalized.
        module.weight = torch._weight_norm(module.weight_v, g, dim=0)
        # PyTorch renamed Conv2d._conv2d_forward to _conv_forward (adding an
        # explicit bias argument) in 1.8; support both so the init runs on
        # modern releases as well as the original 1.x target.
        if hasattr(module, '_conv_forward'):
            return module._conv_forward(inputs[0], module.weight, module.bias)
        return module._conv2d_forward(inputs[0], module.weight)

    handles = []
    for m in model.modules():
        if isinstance(m, torch.nn.Conv2d):
            handles.append(m.register_forward_hook(init_hook_))
    try:
        # Single forward pass triggers one init_hook_ call per Conv2d.
        model(data_batch)
    finally:
        # Always detach the hooks, even if the forward pass raises,
        # so the model is left in a clean state.
        for h in handles:
            h.remove()
# Demo network: two ELU -> weight-normalized 3x3 conv stages, 3 -> 16 -> 16
# channels, spatial size preserved (stride 1, padding 1).
_conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
model = nn.Sequential(
    nn.ELU(),
    weight_norm(nn.Conv2d(in_channels=3, out_channels=16, **_conv_kwargs)),
    nn.ELU(),
    weight_norm(nn.Conv2d(in_channels=16, out_channels=16, **_conv_kwargs)),
)
# NOTE(review): `data_batch` is never defined in this snippet — the call
# below raises NameError as written. Uncomment and replace
# get_batch_of_data() with your own data loader before running.
#data_batch = get_batch_of_data()
#make sure the data is zero centered
data_dependent_init(model, data_batch) #normalizes the output of each layer
#Training ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment