@vuvietbach · Created January 10, 2024 04:09
Minimal code to reproduce a PyTorch OOM issue in which GPU memory usage keeps increasing until the process runs out of memory.
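The gist has two parts: a conda environment spec followed by the reproduction script itself. Assuming the two parts are saved as separate files, the environment can presumably be created with `conda env create -f <env-file>.yml`, after which running the script with `python <script>.py` should show GPU memory climbing across iterations.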
name: test_env
channels:
- pytorch
- nvidia
- anaconda
- conda-forge
- defaults
dependencies:
- pytorch=2.0.1=py3.11_cuda11.8_cudnn8.7.0_0
- torchvision=0.15.2=py311_cu118
- einops
- tqdm
- typing
- easydict
import torch
import torch.nn.functional as F
import itertools
from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict as edict
import torch.nn as nn
from typing import Tuple, Union
from torch.utils.checkpoint import checkpoint as grad_ckpt
from collections import OrderedDict
from einops import rearrange
from torch import einsum
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main():
    cfg = edict(
        {"DATA": {"NUM_INPUT_FRAMES": 8, "num_way": 5, "num_shot": 3}}
    )

    # Dataloader
    dataset = TestDS(cfg)
    train_loader = DataLoader(
        dataset,
        batch_size=1,
    )
    num_iter = 10000
    loader_iter = iter(itertools.cycle(train_loader))

    # Model
    model = TestModel(cfg)
    model.to(DEVICE)
    model.train()

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = CosineAnnealingWarmRestarts(
        optimizer, 1000, eta_min=1e-6
    )

    running_loss = 0
    gradient_accumulation_steps = 4
    for cur_iter in tqdm(range(num_iter)):
        input = next(loader_iter)
        input = squeeze(input, DEVICE)
        input["split"] = "train"
        output = model(input)
        loss = F.cross_entropy(output["logits"], input["target_labels"].long())
        loss.backward()
        running_loss += loss.item()
        if (cur_iter + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
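

# Optional: a minimal sketch (not part of the original repro) for watching GPU
# memory grow across iterations. Call it inside the training loop, e.g. every
# 100 iterations; torch.cuda.memory_allocated and torch.cuda.max_memory_allocated
# are standard PyTorch APIs.
def log_gpu_memory(cur_iter):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**2
        peak = torch.cuda.max_memory_allocated() / 1024**2
        tqdm.write(f"iter {cur_iter}: allocated {allocated:.1f} MiB, peak {peak:.1f} MiB")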


class TestDS(Dataset):
    """Synthetic few-shot video dataset: every item is a random episode."""

    def __init__(self, cfg):
        self.cfg = cfg

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        num_shot = self.cfg.DATA.num_shot
        num_way = self.cfg.DATA.num_way
        num_frame_per_video = self.cfg.DATA.NUM_INPUT_FRAMES
        total_frames = num_shot * num_way * num_frame_per_video
        support_set = torch.randn(total_frames, 3, 224, 224)
        target_set = torch.randn(total_frames, 3, 224, 224)
        support_labels = torch.arange(num_way).repeat_interleave(num_shot)
        real_support_labels = torch.arange(num_way).repeat_interleave(num_shot)
        target_labels = torch.arange(num_way).repeat_interleave(num_shot)
        return {
            "support_set": support_set,
            "support_labels": support_labels,
            "target_set": target_set,
            "target_labels": target_labels,
            "real_support_labels": real_support_labels,
        }


def squeeze(input, DEVICE):
    # Drop the batch dimension added by the DataLoader and move tensors to DEVICE.
    for k in input.keys():
        if torch.is_tensor(input[k]):
            input[k] = input[k].squeeze(0).to(DEVICE)
        elif isinstance(input[k], list):
            input[k] = [x[0] for x in input[k]]
    return input
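

# Note on shapes (default cfg, after squeeze()): support_set and target_set are
# [num_way * num_shot * NUM_INPUT_FRAMES, 3, 224, 224] = [120, 3, 224, 224],
# and the three label tensors are [num_way * num_shot] = [15].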


class TestModel(torch.nn.Module):
    """
    OTAM-style few-shot classifier with a CLIP ViT visual backbone.
    """

    def __init__(self, cfg):
        super(TestModel, self).__init__()
        self.cfg = cfg
        backbone = CLIP(
            embed_dim=512,
            image_resolution=224,
            vision_layers=12,
            vision_width=768,
            vision_patch_size=16,
            context_length=77,
            vocab_size=49408,
            transformer_width=512,
            transformer_heads=8,
            transformer_layers=12,
            spatial=False,
        )
        # Only the visual tower is kept; no pretrained state_dict is loaded in this repro.
        self.backbone = backbone.visual
        self.mid_dim = 512
        with torch.no_grad():
            # Placeholder "text features"; indexed per episode in forward().
            self.text_features = torch.randn(70, 512)
        self.context2 = Transformer_v1(
            dim=self.mid_dim,
            heads=8,
            dim_head_k=self.mid_dim // 8,
            dropout_atte=0.2,
        )

    def get_feats(
        self,
        support_images,
        target_images,
    ):
        """
        Takes in images from the support set and query videos and returns backbone features,
        reshaped to [num_videos, NUM_INPUT_FRAMES, dim].
        """
        support_features = self.backbone(support_images).squeeze()
        target_features = self.backbone(target_images).squeeze()
        dim = int(support_features.shape[1])
        support_features = support_features.reshape(
            -1, self.cfg.DATA.NUM_INPUT_FRAMES, dim
        )
        target_features = target_features.reshape(
            -1, self.cfg.DATA.NUM_INPUT_FRAMES, dim
        )
        return support_features, target_features

    def calculate_sdtw_distance(self, target_features, support_features):
        # Frame-level cosine distances, then OTAM cumulative distances in both directions.
        assert len(target_features.shape) == 3, "target_features must be 3D"
        assert len(support_features.shape) == 3, "support_features must be 3D"
        n_queries = target_features.shape[0]
        n_support = support_features.shape[0]
        support_features = rearrange(support_features, "b s d -> (b s) d")
        target_features = rearrange(target_features, "b s d -> (b s) d")
        frame_sim = cos_sim(target_features, support_features)
        frame_dists = 1 - frame_sim
        dists = rearrange(
            frame_dists, "(tb ts) (sb ss) -> tb sb ts ss", tb=n_queries, sb=n_support
        )  # [n_queries, n_support, 8, 8]
        cum_dists = OTAM_cum_dist_v2(dists) + OTAM_cum_dist_v2(
            rearrange(dists, "tb sb ts ss -> tb sb ss ts")
        )
        return cum_dists

    def forward(self, inputs):
        support_images, support_labels, target_images, support_real_class = (
            inputs["support_set"],
            inputs["support_labels"],
            inputs["target_set"],
            inputs["real_support_labels"],
        )
        support_images = support_images.reshape((-1, *support_images.shape[-3:]))
        target_images = target_images.reshape((-1, *target_images.shape[-3:]))
        res = {}
        support_features, target_features = self.get_feats(
            support_images, target_images
        )
        # Append a per-video text token to the support features before the context transformer.
        context_support = (
            self.text_features[support_real_class.long().cpu()]
            .unsqueeze(1)
            .to(support_features.device)
        )
        support_features = torch.cat([support_features, context_support], dim=1)
        support_features = self.context2(
            support_features, support_features, support_features
        )[:, : self.cfg.DATA.NUM_INPUT_FRAMES, :]
        # Average the support videos of each class into a single prototype.
        unique_labels = torch.unique(support_labels)
        support_features_label = [
            torch.index_select(
                support_features, 0, extract_class_indices(support_labels, c)
            )
            for c in unique_labels
        ]
        support_features = [torch.mean(x, dim=0) for x in support_features_label]
        support_features = torch.stack(support_features)
        support_labels = unique_labels
        target_features = self.context2(
            target_features, target_features, target_features
        )
        # Logits are negative cumulative OTAM distances: smaller distance -> larger logit.
        output = self.calculate_sdtw_distance(target_features, support_features)
        res.update({"logits": -output})
        return res


class CLIP(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        # vision
        image_resolution: int,
        vision_layers: Union[Tuple[int, int, int, int], int],
        vision_width: int,
        vision_patch_size: int,
        # text
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
        spatial=False,
    ):
        super().__init__()
        self.context_length = context_length
        vision_heads = vision_width // 64
        self.visual = VisionTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim,
        )
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width)
        )
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))
        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers) ** -0.5
        )
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower triangle
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
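

# Note: TestModel keeps only CLIP.visual; the text transformer constructed in
# CLIP.__init__ is discarded once the local CLIP instance goes out of scope and
# plays no role in this repro.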


class VisionTransformer(nn.Module):
    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
    ):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
        )
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = grad_ckpt(self.conv1, x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [
                self.class_embedding.to(x.dtype)
                + torch.zeros(
                    x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
                ),
                x,
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj
        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class Transformer(nn.Module):
    def __init__(
        self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        )

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, d_model * 4)),
                    ("gelu", QuickGELU()),
                    ("c_proj", nn.Linear(d_model * 4, d_model)),
                ]
            )
        )
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = (
            self.attn_mask.to(dtype=x.dtype, device=x.device)
            if self.attn_mask is not None
            else None
        )
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + grad_ckpt(self.attention, self.ln_1(x))
        x = x + grad_ckpt(self.mlp, self.ln_2(x))
        return x


class Transformer_v1(nn.Module):
    def __init__(
        self,
        heads=8,
        dim=2048,
        dim_head_k=256,
        dim_head_v=256,
        dropout_atte=0.05,
        mlp_dim=2048,
        dropout_ffn=0.05,
        depth=1,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.depth = depth
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PreNormattention_qkv(
                            dim,
                            Attention_qkv(
                                dim,
                                heads=heads,
                                dim_head=dim_head_k,
                                dropout=dropout_atte,
                            ),
                        ),
                        FeedForward(dim, mlp_dim, dropout=dropout_ffn),
                    ]
                )
            )

    def forward(self, q, k, v):
        for attn, ff in self.layers:
            x = attn(q, k, v)
            x = ff(x) + x
            q = x
            k = x
            v = x
        return x


class PreNormattention_qkv(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, q, k, v, **kwargs):
        return self.fn(self.norm(q), self.norm(k), self.norm(v), **kwargs) + q


class Attention_qkv(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head**-0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, q, k, v):
        b, n, _, h = *q.shape, self.heads
        bk = k.shape[0]
        q = grad_ckpt(self.to_q, q)
        k = grad_ckpt(self.to_k, k)
        v = grad_ckpt(self.to_v, v)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)
        k = rearrange(k, "b n (h d) -> b h n d", b=bk, h=h)
        v = rearrange(v, "b n (h d) -> b h n d", b=bk, h=h)
        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        attn = self.attend(dots)  # [b, h, n_q, n_k]
        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.ff1 = nn.Linear(dim, hidden_dim)
        self.act1 = nn.GELU()
        self.drop1 = nn.Dropout(dropout)
        self.ff2 = nn.Linear(hidden_dim, dim)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.drop1(self.act1(grad_ckpt(self.ff1, x)))
        x = self.drop2(grad_ckpt(self.ff2, x))
        return x


def extract_class_indices(labels, which_class):
    """
    Helper method to extract the indices of elements which have the specified label.
    :param labels: (torch.tensor) Labels of the context set.
    :param which_class: Label for which indices are extracted.
    :return: (torch.tensor) Indices in the form of a mask that indicate the locations of the specified label.
    """
    class_mask = torch.eq(
        labels, which_class
    )  # binary mask of labels equal to which_class
    class_mask_indices = class_mask.nonzero(
        as_tuple=False
    )  # indices of labels equal to which class
    return torch.reshape(class_mask_indices, (-1,))
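

# For example (easy to verify by hand):
# extract_class_indices(torch.tensor([0, 1, 0, 2]), 0) -> tensor([0, 2])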


def cos_sim(x, y, epsilon=0.01):
    """
    Calculates the cosine similarity along the last dimension of two tensors.
    """
    numerator = torch.matmul(x, y.transpose(-1, -2))
    xnorm = torch.norm(x, dim=-1).unsqueeze(-1)
    ynorm = torch.norm(y, dim=-1).unsqueeze(-1)
    denominator = torch.matmul(xnorm, ynorm.transpose(-1, -2)) + epsilon
    dists = torch.div(numerator, denominator)
    return dists
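

# Note: the epsilon in the denominator guards against division by zero, at the
# cost of slightly shrinking similarities for small-norm inputs.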


def OTAM_cum_dist_v2(dists, lbda=0.5):
    """
    Calculates the OTAM distances for sequences in one direction (e.g. query to support).
    :input: Tensor with frame similarity scores of shape [n_queries, n_support, query_seq_len, support_seq_len]
    TODO: clean up if possible - currently messy to work with pt1.8. Possibly due to stack operation?
    """
    dists = F.pad(dists, (1, 1), "constant", 0)  # pad the support axis: [.., .., 8, 10]
    cum_dists = torch.zeros(dists.shape, device=dists.device)

    # top row
    for m in range(1, dists.shape[3]):
        # paper does continuous relaxation of the cum_dists entry, but it trains faster without,
        # so using the simpler version for now:
        cum_dists[:, :, 0, m] = dists[:, :, 0, m] + cum_dists[:, :, 0, m - 1]

    # remaining rows: each entry is its distance plus a soft-min over the allowed predecessors,
    # where soft-min(a, b, ...) = -lbda * log(exp(-a / lbda) + exp(-b / lbda) + ...)
    for l in range(1, dists.shape[2]):
        # first non-zero column
        cum_dists[:, :, l, 1] = dists[:, :, l, 1] - lbda * torch.log(
            torch.exp(-cum_dists[:, :, l - 1, 0] / lbda)
            + torch.exp(-cum_dists[:, :, l - 1, 1] / lbda)
            + torch.exp(-cum_dists[:, :, l, 0] / lbda)
        )

        # middle columns
        for m in range(2, dists.shape[3] - 1):
            cum_dists[:, :, l, m] = dists[:, :, l, m] - lbda * torch.log(
                torch.exp(-cum_dists[:, :, l - 1, m - 1] / lbda)
                + torch.exp(-cum_dists[:, :, l, m - 1] / lbda)
            )

        # last column
        cum_dists[:, :, l, -1] = dists[:, :, l, -1] - lbda * torch.log(
            torch.exp(-cum_dists[:, :, l - 1, -2] / lbda)
            + torch.exp(-cum_dists[:, :, l - 1, -1] / lbda)
            + torch.exp(-cum_dists[:, :, l, -2] / lbda)
        )

    return cum_dists[:, :, -1, -1]


if __name__ == "__main__":
    main()