Skip to content

Instantly share code, notes, and snippets.

@jizongFox
Created September 27, 2022 15:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jizongFox/a938204074b750d2376568d5b441d42a to your computer and use it in GitHub Desktop.
test_function.py
from functools import partial
import matplotlib.pyplot as plt
import torch
from torch import nn
from models.efficientsplitformer.efficientsplitcmtformer2 import _Meta4D, Meta4D, SelfSeparableMeta4D
from models.nn.tokenmixer import MultiHeadSelfAttention, _TokenMixer
from models.nn.mbblock import MBConv, MBConvAttention
class MultiHeadSelfAttentionMeta4D(_Meta4D):
    """Meta4D variant whose token mixer is plain multi-head self-attention.

    Used below to compare input-gradient (effective receptive field) maps
    against the other Meta4D / MBConv variants.
    """

    def __init__(self, dim, *, num_heads, mlp_ratio: float = 4.0,
                 drop_path: float = 0.0, act_layer=nn.GELU, **kwargs) -> None:
        # MHSA is plugged in as the token mixer; the rest of the block
        # (MLP, drop-path, norms) comes from the _Meta4D base class.
        token_mixer = MultiHeadSelfAttention(dim=dim, num_heads=num_heads)
        super().__init__(dim, token_mixer=token_mixer, mlp_ratio=mlp_ratio,
                         drop_path=drop_path, act_layer=act_layer, **kwargs)
def _plot_input_gradient(block: nn.Module, input_: torch.Tensor, title: str) -> None:
    """Backprop from one output activation and plot log-|grad| w.r.t. the input.

    Visualizes the effective receptive field of ``block``: the gradient of the
    output at position [:, 32, 7, 7] is propagated back to ``input_`` and the
    per-pixel magnitude (averaged over channels) is shown on a log scale.
    """
    input_.grad = None  # drop gradient accumulated by a previously tested block
    output = block(input_)
    output[:, 32, 7, 7].mean().backward()
    grad_map = input_.grad[0].abs().mean(0)  # channel-averaged |grad|, (H, W)
    plt.figure()
    plt.imshow(torch.log(grad_map))  # log scale makes faint far-field response visible
    plt.colorbar()
    plt.title(title)


# One shared input so every block sees identical data; requires_grad so the
# input gradient is retained after backward().
input_ = torch.randn(1, 64, 15, 15, requires_grad=True)

_plot_input_gradient(Meta4D(dim=64, num_heads=8), input_, "Meta4D")
_plot_input_gradient(SelfSeparableMeta4D(dim=64, num_heads=8), input_,
                     "SelfSeparableMeta4D")
_plot_input_gradient(MBConv(dim=64, token_size=7 * 7), input_, "MBConv")
_plot_input_gradient(MultiHeadSelfAttentionMeta4D(dim=64, num_heads=8), input_,
                     "MultiHeadSelfAttentionMeta4D")
# Title fixed: original read "BConvAttention" for the MBConvAttention block.
_plot_input_gradient(MBConvAttention(dim=64, num_heads=8), input_,
                     "MBConvAttention")

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment