sajadn / data_dependent_weight_norm_init.py
Last active April 29, 2020 18:38
Data-dependent initialization of weight-norm layers (convolutional)
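The idea can be illustrated on a single layer. The following is a minimal sketch that assumes the usual weight-normalization recipe: compute per-channel batch statistics (σ, μ) of the layer's output, then set g to g/σ and the bias to -μ/σ so that the output on that batch has roughly zero mean and unit standard deviation. `init_conv_wn_` is a hypothetical helper, not part of this gist. The bias step is also an assumption, because the gist's bias handling is not visible here. Unlike the gist's hook-based `data_dependent_init`, the helper initializes one layer directly.

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


def init_conv_wn_(layer: nn.Conv2d, x: torch.Tensor) -> None:
    """Hypothetical helper: data-dependent init of one weight-normalized conv.

    Assumes `layer` was wrapped with torch.nn.utils.weight_norm (so it has
    `weight_g`) and has a bias. After the call, each output channel on batch
    `x` has approximately zero mean and unit standard deviation.
    """
    with torch.no_grad():
        out = layer(x)                                    # (N, C, H, W)
        std, mean = torch.std_mean(out, dim=(0, 2, 3))    # per-channel stats
        # Strip the current bias so only the convolution response is normalized.
        conv_mean = mean - layer.bias
        layer.weight_g.div_(std.reshape(-1, 1, 1, 1))     # g <- g / sigma
        layer.bias.copy_(-conv_mean / std)                # b <- -mu / sigma


# Usage: initialize a random weight-normalized conv on a random batch.
torch.manual_seed(0)
layer = weight_norm(nn.Conv2d(3, 8, kernel_size=3, padding=1))
x = torch.randn(16, 3, 32, 32)
init_conv_wn_(layer, x)

with torch.no_grad():
    std, mean = torch.std_mean(layer(x), dim=(0, 2, 3))
print(mean.abs().max().item(), std.min().item(), std.max().item())
```

The printed mean should be close to 0 and both standard deviations close to 1. `weight_norm` recomputes the weight from `weight_g` and `weight_v` on every forward pass, so rescaling `weight_g` takes effect on the next call. In a deep network, the layers must be initialized in forward order, because each layer's statistics depend on the layers before it. Registering a forward hook, as the gist's `init_hook_` appears to do, is one way to get that ordering.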
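import torch
import torch.nn as nn
from torch import _weight_norm
from torch.nn.utils import weight_norm


def data_dependent_init(model, data_batch):
    def init_hook_(module, input, output):
        # Per-output-channel statistics over the batch and spatial dims (N, H, W).
        std, mean = torch.std_mean(output, dim=[0, 2, 3])
        # Rescale the weight-norm magnitude g (shape: out_channels x 1 x 1 x 1) so
        # that each channel's convolution response has unit std on this batch.
        g = getattr(module, 'weight_g')
        g.data = g.data/std.reshape((len(std), 1, 1, 1))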