# Based on the following notebook:
# https://github.com/machine-perception-robotics-group/MPRGDeepLearningLectureNotebook/blob/master/13_rnn/06_Transformer.ipynb

import copy
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from timm.models.layers import trunc_normal_
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from torchmetrics import MeanMetric
import matplotlib.pyplot as plt

class SelfAttention(nn.Module):
    def __init__(self, dim_in=16, dim_qkv=16, num_heads=8, dim_out=16):
        super().__init__()
        self.dim_in = dim_in
        self.num_heads = num_heads
        self.dim_qkv = dim_qkv
        self.scale = dim_qkv ** -0.5
        self.qkv = nn.Linear(dim_in, 3 * num_heads * dim_qkv, bias=False)
        self.proj = nn.Linear(num_heads * dim_qkv, dim_out)

    def forward(self, x):
        batch, length, dim_in = x.shape
        assert dim_in == self.dim_in, "dim_in does not match"
        # (batch, length, 3 * num_heads * dim_qkv) => (batch, length, 3, num_heads, dim_qkv) => (3, batch, num_heads, length, dim_qkv)
        qkv = self.qkv(x).reshape(batch, length, 3, self.num_heads, self.dim_qkv).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # (batch, num_heads, length, dim_qkv)
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)  # (batch, num_heads, length, length)
        # (batch, num_heads, length, dim_qkv) => (batch, length, num_heads, dim_qkv) => (batch, length, num_heads * dim_qkv) => (batch, length, dim_out)
        return self.proj((attn @ v).transpose(1, 2).reshape(batch, length, self.num_heads * self.dim_qkv))
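
# Hedged usage sketch (illustrative only; never called below): checks that
# SelfAttention maps (batch, length, dim_in) to (batch, length, dim_out).
# The batch and sequence sizes here are assumptions for demonstration.
def _demo_self_attention():
    attn = SelfAttention(dim_in=16, dim_qkv=16, num_heads=8, dim_out=16)
    x = torch.randn(2, 10, 16)  # (batch, length, dim_in)
    assert attn(x).shape == (2, 10, 16)  # (batch, length, dim_out)
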
class CrossAttention(nn.Module):
    def __init__(self, dim_in_main=16, dim_in_side=48, dim_qk=16, dim_v=16, num_heads=8, dim_out=16):
        super().__init__()
        self.dim_in_main = dim_in_main
        self.dim_in_side = dim_in_side
        self.num_heads = num_heads
        self.dim_qk = dim_qk
        self.dim_v = dim_v
        self.scale = dim_qk ** -0.5
        self.q = nn.Linear(dim_in_main, num_heads * dim_qk, bias=False)
        self.k = nn.Linear(dim_in_side, num_heads * dim_qk, bias=False)
        self.v = nn.Linear(dim_in_side, num_heads * dim_v, bias=False)
        self.proj = nn.Linear(num_heads * dim_v, dim_out)

    # ((batch, length_main, dim_in_main), (batch, length_side, dim_in_side)) => (batch, length_main, dim_out)
    def forward(self, x_main, x_side):
        batch, length_main, dim_in_main = x_main.shape
        assert dim_in_main == self.dim_in_main, "dim_in_main does not match"
        batch_side, length_side, dim_in_side = x_side.shape
        assert dim_in_side == self.dim_in_side, "dim_in_side does not match"
        assert batch_side == batch, "batch size does not match"
        q = self.q(x_main).reshape(batch, length_main, self.num_heads, self.dim_qk).permute(0, 2, 1, 3)  # (batch, num_heads, length_main, dim_qk)
        k = self.k(x_side).reshape(batch, length_side, self.num_heads, self.dim_qk).permute(0, 2, 1, 3)  # (batch, num_heads, length_side, dim_qk)
        v = self.v(x_side).reshape(batch, length_side, self.num_heads, self.dim_v).permute(0, 2, 1, 3)  # (batch, num_heads, length_side, dim_v)
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)  # (batch, num_heads, length_main, length_side)
        return self.proj((attn @ v).transpose(1, 2).reshape(batch, length_main, self.num_heads * self.dim_v))
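
# Hedged usage sketch (illustrative only; never called below): the query sequence
# and the key/value sequence may differ in length and feature width, and the
# output follows the main (query) sequence. Sizes here are assumptions.
def _demo_cross_attention():
    attn = CrossAttention(dim_in_main=16, dim_in_side=48, dim_qk=16, dim_v=16, num_heads=8, dim_out=16)
    x_main = torch.randn(2, 10, 16)  # queries: (batch, length_main, dim_in_main)
    x_side = torch.randn(2, 20, 48)  # keys/values: (batch, length_side, dim_in_side)
    assert attn(x_main, x_side).shape == (2, 10, 16)  # (batch, length_main, dim_out)
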
class Mlp(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        self.act = nn.GELU()
        self.fc1 = nn.Linear(dim_in, dim_hidden)
        self.fc2 = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
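
# Sketch (assumed sizes; never called below): the MLP is position-wise, acting
# on the last dimension only, so any leading shape is preserved.
def _demo_mlp():
    mlp = Mlp(dim_in=16, dim_hidden=64, dim_out=16)
    assert mlp(torch.randn(2, 10, 16)).shape == (2, 10, 16)
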
class TransformerDecoder(nn.Module):
    def __init__(self, dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads):
        super().__init__()
        # note: a single LayerNorm instance is shared by all pre-norm sites below
        self.norm_layer = nn.LayerNorm(dim_main)
        self.self_attn = SelfAttention(dim_main, dim_qkv, num_heads, dim_main)
        #self.mlp1 = Mlp(dim_main, dim_main * 4, dim_main)
        self.cross_attn = CrossAttention(dim_main, dim_side, dim_qk, dim_v, num_heads, dim_main)
        self.mlp2 = Mlp(dim_main, dim_main * 4, dim_main)

    # side is assumed to be layer-normalized
    def forward(self, main, side_layer_normed):
        x1 = self.norm_layer(main)
        x1 = self.self_attn(x1)
        x1 = x1 + main
        """
        x2 = self.norm_layer(x1)
        x2 = self.mlp1(x2)
        x2 = x2 + x1
        """
        x3 = self.norm_layer(x1)
        x3 = self.cross_attn(x3, side_layer_normed)
        x3 = x3 + x1
        x4 = self.norm_layer(x3)
        x4 = self.mlp2(x4)
        x4 = x4 + x3
        return x4
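
# Hedged usage sketch (illustrative only; never called below): one pre-norm block
# of self-attention, cross-attention, and MLP, each with a residual connection,
# so the main sequence keeps its shape. Sizes here are assumptions.
def _demo_transformer_decoder():
    block = TransformerDecoder(dim_main=16, dim_side=48, dim_qk=16, dim_v=16, dim_qkv=16, num_heads=8)
    main = torch.randn(2, 10, 16)
    side_normed = nn.LayerNorm(48)(torch.randn(2, 20, 48))  # forward() expects a pre-normalized side
    assert block(main, side_normed).shape == main.shape
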
class TransformerDecoderWCA(nn.Module):  # variant without the self-attention sub-layer
    def __init__(self, dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads):
        super().__init__()
        self.norm_layer = nn.LayerNorm(dim_main)
        #self.self_attn = SelfAttention(dim_main, dim_qkv, num_heads, dim_main)  # unused in this variant
        #self.mlp1 = Mlp(dim_main, dim_main * 4, dim_main)
        self.cross_attn = CrossAttention(dim_main, dim_side, dim_qk, dim_v, num_heads, dim_main)
        self.mlp2 = Mlp(dim_main, dim_main * 4, dim_main)

    # side is assumed to be layer-normalized
    def forward(self, main, side_layer_normed):
        x3 = self.norm_layer(main)
        x3 = self.cross_attn(x3, side_layer_normed)
        x3 = x3 + main
        x4 = self.norm_layer(x3)
        x4 = self.mlp2(x4)
        x4 = x4 + x3
        return x4
class StackedTransformerDecoder(nn.Module):
    def __init__(self, dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads, depth):
        super().__init__()
        decoder = TransformerDecoder(dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads)
        # depth independent copies; when used inside VisionTBA below, _init_weights re-randomizes them
        self.decoder_list = nn.ModuleList([copy.deepcopy(decoder) for i in range(depth)])
        self.norm_layer = nn.LayerNorm(dim_side)

    def forward(self, main, side):
        side_normed = self.norm_layer(side)
        for decoder in self.decoder_list:
            main = decoder(main, side_normed)
        return main

class StackedTransformerDecoderWCA(nn.Module):
    def __init__(self, dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads, depth):
        super().__init__()
        decoder = TransformerDecoderWCA(dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads)
        self.decoder_list = nn.ModuleList([copy.deepcopy(decoder) for i in range(depth)])
        self.norm_layer = nn.LayerNorm(dim_side)

    def forward(self, main, side):
        side_normed = self.norm_layer(side)
        for decoder in self.decoder_list:
            main = decoder(main, side_normed)
        return main
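
# Hedged usage sketch (illustrative only; never called below): the stack applies
# LayerNorm to the side sequence once, then runs every block against that same
# normalized side. Sizes here are assumptions for demonstration.
def _demo_stacked_decoder():
    stack = StackedTransformerDecoder(dim_main=16, dim_side=48, dim_qk=16, dim_v=16, dim_qkv=16, num_heads=8, depth=4)
    main = torch.randn(2, 10, 16)
    side = torch.randn(2, 20, 48)  # normalized internally, unlike the single-block forward()
    assert stack(main, side).shape == main.shape
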
class VisionTBA(pl.LightningModule):
    logger: TensorBoardLogger

    def __init__(self, img_size=32, patch_size=4, in_chans=3, dim_char=16, size_char=16):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.num_patch_side = img_size // patch_size
        dim_main = in_chans * patch_size * patch_size
        size_main = (img_size // patch_size) ** 2
        self.dim_main = dim_main
        self.size_main = size_main
        self.dim_char = dim_char
        self.size_char = size_char
        # arguments: (dim_main, dim_side, dim_qk, dim_v, dim_qkv, num_heads, depth)
        self.encoder = StackedTransformerDecoder(dim_char, dim_main, dim_char, dim_char, dim_char, 8, 8)
        self.decoder = StackedTransformerDecoder(dim_main, dim_char, dim_main, dim_main, dim_main, 8, 8)
        self.pos_embed_main = nn.Parameter(torch.zeros(size_main, dim_main))
        self.pos_embed_latent = nn.Parameter(torch.zeros(size_char, dim_char))
        self.criterion = torch.nn.MSELoss()
        self.sigmoid = nn.Sigmoid()
        trunc_normal_(self.pos_embed_main)
        trunc_normal_(self.pos_embed_latent)
        self.apply(self._init_weights)
        print(f'parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}')
        self.train_loss_epoch = MeanMetric()
        self.validate_loss_epoch = MeanMetric()
        self.train_loss_epoch.reset()
        self.validate_loss_epoch.reset()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    # (batch, in_chans, img_size, img_size) => (batch, num_patch, in_chans * patch_size**2)
    def patch(self, x):
        batch, _, _, _ = x.shape
        patch_size = self.patch_size
        num_patch_side = self.num_patch_side
        num_patch = num_patch_side ** 2
        return x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size).permute(0, 2, 3, 1, 4, 5).reshape(batch, num_patch, self.in_chans * patch_size * patch_size)

    # (batch, num_patch, in_chans * patch_size**2) => (batch, in_chans, img_size, img_size)
    def unpatch(self, x):
        batch, _, _ = x.shape
        patch_size = self.patch_size
        num_patch_side = self.num_patch_side
        num_patch = num_patch_side ** 2
        in_chans = self.in_chans
        return torch.cat(
            torch.cat(
                x.reshape(batch, num_patch, in_chans, patch_size, patch_size).permute(0, 2, 1, 3, 4).reshape(batch, in_chans, num_patch_side, num_patch_side, patch_size, patch_size).unbind(3), dim=4
            ).unbind(2),
            dim=2
        )
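
    # Hedged round-trip sketch (illustrative only; never called in training):
    # unpatch() inverts patch() exactly, since both walk the patch grid in the
    # same row-major order. The batch size below is an assumption.
    @torch.no_grad()
    def _demo_patch_roundtrip(self):
        side = self.num_patch_side * self.patch_size
        x = torch.rand(2, self.in_chans, side, side)
        assert torch.allclose(self.unpatch(self.patch(x)), x)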
    def pos_embed_main_duplicated(self, batch_size):
        return self.pos_embed_main.unsqueeze(dim=0).expand(batch_size, self.size_main, self.dim_main)

    def pos_embed_latent_duplicated(self, batch_size):
        return self.pos_embed_latent.unsqueeze(dim=0).expand(batch_size, self.size_char, self.dim_char)

    def step(self, batch):
        x, _ = batch
        batch_size, _, _, _ = x.shape
        # encoder: learned latent queries cross-attend to the position-embedded image patches
        latent = self.encoder(self.pos_embed_latent_duplicated(batch_size), self.patch(x) + self.pos_embed_main)
        # decoder: positional queries reconstruct the patches from the latent sequence
        x_hat = self.unpatch(self.sigmoid(self.decoder(self.pos_embed_main_duplicated(batch_size), latent)))
        loss = self.criterion(x, x_hat)
        return loss
    def training_step(self, batch, batch_idx):
        loss = self.step(batch)
        self.train_loss_epoch.update(loss)
        return loss

    def training_epoch_end(self, outputs):
        self.log("train loss", self.train_loss_epoch.compute(), on_step=False, on_epoch=True)
        self.train_loss_epoch.reset()

    def validation_step(self, batch, batch_idx):
        loss = self.step(batch)
        self.validate_loss_epoch.update(loss)
        if batch_idx < 2:  # log reconstructions for the first two validation batches
            x, _ = batch
            x = x[0].unsqueeze(0)
            latent = self.encoder(self.pos_embed_latent_duplicated(1), self.patch(x) + self.pos_embed_main)
            x_hat = self.unpatch(self.sigmoid(self.decoder(self.pos_embed_main_duplicated(1), latent)))
            self.log_image(x[0], latent[0], x_hat[0], self.pos_embed_main, self.pos_embed_latent, "validate image {}".format(batch_idx))
        return loss
    @torch.no_grad()
    def log_image(self, x, latent, x_hat, pos_embed_main, pos_embed_latent, label):
        x_image = x.permute(1, 2, 0).cpu().numpy()
        latent_image = latent.cpu().numpy()
        x_hat_image = x_hat.permute(1, 2, 0).cpu().numpy()
        pos_embed_main_image = pos_embed_main.cpu().numpy()
        pos_embed_latent_image = pos_embed_latent.cpu().numpy()
        figure = plt.figure()
        figure.add_subplot(151).imshow(x_image)
        figure.add_subplot(152).imshow(latent_image)
        figure.add_subplot(153).imshow(x_hat_image)
        figure.add_subplot(154).imshow(pos_embed_main_image)
        figure.add_subplot(155).imshow(pos_embed_latent_image)
        self.logger.experiment.add_figure(label, figure, self.global_step)

    def validation_epoch_end(self, outputs):
        self.log("validate loss", self.validate_loss_epoch.compute(), on_step=False, on_epoch=True)
        self.validate_loss_epoch.reset()
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.001, weight_decay=0.05)
        return optimizer
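
# Hedged end-to-end sketch (assumed inputs; never called below): encode a batch of
# CIFAR-10-sized images into a (size_char, dim_char) latent sequence per image and
# reconstruct them, mirroring the flow inside VisionTBA.step().
def _demo_reconstruction():
    m = VisionTBA()
    x = torch.rand(2, 3, 32, 32)  # images in [0, 1]
    latent = m.encoder(m.pos_embed_latent_duplicated(2), m.patch(x) + m.pos_embed_main)
    x_hat = m.unpatch(m.sigmoid(m.decoder(m.pos_embed_main_duplicated(2), latent)))
    assert latent.shape == (2, m.size_char, m.dim_char) and x_hat.shape == x.shape
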
if __name__ == '__main__':
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset_train = torchvision.datasets.CIFAR10("./", train=True, transform=transform, download=True)
    dataset_test = torchvision.datasets.CIFAR10("./", train=False, transform=transform, download=False)
    #print(dataset_train.__getitem__(1)[0].shape)
    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=256, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
    dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1024, num_workers=2, pin_memory=True, drop_last=False)
    model = VisionTBA()
    trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=500)
    trainer.fit(model, dataloader_train, dataloader_test)
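
# To inspect the logged loss curves and validation images after/while training
# (assuming PyTorch Lightning's default TensorBoard log directory):
#   tensorboard --logdir lightning_logs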