@ruotianluo
Created March 11, 2021 06:59
involution_lambdanet
import torch
import torch.nn as nn
class involution(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.K = K = 3
        self.C = C = 256
        self.r = r = 64
        self.G = G = 64
        self.reduce = nn.Conv2d(C, C // r, 1)
        self.span = nn.Conv2d(C // r, K * K * G, 1, bias=False)
        self.unfold = nn.Unfold(K, dilation=1, padding=K // 2, stride=1)

    def forward(self, x):
        # x: BxCxHxW
        B, C, H, W = x.shape
        G, r, K = self.G, self.r, self.K
        x_unfolded = self.unfold(x)
        x_unfolded = x_unfolded.view(B, G, C // G, K * K, H, W)
        # kernel generation, Eqn.(6)
        kernel = self.span(self.reduce(x))  # B, KxKxG, H, W
        kernel = kernel.view(B, G, K * K, H, W).unsqueeze(2)
        # Multiply-Add operation, Eqn.(4)
        # out = torch.einsum('b g kk h w, b g cg kk h w -> b g cg h w', kernel, x_unfolded)
        out = (kernel * x_unfolded).sum(dim=3)  # B, G, C/G, H, W
        out = out.view(B, C, H, W)
        return out
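
# Illustration only (not part of the original gist): the same involution
# aggregation written with the explicit einsum from the comment above, as a
# sanity reference. `involution_ref` is a hypothetical helper; `m` is any
# `involution` instance. involution_ref(m, x) should match m(x).
def involution_ref(m, x):
    B, C, H, W = x.shape
    G, K = m.G, m.K
    x_unfolded = m.unfold(x).view(B, G, C // G, K * K, H, W)
    kernel = m.span(m.reduce(x)).view(B, G, K * K, H, W)
    # contract the KxK neighbourhood against the per-position, per-group kernel
    out = torch.einsum('bgkhw,bgckhw->bgchw', kernel, x_unfolded)
    return out.reshape(B, C, H, W)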
class lambdanet(torch.nn.Module):
    def __init__(self):
        """
        Reference: https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py
        One trivial difference from the original lambdanet: no k, no Yc.
        """
        super().__init__()
        self.K = K = 3
        self.C = C = 256
        self.r = r = 64
        self.G = G = 64
        self.reduce = nn.Conv2d(C, C // r, 1)
        # Part of the reason the implementation here is complicated is mentioned below:
        # different groups share the C//r weights; if we could have different weights
        # for different groups, we would not need the group conv, a normal conv would do.
        # An attempt at an explanation:
        # First, the input is B, G, C//G, H, W.
        # C//G sits in a coordinate dimension rather than the feature dimension (and the
        # kernel size along it is 1) because the operation is the same for every feature
        # in C//G: it is pointwise along C//G, so features in that dimension never talk
        # to each other.
        # The idea of the group convolution is: for each dimension in G, expand it to
        # C//r, where each entry (g, cr) is the feature convolved with template cr of
        # group g (cr in [1..C//r]).
        self.pos_conv = nn.Conv3d(G, G * C // r, (1, K, K), padding=(0, K // 2, K // 2), groups=G, bias=False)

    def forward(self, x):
        # x: BxCxHxW
        B, C, H, W = x.shape
        G, r, K = self.G, self.r, self.K
        # q gives the weights over the different KxK attention templates
        q = self.reduce(x)  # B, C//r, H, W
        x = x.view(B, G, C // G, H, W)
        x = self.pos_conv(x)  # B, G x C//r, C//G, H, W
        # for each template cr (1..C//r) in group g, this gives the feature
        # convolved with that template
        x = x.view(B, G, C // r, C // G, H, W)
        # The main difference is that we change the order of operations:
        # before it was (QK)V, now it is Q(KV)
        # (this is why lambda networks can afford a large kernel size).
        # Another way to think about it: we don't actually have a different kernel for
        # every dimension; the kernels are just weighted combinations of some kernel
        # templates, so we can first convolve with the templates and then take the
        # weighted sum over feature space.
        # weighted sum according to the weights (q)
        x = x * q.reshape(B, 1, C // r, 1, H, W)  # B, G, C//r, C//G, H, W
        return x.sum(dim=2).view(B, C, H, W)
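
# Illustration only (not part of the original gist): a hypothetical sketch of what
# pos_conv computes, written as unfold + einsum to make the grouped-conv semantics
# explicit. `templates` is pos_conv.weight viewed as (G, C//r, K, K): C//r KxK
# templates per group, shared across the C//G "coordinate" dimension.
def pos_conv_sketch(x, templates, K=3):
    B, G, D, H, W = x.shape                  # D = C//G
    T = templates.shape[1]                   # T = C//r
    # unfold every (g, d) map into its KxK neighbourhoods
    patches = nn.functional.unfold(
        x.reshape(B * G * D, 1, H, W), K, padding=K // 2)  # (B*G*D, K*K, H*W)
    patches = patches.view(B, G, D, K * K, H, W)
    # contract each neighbourhood against the templates of its own group
    out = torch.einsum('bgdkhw,gtk->bgtdhw', patches, templates.view(G, T, K * K))
    return out  # (B, G, C//r, C//G, H, W), matching pos_conv(x).view(B, G, C//r, C//G, H, W)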
# In both models, the final feature is the concatenation of the per-head/per-group features.
# Difference: in lambdanet, the weights over the k templates differ between the h heads,
# so the attended features are different in different heads; the final feature is their
# concatenation, and the feature length in each head is v.
# In involution, the weights (C//r of them) over the templates are the same across
# heads/groups, but the templates themselves differ between groups:
# each group has C//r KxK templates, so G x C//r KxK templates in total.
# Ignoring u (which is 1 anyway) in the lambda network:
# roughly speaking, h is G, v is C//G, k is C//r.
# Figure 3 in the lambdanet paper:
# position lambdas = einsum(embeddings, values, 'nmk,bmv->bnkv')
# You can think of it as having k nxm attention/kernel templates; applying each template
# to bmv gives bnkv, i.e. one output feature map per template.
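# Illustration only (not part of the original gist): a toy run of the
# 'nmk,bmv->bnkv' einsum from Figure 3, with made-up sizes n=m=4, k=2, b=1, v=3.
pos_lambdas = torch.einsum('nmk,bmv->bnkv',
                           torch.randn(4, 4, 2),   # k=2 templates, each (n=4, m=4)
                           torch.randn(1, 4, 3))   # values (b=1, m=4, v=3)
print(pos_lambdas.shape)  # torch.Size([1, 4, 2, 3]): one (n, v) map per template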
m1 = involution()
m2 = lambdanet()
inp = torch.randn(1,256,10,10)
K = 3
C = 256
r = 64
G = 64
# Make them equivalent by copying involution's kernel-generating weights into lambdanet
m2.reduce = m1.reduce
# tmp = m1.span.weight.view(K, K, G, C//r)  # wrong layout
tmp = m1.span.weight.view(G, K, K, C // r)
tmp1 = m2.pos_conv.weight.view(G, C // r, K, K)
tmp1.data.copy_(tmp.permute(0, 3, 1, 2))
print(torch.isclose(m1(inp), m2(inp)).all())
print((m1(inp)-m2(inp)).mean())
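# Illustration only (not part of the original gist): an equivalence check with an
# explicit tolerance, in case the default torch.isclose tolerances are too tight
# for accumulated floating-point error, plus the maximum absolute difference.
print(torch.allclose(m1(inp), m2(inp), atol=1e-5))
print((m1(inp) - m2(inp)).abs().max())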
# tmp1.data.copy_(tmp.permute(0,3,2,1))
# print(torch.isclose(m1(inp), m2(inp)).all())
# print((m1(inp)-m2(inp)).mean())
# tmp1 = m2.pos_conv.weight.view(C//r, G, K, K)
# tmp1.data.copy_(tmp.permute(3,0,1,2))
# print(torch.isclose(m1(inp), m2(inp)).all())
# print((m1(inp)-m2(inp)).mean())
# tmp1.data.copy_(tmp.permute(3,0,2,1))
# print(torch.isclose(m1(inp), m2(inp)).all())
# print((m1(inp)-m2(inp)).mean())
# print(m1.span.weight.shape)
# print(m2.pos_conv.weight.shape)
# pos_conv = nn.Conv1d(3, 3*3, 1, padding=0, groups=3, bias=False)
# inp = torch.randn(1,3,2).fill_(1)
# import pudb;pu.db
# pos_conv(inp)