Skip to content

Instantly share code, notes, and snippets.

Last active April 6, 2020 13:32
Show Gist options
  • Save nurpax/cbac236c68b38f27e2ec64cfdb978863 to your computer and use it in GitHub Desktop.
Save nurpax/cbac236c68b38f27e2ec64cfdb978863 to your computer and use it in GitHub Desktop.
import numpy as np
import math
import functools
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P
import layers
from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
# Architectures for G
# Attention is passed in in the format '32_64' to mean applying an attention
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    """Build the generator architecture table, keyed by output resolution.

    Args:
        ch: Base channel width; every channel count is a multiple of it.
        attention: Underscore-separated resolutions at which an attention
            block is requested, e.g. '32_64' (32x32 and 64x64) or '64'.
        ksize: Unused here; kept so the signature stays unchanged.
        dilation: Unused here; kept so the signature stays unchanged.

    Returns:
        A dict mapping each supported resolution (512, 256, 128, 64, 32) to a
        dict with keys 'in_channels', 'out_channels', 'upsample',
        'resolution' and 'attention' (resolution -> bool).
    """
    # Parse the attention spec once rather than once per dictionary entry.
    attn_resolutions = {int(item) for item in attention.split('_')}

    def spec(in_mults, out_mults, max_exp):
        # Block resolutions run from 8 (2**3) up to 2**(max_exp - 1).
        resolutions = [2 ** i for i in range(3, max_exp)]
        return {
            'in_channels': [ch * item for item in in_mults],
            'out_channels': [ch * item for item in out_mults],
            'upsample': [True] * len(in_mults),
            'resolution': resolutions,
            'attention': {res: (res in attn_resolutions) for res in resolutions},
        }

    arch = {}
    arch[512] = spec([16, 16, 8, 8, 4, 2, 1], [16, 8, 8, 4, 2, 1, 1], 10)
    arch[256] = spec([16, 16, 8, 8, 4, 2], [16, 8, 8, 4, 2, 1], 9)
    arch[128] = spec([16, 16, 8, 4, 2], [16, 8, 4, 2, 1], 8)
    arch[64] = spec([16, 16, 8, 4], [16, 8, 4, 2], 7)
    arch[32] = spec([4, 4, 4], [4, 4, 4], 6)
    return arch
class Generator(nn.Module):
    """BigGAN-style generator with optional hierarchical z or self-modulation.

    NOTE(review): the extracted source for this class had lost its
    indentation, several `else:` lines, the end of the constructor signature
    and a number of keyword arguments. Lines marked RECONSTRUCTED below were
    rebuilt from how the corresponding attributes are used elsewhere in this
    class; confirm them against the original gist before relying on them.
    """

    def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
                 G_kernel_size=3, G_attn='64', n_classes=1000,
                 num_G_SVs=1, num_G_SV_itrs=1,
                 G_shared=True, shared_dim=0, hier=False, self_modulation=False,
                 cross_replica=False, mybn=False,
                 G_activation=nn.ReLU(inplace=False),  # RECONSTRUCTED
                 G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
                 BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
                 G_init='ortho', skip_init=False, no_optim=False,
                 G_param='SN', norm_style='bn',
                 **kwargs):  # RECONSTRUCTED: absorbs unrelated config keys
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Flat z with self-modulation like in "A U-Net Based Discriminator
        # for Generative Adversarial Networks"
        self.self_modulation = self_modulation
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # If using hierarchical latents, adjust z
        if self.hier:
            # Number of places z slots into
            self.num_slots = len(self.arch['in_channels']) + 1
            self.z_chunk_size = (self.dim_z // self.num_slots)
            # Recalculate latent dimensionality for even splitting into chunks
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == 'SN':
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3, padding=1,
                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                eps=self.SN_eps)  # RECONSTRUCTED kwarg
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                eps=self.SN_eps)  # RECONSTRUCTED kwarg
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
        if not self.self_modulation:
            bn_linear = (functools.partial(self.which_linear, bias=False)
                         if self.G_shared else self.which_embedding)
            input_size = (self.shared_dim + self.z_chunk_size if self.G_shared
                          else self.n_classes)
        else:
            bn_linear = nn.Linear
            # THIS BREAKS with --G_shared
            input_size = self.dim_z + (self.shared_dim if self.G_shared else 0)
            #input_size = self.shared_dim + self.z_chunk_size if self.G_shared else self.n_classes
        # RECONSTRUCTED kwargs: only `functools.partial(layers.ccbn,` survived.
        self.which_bn = functools.partial(layers.ccbn,
                                          which_linear=bn_linear,
                                          cross_replica=self.cross_replica,
                                          mybn=self.mybn,
                                          input_size=input_size,
                                          norm_style=self.norm_style,
                                          eps=self.BN_eps,
                                          self_modulation=self.self_modulation)

        # Prepare model
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared
                       else layers.identity())
        # First linear layer
        self.linear = self.which_linear(
            self.dim_z // self.num_slots,
            self.arch['in_channels'][0] * (self.bottom_width ** 2))

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            # RECONSTRUCTED kwargs between in_channels and upsample.
            self.blocks += [[layers.GBlock(
                in_channels=self.arch['in_channels'][index],
                out_channels=self.arch['out_channels'][index],
                which_conv=self.which_conv,
                which_bn=self.which_bn,
                activation=self.activation,
                upsample=(functools.partial(F.interpolate, scale_factor=2)
                          if self.arch['upsample'][index] else None))]]

            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        # RECONSTRUCTED: the bn constructor call and activation were stripped.
        self.output_layer = nn.Sequential(
            layers.bn(self.arch['out_channels'][-1],
                      cross_replica=self.cross_replica,
                      mybn=self.mybn),
            self.activation,
            self.which_conv(self.arch['out_channels'][-1], 3))

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print('Using fp16 adam in G...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0,
                                      eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0,
                                    eps=self.adam_eps)

        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    # Initialize
    def init_weights(self):
        """Initialise conv/linear/embedding weights per `self.init` and count them."""
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        # The original literal 'G''s' concatenated to "Gs"; use a real apostrophe.
        print("Param count for G's initialized parameters: %d" % self.param_count)

    # Note on this forward function: we pass in a y vector which has
    # already been passed through G.shared to enable easy class-wise
    # interpolation later. If we passed in the one-hot and then ran it through
    # G.shared in this forward function, it would be harder to handle.
    def forward(self, z, y):
        """Generate images from latent `z` and (already embedded) class input `y`.

        The per-block conditioning is: `y` concatenated with a chunk of `z`
        when `hier` is set, `z` itself when `self_modulation` is set, and
        plain `y` otherwise. The result is passed through tanh.
        """
        # If hierarchical, concatenate zs and ys
        if self.hier:
            zs = torch.split(z, self.z_chunk_size, 1)
            z = zs[0]
            ys = [torch.cat([y, item], 1) for item in zs[1:]]
        elif self.self_modulation:
            ys = [z] * len(self.blocks)
        else:
            ys = [y] * len(self.blocks)

        # First linear layer
        h = self.linear(z)
        # Reshape
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)

        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, ys[index])

        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))
# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
    """Build the discriminator architecture table, keyed by input resolution.

    Args:
        ch: Base channel width; every non-image channel count is a multiple
            of it (the first block always takes 3 input channels).
        attention: Underscore-separated resolutions at which an attention
            block is requested, e.g. '32_64' or '64'.
        ksize: Unused here; kept so the signature stays unchanged.
        dilation: Unused here; kept so the signature stays unchanged.

    Returns:
        A dict mapping each supported resolution (256, 128, 64, 32) to a dict
        with keys 'in_channels', 'out_channels', 'downsample', 'resolution'
        and 'attention' (resolution -> bool).
    """
    # Parse the attention spec once rather than once per dictionary entry.
    attn_resolutions = {int(item) for item in attention.split('_')}

    def attn_map(max_exp):
        # Attention flags for resolutions 4 (2**2) up to 2**(max_exp - 1).
        return {2 ** i: 2 ** i in attn_resolutions for i in range(2, max_exp)}

    arch = {}
    arch[256] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
                 'downsample': [True] * 6 + [False],
                 'resolution': [128, 64, 32, 16, 8, 4, 4],
                 'attention': attn_map(8)}
    arch[128] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 16]],
                 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
                 'downsample': [True] * 5 + [False],
                 'resolution': [64, 32, 16, 8, 4, 4],
                 'attention': attn_map(8)}
    arch[64] = {'in_channels': [3] + [ch * item for item in [1, 2, 4, 8]],
                'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
                'downsample': [True] * 4 + [False],
                'resolution': [32, 16, 8, 4, 4],
                'attention': attn_map(7)}
    arch[32] = {'in_channels': [3] + [item * ch for item in [4, 4, 4]],
                'out_channels': [item * ch for item in [4, 4, 4, 4]],
                'downsample': [True, True, False, False],
                'resolution': [16, 16, 16, 16],
                'attention': attn_map(6)}
    return arch
class Discriminator(nn.Module):
    """BigGAN-style projection discriminator.

    NOTE(review): the extracted source for this class had lost its
    indentation, its `else:` lines, `self.ch`/`self.lr` assignments and
    several keyword arguments. Lines marked RECONSTRUCTED were rebuilt from
    how the attributes are used in this class; confirm against the original.
    """

    def __init__(self, D_ch=64, D_wide=True, resolution=128,
                 D_kernel_size=3, D_attn='64', n_classes=1000,
                 num_D_SVs=1, num_D_SV_itrs=1, D_activation=nn.ReLU(inplace=False),
                 D_lr=2e-4, D_B1=0.0, D_B2=0.999, adam_eps=1e-8,
                 SN_eps=1e-12, output_dim=1, D_mixed_precision=False, D_fp16=False,
                 D_init='ortho', skip_init=False, D_param='SN', **kwargs):
        super(Discriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == 'SN':
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3, padding=1,
                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                eps=self.SN_eps)  # RECONSTRUCTED kwarg
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                eps=self.SN_eps)  # RECONSTRUCTED kwarg
            self.which_embedding = functools.partial(
                layers.SNEmbedding,
                num_svs=num_D_SVs, num_itrs=num_D_SV_itrs,
                eps=self.SN_eps)  # RECONSTRUCTED kwarg

        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch['out_channels'])):
            # RECONSTRUCTED kwargs between in_channels and preactivation.
            self.blocks += [[layers.DBlock(
                in_channels=self.arch['in_channels'][index],
                out_channels=self.arch['out_channels'][index],
                which_conv=self.which_conv,
                wide=self.D_wide,
                activation=self.activation,
                preactivation=(index > 0),
                downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))]]
            # If attention on this block, attach it to the end
            if self.arch['attention'][self.arch['resolution'][index]]:
                print('Adding attention layer in D at resolution %d' % self.arch['resolution'][index])
                self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index],
                                                     self.which_conv)]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
        # Embedding for projection discrimination
        self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print('Using fp16 adam in D...')
            import utils
            self.optim = utils.Adam16(params=self.parameters(), lr=self.lr,
                                      betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
        else:
            self.optim = optim.Adam(params=self.parameters(), lr=self.lr,
                                    betas=(self.B1, self.B2), weight_decay=0, eps=self.adam_eps)
        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    # Initialize
    def init_weights(self):
        """Initialise conv/linear/embedding weights per `self.init` and count them."""
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == 'ortho':
                    init.orthogonal_(module.weight)
                elif self.init == 'N02':
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ['glorot', 'xavier']:
                    init.xavier_uniform_(module.weight)
                else:
                    print('Init style not recognized...')
                self.param_count += sum([p.data.nelement() for p in module.parameters()])
        # The original literal 'D''s' concatenated to "Ds"; use a real apostrophe.
        print("Param count for D's initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None):
        """Score images `x` conditioned on class indices `y`.

        Returns the unconditional linear output plus the projection of the
        pooled features onto the class embedding of `y`.
        NOTE(review): `y` defaults to None but is passed straight to
        `self.embed`, so callers presumably always supply it - confirm.
        """
        # Stick x into h for cleaner for loops without flow control
        h = x
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Get initial class-unconditional output
        out = self.linear(h)
        # Get projection of final featureset onto class vectors and add to evidence
        out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        return out
# Parallelized G_D to minimize cross-gpu communication
# Without this, Generator outputs would get all-gathered and then rebroadcast.
class G_D(nn.Module):
    """Wrap a generator and a discriminator so both run in one forward pass."""

    def __init__(self, G, D):
        super(G_D, self).__init__()
        self.G = G
        self.D = D

    def forward(self, z, gy, x=None, dy=None, train_G=False, return_G_z=False,
                split_D=False):
        """Generate from (z, gy) and score the result, optionally with real data.

        Args:
            z: Latent input passed to G.
            gy: Class input for the generated samples; run through G.shared
                before being handed to G, and passed as-is to D.
            x: Optional real batch to score alongside the generated one.
            dy: Optional class input for the real batch.
            train_G: Whether gradients are enabled while running G.
            return_G_z: If True (and x is None), also return G's output.
            split_D: If True, run D separately on fake and real batches
                instead of concatenating them along the batch dimension.

        Returns:
            (D_fake, D_real) when x is given; otherwise D's output on the
            generated batch, paired with G's output if return_G_z is set.
        """
        # If training G, enable grad tape
        with torch.set_grad_enabled(train_G):
            # Get Generator output given noise
            G_z = self.G(z, self.G.shared(gy))
            # Cast as necessary
            if self.G.fp16 and not self.D.fp16:
                G_z = G_z.float()
            if self.D.fp16 and not self.G.fp16:
                G_z = G_z.half()
        # Split_D means to run D once with real data and once with fake,
        # rather than concatenating along the batch dimension.
        if split_D:
            D_fake = self.D(G_z, gy)
            if x is not None:
                D_real = self.D(x, dy)
                return D_fake, D_real
            if return_G_z:
                return D_fake, G_z
            return D_fake
        # If real data is provided, concatenate it with the Generator's output
        # along the batch dimension for improved efficiency.
        D_input = torch.cat([G_z, x], 0) if x is not None else G_z
        D_class = torch.cat([gy, dy], 0) if dy is not None else gy
        # Get Discriminator output
        D_out = self.D(D_input, D_class)
        if x is not None:
            return torch.split(D_out, [G_z.shape[0], x.shape[0]])  # D_fake, D_real
        if return_G_z:
            return D_out, G_z
        return D_out
''' Layers
    This file contains various layers for the BigGAN models.
'''
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P
from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
# Projection of x onto y
def proj(x, y):
    """Return the projection of row vector x onto row vector y.

    Computes (y . x) * y / (y . y) using matrix products, so x and y are
    expected to be 2-D tensors that `torch.mm` accepts.
    """
    return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
# Orthogonalize x wrt list of vectors ys
def gram_schmidt(x, ys):
    """Orthogonalize x against each vector in ys, in order, and return it.

    Subtracts from x its projection onto every y in ys; with an empty ys the
    input is returned unchanged.
    """
    for y in ys:
        x = x - proj(x, y)
    return x
# Apply num_itrs steps of the power method to estimate top N singular values.
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-method step per vector in u_ to estimate singular values.

    Args:
        W: 2-D weight matrix.
        u_: List of current left-singular-vector estimates, one per singular
            value to estimate; each is multiplied as `u @ W`.
        update: If True, write the refreshed estimates back into u_ in place.
        eps: Epsilon passed to `F.normalize`.

    Returns:
        (svs, us, vs): lists of singular-value estimates and the refreshed
        left and right singular-vector estimates.
    """
    # Lists holding singular vectors and values
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        # Run one step of the power iteration; the vectors themselves are
        # computed without tracking gradients.
        with torch.no_grad():
            v = torch.matmul(u, W)
            # Run Gram-Schmidt to subtract components of all other singular vectors
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            # Add to the list
            vs += [v]
            # Update the other singular vector
            u = torch.matmul(v, W.t())
            # Run Gram-Schmidt to subtract components of all other singular vectors
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            # Add to the list
            us += [u]
            if update:
                u_[i][:] = u
        # Compute this singular value and add it to the list. This happens
        # outside no_grad so the estimate stays differentiable w.r.t. W.
        svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
        #svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
    return svs, us, vs
# Convenience passthrough function
class identity(nn.Module):
    """Pass-through module: forward returns its input unchanged."""

    def forward(self, input):
        return input
# Spectral normalization base class
class SN(object):
    """Spectral-normalization mixin for modules that own a `weight`.

    Meant to be combined with an nn.Module subclass whose __init__ has
    already run, because the constructor calls `self.register_buffer`.
    """

    def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
        # Number of power iterations per step
        self.num_itrs = num_itrs
        # Number of singular values
        self.num_svs = num_svs
        # Transposed?
        self.transpose = transpose
        # Epsilon value for avoiding divide-by-0
        self.eps = eps
        # Register a singular vector for each sv
        for i in range(self.num_svs):
            self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
            self.register_buffer('sv%d' % i, torch.ones(1))

    # Singular vectors (u side).
    # Must be a property: W_ passes `self.u` directly to power_iteration,
    # which iterates over it.
    @property
    def u(self):
        return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]

    # Singular values;
    # note that these buffers are just for logging and are not used in training.
    @property
    def sv(self):
        return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]

    # Compute the spectrally-normalized weight
    def W_(self):
        """Return `self.weight` divided by its estimated top singular value."""
        W_mat = self.weight.view(self.weight.size(0), -1)
        if self.transpose:
            W_mat = W_mat.t()
        # Apply num_itrs power iterations; the stored u vectors are only
        # refreshed while the module is in training mode.
        for _ in range(self.num_itrs):
            svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
        # Update the svs
        # NOTE(review): the `if self.training` guard was missing from the
        # extracted source and is restored to mirror the `update` flag above.
        if self.training:
            with torch.no_grad():  # Make sure to do this in a no_grad() context or you'll get memory leaks!
                for i, sv in enumerate(svs):
                    self.sv[i][:] = sv
        return self.weight / svs[0]
# 2D Conv layer with spectral norm
class SNConv2d(nn.Conv2d, SN):
    """2D convolution whose weight is spectrally normalized on every forward."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        # Initialise the conv first so SN can register its buffers on it.
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
                           padding, dilation, groups, bias)
        SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)

    def forward(self, x):
        return F.conv2d(x, self.W_(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
# Linear layer with spectral norm
class SNLinear(nn.Linear, SN):
    """Linear layer whose weight is spectrally normalized on every forward."""

    def __init__(self, in_features, out_features, bias=True,
                 num_svs=1, num_itrs=1, eps=1e-12):
        # Initialise the linear layer first so SN can register its buffers on it.
        nn.Linear.__init__(self, in_features, out_features, bias)
        SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)

    def forward(self, x):
        return F.linear(x, self.W_(), self.bias)
# Embedding layer with spectral norm
# We use num_embeddings as the dim instead of embedding_dim here
# for convenience sake
class SNEmbedding(nn.Embedding, SN):
    """Embedding whose weight is spectrally normalized on every forward.

    The SN singular vectors are sized by num_embeddings rather than
    embedding_dim. The forward pass uses only the index tensor and the
    normalized weight; the other embedding options are not forwarded to
    `F.embedding`.
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
                 sparse=False, _weight=None,
                 num_svs=1, num_itrs=1, eps=1e-12):
        # Initialise the embedding first so SN can register its buffers on it.
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx,
                              max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight)
        SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)

    def forward(self, x):
        return F.embedding(x, self.W_())
# A non-local block as used in SA-GAN
# Note that the implementation as described in the paper is largely incorrect;
# refer to the released code for the actual implementation.
class Attention(nn.Module):
    """Non-local (self-attention) block with a learnable residual gain.

    Args:
        ch: Number of input/output channels.
        which_conv: Constructor used for the four 1x1 convolutions.
        name: Unused here; kept so the signature stays unchanged.
    """

    def __init__(self, ch, which_conv=SNConv2d, name='attention'):
        super(Attention, self).__init__()
        # Channel multiplier
        self.ch = ch
        self.which_conv = which_conv
        self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
        self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
        self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
        # Learnable gain parameter; starts at 0 so the block is initially an identity.
        self.gamma = P(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        """Return x plus gamma times the attention output; y is ignored."""
        # Apply convs
        theta = self.theta(x)
        # phi and g are max-pooled 2x2, quartering their number of positions.
        phi = F.max_pool2d(self.phi(x), [2, 2])
        g = F.max_pool2d(self.g(x), [2, 2])
        # Perform reshapes
        theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
        phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
        g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
        # Matmul and softmax to get attention maps
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path
        o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
        return self.gamma * o + x
# Fused batchnorm op
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    """Normalize x with the given mean/var, folding gain and bias into one op.

    Equivalent to ((x - mean) / sqrt(var + eps)) * gain + bias, computed as
    a single scale and a single shift. gain and bias are skipped when None.
    """
    # Prepare scale
    scale = torch.rsqrt(var + eps)
    # If a gain is provided, use it
    if gain is not None:
        scale = scale * gain
    # Prepare shift
    shift = mean * scale
    # If bias is provided, use it (subtracted here because shift is subtracted below)
    if bias is not None:
        shift = shift - bias
    return x * scale - shift
    #return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    """Batch-normalize x using statistics computed over dims 0, 2 and 3.

    The variance is computed as mean-of-squares minus squared-mean on a
    float32 copy of x, then cast back to x's type.

    Returns:
        The normalized tensor, or (normalized, mean, var) with the statistics
        squeezed when return_mean_var is True.
    """
    # Cast x to float32 if necessary
    float_x = x.float()
    # Calculate expected value of x (m) and expected value of x**2 (m2)
    # Mean of x
    m = torch.mean(float_x, [0, 2, 3], keepdim=True)
    # Mean of x squared
    m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
    # Calculate variance as mean of squared minus mean squared.
    var = (m2 - m ** 2)
    # Cast back to float 16 if necessary
    var = var.type(x.type())
    m = m.type(x.type())
    # Return mean and variance for updating stored mean/var if requested
    if return_mean_var:
        return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
    return fused_bn(x, m, var, gain, bias, eps)
# My batchnorm, supports standing stats
class myBN(nn.Module):
    """Batchnorm with externally supplied gain/bias and optional standing stats.

    In training mode batch statistics are used, and the stored statistics are
    either accumulated (`accumulate_standing`) or updated as running averages.
    In eval mode the stored statistics are used.
    """

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        # momentum for updating running stats
        self.momentum = momentum
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Register buffers
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # Accumulate running means and vars
        self.accumulate_standing = False

    # reset standing stats
    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # If accumulating standing stats, increment them
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # If not accumulating standing stats, take running averages
            else:
                self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
                self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
            return out
        # If not in training mode, use the stored statistics
        mean = self.stored_mean.view(1, -1, 1, 1)
        var = self.stored_var.view(1, -1, 1, 1)
        # If using standing stats, divide them by the accumulation counter
        if self.accumulate_standing:
            mean = mean / self.accumulation_counter
            var = var / self.accumulation_counter
        return fused_bn(x, mean, var, gain, bias, self.eps)
# Simple function to handle groupnorm norm stylization
def groupnorm(x, norm_style):
    """Apply group norm to x with the group count encoded in norm_style.

    norm_style containing 'ch' gives channels-per-group as its last
    '_'-separated field; containing 'grp' gives the number of groups the
    same way; anything else falls back to 16 groups.
    """
    # If number of channels specified in norm_style:
    if 'ch' in norm_style:
        ch = int(norm_style.split('_')[-1])
        groups = max(int(x.shape[1]) // ch, 1)
    # If number of groups specified in norm style
    elif 'grp' in norm_style:
        groups = int(norm_style.split('_')[-1])
    # If neither, default to groups = 16
    else:
        groups = 16
    return F.group_norm(x, groups)
# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable).
class ccbn(nn.Module):
    """Conditional normalization: gain and bias are predicted from input y.

    Args:
        output_size: Number of channels being normalized.
        input_size: Size of the conditioning input fed to the gain/bias layers.
        which_linear: Constructor for the gain/bias layers, called as
            which_linear(input_size, output_size).
        eps: Epsilon for the normalization.
        momentum: Momentum handed to the SyncBN/myBN backends.
        cross_replica: Use the synchronized batchnorm backend.
        mybn: Use the myBN backend.
        norm_style: One of 'bn', 'in', 'gn', 'nonorm' (used when neither
            backend above is selected).
        self_modulation: If True, gain and bias are each a two-layer MLP
            (linear-ReLU-linear) instead of a single layer.
    """

    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False, norm_style='bn', self_modulation=False):
        super(ccbn, self).__init__()
        self.output_size, self.input_size = output_size, input_size
        # Prepare gain and bias layers
        if not self_modulation:
            self.gain = which_linear(input_size, output_size)
            self.bias = which_linear(input_size, output_size)
        else:
            self.gain = nn.Sequential(which_linear(input_size, input_size, bias=True), nn.ReLU(),
                                      which_linear(input_size, output_size, bias=False))
            self.bias = nn.Sequential(which_linear(input_size, input_size, bias=True), nn.ReLU(),
                                      which_linear(input_size, output_size, bias=False))
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # Norm style?
        self.norm_style = norm_style

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif self.mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        elif self.norm_style in ['bn', 'in']:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y):
        # Calculate class-conditional gains and biases
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        # If using my batchnorm
        if self.mybn or self.cross_replica:
            return self.bn(x, gain=gain, bias=bias)
        # else:
        if self.norm_style == 'bn':
            out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                               self.training, 0.1, self.eps)
        elif self.norm_style == 'in':
            out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                  self.training, 0.1, self.eps)
        elif self.norm_style == 'gn':
            # Fixed: the attribute is `norm_style`; `self.normstyle` does not exist.
            out = groupnorm(x, self.norm_style)
        elif self.norm_style == 'nonorm':
            out = x
        return out * gain + bias

    def extra_repr(self):
        s = 'out: {output_size}, in: {input_size},'
        s += ' cross_replica={cross_replica}'
        return s.format(**self.__dict__)
# Normal, non-class-conditional BN
class bn(nn.Module):
    """Plain (non-conditional) batchnorm with a learnable gain and bias.

    Uses the synchronized or myBN backend when requested, otherwise
    `F.batch_norm` with this module's own stored statistics.
    """

    def __init__(self, output_size, eps=1e-5, momentum=0.1,
                 cross_replica=False, mybn=False):
        super(bn, self).__init__()
        self.output_size = output_size
        # Prepare gain and bias layers
        self.gain = P(torch.ones(output_size), requires_grad=True)
        self.bias = P(torch.zeros(output_size), requires_grad=True)
        # epsilon to avoid dividing by 0
        self.eps = eps
        # Momentum
        self.momentum = momentum
        # Use cross-replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn

        if self.cross_replica:
            self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
        elif mybn:
            self.bn = myBN(output_size, self.eps, self.momentum)
        # Register buffers if neither of the above
        else:
            self.register_buffer('stored_mean', torch.zeros(output_size))
            self.register_buffer('stored_var', torch.ones(output_size))

    def forward(self, x, y=None):
        """Normalize x; y is accepted for interface parity and ignored."""
        if self.cross_replica or self.mybn:
            gain = self.gain.view(1, -1, 1, 1)
            bias = self.bias.view(1, -1, 1, 1)
            return self.bn(x, gain=gain, bias=bias)
        return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
                            self.bias, self.training, self.momentum, self.eps)
# Generator blocks
# Note that this class assumes the kernel size and padding (and any other
# settings) have been selected in the main generator module and passed in
# through the which_conv arg. Similar rules apply with which_bn (the input
# size [which is actually the number of channels of the conditional info] must
# be preselected)
class GBlock(nn.Module):
    """Generator residual block: norm-act-(upsample)-conv, twice, plus shortcut.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels of the block output.
        which_conv: Conv constructor, called as which_conv(in, out) for the
            main path and with kernel_size=1, padding=0 for the shortcut.
        which_bn: Norm constructor, called with a channel count; instances
            are called as bn(x, y).
        activation: Callable applied after each norm layer.
        upsample: Optional callable applied to both the main path and the
            shortcut. NOTE(review): this parameter was missing from the
            truncated signature in the extracted source and is restored
            from its uses in the body - confirm the default.
    """

    def __init__(self, in_channels, out_channels,
                 which_conv=nn.Conv2d, which_bn=bn, activation=None,
                 upsample=None):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        # upsample layers
        self.upsample = upsample
        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.out_channels)
        self.conv2 = self.which_conv(self.out_channels, self.out_channels)
        # A 1x1 shortcut conv is needed whenever the shape changes.
        self.learnable_sc = in_channels != out_channels or upsample
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)
        # Batchnorm layers
        self.bn1 = self.which_bn(in_channels)
        self.bn2 = self.which_bn(out_channels)

    def forward(self, x, y):
        h = self.activation(self.bn1(x, y))
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        h = self.conv1(h)
        h = self.activation(self.bn2(h, y))
        h = self.conv2(h)
        if self.learnable_sc:
            x = self.conv_sc(x)
        return h + x
# Residual block for the discriminator
class DBlock(nn.Module):
    """Discriminator residual block: (act)-conv-act-conv-(downsample) plus shortcut.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels of the block output.
        which_conv: Conv constructor, called as which_conv(in, out) for the
            main path and with kernel_size=1, padding=0 for the shortcut.
        wide: If True the hidden width equals out_channels, else in_channels.
        preactivation: If True, apply a ReLU to the input before the first
            conv and run the shortcut as conv-then-downsample; otherwise the
            shortcut is downsample-then-conv.
        activation: Callable applied between the two convs.
        downsample: Optional callable applied to main path and shortcut.
    """

    def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
                 preactivation=False, activation=None, downsample=None,):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
        self.hidden_channels = self.out_channels if wide else self.in_channels
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Conv layers
        self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
        self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
        self.learnable_sc = True if (in_channels != out_channels) or downsample else False
        if self.learnable_sc:
            self.conv_sc = self.which_conv(in_channels, out_channels,
                                           kernel_size=1, padding=0)

    def shortcut(self, x):
        """Apply the shortcut path; op order depends on `preactivation`."""
        if self.preactivation:
            if self.learnable_sc:
                x = self.conv_sc(x)
            if self.downsample:
                x = self.downsample(x)
        else:
            if self.downsample:
                x = self.downsample(x)
            if self.learnable_sc:
                x = self.conv_sc(x)
        return x

    def forward(self, x):
        if self.preactivation:
            # h = self.activation(x) # NOT TODAY SATAN
            # Andy's note: This line *must* be an out-of-place ReLU or it
            #              will negatively affect the shortcut connection.
            h = F.relu(x)
        else:
            h = x
        h = self.conv1(h)
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.downsample(h)

        return h + self.shortcut(x)
# dogball
Copy link

nurpax commented Apr 6, 2020


augment = False
num_workers = 8
pin_memory = True
shuffle = True
load_in_mem = True
use_multiepoch_sampler = True
model = BigGAN
G_param = SN
D_param = SN
G_ch = 64
D_ch = 64
G_depth = 1
D_depth = 1
D_wide = True
G_shared = False
shared_dim = 128
dim_z = 128
z_var = 1.0
hier = False
self_modulation = True
cross_replica = False
mybn = False
G_nl = inplace_relu
D_nl = inplace_relu
G_attn = 64
D_attn = 64
norm_style = bn
seed = 0
G_init = ortho
D_init = ortho
skip_init = False
G_lr = 0.0001
D_lr = 0.0005
G_B1 = 0.0
D_B1 = 0.0
G_B2 = 0.999
D_B2 = 0.999
batch_size = 20
G_batch_size = 0
num_G_accumulations = 1
num_D_steps = 1
num_D_accumulations = 1
split_D = False
num_epochs = 357
parallel = True
G_fp16 = False
D_fp16 = False
D_mixed_precision = False
G_mixed_precision = False
accumulate_stats = False
num_standing_accumulations = 16
G_eval_mode = True
save_every = 10000
num_save_copies = 2
num_best_copies = 5
which_best = IS
no_fid = False
test_every = 10000
num_inception_images = 50000
hashname = False
base_root = 
data_root = <omitted>
weights_root = weights
logs_root = logs
samples_root = samples
pbar = mine
name_suffix = 
experiment_name = 
config_from_name = False
ema = True
ema_decay = 0.9999
use_ema = True
ema_start = 20000
adam_eps = 1e-06
BN_eps = 1e-05
SN_eps = 1e-06
num_G_SVs = 1
num_D_SVs = 1
num_G_SV_itrs = 1
num_D_SV_itrs = 1
G_ortho = 0.0
D_ortho = 0.0
toggle_grads = True
which_train_fn = GAN
load_weights = 
resume = False
logstyle = %3.3e
log_G_spectra = False
log_D_spectra = False
sv_log_interval = 10

Copy link

nurpax commented Apr 6, 2020

Network dump:

  (activation): ReLU(inplace)
  (shared): identity()
  (linear): SNLinear(in_features=128, out_features=16384, bias=True)
  (blocks): ModuleList(
    (0): ModuleList(
      (0): GBlock(
        (activation): ReLU(inplace)
        (conv1): SNConv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        (bn1): ccbn(
          out: 1024, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=1024, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=1024, bias=False)
        (bn2): ccbn(
          out: 1024, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=1024, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=1024, bias=False)
    (1): ModuleList(
      (0): GBlock(
        (activation): ReLU(inplace)
        (conv1): SNConv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
        (bn1): ccbn(
          out: 1024, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=1024, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=1024, bias=False)
        (bn2): ccbn(
          out: 512, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=512, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=512, bias=False)
    (2): ModuleList(
      (0): GBlock(
        (activation): ReLU(inplace)
        (conv1): SNConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (bn1): ccbn(
          out: 512, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=512, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=512, bias=False)
        (bn2): ccbn(
          out: 512, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=512, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=512, bias=False)
    (3): ModuleList(
      (0): GBlock(
        (activation): ReLU(inplace)
        (conv1): SNConv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        (bn1): ccbn(
          out: 512, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=512, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=512, bias=False)
        (bn2): ccbn(
          out: 256, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=256, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=256, bias=False)
      (1): Attention(
        (theta): SNConv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (phi): SNConv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (g): SNConv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (o): SNConv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): ModuleList(
      (0): GBlock(
        (activation): ReLU(inplace)
        (conv1): SNConv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        (bn1): ccbn(
          out: 256, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=256, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=256, bias=False)
        (bn2): ccbn(
          out: 128, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=128, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=128, bias=False)
    (5): ModuleList(
      (0): GBlock(
        (activation): ReLU(inplace)
        (conv1): SNConv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        (bn1): ccbn(
          out: 128, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=128, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=128, bias=False)
        (bn2): ccbn(
          out: 64, in: 128, cross_replica=False
          (gain): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=64, bias=False)
          (bias): Sequential(
            (0): Linear(in_features=128, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=64, bias=False)
  (output_layer): Sequential(
    (0): bn()
    (1): ReLU(inplace)
    (2): SNConv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (activation): ReLU(inplace)
  (blocks): ModuleList(
    (0): ModuleList(
      (0): DBlock(
        (activation): ReLU(inplace)
        (downsample): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (conv1): SNConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): ModuleList(
      (0): DBlock(
        (activation): ReLU(inplace)
        (downsample): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (conv1): SNConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): Attention(
        (theta): SNConv2d(128, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (phi): SNConv2d(128, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (g): SNConv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (o): SNConv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (2): ModuleList(
      (0): DBlock(
        (activation): ReLU(inplace)
        (downsample): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (conv1): SNConv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
    (3): ModuleList(
      (0): DBlock(
        (activation): ReLU(inplace)
        (downsample): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (conv1): SNConv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
    (4): ModuleList(
      (0): DBlock(
        (activation): ReLU(inplace)
        (downsample): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (conv1): SNConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    (5): ModuleList(
      (0): DBlock(
        (activation): ReLU(inplace)
        (downsample): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (conv1): SNConv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv_sc): SNConv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
    (6): ModuleList(
      (0): DBlock(
        (activation): ReLU(inplace)
        (conv1): SNConv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): SNConv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (linear): SNLinear(in_features=1024, out_features=1, bias=True)
  (embed): SNEmbedding(1, 1024)
Number of params in G: 39125124 D: 43422914

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment