import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledParamModule(nn.Module):
    # linear w: [ fan_out, fan_in ]
    # conv   w: [ nc_out, nc_in, k1, k2 ]
    # convT  w: [ nc_in, nc_out, k1, k2 ], but we ignore this case because
    #   (1) the tf impl doesn't special-case it, and
    #   (2) convT is only used for fusing Upsample & Conv2d, in which case the
    #       weight should be initialized as if it were for a Conv2d.
    #
    # NB: in the tf code, use_wscale has default value False, but StyleGAN sets
    # it to True everywhere, so the default is changed here.

    def scale_weight(self, gain=np.sqrt(2), use_wscale=True, lrmul=1, new_name='_weight'):
        assert isinstance(self.weight, nn.Parameter)
        runtime_coef = self.init_weight_(self.weight, gain, use_wscale, lrmul)
        # add scale hook
        self.add_scale_hook('weight', new_name, runtime_coef)

    # helper for scale_weight and reset_parameters
    def init_weight_(self, weight, gain, use_wscale, lrmul):
        fan_in = np.prod(weight.shape[1:])
        he_std = gain / np.sqrt(fan_in)  # He init
        # Equalized learning rate and custom learning rate multiplier.
        if use_wscale:
            init_std = 1.0 / lrmul
            runtime_coef = he_std * lrmul
        else:
            init_std = he_std / lrmul
            runtime_coef = lrmul
        # Init variable using He init.
        weight.data.normal_(0, init_std)
        return runtime_coef
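
    # Worked example with illustrative numbers: for fan_in = 512,
    # gain = sqrt(2), use_wscale = True, and lrmul = 0.01:
    #   he_std       = sqrt(2) / sqrt(512) ~= 0.0625
    #   init_std     = 1.0 / 0.01          =  100.0
    #   runtime_coef = 0.0625 * 0.01       =  0.000625
    # The weight seen by forward() then has std init_std * runtime_coef =
    # he_std (this holds for any lrmul), while the stored parameter lives at
    # std 1 / lrmul, which is what implements the equalized learning rate.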

    def scale_bias(self, lrmul=1, new_name='_bias'):
        if self.bias is None:
            assert not hasattr(self, new_name)
            # keep self.bias registered (as None) so forward() works unchanged,
            # and register new_name as None so extra_repr can query the real slot
            self.register_parameter(new_name, None)
            return
        bias = self.bias
        assert isinstance(bias, nn.Parameter)
        # zero out
        bias.data.zero_()
        # add scale hook
        self.add_scale_hook('bias', new_name, lrmul)
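
    # Note: biases start at zero and are only scaled by lrmul at runtime; the
    # He fan-in factor applies to weights alone, as in the tf code being ported.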

    def add_scale_hook(self, name, new_name, coef):
        param = getattr(self, name)
        assert isinstance(param, nn.Parameter)
        assert not hasattr(self, new_name)
        delattr(self, name)
        self.register_parameter(new_name, param)
        # Note that the hook takes the module as the argument `m` rather than
        # closing over `self`, so it doesn't hold a reference to this module
        # and the module remains deep-copyable.
        self.register_forward_pre_hook(lambda m, inp: setattr(m, name, getattr(m, new_name) * coef))
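

# How the scale hook plays out: the trainable tensor is re-registered under
# new_name (e.g. `_weight`), and just before every forward() the pre-hook
# recomputes `weight = _weight * coef` as a plain, non-Parameter attribute.
# The optimizer therefore only sees and updates `_weight`, while the forward
# math always uses the runtime-scaled copy.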


class ScaledParamLinear(nn.Linear, ScaledParamModule):
    def __init__(self, *args, gain=np.sqrt(2), use_wscale=True, lrmul=1, **kwargs):
        # set before super().__init__(), which calls reset_parameters()
        self.gain = gain
        self.use_wscale = use_wscale
        self.lrmul = lrmul
        super().__init__(*args, **kwargs)
        self.scale_weight(gain, use_wscale, lrmul)
        self.scale_bias(lrmul)
        self.reset_parameters()

    def reset_parameters(self):
        weight = self._weight if hasattr(self, '_weight') else self.weight
        self.init_weight_(weight, self.gain, self.use_wscale, self.lrmul)
        bias = self._bias if hasattr(self, '_bias') else self.bias
        if bias is not None:
            bias.data.zero_()

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self._bias is not None  # use the real param
        )


class ScaledParamConv2d(nn.Conv2d, ScaledParamModule):
    def __init__(self, *args, gain=np.sqrt(2), use_wscale=True, lrmul=1, **kwargs):
        # set before super().__init__(), which calls reset_parameters()
        self.gain = gain
        self.use_wscale = use_wscale
        self.lrmul = lrmul
        super().__init__(*args, **kwargs)
        self.scale_weight(gain, use_wscale, lrmul)
        self.scale_bias(lrmul)
        self.reset_parameters()

    def reset_parameters(self):
        weight = self._weight if hasattr(self, '_weight') else self.weight
        self.init_weight_(weight, self.gain, self.use_wscale, self.lrmul)
        bias = self._bias if hasattr(self, '_bias') else self.bias
        if bias is not None:
            bias.data.zero_()

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self._bias is None:  # use the real param
            s += ', bias=False'
        return s.format(**self.__dict__)
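

if __name__ == '__main__':
    # Minimal usage sketch (layer sizes here are arbitrary examples). The
    # stored parameters are `_weight` / `_bias`; `weight` / `bias` are plain
    # tensors recomputed by the forward pre-hook.
    fc = ScaledParamLinear(512, 512, lrmul=0.01)
    conv = ScaledParamConv2d(3, 16, kernel_size=3, padding=1)
    print(fc)
    print(conv)

    x = torch.randn(4, 512)
    y = fc(x)  # the pre-hook sets fc.weight = fc._weight * runtime_coef first
    print(y.shape)  # torch.Size([4, 512])

    # the stored weight is large, but the effective one matches He init:
    print(fc._weight.std().item())  # ~ 1 / lrmul = 100
    print(fc.weight.std().item())   # ~ sqrt(2 / 512) ~= 0.0625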