@ssnl · Created May 27, 2019

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledParamModule(nn.Module):
    # linear w: [ fan_out, fan_in ]
    # conv   w: [ nc_out, nc_in, k1, k2 ]
    # convT  w: [ nc_in, nc_out, k1, k2 ], but let's ignore this case because
    #   (1) the tf impl doesn't special-case it
    #   (2) convT is only used for fusing Upsample & Conv2d, and in that case the
    #       weight should be treated as if it were for a Conv2d.
    #
    # NB: in the tf code, use_wscale defaults to False, but for StyleGAN it is
    #     True everywhere, so I changed it.
    def scale_weight(self, gain=np.sqrt(2), use_wscale=True, lrmul=1, new_name='_weight'):
        assert isinstance(self.weight, nn.Parameter)

        runtime_coef = self.init_weight_(self.weight, gain, use_wscale, lrmul)

        # add scale hook
        self.add_scale_hook('weight', new_name, runtime_coef)

    # helper for scale_weight and reset_parameters
    def init_weight_(self, weight, gain, use_wscale, lrmul):
        fan_in = np.prod(weight.shape[1:])
        he_std = gain / np.sqrt(fan_in)  # He init

        # Equalized learning rate and custom learning rate multiplier.
        if use_wscale:
            init_std = 1.0 / lrmul
            runtime_coef = he_std * lrmul
        else:
            init_std = he_std / lrmul
            runtime_coef = lrmul

        # Init variable using He init.
        weight.data.normal_(0, init_std)
        return runtime_coef
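
    # Worked example (added note, not in the original gist): for a 3x3 conv with
    # 512 input channels, fan_in = 512 * 3 * 3 = 4608, so
    # he_std = sqrt(2) / sqrt(4608) ≈ 0.0208. With use_wscale=True and lrmul=1,
    # the weight is drawn from N(0, 1) at init and multiplied by
    # runtime_coef ≈ 0.0208 on every forward pass, which is the equalized
    # learning rate trick used throughout StyleGAN.
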
    def scale_bias(self, lrmul=1, new_name='_bias'):
        if self.bias is None:
            assert not hasattr(self, new_name)
            # do not delete `self.bias`, so we don't have to restore it in forward
            # del self.bias
            self.register_parameter(new_name, None)
            return

        bias = self.bias
        assert isinstance(bias, nn.Parameter)

        # zero out
        bias.data.zero_()

        # add scale hook
        self.add_scale_hook('bias', new_name, lrmul)

    def add_scale_hook(self, name, new_name, coef):
        param = getattr(self, name)
        assert isinstance(param, nn.Parameter)
        assert not hasattr(self, new_name)

        delattr(self, name)
        self.register_parameter(new_name, param)

        # Note that the lambda below uses `m` rather than `self`, and thus
        # doesn't hold a reference to the module, which allows deep copying.
        self.register_forward_pre_hook(lambda m, inp: setattr(m, name, getattr(m, new_name) * coef))
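
# How the hook plays out in practice (added note, not in the original gist):
# after scale_weight() / scale_bias(), a forward call effectively computes
#
#     weight = module._weight * runtime_coef   # rebuilt by the pre-hook
#     bias   = module._bias * lrmul            # if a bias exists
#     out    = F.linear(x, weight, bias)       # or F.conv2d for conv layers
#
# so the optimizer only ever updates '_weight' / '_bias', while the scaled
# tensors used in forward are recomputed on every call.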


class ScaledParamLinear(nn.Linear, ScaledParamModule):
    def __init__(self, *args, gain=np.sqrt(2), use_wscale=True, lrmul=1, **kwargs):
        self.gain = gain
        self.use_wscale = use_wscale
        self.lrmul = lrmul
        super().__init__(*args, **kwargs)
        self.scale_weight(gain, use_wscale, lrmul)
        self.scale_bias(lrmul)
        self.reset_parameters()

    def reset_parameters(self):
        weight = self._weight if hasattr(self, '_weight') else self.weight
        self.init_weight_(weight, self.gain, self.use_wscale, self.lrmul)
        bias = self._bias if hasattr(self, '_bias') else self.bias
        if bias is not None:
            bias.data.zero_()

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self._bias is not None  # use the real, renamed param
        )


class ScaledParamConv2d(nn.Conv2d, ScaledParamModule):
    def __init__(self, *args, gain=np.sqrt(2), use_wscale=True, lrmul=1, **kwargs):
        self.gain = gain
        self.use_wscale = use_wscale
        self.lrmul = lrmul
        super().__init__(*args, **kwargs)
        self.scale_weight(gain, use_wscale, lrmul)
        self.scale_bias(lrmul)
        self.reset_parameters()

    def reset_parameters(self):
        weight = self._weight if hasattr(self, '_weight') else self.weight
        self.init_weight_(weight, self.gain, self.use_wscale, self.lrmul)
        bias = self._bias if hasattr(self, '_bias') else self.bias
        if bias is not None:
            bias.data.zero_()

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self._bias is None:  # use the real, renamed param
            s += ', bias=False'
        return s.format(**self.__dict__)
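

# Usage sketch (added, not part of the original gist): minimal sanity checks for
# the layers above. Shapes and hyper-parameters are only illustrative; lrmul=0.01
# mirrors the value StyleGAN uses in its mapping network.
if __name__ == '__main__':
    import copy

    fc = ScaledParamLinear(512, 256, lrmul=0.01)
    conv = ScaledParamConv2d(64, 128, kernel_size=3, padding=1)

    # The hook closes over `m` rather than `self`, so a freshly built module can
    # still be deep-copied.
    fc_copy = copy.deepcopy(fc)

    x = torch.randn(4, 512)
    print(fc(x).shape)                            # torch.Size([4, 256])
    print(fc_copy(x).shape)                       # torch.Size([4, 256])

    img = torch.randn(4, 64, 32, 32)
    print(conv(img).shape)                        # torch.Size([4, 128, 32, 32])

    # Only the underscored tensors are registered parameters; 'weight' / 'bias'
    # are plain tensors recreated by the forward pre-hook on each call.
    print([name for name, _ in fc.named_parameters()])   # ['_weight', '_bias']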