Skip to content

Instantly share code, notes, and snippets.

@ahyunSeo
Created September 24, 2022 15:50
Show Gist options
  • Save ahyunSeo/3cd6d388fa5d17c55fe7126c8db7e373 to your computer and use it in GitHub Desktop.
Save ahyunSeo/3cd6d388fa5d17c55fe7126c8db7e373 to your computer and use it in GitHub Desktop.
Check rotation equivariance of an e2cnn R2Conv layer (C8 rotation group)
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from e2cnn import gspaces
from e2cnn import nn as enn
def conv7x7(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=0,
            dilation=1, bias=False):
    """Build a 7x7 steerable convolution (``enn.R2Conv``) from ``in_type`` to ``out_type``.

    Note: the default ``padding`` is 0, so with the default stride and
    dilation the output is 6 pixels smaller than the input along each spatial
    dimension. Pass ``padding=3`` for a size-preserving convolution.

    ``sigma`` is left as ``None`` (library default), and the maximum angular
    frequency allowed at ring radius ``r`` is ``3 * r``.
    """
    kernel_size = 7

    def _max_frequency(radius):
        # Frequency cutoff as a function of the sampling-ring radius.
        return 3 * radius

    conv_options = dict(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=_max_frequency,
    )
    return enn.R2Conv(in_type, out_type, kernel_size, **conv_options)
def conv5x5(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=2,
            dilation=1, bias=False):
    """Build a 5x5 steerable convolution (``enn.R2Conv``) from ``in_type`` to ``out_type``.

    The default ``padding=2`` keeps the spatial size unchanged when stride and
    dilation are left at 1.

    ``sigma`` is left as ``None`` (library default), and the maximum angular
    frequency allowed at ring radius ``r`` is ``3 * r``.
    """
    kernel_size = 5

    def _max_frequency(radius):
        # Frequency cutoff as a function of the sampling-ring radius.
        return 3 * radius

    conv_options = dict(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=_max_frequency,
    )
    return enn.R2Conv(in_type, out_type, kernel_size, **conv_options)
def conv3x3(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=1,
            dilation=1, bias=False):
    """Build a 3x3 steerable convolution (``enn.R2Conv``) from ``in_type`` to ``out_type``.

    The default ``padding=1`` keeps the spatial size unchanged when stride and
    dilation are left at 1.

    ``sigma`` is left as ``None`` (library default), and the maximum angular
    frequency allowed at ring radius ``r`` is ``3 * r``.
    """
    kernel_size = 3

    def _max_frequency(radius):
        # Frequency cutoff as a function of the sampling-ring radius.
        return 3 * radius

    conv_options = dict(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=_max_frequency,
    )
    return enn.R2Conv(in_type, out_type, kernel_size, **conv_options)
def conv1x1(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=0,
            dilation=1, bias=False):
    """Build a pointwise (1x1) steerable convolution (``enn.R2Conv``).

    Maps ``in_type`` to ``out_type``. The default ``padding=0`` keeps the
    spatial size unchanged when stride is 1.

    ``sigma`` is left as ``None`` (library default), and the maximum angular
    frequency allowed at ring radius ``r`` is ``3 * r``.
    """
    kernel_size = 1

    def _max_frequency(radius):
        # Frequency cutoff as a function of the sampling-ring radius.
        return 3 * radius

    conv_options = dict(
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=_max_frequency,
    )
    return enn.R2Conv(in_type, out_type, kernel_size, **conv_options)
def main(device=None):
    """Check rotation equivariance of a 3x3 ``R2Conv`` at the image centre.

    For several odd image sizes, this does the following:
      * draws a random 3-channel input;
      * rotates it by -360/N and -2*360/N degrees (group elements N-1 and N-2);
      * runs the original and rotated inputs through the convolution;
      * rotates the outputs back by +360/N and +2*360/N degrees (elements 1
        and 2).
    If the layer is equivariant, the first output field at the centre pixel
    must match the unrotated result.

    Args:
        device: torch device to run on. ``None`` selects CUDA when available
            and CPU otherwise.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    N = 8  # order of the cyclic rotation group C_N
    step_deg = 360 / N  # rotation angle of group element 1

    gspace = gspaces.Rot2dOnR2(N=N)
    in_type = enn.FieldType(gspace, [gspace.trivial_repr] * 3)
    out_type = enn.FieldType(gspace, [gspace.regular_repr] * 4)

    model = conv3x3(in_type, out_type, stride=1, padding=1)
    model.to(device)
    model.eval()

    for img_size in [181, 183, 185, 187]:
        x = torch.randn([1, 3, img_size, img_size]).to(device)
        x = enn.GeometricTensor(x, in_type)

        # Element k of C_N rotates by k * step_deg, so N-1 and N-2 are
        # rotations by -step_deg and -2 * step_deg respectively.
        x_rot1 = x.transform(N - 1)
        x_rot2 = x.transform(N - 2)

        with torch.no_grad():
            lat = model(x)
            lat_rot1 = model(x_rot1)
            lat_rot2 = model(x_rot2)

        # Undo the input rotations on the outputs (+step_deg, +2 * step_deg).
        lat_rot1 = lat_rot1.transform(1)
        lat_rot2 = lat_rot2.transform(2)

        # With odd spatial size, the centre pixel maps to itself under the
        # rotation (assumes transform rotates about the image centre).
        center = (lat.shape[-1] - 1) // 2

        # First output field (one regular repr = N channels) at the centre.
        ref = lat[0, 0, center, center].tensor.view(-1)
        rot1 = lat_rot1[0, 0, center, center].tensor.view(-1)
        rot2 = lat_rot2[0, 0, center, center].tensor.view(-1)

        print(f"img_size={img_size}, output shape={tuple(lat.shape)}")
        print("  reference:", ref)
        print(f"  {step_deg:g} deg:", rot1)
        print(f"  {2 * step_deg:g} deg:", rot2)
        for angle, other in ((step_deg, rot1), (2 * step_deg, rot2)):
            max_err = (ref - other).abs().max().item()
            print(f"  {angle:g} deg allclose={torch.allclose(ref, other)} "
                  f"max|diff|={max_err:.3e}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment