Created September 24, 2022 15:50
-
-
Save ahyunSeo/3cd6d388fa5d17c55fe7126c8db7e373 to your computer and use it in GitHub Desktop.
Check equivariance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from e2cnn import gspaces | |
from e2cnn import nn as enn | |
def conv7x7(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=0,
            dilation=1, bias=False):
    """7x7 rotation-equivariant convolution built on ``enn.R2Conv``.

    Note that the default ``padding=0`` is a "valid" convolution, not
    "same" padding: with ``stride=1`` and ``dilation=1`` each spatial side
    shrinks by 6. Pass ``padding=3`` to keep height and width unchanged.

    Args:
        in_type: field type of the input ``GeometricTensor``.
        out_type: field type of the output ``GeometricTensor``.
        stride: convolution stride, forwarded to ``enn.R2Conv``.
        padding: zero padding on each side, forwarded to ``enn.R2Conv``.
        dilation: kernel dilation, forwarded to ``enn.R2Conv``.
        bias: whether the convolution has a bias term.

    Returns:
        An ``enn.R2Conv`` module with ``sigma=None`` and
        ``frequencies_cutoff`` set to ``3 * r``.
    """
    return enn.R2Conv(in_type, out_type, 7,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3 * r,
                      )
def conv5x5(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=2,
            dilation=1, bias=False):
    """Build a 5x5 rotation-equivariant ``enn.R2Conv`` from ``in_type`` to ``out_type``.

    The default ``padding=2`` preserves height and width when ``stride`` and
    ``dilation`` are both 1.
    """
    # Collect every keyword option in one place, then hand them over at once.
    options = dict(
        padding=padding,
        stride=stride,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=lambda radius: 3 * radius,
    )
    return enn.R2Conv(in_type, out_type, 5, **options)
def conv3x3(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=1,
            dilation=1, bias=False):
    """Return a 3x3 rotation-equivariant ``enn.R2Conv`` from ``in_type`` to ``out_type``.

    The default ``padding=1`` preserves height and width when ``stride`` and
    ``dilation`` are both 1.
    """
    def cutoff(radius):
        # Passed as e2cnn's ``frequencies_cutoff`` callable: 3x the radius.
        return 3 * radius

    kernel_size = 3
    return enn.R2Conv(
        in_type,
        out_type,
        kernel_size,
        padding=padding,
        stride=stride,
        dilation=dilation,
        bias=bias,
        sigma=None,
        frequencies_cutoff=cutoff,
    )
def conv1x1(in_type: enn.FieldType, out_type: enn.FieldType, stride=1, padding=0,
            dilation=1, bias=False):
    """1x1 rotation-equivariant convolution built on ``enn.R2Conv``.

    A 1x1 kernel needs no padding to preserve spatial size, so ``padding``
    defaults to 0 (no padding); with ``stride=1`` height and width are kept.

    Args:
        in_type: field type of the input ``GeometricTensor``.
        out_type: field type of the output ``GeometricTensor``.
        stride: convolution stride, forwarded to ``enn.R2Conv``.
        padding: zero padding on each side, forwarded to ``enn.R2Conv``.
        dilation: kernel dilation, forwarded to ``enn.R2Conv``.
        bias: whether the convolution has a bias term.

    Returns:
        An ``enn.R2Conv`` module with ``sigma=None`` and
        ``frequencies_cutoff`` set to ``3 * r``.
    """
    return enn.R2Conv(in_type, out_type, 1,
                      stride=stride,
                      padding=padding,
                      dilation=dilation,
                      bias=bias,
                      sigma=None,
                      frequencies_cutoff=lambda r: 3 * r,
                      )
if __name__ == "__main__":
    # Empirical equivariance check for a C_N-steerable 3x3 convolution.
    # For a group element g, equivariance means  g . f(g^-1 . x) == f(x).
    # Each input is rotated by g^-1 and passed through the model. The output
    # is then rotated back by g, and its fiber at the center pixel is
    # compared with the un-rotated output's fiber. Odd image sizes with a
    # size-preserving conv (padding=1) give an integer center pixel
    # (w - 1) // 2, assumed to be the rotation's fixed point.
    N = 8
    gspace = gspaces.Rot2dOnR2(N=N)
    in_type = enn.FieldType(gspace, [gspace.trivial_repr] * 3)
    out_type = enn.FieldType(gspace, [gspace.regular_repr] * 4)

    # Use the GPU when present, otherwise fall back to CPU so the check
    # still runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = conv3x3(in_type, out_type, stride=1, padding=1)
    model.to(device)
    model.eval()

    for img_size in [181, 183, 185, 187]:
        x = torch.randn([1, in_type.size, img_size, img_size], device=device)
        x = enn.GeometricTensor(x, in_type)
        # With N = 8, element k rotates by k * 45 deg. So N-1 is a -45 deg
        # (315 deg) rotation and N-2 is a -90 deg (270 deg) rotation.
        xrot1 = x.transform(N - 1)  # -45 deg
        xrot2 = x.transform(N - 2)  # -90 deg
        with torch.no_grad():
            lat = model(x)
            latrot1 = model(xrot1)
            latrot2 = model(xrot2)
        w = lat.shape[-1]
        center = (w - 1) // 2
        # Rotate the outputs back (+45 / +90 deg) to undo the input rotation.
        latrot1 = latrot1.transform(1)
        latrot2 = latrot2.transform(2)

        # Index 0 along dim 1 -- presumably the first regular field, since
        # e2cnn slices GeometricTensors by field (TODO confirm) -- at the
        # center pixel.
        ref = lat[0, 0, center, center].tensor
        back1 = latrot1[0, 0, center, center].tensor
        back2 = latrot2[0, 0, center, center].tensor
        print(ref.view(-1))
        print(back1.view(-1))
        print(back2.view(-1))
        print(lat.shape)
        # A 45 deg rotation does not map the square pixel grid onto itself,
        # so the first comparison is generally only approximate; 90 deg
        # rotations permute pixels exactly.
        print(torch.allclose(ref, back1))
        print(torch.allclose(ref, back2))
        # allclose alone hides how far off the fibers are; report it.
        print("max |diff| 45 deg:", (ref - back1).abs().max().item(),
              " 90 deg:", (ref - back2).abs().max().item())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment