Reversible VIT
Created September 12, 2022 10:24
from types import SimpleNamespace

import yahp as hp
import torch
from torch import nn
from torch.autograd import Function as Function
from einops import rearrange
from einops.layers.torch import Rearrange

# A one-file implementation of the reversible VIT architecture
# Lacking:
# - Stochastic Depth (for now)
# - Dropout (never used)


def norm(dim: int):
    return nn.LayerNorm(dim, eps=1e-6, elementwise_affine=True)


def mlp(in_features: int, hidden_features: int, out_features: int):
    # If you need dropout, just get more data ...
    # None of the ViT configs (that I checked) in PySlowFast
    # use dropout
    # (Possibly b/c they are using stochastic depth instead ...)
    # check in case someone passes a float into mlp_ratio
    assert isinstance(hidden_features, int) or (hidden_features.is_integer())
    return nn.Sequential(
        nn.Linear(in_features, int(hidden_features)),
        nn.GELU(),
        nn.Linear(int(hidden_features), out_features),
    )


def mlp_block(dim: int, mlp_ratio: int):
    return nn.Sequential(
        norm(dim),
        mlp(in_features=dim, hidden_features=dim * mlp_ratio, out_features=dim),
    )


# Taken from lucidrain's ViT repository
class Attention(nn.Module):
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.heads = heads
        head_dim = dim // heads
        # TODO: The original paper says sqrt(d_k)
        # but FBAI + lucidrains do something else
        self.scale = head_dim ** -0.5
        self.to_probabilities = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

    def forward(self, x):
        b, n_patches, dim = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        assert qkv[0].shape == x.shape
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
        assert q.shape == (b, self.heads, n_patches, dim // self.heads)
        attn_matrix = (q @ k.transpose(-1, -2)) * self.scale
        assert attn_matrix.shape[-1] == attn_matrix.shape[-2]
        assert attn_matrix.shape[-1] == n_patches
        probs = self.to_probabilities(attn_matrix)
        reaveraged_values = probs @ v
        reaveraged_values = rearrange(reaveraged_values, "b h n d -> b n (h d)")
        return reaveraged_values


def attention_block(*, dim: int, heads: int):
    return nn.Sequential(norm(dim), Attention(dim, heads))


class ReversibleBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads,
        mlp_ratio,
        drop_path_rate: float,
        layer_id: int
    ):
        super().__init__()
        self.layer_id = layer_id
        self.drop_path_rate = drop_path_rate
        # No residual connections on purpose:
        # the paper says the two-stream architecture
        # has builtin skip connections
        self.F = attention_block(
            # dim should be divisible by 2
            dim=dim // 2,
            heads=heads,
        )
        self.G = mlp_block(
            dim=dim // 2,
            mlp_ratio=mlp_ratio,
        )
        # self.seeds = {}

    # def seed_cuda(self, key):
    #     """
    #     Fix seeds to allow for stochastic elements such as
    #     dropout to be reproduced exactly in activation
    #     recomputation in the backward pass.
    #     """
    #     # randomize seeds
    #     # use cuda generator if available
    #     if (
    #         hasattr(torch.cuda, "default_generators")
    #         and len(torch.cuda.default_generators) > 0
    #     ):
    #         # GPU
    #         device_idx = torch.cuda.current_device()
    #         seed = torch.cuda.default_generators[device_idx].seed()
    #     else:
    #         # CPU
    #         seed = int(torch.seed() % sys.maxsize)
    #     self.seeds[key] = seed
    #     torch.manual_seed(self.seeds[key])

    def forward(self, X_1, X_2):
        """
        forward pass equations:
        Y_1 = X_1 + Attention(X_2), F = Attention
        Y_2 = X_2 + MLP(Y_1), G = MLP
        """
        # Y_1 : attn_output
        f_X_2 = self.F(X_2)
        assert f_X_2.shape == X_2.shape
        # self.seed_cuda("droppath")
        # f_X_2_dropped = drop_path(
        #     f_X_2, drop_prob=self.drop_path_rate, training=self.training
        # )
        # Y_1 = X_1 + f(X_2)
        # Y_1 = X_1 + f_X_2_dropped
        Y_1 = X_1 + f_X_2
        # free memory
        del X_1
        g_Y_1 = self.G(Y_1)
        # torch.manual_seed(self.seeds["droppath"])
        # g_Y_1_dropped = drop_path(
        #     g_Y_1, drop_prob=self.drop_path_rate, training=self.training
        # )
        # Y_2 = X_2 + g(Y_1)
        # Y_2 = X_2 + g_Y_1_dropped
        Y_2 = X_2 + g_Y_1
        del X_2
        return Y_1, Y_2

    def backward_pass(
        self,
        Y_1,
        Y_2,
        dY_1,
        dY_2,
    ):
        """
        equations for activation recomputation:
        X_2 = Y_2 - G(Y_1), G = MLP
        X_1 = Y_1 - F(X_2), F = Attention
        """
        print(f"Backwards: {self.layer_id} device_id={Y_1.get_device()}")
        # TODO: I don't fully understand
        # why this works ... specific questions around
        # how the gradients dX_1 and dX_2 are being calculated
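        # Why this gives the right gradients: from the forward pass,
        # Y_1 = X_1 + F(X_2) and Y_2 = X_2 + G(Y_1).
        # - X_1 reaches the loss only through the identity branch of Y_1, so
        #   dX_1 = dY_1 + (dG/dY_1)^T dY_2; the second term is exactly what
        #   g_Y_1.backward(dY_2) accumulates into Y_1.grad below.
        # - X_2 reaches the loss through the identity branch of Y_2 and through F,
        #   so dX_2 = dY_2 + (dF/dX_2)^T dX_1, where the second term is what
        #   f_X_2.backward(dY_1) accumulates into X_2.grad (dY_1 has already been
        #   updated to the full dX_1 by then).
        # The same backward() calls also accumulate parameter gradients for G and F.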
        # temporarily record intermediate activations for G
        # and use them for gradient calculation of G
        with torch.enable_grad():
            Y_1.requires_grad = True
            g_Y_1 = self.G(Y_1)
            assert g_Y_1.shape == Y_1.shape
            # torch.manual_seed(self.seeds["droppath"])
            # g_Y_1 = drop_path(
            #     g_Y_1, drop_prob=self.drop_path_rate, training=self.training
            # )
            g_Y_1.backward(dY_2, retain_graph=True)
        # activation recomputation is by design and not part of
        # the computation graph in the forward pass
        with torch.no_grad():
            X_2 = Y_2 - g_Y_1
            del g_Y_1
            dY_1 = dY_1 + Y_1.grad
            Y_1.grad = None
        # record F activations and calc gradients on F
        with torch.enable_grad():
            X_2.requires_grad = True
            f_X_2 = self.F(X_2)
            # torch.manual_seed(self.seeds["droppath"])
            # f_X_2 = drop_path(
            #     f_X_2, drop_prob=self.drop_path_rate, training=self.training
            # )
            f_X_2.backward(dY_1, retain_graph=True)
        # propagate reverse-computed activations at the start of
        # the previous block for backprop
        with torch.no_grad():
            X_1 = Y_1 - f_X_2
            del f_X_2, Y_1
            dY_2 = dY_2 + X_2.grad
            X_2.grad = None
            X_2 = X_2.detach()
        return X_1, X_2, dY_1, dY_2


class RevBackProp(Function):
    @staticmethod
    def forward(
        ctx,
        x,
        layers,
    ):
        print(f"FORWARD PASS ... (device={x.get_device()})")
        X_1, X_2 = torch.chunk(x, 2, dim=-1)
        for layer in layers:
            X_1, X_2 = layer(X_1, X_2)
        all_tensors = [X_1.detach(), X_2.detach()]
        ctx.save_for_backward(*all_tensors)
        ctx.layers = layers
        return torch.cat([X_1, X_2], dim=-1)

    @staticmethod
    def backward(ctx, dx):
        print(f"BACKWARD PASS ... (device={dx.get_device()})")
        dX_1, dX_2 = torch.chunk(dx, 2, dim=-1)
        assert dX_1.shape == dX_2.shape
        X_1, X_2 = ctx.saved_tensors
        layers = ctx.layers
        for layer in layers[::-1]:
            X_1, X_2, dX_1, dX_2 = layer.backward_pass(
                Y_1=X_1,
                Y_2=X_2,
                dY_1=dX_1,
                dY_2=dX_2,
            )
        dx = torch.cat([dX_1, dX_2], dim=-1)
        del dX_1, dX_2, X_1, X_2
        return dx, None, None


# def patch_embed()
#     self.to_patch_embedding = nn.Sequential(
#         Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
#         nn.Linear(patch_dim, dim),
#     )
def patch_embed(dim_out: int, patch_size: int, img_size: int):
    assert (
        img_size % patch_size == 0
    ), f"img_size: {img_size} not divisible by patch_size: {patch_size}"
    return nn.Sequential(
        nn.Conv2d(
            3,
            dim_out,
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            padding=(0, 0),
        ),
        Rearrange("b dim_out w_patched h_patched -> b (w_patched h_patched) dim_out"),
    )
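

# A minimal sanity-check sketch (illustrative helper, call it manually): the strided
# Conv2d in patch_embed should be equivalent to the commented-out Rearrange + Linear
# embedding from lucidrains' repo once the conv weights are copied into a matching
# Linear layer. The dimensions below are arbitrary small values for a quick check.
def check_patch_embed_equivalence(dim_out: int = 32, patch_size: int = 4, img_size: int = 8):
    conv_embed = patch_embed(dim_out, patch_size, img_size)
    conv = conv_embed[0]
    linear_embed = nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size),
        nn.Linear(3 * patch_size * patch_size, dim_out),
    )
    with torch.no_grad():
        # Conv2d weight has shape (dim_out, c, p1, p2); flatten it in the same
        # (p1 p2 c) order the Rearrange above uses for each patch
        linear_embed[1].weight.copy_(rearrange(conv.weight, "d c p1 p2 -> d (p1 p2 c)"))
        linear_embed[1].bias.copy_(conv.bias)
    x = torch.randn(2, 3, img_size, img_size)
    assert torch.allclose(conv_embed(x), linear_embed(x), atol=1e-5)
    print("patch_embed equivalence check passed")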


class ReversibleViTParams(hp.Hparams):
    depth: int = hp.required("# of transformer blocks")
    model_dim: int = hp.required("width of internal representation")
    output_dim: int = hp.required("width of final representation")
    heads: int = hp.required("# of attention heads")
    mlp_ratio: int = hp.required("ratio of MLP hidden width to model width")
    patch_size: int = hp.required("width/height of patch")


class ReversibleVIT(nn.Module):
    def __init__(self, cfg, img_size: int):
        super().__init__()
        depth, model_dim, output_dim, heads, mlp_ratio, patch_size = (
            cfg.depth,
            cfg.model_dim,
            cfg.output_dim,
            cfg.heads,
            cfg.mlp_ratio,
            cfg.patch_size,
        )
        self.output_dim = output_dim
        assert (
            model_dim % 2 == 0
        ), "model_dim must be divisible by 2 for reversible ViT"
        self.patchify = patch_embed(model_dim // 2, patch_size, img_size)
        self.blocks = nn.ModuleList([])
        for i in range(depth):
            block = ReversibleBlock(
                dim=model_dim,
                heads=heads,
                mlp_ratio=mlp_ratio,
                drop_path_rate=None,
                layer_id=i,
            )
            self.blocks.append(block)
        self.norm = norm(model_dim)
        # TODO: Not sure where this scale parameter comes from
        scale = model_dim ** -0.5
        # OpenAI has a final projection
        self.proj = nn.Parameter(scale * torch.randn(model_dim, output_dim))

    def forward(self, x):
        patches = self.patchify(x)
        concat = torch.cat([patches, patches], dim=-1)
        concat = RevBackProp.apply(concat, self.blocks)
        concat = self.norm(concat)
        concat = concat.mean(1)
        return concat @ self.proj


# ignore unnecessary kwargs
def rev_vit_base_patch16_224(**_):
    heads = 12
    return ReversibleVIT(
        SimpleNamespace(**dict(
            depth=12,
            model_dim=heads * 64,
            output_dim=512,
            heads=heads,
            mlp_ratio=4,
            patch_size=16,
        )),
        img_size=224,
    )


def ddp_test_helper(rank, world_size):
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    import torch.optim as optim

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    model = rev_vit_base_patch16_224().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.0001)
    x = torch.randn((1, 3, 224, 224)).to(rank)
    outputs = ddp_model(x)
    outputs.sum().backward()
    optimizer.step()
    print("Basic DDP passed")

    # Try a more complicated case
    print("Running more complicated DDP test ...")
    x1 = torch.randn((1, 3, 224, 224)).to(rank)
    x2 = torch.randn((1, 3, 224, 224)).to(rank)
    ddp_model = DDP(model, device_ids=[rank], static_graph=True)
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.0001)
    outputs1 = ddp_model(x1)
    outputs2 = ddp_model(x2)
    (outputs1 + outputs2).sum().backward()
    optimizer.step()
    print("Re-entrant DDP passed")
    # clean up the process group before the worker exits
    dist.destroy_process_group()


if __name__ == "__main__":
    cfg = SimpleNamespace(
        **dict(
            depth=6,
            model_dim=768,
            output_dim=512,
            heads=8,
            patch_size=16,
            mlp_ratio=4,
        )
    )
    img_size = 16 * 14
    model = ReversibleVIT(cfg, img_size=img_size)
    x = torch.randn((1, 3, img_size, img_size))
    y = model(x)
    print("Forward finished")
    y.sum().backward()
    print("Backward finished")
    print("Custom vit config passed")

    model2 = rev_vit_base_patch16_224()
    x = torch.randn((1, 3, img_size, img_size))
    y = model2(x)
    print("Forward finished")
    y.sum().backward()
    print("Backward finished")
    print("Standard vit config passed")

    if torch.cuda.is_available():
        print("running cuda tests")
        model3 = rev_vit_base_patch16_224().cuda()
        x = torch.randn((1, 3, img_size, img_size)).cuda()
        y = model3(x)
        print("Forward finished (cuda)")
        y.sum().backward()
        print("Backward finished (cuda)")
        print("Standard vit config passed (cuda)")

    if torch.cuda.device_count() > 1:
        print("running ddp tests")
        import torch.multiprocessing as mp
        import os

        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        world_size = 2
        mp.spawn(ddp_test_helper,
                 args=(world_size,),
                 nprocs=world_size,
                 join=True)