involution_lambdanet
import torch
import torch.nn as nn
class involution(torch.nn.Module):
    # Involution block with fixed sizes: C=256 channels, reduction r=64, G=64 groups, kernel K=3.
    def __init__(self):
        super().__init__()
        self.K = K = 3
        self.C = C = 256
        self.r = r = 64
        self.G = G = 64
        self.reduce = nn.Conv2d(C, C // r, 1)
        self.span = nn.Conv2d(C // r, K * K * G, 1, bias=False)
        self.unfold = nn.Unfold(K, dilation=1, padding=K // 2, stride=1)

    def forward(self, x):
        # x: BxCxHxW
        B, C, H, W = x.shape
        G, r, K = self.G, self.r, self.K
        # gather the KxK neighborhood of every position
        x_unfolded = self.unfold(x)
        x_unfolded = x_unfolded.view(B, G, C // G, K * K, H, W)
        # kernel generation, Eqn.(6)
        kernel = self.span(self.reduce(x))  # B, KxKxG, H, W
        kernel = kernel.view(B, G, K * K, H, W).unsqueeze(2)
        # Multiply-Add operation, Eqn.(4)
        # out = torch.einsum('b g kk h w, b g cg kk h w -> b g cg h w', kernel, x_unfolded)
        out = (kernel * x_unfolded).sum(dim=3)  # B, G, C//G, H, W
        out = out.view(B, C, H, W)
        return out
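
# --- Added sketch (not part of the original gist): a quick shape sanity check for the involution
# --- block above; the batch and spatial sizes below are arbitrary assumptions. The block maps
# --- BxCxHxW to BxCxHxW, with C fixed to 256 by the constructor.
_x = torch.randn(2, 256, 8, 8)
assert involution()(_x).shape == (2, 256, 8, 8)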
class lambdanet(torch.nn.Module):
    def __init__(self):
        """
        Reference: https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py
        One trivial difference from the original lambdanet is: no k, no Yc.
        """
        super().__init__()
        self.K = K = 3
        self.C = C = 256
        self.r = r = 64
        self.G = G = 64
        self.reduce = nn.Conv2d(C, C // r, 1)
        # Part of the reason why the implementation here is complicated is mentioned below:
        # all groups share the same C//r weights (from q). If each group could have its own
        # weights, we would not need the grouped conv; a normal conv would be fine.
        # To be honest, it is really hard to explain why. Let me try.
        # First, the input is B, G, C//G, H, W.
        # C//G sits in a coordinate dimension rather than the channel dimension (and the kernel
        # size along it is 1) because the operation is identical for every feature in C//G:
        # it is pointwise along C//G, so features in that dimension never talk to each other.
        # (definitely confusing)
        # The idea of the group convolution: each dimension g in G is expanded to C//r outputs,
        # where output (g, cr) is the feature convolved with template cr of group g (cr in [1..C//r]).
        # See the weight-shape sketch after this class.
        self.pos_conv = nn.Conv3d(G, G * C // r, (1, K, K), padding=(0, K // 2, K // 2), groups=G, bias=False)
    def forward(self, x):
        # x: BxCxHxW
        B, C, H, W = x.shape
        G, r, K = self.G, self.r, self.K
        # The idea: q gives us the weights on the different KxK attention templates.
        q = self.reduce(x)  # B x C//r x H x W
        x = x.view(B, G, C // G, H, W)
        x = self.pos_conv(x)  # B, G*C//r, C//G, H, W
        # The idea here: for each template cr (1..C//r) in group g, we get the feature map
        # convolved with that template.
        x = x.view(B, G, C // r, C // G, H, W)
        # The main difference is that we change the order of operations:
        # before it was (QK)V,
        # now it is Q(KV) (this is why lambda networks can afford large kernel sizes).
        # Another way to think of it: we don't actually have a different kernel for every dimension;
        # these kernels are just weighted combinations of a few kernel templates,
        # so we can first convolve with the kernel templates
        # and then do the weighted sum in feature space,
        # weighted according to the weights q.
        x = x * q.reshape(B, 1, C // r, 1, H, W)  # B, G, C//r, C//G, H, W
        return x.sum(dim=2).view(B, C, H, W)
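
# --- Added sketches (not part of the original gist). ---
# 1) Weight layout of the grouped Conv3d above: G*C//r output channels, each seeing a single
#    input group (in_channels/groups = 1) with a 1xKxK kernel, i.e. C//r spatial templates per
#    group and G*C//r templates in total.
_m = lambdanet()
assert _m.pos_conv.weight.shape == (64 * 4, 1, 1, 3, 3)              # (G*C//r, 1, 1, K, K)
assert _m.pos_conv.weight.view(64, 4, 3, 3).shape == (64, 4, 3, 3)   # (G, C//r, K, K)
# 2) The (QK)V vs Q(KV) remark in forward() rests on linearity/associativity: mixing the templates
#    with q and then applying them gives the same result as applying each template first and mixing
#    afterwards. A toy matrix version (sizes are arbitrary assumptions):
_q = torch.randn(5, 4)    # per-position weights on 4 templates
_t = torch.randn(4, 9)    # 4 flattened KxK templates
_v = torch.randn(9, 7)    # neighborhood values
assert torch.allclose((_q @ _t) @ _v, _q @ (_t @ _v), atol=1e-5)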
# In both models, the final feature is the concatenation of the per-head/per-group features.
# Difference: in lambdanet, the weights on the templates (k templates) differ across heads (h heads),
# so the attended features differ across heads; the final feature is the concatenation, and the
# feature length in each head is v.
# In involution, the weights on the templates (C//r of them) are shared across heads/groups,
# but the templates themselves differ across groups:
# each group has C//r KxK templates, G x C//r KxK templates in total.
# If we ignore u in the lambda network (it is 1 anyway), then roughly speaking
# h is G, v is C//G, k is C//r.
# Figure 3 in the lambdanet paper:
# position lambdas = einsum(embeddings, values, 'nmk,bmv->bnkv')
# You can think of it as having k nxm attention/kernel templates; applying each template to bmv
# gives bnkv, i.e. one output feature map per template. (See the einsum sketch below.)
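
# --- Added sketch (not part of the original gist): a toy version of the Figure-3 einsum quoted
# --- above; the sizes are arbitrary assumptions. k templates of shape n x m applied to values of
# --- shape b x m x v give one attended map per template, stacked into b x n x k x v.
_bb, _nn, _mm, _kk, _vv = 2, 5, 5, 4, 8
_embeddings = torch.randn(_nn, _mm, _kk)   # k position templates, each n x m
_values = torch.randn(_bb, _mm, _vv)
_position_lambdas = torch.einsum('nmk,bmv->bnkv', _embeddings, _values)
assert _position_lambdas.shape == (_bb, _nn, _kk, _vv)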
m1 = involution()
m2 = lambdanet()
inp = torch.randn(1, 256, 10, 10)
K = 3
C = 256
r = 64
G = 64
# Make the two modules equivalent: share the reduce layer and copy the involution
# span weights into the grouped pos_conv as (G, C//r, K, K) templates.
m2.reduce = m1.reduce
# tmp = m1.span.weight.view(K, K, G, C//r)  # wrong
tmp = m1.span.weight.view(G, K, K, C // r)
tmp1 = m2.pos_conv.weight.view(G, C // r, K, K)
tmp1.data.copy_(tmp.permute(0, 3, 1, 2))
print(torch.isclose(m1(inp), m2(inp)).all())
print((m1(inp) - m2(inp)).mean())
# tmp1.data.copy_(tmp.permute(0, 3, 2, 1))
# print(torch.isclose(m1(inp), m2(inp)).all())
# print((m1(inp) - m2(inp)).mean())
# tmp1 = m2.pos_conv.weight.view(C//r, G, K, K)
# tmp1.data.copy_(tmp.permute(3, 0, 1, 2))
# print(torch.isclose(m1(inp), m2(inp)).all())
# print((m1(inp) - m2(inp)).mean())
# tmp1.data.copy_(tmp.permute(3, 0, 2, 1))
# print(torch.isclose(m1(inp), m2(inp)).all())
# print((m1(inp) - m2(inp)).mean())
# print(m1.span.weight.shape)
# print(m2.pos_conv.weight.shape)
# pos_conv = nn.Conv1d(3, 3*3, 1, padding=0, groups=3, bias=False)
# inp = torch.randn(1, 3, 2).fill_(1)
# import pudb; pu.db
# pos_conv(inp)