@ndgnuh
Last active October 11, 2023 09:01
Custom torch layers, modules and utilities, ready to be copy-and-pasted
import torch
from torch import nn
from torch.autograd import Function
from torch.nn import functional as F


class BackwardGradNormFn(Function):
    """
    A normalization layer that does nothing to the input, but
    normalizes the gradient.
    Reference: https://arxiv.org/abs/2106.09475
    Very cool idea. I tried applying this on the convolution stem
    (the first two conv layers, which scale down the resolution) and
    it works quite well.
    I'm not sure about applying it everywhere like the paper suggests, though.
    """

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        norm = torch.norm(grad_output)
        if norm > 0:
            grad_output = grad_output / norm
        # grad_output = torch.clamp(grad_output, -1000, 1000)
        return grad_output


class BackwardGradNorm(nn.Module):
    """Identity in the forward pass; normalizes the gradient during training."""

    def forward(self, x):
        if self.training:
            return BackwardGradNormFn.apply(x)
        else:
            return x
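

# Minimal usage sketch (the stem below is a made-up example, not from the
# referenced paper): BackwardGradNorm is an identity at inference time, so it
# can be dropped anywhere; here it rescales the gradients flowing into the
# first two strided convolutions.
def backward_grad_norm_example():
    stem = nn.Sequential(
        nn.Conv2d(3, 32, 3, stride=2, padding=1),
        BackwardGradNorm(),
        nn.Conv2d(32, 64, 3, stride=2, padding=1),
        BackwardGradNorm(),
    )
    x = torch.randn(2, 3, 64, 64, requires_grad=True)
    stem(x).sum().backward()  # each BackwardGradNorm rescales its incoming gradient to unit norm
    return x.grad.shape  # torch.Size([2, 3, 64, 64])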
class AccNorm(nn.Module):
"""Don't have 8 NVIDIA A100-s for the 100-batchsize? Gotcha!
This is a normalization hack to:
- gain the benefit from batch normalization without
having to crank up the batch size or having to have
the GPU to do so; normalization for everyone!
- deal with the annoying drawbacks from batch normalization, such as
train/validate performance difference, batch size dependent.
"""
def __init__(
self,
hidden_size: int,
virtual_batch_size: int = 75,
eps: float = 1e-5,
momentum: float = 0.1,
):
super().__init__()
self.momentum = momentum
self.eps = eps
self.T = virtual_batch_size
shape = (1, hidden_size, 1, 1)
self.weight = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape))
self.register_buffer("t", torch.tensor(0))
self.register_buffer("t0", torch.tensor(1))
self.register_buffer("mean", torch.zeros(shape))
self.register_buffer("acc_mean", torch.zeros(shape))
self.register_buffer("var", torch.ones(shape))
self.register_buffer("acc_var", torch.ones(shape))
self.register_buffer("std", torch.ones(shape))
@torch.no_grad()
def update(self, x):
var, mean = torch.var_mean(x, (-2, -1), keepdim=True)
bsize = x.shape[0]
self.t = self.t + bsize
self.acc_mean = self.acc_mean + mean.sum(dim=0, keepdim=True)
self.acc_var = self.acc_var + var.sum(dim=0, keepdim=True)
if self.t >= self.t0:
# Calculate mean statistics
mean = self.acc_mean / self.t
var = self.acc_var / self.t
# Update running stats
mom = self.momentum
self.mean = self.mean * (1 - mom) + mom * mean
self.var = self.var * (1 - mom) + mom * var
self.std = torch.sqrt(self.var + self.eps)
# Reset accumulator
self.t.fill_(0)
self.acc_mean.fill_(0)
self.acc_var.fill_(1)
# Scale up the virtual batch size until reaching the limit
t1 = (self.t0 * 1.5).type(torch.long)
self.t0 = torch.clamp(t1, 1, self.T)
def forward(self, x):
if self.training:
self.update(x)
mean, std = self.mean, self.std
x = (x - mean) / std
x = x * self.weight + self.bias
return x
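

# Minimal usage sketch (shapes and sizes below are assumptions for the demo):
# AccNorm is meant as a drop-in replacement for nn.BatchNorm2d on (B, C, H, W)
# inputs; statistics are accumulated across training steps until the virtual
# batch size is reached.
def acc_norm_example():
    norm = AccNorm(hidden_size=64, virtual_batch_size=128)
    x = torch.randn(4, 64, 32, 32)  # small physical batch
    y = norm(x)  # running stats move toward the accumulated statistics
    return y.shape  # torch.Size([4, 64, 32, 32])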


class AdaGreedNorm(nn.Module):
    """My goated normalization layer. Based on Adam and existing normalization layers."""

    def __init__(self, num_channels: int, eps=1e-5, betas=(0.9, 0.999)):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.register_buffer("need_init", torch.tensor(True))
        self.eps = eps
        self.betas = betas
        self.register_buffer("m_t", torch.zeros(1))
        self.register_buffer("v_t", torch.zeros(1))
        self.register_buffer("v_t_max", torch.zeros(1))
        self.register_buffer("t", torch.ones(1))
        self.register_buffer("m_t_hat", torch.zeros(1))
        self.register_buffer("v_t_hat", torch.zeros(1))
        self.register_buffer("mean", torch.zeros(1))
        self.register_buffer("std", torch.ones(1))

    @torch.no_grad()
    def update_stats(self, x):
        eps = self.eps
        b1, b2 = self.betas
        # This is why it is called greedy
        v, m = torch.var_mean(x)
        # Update running stats (Adam-style bias-corrected moments, with an AMSGrad-like max)
        self.m_t = self.m_t * b1 + (1 - b1) * m
        self.v_t = self.v_t * b2 + (1 - b2) * v
        self.m_t_hat = self.m_t / (1 - b1**self.t)
        self.v_t_hat = self.v_t / (1 - b2**self.t)
        self.v_t_max = torch.maximum(self.v_t_max, self.v_t_hat)
        self.t = self.t + 1
        # Calculate shift and std
        self.mean = self.m_t
        self.std = torch.sqrt(self.v_t_max + eps)

    def forward(self, x):
        # Training
        if self.training:
            self.update_stats(x)
        # Standardize
        x = (x - self.mean) / self.std
        x = x * self.weight + self.bias
        return x


class WSConv2d(nn.Conv2d):
    """Weight standardized Convolution layer.
    Ref: https://arxiv.org/abs/1903.10520v2
    """

    def __init__(self, *args, eps=1e-5, gain=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps
        if gain:
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        else:
            self.gain = 1
        ks = self.kernel_size
        self.fan_in = ks[0] * ks[1] * self.in_channels
        self.register_buffer("nweight", torch.ones_like(self.weight))

    def get_weight(self):
        weight = self.weight
        fan_in = self.fan_in
        eps = self.eps
        if self.training:
            var, mean = torch.var_mean(weight, dim=(1, 2, 3), keepdim=True)
            # Standardize
            weight = (weight - mean) / torch.sqrt(var * fan_in + eps)
            # Ha! Self, gain weight, get it?
            weight = self.gain * weight
            # Cache the standardized weight for inference
            self.nweight = weight.clone().detach()
        else:
            weight = self.nweight
        return weight

    def forward(self, x):
        weight = self.get_weight()
        return F.conv2d(
            x,
            weight=weight,
            bias=self.bias,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            stride=self.stride,
        )
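

# Minimal usage sketch (the GroupNorm pairing is a common choice, not something
# this gist prescribes): weight standardization is usually combined with a
# batch-independent normalization layer such as GroupNorm.
def ws_conv_example():
    block = nn.Sequential(
        WSConv2d(3, 32, 3, padding=1),
        nn.GroupNorm(8, 32),
        nn.ReLU(),
    )
    x = torch.randn(2, 3, 64, 64)
    return block(x).shape  # torch.Size([2, 32, 64, 64])

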
from typing import List, Union
import torch
from torch import nn


class GlobalResponseNorm(nn.Module):
    """Global Response Normalization from ConvNeXt V2 (https://arxiv.org/abs/2301.00808).
    Expects channels-last input of shape (B, H, W, C).
    """

    def __init__(self, channels: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(1, 1, 1, channels))
        self.bias = nn.Parameter(torch.randn(1, 1, 1, channels))
        self.eps = eps

    def forward(self, x):
        # x dims: B H W C
        Gx = torch.norm(x, dim=(-2, -3), p=2, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + self.eps)
        x = x + (x * Nx) * self.weight + self.bias
        return x


class PermuteDim(nn.Module):
    """Permute tensor dimensions by named axes, e.g. PermuteDim("bchw", "bhwc")."""

    def __init__(self, src: str, dst: str):
        super().__init__()
        self.perm = [src.index(s) for s in dst]
        self.extra_repr = lambda: f"from='{src}', to='{dst}', perm={self.perm}"

    def forward(self, x):
        x = x.permute(self.perm)
        return x


class ConvNextBlock(nn.Module):
    def __init__(self, channels: int, expansion: int = 4):
        super().__init__()
        self.conv_mlp = nn.Sequential(
            nn.Conv2d(channels, channels, 7, padding=3, groups=channels),
            PermuteDim("bchw", "bhwc"),
            nn.LayerNorm(channels),
            nn.Linear(channels, channels * expansion),
            nn.GELU(approximate="tanh"),
            GlobalResponseNorm(channels * expansion),
            nn.Linear(channels * expansion, channels),
            PermuteDim("bhwc", "bchw"),
        )

    def forward(self, x):
        return self.conv_mlp(x) + x


class DownSample(nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 2,
        prenorm: bool = True,
    ):
        super().__init__()
        down = nn.Conv2d(in_channels, out_channels, stride, stride)
        # Submodule assignment order determines execution order in nn.Sequential:
        # post-norm runs conv -> LayerNorm, pre-norm runs LayerNorm -> conv.
        if not prenorm:
            self.down = down
        self.ch_last = PermuteDim("bchw", "bhwc")
        self.norm = nn.LayerNorm(in_channels if prenorm else out_channels)
        self.ch_first = PermuteDim("bhwc", "bchw")
        if prenorm:
            self.down = down


class ConvNext(nn.Module):
    def __init__(
        self,
        channels: List[int],
        num_layers: List[int],
        expansion: int = 4,
        strides: Union[int, List[int]] = 2,
        patch_size: int = 4,
    ):
        super().__init__()
        layers = [DownSample(3, channels[0], patch_size, prenorm=False)]
        n = len(num_layers)
        if isinstance(strides, int):
            strides = [strides] * n
        for i, nl in enumerate(num_layers):
            c1 = channels[i]
            c2 = channels[i + 1]
            stride = strides[i]
            for _ in range(nl):
                layers.append(ConvNextBlock(c1, expansion))
            if i != n - 1:
                layers.append(DownSample(c1, c2, stride=stride))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


if __name__ == "__main__":
    model = ConvNext([20, 40, 60, 80, 80], [2, 2, 6, 2])
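    # Quick shape check (the 224x224 input size is an assumption for the demo):
    # a patch_size-4 stem plus three stride-2 downsamples gives 224 -> 56 -> 28 -> 14 -> 7.
    x = torch.randn(1, 3, 224, 224)
    print(model(x).shape)  # torch.Size([1, 80, 7, 7])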


# Reference: https://arxiv.org/abs/1603.07285
def get_conv_output_size_1(x: int, k: int, s: int = 1, p: int = 0, d: int = 0):
    """Get the output resolution of a convolution operation.
    For reference, see https://arxiv.org/abs/1603.07285.

    Args:
        x (int): The input resolution
        k (int): Kernel size
        s (int): Stride (Default: 1)
        p (int): Padding (Default: 0)
        d (int): Dilation (Default: 0, treated the same as 1)

    Returns:
        _ (int): The output resolution
    """
    if d > 0:
        k = k + (k - 1) * (d - 1)
    return int((x + 2 * p - k) / s) + 1


def get_conv_output_size(x, *configs):
    """
    Get the output resolution after a chain of convolution operations.
    This function uses `get_conv_output_size_1`.

    Args:
        x (int): The input resolution
        configs (List[Tuple]):
            List of tuples of (kernel size, stride, padding, dilation).

    Returns:
        _ (int): The output resolution
    """
    for config in configs:
        x = get_conv_output_size_1(x, *config)
    return x
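

# Worked example (the layer configs below are illustrative): a 3x3/stride-2/pad-1
# convolution followed by a 3x3/stride-1/pad-1 convolution maps 224 -> 112 -> 112.
def conv_output_size_example():
    single = get_conv_output_size_1(224, 3, 2, 1)  # int((224 + 2*1 - 3) / 2) + 1 = 112
    chained = get_conv_output_size(224, (3, 2, 1), (3, 1, 1))  # 112 -> 112
    return single, chained  # (112, 112)

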
import torch
from torch import Tensor, nn


# Corner pooling: unbind-stack version.
# Running maximum along `dim`; with flip=True the scan runs from the end of the
# dimension toward the start (e.g. bottom-to-top for TopPool).
@torch.jit.script
def corner_pool(x: Tensor, dim: int, flip: bool):
    sz = x.size(dim)
    outputs = list(x.unbind(dim))
    for i in range(1, sz):
        if flip:
            i_in = sz - i
            i_out = sz - i - 1
        else:
            i_in = i - 1
            i_out = i
        outputs[i_out] = torch.maximum(outputs[i_out], outputs[i_in])
    return torch.stack(outputs, dim=dim)


class TopPool(nn.Module):
    def forward(self, x):
        return corner_pool(x, -2, True)


class BottomPool(nn.Module):
    def forward(self, x):
        return corner_pool(x, -2, False)


class LeftPool(nn.Module):
    def forward(self, x):
        return corner_pool(x, -1, True)


class RightPool(nn.Module):
    def forward(self, x):
        return corner_pool(x, -1, False)
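

# Tiny worked example (the 2x2 input is made up): TopPool propagates the maximum
# upward along the height axis, RightPool propagates it rightward along the width axis.
def corner_pool_example():
    x = torch.tensor([[[[1.0, 3.0], [2.0, 0.0]]]])  # shape (1, 1, 2, 2)
    top = TopPool()(x)      # tensor([[[[2., 3.], [2., 0.]]]])
    right = RightPool()(x)  # tensor([[[[1., 3.], [2., 2.]]]])
    return top, right

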
import copy
from typing import Optional
import torch
from torch import autograd, nn


class ReversibleFN(autograd.Function):
    @staticmethod
    def forward(ctx, Fm, Gm, x, *params):
        x = x.detach()
        with torch.no_grad():
            x1, x2 = torch.chunk(x, chunks=2, dim=1)
            y1 = x1 + Fm(x2)
            y2 = x2 + Gm(y1)
            y = torch.cat((y1, y2), dim=1)
            del x1, x2, y1, y2
        ctx.Fm = Fm
        ctx.Gm = Gm
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        Fm = ctx.Fm
        Gm = ctx.Gm
        Fparams = tuple(Fm.parameters())
        Gparams = tuple(Gm.parameters())
        x = ctx.saved_tensors[0]
        x1, x2 = torch.chunk(x, 2, dim=1)
        # compute outputs building a sub-graph
        with torch.set_grad_enabled(True):
            x1.requires_grad = True
            x2.requires_grad = True
            y1 = x1 + Fm(x2)
            y2 = x2 + Gm(y1)
            y = torch.cat([y1, y2], dim=1)
            inputs = (x1, x2) + Fparams + Gparams
            grads = autograd.grad(y, inputs, grad_output)
        grad_input = torch.cat([grads[0], grads[1]], dim=1)
        return (None, None, grad_input) + tuple(grads[2:])


class Reversible(nn.Module):
    def __init__(self, Fm: nn.Module, Gm: Optional[nn.Module] = None):
        super().__init__()
        self.Fm = Fm
        if Gm is None:
            Gm = copy.deepcopy(Fm)
        self.Gm = Gm

    def forward(self, x):
        if self.training:
            params = list(self.Fm.parameters()) + list(self.Gm.parameters())
            y = ReversibleFN.apply(self.Fm, self.Gm, x, *params)
        else:
            x1, x2 = torch.chunk(x, chunks=2, dim=1)
            y1 = x1 + self.Fm(x2)
            y2 = x2 + self.Gm(y1)
            y = torch.cat((y1, y2), dim=1)
        return y
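

# Minimal usage sketch (Fm below is a made-up sub-block): the input is split into
# two halves along the channel dimension, so the channel count must be even and
# Fm/Gm must map half the channels to half the channels.
def reversible_example():
    half = 32
    Fm = nn.Sequential(nn.Conv2d(half, half, 3, padding=1), nn.ReLU())
    block = Reversible(Fm)  # Gm defaults to a deep copy of Fm
    x = torch.randn(2, 2 * half, 16, 16, requires_grad=True)
    y = block(x)  # training mode routes through ReversibleFN
    y.sum().backward()  # ReversibleFN.backward rebuilds the sub-graph from the saved input
    return x.grad.shape  # torch.Size([2, 64, 16, 16])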