# Based on the following reference:
# https://github.com/machine-perception-robotics-group/MPRGDeepLearningLectureNotebook/blob/master/13_rnn/06_Transformer.ipynb
import copy
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from timm.models.layers import trunc_normal_
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics import MeanMetric
import matplotlib.pyplot as plt
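
# A Transformer-based autoencoder for CIFAR-10: the encoder lets a small sequence of
# learned latent tokens attend to the image patches, and the decoder reconstructs the
# patches from that latent sequence; training minimizes an MSE reconstruction loss.

# Multi-head self-attention over a (batch, length, dim_in) sequence: q, k and v come
# from a single linear projection, and the concatenated heads go through an output
# projection to dim_out.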
class SelfAttention(nn.Module):
    def __init__(self, dim_in=16, dim_qkv=16, num_heads=8, dim_out=16):
        super().__init__()
        self.dim_in = dim_in
        self.num_heads = num_heads
        self.dim_qkv = dim_qkv
        self.scale = dim_qkv ** -0.5
        self.qkv = nn.Linear(dim_in, 3 * num_heads * dim_qkv, bias=False)
        self.proj = nn.Linear(num_heads * dim_qkv, dim_out)

    def forward(self, x):
        batch, length, dim_in = x.shape
        assert dim_in == self.dim_in, "dim_in does not match"
        # (batch, length, 3 * num_heads * dim_qkv) => (batch, length, 3, num_heads, dim_qkv) => (3, batch, num_heads, length, dim_qkv)
        qkv = self.qkv(x).reshape(batch, length, 3, self.num_heads, self.dim_qkv).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # (batch, num_heads, length, dim_qkv)
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)  # (batch, num_heads, length, length)
        # (batch, num_heads, length, dim_qkv) => (batch, length, num_heads, dim_qkv) => (batch, length, num_heads * dim_qkv) => (batch, length, dim_out)
        return self.proj((attn @ v).transpose(1, 2).reshape(batch, length, self.num_heads * self.dim_qkv))
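
# Multi-head cross-attention: queries are projected from the "main" sequence and keys/values
# from the "side" sequence, so the output keeps main's length while attending over side:
# (batch, length_main, dim_in_main) x (batch, length_side, dim_in_side) -> (batch, length_main, dim_out).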
class CrossAttention(nn.Module):
    def __init__(self, dim_in_main=16, dim_in_side=48, dim_qk=16, dim_v=16, num_heads=8, dim_out=16):
        super().__init__()
        self.dim_in_main = dim_in_main
        self.dim_in_side = dim_in_side
        self.num_heads = num_heads
        self.dim_qk = dim_qk
        self.dim_v = dim_v
        self.scale = dim_qk ** -0.5
        self.q = nn.Linear(dim_in_main, num_heads * dim_qk, bias=False)
        self.k = nn.Linear(dim_in_side, num_heads * dim_qk, bias=False)
        self.v = nn.Linear(dim_in_side, num_heads * dim_v, bias=False)
        self.proj = nn.Linear(num_heads * dim_v, dim_out)

    # ((batch, length_main, dim_in_main), (batch, length_side, dim_in_side)) => (batch, length_main, dim_out)
    def forward(self, x_main, x_side):
        batch, length_main, dim_in_main = x_main.shape
        assert dim_in_main == self.dim_in_main, "dim_in_main does not match"
        batch_side, length_side, dim_in_side = x_side.shape
        assert dim_in_side == self.dim_in_side, "dim_in_side does not match"
        assert batch_side == batch, "batch size does not match"
        q = self.q(x_main).reshape(batch, length_main, self.num_heads, self.dim_qk).permute(0, 2, 1, 3)  # (batch, num_heads, length_main, dim_qk)
        k = self.k(x_side).reshape(batch, length_side, self.num_heads, self.dim_qk).permute(0, 2, 1, 3)  # (batch, num_heads, length_side, dim_qk)
        v = self.v(x_side).reshape(batch, length_side, self.num_heads, self.dim_v).permute(0, 2, 1, 3)  # (batch, num_heads, length_side, dim_v)
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)  # (batch, num_heads, length_main, length_side)
        return self.proj((attn @ v).transpose(1, 2).reshape(batch, length_main, self.num_heads * self.dim_v))
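
# Two-layer feed-forward block with a GELU activation, used as the MLP sub-layer of the
# Transformer blocks below.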
class Mlp(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        self.act = nn.GELU()
        self.fc1 = nn.Linear(dim_in, dim_hidden)
        self.fc2 = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
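
# Pre-LayerNorm Transformer decoder block: self-attention, cross-attention to the side
# sequence, and an MLP, each wrapped in a residual connection. A single LayerNorm
# instance is shared by all sub-layers as written.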
class TransformerDecoder(nn.Module):
    def __init__(self, dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads):
        super().__init__()
        self.norm_layer = nn.LayerNorm(dim_main)
        self.self_attn = SelfAttention(dim_main, dim_qkv, num_heads, dim_main)
        #self.mlp1 = Mlp(dim_main, dim_main * 4, dim_main)
        self.cross_attn = CrossAttention(dim_main, dim_side, dim_qk, dim_v, num_heads, dim_main)
        self.mlp2 = Mlp(dim_main, dim_main * 4, dim_main)

    # side is assumed to be layer-normalized
    def forward(self, main, side_layer_normed):
        x1 = self.norm_layer(main)
        x1 = self.self_attn(x1)
        x1 = x1 + main
        """
        x2 = self.norm_layer(x1)
        x2 = self.mlp1(x2)
        x2 = x2 + x1
        """
        x3 = self.norm_layer(x1)
        x3 = self.cross_attn(x3, side_layer_normed)
        x3 = x3 + x1
        x4 = self.norm_layer(x3)
        x4 = self.mlp2(x4)
        x4 = x4 + x3
        return x4
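
# The "WCA" variant and its stacked wrapper below are defined but not used by VisionTBA,
# which builds both its encoder and decoder from StackedTransformerDecoder.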
class TransformerDecoderWCA(nn.Module):  # variant without the self-attention sub-layer (cross-attention and MLP only)
    def __init__(self, dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads):
        super().__init__()
        self.norm_layer = nn.LayerNorm(dim_main)
        self.self_attn = SelfAttention(dim_main, dim_qkv, num_heads, dim_main)
        #self.mlp1 = Mlp(dim_main, dim_main * 4, dim_main)
        self.cross_attn = CrossAttention(dim_main, dim_side, dim_qk, dim_v, num_heads, dim_main)
        self.mlp2 = Mlp(dim_main, dim_main * 4, dim_main)

    # side is assumed to be layer-normalized
    def forward(self, main, side_layer_normed):
        x3 = self.norm_layer(main)
        x3 = self.cross_attn(x3, side_layer_normed)
        x3 = x3 + main
        x4 = self.norm_layer(x3)
        x4 = self.mlp2(x4)
        x4 = x4 + x3
        return x4
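
# Stacks `depth` deep-copied (non-weight-sharing) decoder blocks and layer-normalizes the
# side sequence once before feeding it to every block.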
class StackedTransformerDecoder(nn.Module):
    def __init__(self, dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads, depth):
        super().__init__()
        decoder = TransformerDecoder(dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads)
        self.decoder_list = nn.ModuleList([copy.deepcopy(decoder) for i in range(depth)])
        self.norm_layer = nn.LayerNorm(dim_side)

    def forward(self, main, side):
        side_normed = self.norm_layer(side)
        for decoder in self.decoder_list:
            main = decoder(main, side_normed)
        return main

class StackedTransformerDecoderWCA(nn.Module):
    def __init__(self, dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads, depth):
        super().__init__()
        decoder = TransformerDecoderWCA(dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads)
        self.decoder_list = nn.ModuleList([copy.deepcopy(decoder) for i in range(depth)])
        self.norm_layer = nn.LayerNorm(dim_side)

    def forward(self, main, side):
        side_normed = self.norm_layer(side)
        for decoder in self.decoder_list:
            main = decoder(main, side_normed)
        return main
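
# LightningModule implementing the autoencoder. "main" tokens are flattened image patches
# (dim_main = in_chans * patch_size**2) with a learned positional embedding; the latent is
# a sequence of size_char tokens of width dim_char, seeded from a second learned embedding.
# Validation logs reconstructions and the embeddings to TensorBoard.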
class VisionTBA(pl.LightningModule):
    logger: TensorBoardLogger

    def __init__(self, img_size=32, patch_size=4, in_chans=3, dim_char=16, size_char=16):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.num_patch_side = img_size // patch_size
        dim_main = in_chans * patch_size * patch_size
        size_main = (img_size // patch_size) ** 2
        self.dim_main = dim_main
        self.size_main = size_main
        self.dim_char = dim_char
        self.size_char = size_char
        self.encoder = StackedTransformerDecoder(dim_char, dim_main, dim_char, dim_char, dim_char, 8, 8)
        self.decoder = StackedTransformerDecoder(dim_main, dim_char, dim_main, dim_main, dim_main, 8, 8)
        self.pos_embed_main = nn.Parameter(torch.zeros(size_main, dim_main))
        self.pos_embed_latent = nn.Parameter(torch.zeros(size_char, dim_char))
        self.criterion = torch.nn.MSELoss()
        self.sigmoid = nn.Sigmoid()
        trunc_normal_(self.pos_embed_main)
        trunc_normal_(self.pos_embed_latent)
        self.apply(self._init_weights)
        print(f'parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}')
        self.train_loss_epoch = MeanMetric()
        self.validate_loss_epoch = MeanMetric()
        self.train_loss_epoch.reset()
        self.validate_loss_epoch.reset()

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patch(self, x):
        batch, _, _, _ = x.shape
        patch_size = self.patch_size
        num_patch_side = self.num_patch_side
        num_patch = num_patch_side ** 2
        return x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size).permute(0, 2, 3, 1, 4, 5).reshape(batch, num_patch, self.in_chans * patch_size * patch_size)

    def unpatch(self, x):
        batch, _, _ = x.shape
        patch_size = self.patch_size
        num_patch_side = self.num_patch_side
        num_patch = num_patch_side ** 2
        in_chans = self.in_chans
        return torch.cat(
            torch.cat(
                x.reshape(batch, num_patch, in_chans, patch_size, patch_size).permute(0, 2, 1, 3, 4).reshape(batch, in_chans, num_patch_side, num_patch_side, patch_size, patch_size).unbind(3), dim=4
            ).unbind(2),
            dim=2
        )

    def pos_embed_main_duplicated(self, batch_size):
        return self.pos_embed_main.unsqueeze(dim=0).expand(batch_size, self.size_main, self.dim_main)

    def pos_embed_latent_duplicated(self, batch_size):
        return self.pos_embed_latent.unsqueeze(dim=0).expand(batch_size, self.size_char, self.dim_char)
    def step(self, batch):
        x, _ = batch
        batch_size, _, _, _ = x.shape
        latent = self.encoder(self.pos_embed_latent_duplicated(batch_size), self.patch(x) + self.pos_embed_main)
        x_hat = self.unpatch(self.sigmoid(self.decoder(self.pos_embed_main_duplicated(batch_size), latent)))
        loss = self.criterion(x, x_hat)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.step(batch)
        self.train_loss_epoch.update(loss)
        return loss

    def training_epoch_end(self, outputs):
        self.log("train loss", self.train_loss_epoch.compute(), on_step=False, on_epoch=True)
        self.train_loss_epoch.reset()

    def validation_step(self, batch, batch_idx):
        loss = self.step(batch)
        self.validate_loss_epoch.update(loss)
        if batch_idx < 2:
            x, _ = batch
            x = x[0].unsqueeze(0)
            latent = self.encoder(self.pos_embed_latent_duplicated(1), self.patch(x) + self.pos_embed_main)
            x_hat = self.unpatch(self.sigmoid(self.decoder(self.pos_embed_main_duplicated(1), latent)))
            self.log_image(x[0], latent[0], x_hat[0], self.pos_embed_main, self.pos_embed_latent, "validate image {}".format(batch_idx))
        return loss
    @torch.no_grad()
    def log_image(self, x, latent, x_hat, pos_embed_main, pos_embed_latent, label):
        x_image = x.permute(1, 2, 0).cpu().numpy()
        latent_image = latent.cpu().numpy()
        x_hat_image = x_hat.permute(1, 2, 0).cpu().numpy()
        pos_embed_main_image = pos_embed_main.cpu().numpy()
        pos_embed_latent_image = pos_embed_latent.cpu().numpy()
        figure = plt.figure()
        figure.add_subplot(151).imshow(x_image)
        figure.add_subplot(152).imshow(latent_image)
        figure.add_subplot(153).imshow(x_hat_image)
        figure.add_subplot(154).imshow(pos_embed_main_image)
        figure.add_subplot(155).imshow(pos_embed_latent_image)
        self.logger.experiment.add_figure(label, figure, self.global_step)

    def validation_epoch_end(self, outputs):
        self.log("validate loss", self.validate_loss_epoch.compute(), on_step=False, on_epoch=True)
        self.validate_loss_epoch.reset()
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.001, weight_decay=0.05)
        return optimizer
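
# Training entry point: CIFAR-10 with ToTensor only, batch size 256, 500 epochs on a
# single GPU; the second loader passed to trainer.fit is used for validation.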
if __name__ == '__main__':
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset_train = torchvision.datasets.CIFAR10("./", train=True, transform=transform, download=True)
    dataset_test = torchvision.datasets.CIFAR10("./", train=False, transform=transform, download=False)
    #print(dataset_train.__getitem__(1)[0].shape)
    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=256, num_workers=2, pin_memory=True, drop_last=True)
    dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1024, num_workers=2, pin_memory=True, drop_last=False)
    model = VisionTBA()
    trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=500)
    trainer.fit(model, dataloader_train, dataloader_test)
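
    # Optional sanity check (a minimal sketch; "reconstruction.png" is an arbitrary output
    # path): reconstruct one validation image and save it next to the input for a quick
    # visual comparison after training.
    model.eval()
    with torch.no_grad():
        x, _ = next(iter(dataloader_test))
        x = x[:1].to(model.device)
        latent = model.encoder(model.pos_embed_latent_duplicated(1), model.patch(x) + model.pos_embed_main)
        x_hat = model.unpatch(model.sigmoid(model.decoder(model.pos_embed_main_duplicated(1), latent)))
        fig, axes = plt.subplots(1, 2)
        axes[0].imshow(x[0].permute(1, 2, 0).cpu().numpy())      # input image
        axes[1].imshow(x_hat[0].permute(1, 2, 0).cpu().numpy())  # reconstruction
        fig.savefig("reconstruction.png")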