# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_pytorch.py

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn.modules.utils import _pair

MIN_NUM_PATCHES = 16

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dropout = 0.):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5  # attention scaling factor
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            # prepend a True entry for the cls token, then build the pairwise (b, n, n) visibility mask
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            # insert a head dimension so the mask broadcasts against dots of shape (b, h, n, n)
            dots.masked_fill_(~mask[:, None, :, :], float('-inf'))
            del mask

        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class Transformer(nn.Module):
    """Transformer encoder whose `depth` blocks are each applied `recurrent_steps` times with shared weights."""
    def __init__(self, dim, recurrent_steps, heads, mlp_dim, dropout, depth=1):
        super().__init__()
        self.recurrent_steps = recurrent_steps
        self.depth = depth
        self.layers = nn.ModuleList([])
        for _ in range(self.depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))

    def forward(self, x, mask = None):
        for j in range(self.depth):
            for i in range(self.recurrent_steps):
                x = self.layers[j][0](x, mask = mask)
                x = self.layers[j][1](x)
        return x

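# Note (illustrative): because each of the `depth` blocks above is reused
# `recurrent_steps` times, a configuration such as depth=1, recurrent_steps=12
# applies one shared attention/feed-forward block twelve times, so the encoder
# holds roughly 1/12 of the parameters of a vanilla 12-layer ViT encoder while
# doing a comparable amount of compute per forward pass.
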
class RecurrentViT(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim,
                 recurrent_steps, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # total number of encoder applications should match the 12 blocks of a vanilla ViT
        assert recurrent_steps * depth == 12, 'The architecture should match vanilla ViT'
        self.transformer = Transformer(dim, recurrent_steps, heads, mlp_dim, dropout, depth=depth)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, num_classes)
        )
    def forward(self, img, mask = None):
        p = self.patch_size

        # split the image into non-overlapping p x p patches and embed them
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        # prepend the cls token and add positional embeddings
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x, mask)

        # classify from the cls token representation
        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)
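
# --- Minimal usage sketch (not part of the original gist) ---
# The hyperparameters and tensor shapes below are illustrative assumptions,
# chosen only so that recurrent_steps * depth == 12 and the patch-count assert passes.
if __name__ == "__main__":
    model = RecurrentViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1000,
        dim = 768,
        recurrent_steps = 12,   # one shared block applied 12 times
        depth = 1,
        heads = 12,
        mlp_dim = 3072,
    )
    img = torch.randn(2, 3, 224, 224)   # dummy batch of 2 RGB images
    logits = model(img)                 # no mask: full self-attention over all patches
    print(logits.shape)                 # expected: torch.Size([2, 1000])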