import torch
import numpy as np
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from robustness.tools.custom_modules import SequentialWithArgs, FakeReLU
from e2cnn import gspaces
from e2cnn import nn as enn

def conv7x7(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=3,
            dilation=1, bias=False):
    """7x7 convolution with padding"""
    return enn.R2Conv(in_type, out_type, 7,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3*r,
                      )


def conv5x5(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=2,
            dilation=1, bias=False):
    """5x5 convolution with padding"""
    return enn.R2Conv(in_type, out_type, 5,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3*r,
                      )


def conv3x3(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=1,
            dilation=1, bias=False):
    """3x3 convolution with padding"""
    return enn.R2Conv(in_type, out_type, 3,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3*r,
                      )

def conv1x1(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=0,
            dilation=1, bias=False):
    """1x1 convolution"""
    return enn.R2Conv(in_type, out_type, 1,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3*r,
                      )

class EquivariantResNet(nn.Module):
    def __init__(self, N, block, num_blocks, num_classes=10, feat_scale=1, wm=1):
        super(EquivariantResNet, self).__init__()
        widths = [64, 128, 256, 512]
        widths = [int(w * wm) for w in widths]
        self.in_planes = widths[0]
        self.N = N

        if self.N < 0:
            # SO(2)-equivariant model: build a field type from irreps up to frequency |N|.
            self.gspace = gspaces.Rot2dOnR2(N=-1, maximum_frequency=-self.N)
            self.out_type, trivials, others, irreps, labels = self.get_irreps_field_type(
                self.in_planes, conv7x7, 7)
        else:
            # C_N-equivariant model: use copies of the regular representation.
            self.gspace = gspaces.Rot2dOnR2(N=self.N)
            num_planes = self.in_planes // N
            self.out_type = enn.FieldType(self.gspace, [self.gspace.regular_repr] * num_planes)
            labels = ["trivial" for r in self.out_type]
            trivials = self.out_type
            others = []
            irreps = []

        self.input_type = enn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
        self.in_type = self.input_type
        self.conv1 = conv7x7(self.in_type, self.out_type, stride=2)

        modules = [(enn.InnerBatchNorm(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormBatchNorm(others), "others")]
        self.bn1 = enn.MultipleModule(self.out_type, labels, modules)

        modules = [(enn.ELU(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormNonLinearity(others, function="n_relu"), "others")]
        self.relu1 = enn.MultipleModule(self.out_type, labels, modules)

        modules = [(enn.PointwiseMaxPool(trivials, kernel_size=3, stride=2, padding=1), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormMaxPool(others, kernel_size=3, stride=2, padding=1), "others")]
        self.maxpool = enn.MultipleModule(self.out_type, labels, modules)

        self.in_type = self.out_type
        self.layer1 = self._make_layer(block, widths[0], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, widths[1], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, widths[2], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, widths[3], num_blocks[3], stride=2, totrivial=True)

        if self.N <= 0:
            feat_len = feat_scale * widths[3] * block.expansion
            # feat_len = 268  # hardcoded for 225x225, N=-3
            feat_len = 306  # hardcoded for 225x225, N=-5
        else:
            feat_len = feat_scale * widths[3] * block.expansion // N
        self.linear = nn.Linear(feat_len, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        print("MODEL TOPOLOGY:")
        for i, (name, mod) in enumerate(self.named_modules()):
            print(f"\t{i} - {name}")

    def _make_layer(self, block, planes, num_blocks, stride, totrivial=False):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        if self.N < 0:
            self.out_type, trivials, others, irreps, labels = self.get_irreps_field_type(planes, conv3x3, 3)
        else:
            num_planes = planes // self.N
            self.out_type = enn.FieldType(self.gspace, [self.gspace.regular_repr] * num_planes)
            labels = ["trivial" for r in self.out_type]
            trivials = self.out_type
            others = []
            irreps = []
        for ind_stride, stride in enumerate(strides):
            is_trivial = totrivial and (ind_stride == num_blocks - 1)
            layers.append(block(self.gspace, self.in_type, self.out_type, self.in_planes, planes,
                                trivials, others, irreps, labels, self.N, stride, totrivial=is_trivial))
            self.in_type = self.out_type
            self.in_planes = planes * block.expansion
        return SequentialWithArgs(*layers)

    def get_irreps_field_type(self, planes, conv_type, conv_size):
        # Build a field type mixing trivial and non-trivial irreps; in this model
        # it is only called for the SO(2)-equivariant case (self.N < 0).
        gc = self.gspace
        irreps = []
        if self.N >= 0:
            out_type = enn.FieldType(gc, [gc.regular_repr] * planes)
        else:
            for n, irr in gc.fibergroup.irreps.items():
                if n != gc.trivial_repr.name:
                    irreps += [irr] * int(irr.size // irr.sum_of_squares_constituents)
            irreps = list(irreps)

        C = self.get_num_channels_for_fixed_params(planes, irreps, conv_type, conv_size)

        trivials = enn.FieldType(gc, [gc.trivial_repr] * C)
        if len(irreps) > 0:
            others = enn.FieldType(gc, irreps * C).sorted()
            out_type = trivials + others
        else:
            others = []
            out_type = trivials

        labels = ["trivial" if r.is_trivial() else "others" for r in out_type]
        out_types = out_type.group_by_labels(labels)
        trivials = out_types["trivial"]
        others = out_types["others"]
        for r in trivials:
            r.supported_nonlinearities.add("pointwise")
        return out_type, trivials, others, irreps, labels

    def get_num_channels_for_fixed_params(self, channels, irreps, conv_type, conv_size):
        gc = self.gspace
        r_in = enn.FieldType(gc, [gc.trivial_repr] + irreps)
        r_out = enn.FieldType(gc, [gc.trivial_repr] + irreps)
        tmp_cl = conv_type(r_in, r_out)
        t = tmp_cl.basisexpansion.dimension()
        t /= conv_size ** 2 / 12  # * 12 and remove / 12
        C = int(round(channels / np.sqrt(t)))
        print(f'old num channels: {channels}, new num channels: {C}, factor: {t}')
        return C

    def forward(self, x, with_latent=False, fake_relu=False, no_relu=False):
        x = enn.GeometricTensor(x, self.input_type)
        out = self.conv1(x)
        out = self.relu1(self.bn1(out))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        preout = out.tensor
        out = out.tensor
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        final = self.linear(out)
        if with_latent:
            return final, preout
        return final

class EquivariantBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, gspace, in_type, out_type, in_planes, planes, trivials, others, irreps,
                 labels, N, stride=1, totrivial=False):
        super(EquivariantBasicBlock, self).__init__()
        self.conv1 = conv3x3(in_type, out_type, stride=stride)

        modules = [(enn.InnerBatchNorm(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormBatchNorm(others), "others")]
        self.bn1 = enn.MultipleModule(out_type, labels, modules)

        modules = [(enn.ELU(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormNonLinearity(others, function="n_relu"), "others")]
        self.relu1 = enn.MultipleModule(out_type, labels, modules)

        inner_type = out_type
        print("fibergroup order:", gspace.fibergroup.order())
        self.conv2 = conv3x3(inner_type, out_type)

        modules = [(enn.InnerBatchNorm(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormBatchNorm(others), "others")]
        self.bn2 = enn.MultipleModule(out_type, labels, modules)

        self.invariant_map = None
        if totrivial:
            print('should happen only once at the end')
            if N <= 0:
                modules = [(enn.IdentityModule(trivials), "trivial"),
                           (enn.NormPool(others), "others")]
                self.invariant_map = enn.MultipleModule(out_type, labels, modules)
            else:
                self.invariant_map = enn.GroupPooling(out_type)
            out_type = self.invariant_map.out_type
            labels = ["trivial" if r.is_trivial() else "others" for r in out_type]
            out_types = out_type.group_by_labels(labels)
            trivials = out_types["trivial"]
            for r in trivials:
                r.supported_nonlinearities.add("pointwise")
            irreps = []
            others = []

        self.shortcut = None
        if stride != 1 or in_planes != planes:
            if self.invariant_map is not None:
                self.shortcut = enn.SequentialModule(
                    conv1x1(in_type, out_type, stride=stride),
                    self.bn2,
                    self.invariant_map)
            else:
                self.shortcut = enn.SequentialModule(
                    conv1x1(in_type, out_type, stride=stride),
                    self.bn2)

        modules = [(enn.ELU(trivials), "trivial")]
        if len(irreps) > 0:
            modules += [(enn.NormNonLinearity(others, function="n_relu"), "others")]
        self.relu2 = enn.MultipleModule(out_type, labels, modules)

    def forward(self, x, fake_relu=False):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.invariant_map is not None:
            out = self.invariant_map(out)
        if self.shortcut is not None:
            out += self.shortcut(x)
        return self.relu2(out)

def EquivariantResNet18(N, **kwargs):
    """
    ResNet-18 equivariant to rotations discretized to the cyclic group C_N.

    N < 0 indicates equivariance to SO(2), i.e. to all rotations; in that case
    |N| gives the maximum irrep frequency to use.

    Due to pooling/dilation, the input image resolution needs to be of the form
    32k + 1, e.g. 225x225. For SO(2)-invariant networks, change feat_len accordingly.
    """
    return EquivariantResNet(N=N, block=EquivariantBasicBlock,
                             num_blocks=[2, 2, 2, 2], **kwargs)
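

# Minimal usage sketch: the group order N=8, the batch size, and the 225x225
# input below are illustrative assumptions; the resolution only needs to satisfy
# the 32k + 1 constraint noted in the docstring above.
if __name__ == "__main__":
    # C_8-equivariant ResNet-18 with 10 output classes.
    model = EquivariantResNet18(N=8, num_classes=10)
    model.eval()

    # Dummy batch of two RGB images at a 32k + 1 resolution (k = 7 -> 225).
    x = torch.randn(2, 3, 225, 225)
    with torch.no_grad():
        logits = model(x)                            # class scores
        logits, latent = model(x, with_latent=True)  # also return pre-avgpool features
    print(logits.shape)  # torch.Size([2, 10])
    print(latent.shape)  # invariant feature map after layer4, before avgpool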