import torch
import torch.nn as nn
import numpy as np

from robustness.tools.custom_modules import SequentialWithArgs
from e2cnn import gspaces
from e2cnn import nn as enn


def conv7x7(in_type: enn.FieldType, out_type: enn.FieldType, stride=1,
            padding=3, dilation=1, bias=False):
    """7x7 convolution with padding"""
    return enn.R2Conv(in_type, out_type, 7,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3 * r)


def conv5x5(in_type: enn.FieldType, out_type: enn.FieldType, stride=1,
            padding=2, dilation=1, bias=False):
    """5x5 convolution with padding"""
    return enn.R2Conv(in_type, out_type, 5,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3 * r)


def conv3x3(in_type: enn.FieldType, out_type: enn.FieldType, stride=1,
            padding=1, dilation=1, bias=False):
    """3x3 convolution with padding"""
    return enn.R2Conv(in_type, out_type, 3,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3 * r)


def conv1x1(in_type: enn.FieldType, out_type: enn.FieldType, stride=1,
            padding=0, dilation=1, bias=False):
    """1x1 convolution"""
    return enn.R2Conv(in_type, out_type, 1,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3 * r)
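

# Usage sketch (an illustrative addition, not part of the original gist):
# these helpers operate on e2cnn field types rather than raw channel counts.
# The C_8 group and the field multiplicities below are arbitrary example
# choices.
#
#   gspace = gspaces.Rot2dOnR2(N=8)
#   feat_in = enn.FieldType(gspace, [gspace.trivial_repr] * 3)
#   feat_out = enn.FieldType(gspace, [gspace.regular_repr] * 8)
#   conv = conv3x3(feat_in, feat_out)
#   y = conv(enn.GeometricTensor(torch.randn(1, 3, 33, 33), feat_in))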


class EquivariantResNet(nn.Module):
    def __init__(self, N, block, num_blocks, num_classes=10, feat_scale=1, wm=1):
        super(EquivariantResNet, self).__init__()
        widths = [64, 128, 256, 512]
        widths = [int(w * wm) for w in widths]
        self.in_planes = widths[0]
        self.N = N

        if self.N < 0:
            # SO(2) equivariance: use irreps up to maximum frequency |N|
            self.gspace = gspaces.Rot2dOnR2(N=-1, maximum_frequency=-self.N)
            self.out_type, trivials, others, irreps, labels = \
                self.get_irreps_field_type(self.in_planes, conv7x7, 7)
        else:
            # C_N equivariance: use regular representations only
            self.gspace = gspaces.Rot2dOnR2(N=self.N)
            num_planes = self.in_planes // N
            self.out_type = enn.FieldType(self.gspace,
                                          [self.gspace.regular_repr] * num_planes)
            labels = ["trivial" for r in self.out_type]
            trivials = self.out_type
            others = []
            irreps = []

        self.input_type = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
        self.in_type = self.input_type
        self.conv1 = conv7x7(self.in_type, self.out_type, stride=2)

        modules = [(enn.InnerBatchNorm(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormBatchNorm(others), "others")]
        self.bn1 = enn.MultipleModule(self.out_type, labels, modules)

        modules = [(enn.ELU(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormNonLinearity(others, function="n_relu"), "others")]
        self.relu1 = enn.MultipleModule(self.out_type, labels, modules)

        modules = [(enn.PointwiseMaxPool(trivials, kernel_size=3, stride=2,
                                         padding=1), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormMaxPool(others, kernel_size=3, stride=2,
                                         padding=1), "others")]
        self.maxpool = enn.MultipleModule(self.out_type, labels, modules)

        self.in_type = self.out_type
        self.layer1 = self._make_layer(block, widths[0], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, widths[1], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, widths[2], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, widths[3], num_blocks[3], stride=2,
                                       totrivial=True)

        if self.N <= 0:
            feat_len = feat_scale * widths[3] * block.expansion
            # feat_len = 268  # hardcoded for 225x225, N=-3
            feat_len = 306  # hardcoded for 225x225, N=-5
        else:
            feat_len = feat_scale * widths[3] * block.expansion // N
        self.linear = nn.Linear(feat_len, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        print("MODEL TOPOLOGY:")
        for i, (name, mod) in enumerate(self.named_modules()):
            print(f"\t{i} - {name}")

    def _make_layer(self, block, planes, num_blocks, stride, totrivial=False):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        if self.N < 0:
            self.out_type, trivials, others, irreps, labels = \
                self.get_irreps_field_type(planes, conv3x3, 3)
        else:
            num_planes = planes // self.N
            self.out_type = enn.FieldType(self.gspace,
                                          [self.gspace.regular_repr] * num_planes)
            labels = ["trivial" for r in self.out_type]
            trivials = self.out_type
            others = []
            irreps = []
        for ind_stride, stride in enumerate(strides):
            # only the very last block may map to invariant (trivial) features
            is_trivial = totrivial and (ind_stride == num_blocks - 1)
            layers.append(block(self.gspace, self.in_type, self.out_type,
                                self.in_planes, planes, trivials, others,
                                irreps, labels, self.N, stride,
                                totrivial=is_trivial))
            self.in_type = self.out_type
            self.in_planes = planes * block.expansion
        return SequentialWithArgs(*layers)

    def get_irreps_field_type(self, planes, conv_type, conv_size):
        irreps = []
        if self.N >= 0:
            out_type = enn.FieldType(self.gspace,
                                     [self.gspace.regular_repr] * planes)
        else:
            gc = self.gspace
            # collect all non-trivial irreps of SO(2) up to the maximum frequency
            for n, irr in gc.fibergroup.irreps.items():
                if n != gc.trivial_repr.name:
                    irreps += [irr] * int(irr.size // irr.sum_of_squares_constituents)
            irreps = list(irreps)
            C = self.get_num_channels_for_fixed_params(planes, irreps,
                                                       conv_type, conv_size)
            trivials = enn.FieldType(gc, [gc.trivial_repr] * C)
            if len(irreps) > 0:
                others = enn.FieldType(gc, irreps * C).sorted()
                out_type = trivials + others
            else:
                others = []
                out_type = trivials
        labels = ["trivial" if r.is_trivial() else "others" for r in out_type]
        out_types = out_type.group_by_labels(labels)
        trivials = out_types["trivial"]
        others = out_types["others"]
        for r in trivials:
            r.supported_nonlinearities.add("pointwise")
        return out_type, trivials, others, irreps, labels

    def get_num_channels_for_fixed_params(self, channels, irreps, conv_type, conv_size):
        # pick the number of fields C so that the equivariant layer has roughly
        # as many parameters as a standard conv layer with `channels` channels
        gc = self.gspace
        r_in = enn.FieldType(gc, [gc.trivial_repr] + irreps)
        r_out = enn.FieldType(gc, [gc.trivial_repr] + irreps)
        tmp_cl = conv_type(r_in, r_out)
        t = tmp_cl.basisexpansion.dimension()
        t /= conv_size ** 2 / 12  # * 12 and remove / 12
        C = int(round(channels / np.sqrt(t)))
        print(f'old num channels: {channels}, new num channels: {C}, factor: {t}')
        return C

    def forward(self, x, with_latent=False, fake_relu=False, no_relu=False):
        # fake_relu and no_relu are kept for interface compatibility with the
        # robustness library and are unused here
        x = enn.GeometricTensor(x, self.input_type)
        out = self.conv1(x)
        out = self.relu1(self.bn1(out))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # unwrap the GeometricTensor before the non-equivariant head
        preout = out.tensor
        out = self.avgpool(preout)
        out = torch.flatten(out, 1)
        final = self.linear(out)
        if with_latent:
            return final, preout
        return final


class EquivariantBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, gspace, in_type, out_type, in_planes, planes, trivials,
                 others, irreps, labels, N, stride=1, totrivial=False):
        super(EquivariantBasicBlock, self).__init__()
        self.conv1 = conv3x3(in_type, out_type, stride=stride)

        modules = [(enn.InnerBatchNorm(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormBatchNorm(others), "others")]
        self.bn1 = enn.MultipleModule(out_type, labels, modules)

        modules = [(enn.ELU(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormNonLinearity(others, function="n_relu"), "others")]
        self.relu1 = enn.MultipleModule(out_type, labels, modules)

        inner_type = out_type
        print("fibergroup order:", gspace.fibergroup.order())
        self.conv2 = conv3x3(inner_type, out_type)

        modules = [(enn.InnerBatchNorm(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormBatchNorm(others), "others")]
        self.bn2 = enn.MultipleModule(out_type, labels, modules)

        self.invariant_map = None
        if totrivial:
            print('should happen only once at the end')
            if N <= 0:
                modules = [(enn.IdentityModule(trivials), "trivial"),
                           (enn.NormPool(others), "others")]
                self.invariant_map = enn.MultipleModule(out_type, labels, modules)
            else:
                self.invariant_map = enn.GroupPooling(out_type)
            out_type = self.invariant_map.out_type
            labels = ["trivial" if r.is_trivial() else "others" for r in out_type]
            out_types = out_type.group_by_labels(labels)
            trivials = out_types["trivial"]
            for r in trivials:
                r.supported_nonlinearities.add("pointwise")
            irreps = []
            others = []

        self.shortcut = None
        if stride != 1 or in_planes != planes:
            if self.invariant_map is not None:
                self.shortcut = enn.SequentialModule(
                    conv1x1(in_type, out_type, stride=stride),
                    self.bn2,
                    self.invariant_map)
            else:
                self.shortcut = enn.SequentialModule(
                    conv1x1(in_type, out_type, stride=stride),
                    self.bn2)

        modules = [(enn.ELU(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormNonLinearity(others, function="n_relu"), "others")]
        self.relu2 = enn.MultipleModule(out_type, labels, modules)

    def forward(self, x, fake_relu=False):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.invariant_map is not None:
            out = self.invariant_map(out)
        if self.shortcut is not None:
            out += self.shortcut(x)
        return self.relu2(out)


def EquivariantResNet18(N, **kwargs):
    """
    ResNet18 equivariant to rotations discretized to the cyclic group C_N.

    N < 0 indicates equivariance to SO(2), i.e. all rotations; in that case
    |N| is the maximum irrep frequency to use.

    Due to pooling/dilation, the input resolution needs to be of the form
    32k + 1, e.g. 225x225.

    For SO(2)-invariant networks, change feat_len accordingly.
    """
    return EquivariantResNet(N=N, block=EquivariantBasicBlock,
                             num_blocks=[2, 2, 2, 2], **kwargs)
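

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original
    # gist; N=8, the batch size, and num_classes are arbitrary choices).
    # Uses a 225x225 input, per the 32k + 1 resolution note above.
    model = EquivariantResNet18(N=8, num_classes=10)
    x = torch.randn(2, 3, 225, 225)
    logits = model(x)
    print(logits.shape)  # expected: torch.Size([2, 10])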