Minimal code to reproduce a PyTorch OOM issue in which GPU memory usage keeps increasing until it runs out of memory.
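One way to watch the growth while the script below runs is to log the CUDA allocator counters once in a while inside the training loop in main(). The helper below is an illustrative addition, not part of the original gist; the name log_cuda_memory is hypothetical, while torch.cuda.memory_allocated() and torch.cuda.memory_reserved() are standard PyTorch APIs:

    # Illustrative helper (not in the original script): report current CUDA memory usage.
    def log_cuda_memory(tag: str = "") -> None:
        if torch.cuda.is_available():
            alloc_mb = torch.cuda.memory_allocated() / 2**20
            reserved_mb = torch.cuda.memory_reserved() / 2**20
            print(f"{tag} allocated={alloc_mb:.1f} MiB reserved={reserved_mb:.1f} MiB")

Calling it once per iteration, e.g. log_cuda_memory(f"iter {cur_iter}"), should show the allocated figure climbing across iterations until the process aborts with a CUDA out-of-memory error.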
name: test_env
channels:
  - pytorch
  - nvidia
  - anaconda
  - conda-forge
  - defaults
dependencies:
  - pytorch=2.0.1=py3.11_cuda11.8_cudnn8.7.0_0
  - torchvision=0.15.2=py311_cu118
  - einops
  - tqdm
  - typing
  - easydict
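Assuming the spec above is saved as environment.yml, the environment can be created with conda env create -f environment.yml and activated with conda activate test_env before running the reproduction script below.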
import torch
import torch.nn.functional as F
import itertools
from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict as edict
import torch.nn as nn
from typing import Tuple, Union
from torch.utils.checkpoint import checkpoint as grad_ckpt
from collections import OrderedDict
from einops import rearrange
from torch import einsum
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main():
    cfg = edict(
        {"DATA": {"NUM_INPUT_FRAMES": 8, "num_way": 5, "num_shot": 3}}
    )

    # Dataloader
    dataset = TestDS(cfg)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
    )
    num_iter = 10000
    loader_iter = iter(itertools.cycle(train_loader))

    # Model
    model = TestModel(cfg)
    model.to(DEVICE)
    model.train()

    # Optimizer and LR schedule
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = CosineAnnealingWarmRestarts(
        optimizer, 1000, eta_min=1e-6
    )

    running_loss = 0
    gradient_accumulation_steps = 4
    for cur_iter in tqdm(range(num_iter)):
        input = next(loader_iter)
        input = squeeze(input, DEVICE)
        input["split"] = "train"
        output = model(input)
        loss = F.cross_entropy(output["logits"], input["target_labels"].long())
        loss.backward()
        running_loss += loss.item()
        # Gradient accumulation: update the weights every `gradient_accumulation_steps` iterations.
        if (cur_iter + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
class TestDS(Dataset):
    """Synthetic few-shot video dataset: random frames with cyclic class labels."""

    def __init__(self, cfg):
        self.cfg = cfg

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        num_shot = self.cfg.DATA.num_shot
        num_way = self.cfg.DATA.num_way
        num_frame_per_video = self.cfg.DATA.NUM_INPUT_FRAMES
        total_frames = num_shot * num_way * num_frame_per_video
        support_set = torch.randn(total_frames, 3, 224, 224)
        target_set = torch.randn(total_frames, 3, 224, 224)
        support_labels = torch.arange(5).repeat_interleave(3)
        real_support_labels = torch.arange(5).repeat_interleave(3)
        target_labels = torch.arange(5).repeat_interleave(3)
        return {
            "support_set": support_set,
            "support_labels": support_labels,
            "target_set": target_set,
            "target_labels": target_labels,
            "real_support_labels": real_support_labels,
        }


def squeeze(input, DEVICE):
    # Remove the leading batch dimension added by the DataLoader (batch_size=1)
    # and move tensors to the target device.
    for k in input.keys():
        if torch.is_tensor(input[k]):
            input[k] = input[k].squeeze(0).to(DEVICE)
        elif isinstance(input[k], list):
            input[k] = [x[0] for x in input[k]]
    return input
class TestModel(torch.nn.Module):
    """
    OTAM with a CNN backbone.
    """

    def __init__(self, cfg):
        super(TestModel, self).__init__()
        self.cfg = cfg
        backbone = CLIP(
            embed_dim=512,
            image_resolution=224,
            vision_layers=12,
            vision_width=768,
            vision_patch_size=16,
            context_length=77,
            vocab_size=49408,
            transformer_width=512,
            transformer_heads=8,
            transformer_layers=12,
            spatial=False,
        )
        self.backbone = backbone.visual  # model.load_state_dict(state_dict)
        self.mid_dim = 512
        with torch.no_grad():
            # Placeholder for pre-computed text features (random here).
            self.text_features = torch.randn(70, 512)
        self.context2 = Transformer_v1(
            dim=self.mid_dim,
            heads=8,
            dim_head_k=self.mid_dim // 8,
            dropout_atte=0.2,
        )

    def get_feats(
        self,
        support_images,
        target_images,
    ):
        """
        Takes in images from the support set and query video and returns CNN features.
        """
        support_features = self.backbone(support_images).squeeze()
        target_features = self.backbone(target_images).squeeze()
        dim = int(support_features.shape[1])
        support_features = support_features.reshape(
            -1, self.cfg.DATA.NUM_INPUT_FRAMES, dim
        )
        target_features = target_features.reshape(
            -1, self.cfg.DATA.NUM_INPUT_FRAMES, dim
        )
        return support_features, target_features

    def calculate_sdtw_distance(self, target_features, support_features):
        # calculate frame-to-frame distance matrix
        assert len(target_features.shape) == 3, "target_features must be 3D"
        assert len(support_features.shape) == 3, "support_features must be 3D"
        n_queries = target_features.shape[0]
        n_support = support_features.shape[0]
        support_features = rearrange(support_features, "b s d -> (b s) d")
        target_features = rearrange(target_features, "b s d -> (b s) d")
        frame_sim = cos_sim(target_features, support_features)
        frame_dists = 1 - frame_sim
        dists = rearrange(
            frame_dists, "(tb ts) (sb ss) -> tb sb ts ss", tb=n_queries, sb=n_support
        )  # [n_queries, n_support, query_len, support_len]
        cum_dists = OTAM_cum_dist_v2(dists) + OTAM_cum_dist_v2(
            rearrange(dists, "tb sb ts ss -> tb sb ss ts")
        )
        return cum_dists

    def forward(self, inputs):
        support_images, support_labels, target_images, support_real_class = (
            inputs["support_set"],
            inputs["support_labels"],
            inputs["target_set"],
            inputs["real_support_labels"],
        )
        support_images = support_images.reshape((-1, *support_images.shape[-3:]))
        target_images = target_images.reshape((-1, *target_images.shape[-3:]))
        res = {}
        support_features, target_features = self.get_feats(
            support_images, target_images
        )
        # Append a text-feature token to each support video before the context transformer.
        context_support = (
            self.text_features[support_real_class.long().cpu()]
            .unsqueeze(1)
            .to(support_features.device)
        )
        support_features = torch.cat([support_features, context_support], dim=1)
        support_features = self.context2(
            support_features, support_features, support_features
        )[:, : self.cfg.DATA.NUM_INPUT_FRAMES, :]
        # Average support features per class to build class prototypes.
        unique_labels = torch.unique(support_labels)
        support_features_label = [
            torch.index_select(
                support_features, 0, extract_class_indices(support_labels, c)
            )
            for c in unique_labels
        ]
        support_features = [torch.mean(x, dim=0) for x in support_features_label]
        support_features = torch.stack(support_features)
        support_labels = unique_labels
        target_features = self.context2(
            target_features, target_features, target_features
        )
        output = self.calculate_sdtw_distance(target_features, support_features)
        res.update({"logits": -output})
        return res
class CLIP(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        # vision
        image_resolution: int,
        vision_layers: Union[Tuple[int, int, int, int], int],
        vision_width: int,
        vision_patch_size: int,
        # text
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
        spatial=False,
    ):
        super().__init__()
        self.context_length = context_length
        vision_heads = vision_width // 64
        self.visual = VisionTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim,
        )
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width)
        )
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))
        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers) ** -0.5
        )
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
class VisionTransformer(nn.Module):
    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
    ):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
        )
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = grad_ckpt(self.conv1, x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [
                self.class_embedding.to(x.dtype)
                + torch.zeros(
                    x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
                ),
                x,
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj
        return x
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class Transformer(nn.Module):
    def __init__(
        self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        )

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, d_model * 4)),
                    ("gelu", QuickGELU()),
                    ("c_proj", nn.Linear(d_model * 4, d_model)),
                ]
            )
        )
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = (
            self.attn_mask.to(dtype=x.dtype, device=x.device)
            if self.attn_mask is not None
            else None
        )
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + grad_ckpt(self.attention, self.ln_1(x))
        x = x + grad_ckpt(self.mlp, self.ln_2(x))
        return x
class Transformer_v1(nn.Module):
    def __init__(
        self,
        heads=8,
        dim=2048,
        dim_head_k=256,
        dim_head_v=256,
        dropout_atte=0.05,
        mlp_dim=2048,
        dropout_ffn=0.05,
        depth=1,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.depth = depth
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [  # PreNormattention(2048, Attention(2048, heads = 8, dim_head = 256, dropout = 0.2))
                        # PreNormattention(heads, dim, dim_head_k, dim_head_v, dropout=dropout_atte),
                        PreNormattention_qkv(
                            dim,
                            Attention_qkv(
                                dim,
                                heads=heads,
                                dim_head=dim_head_k,
                                dropout=dropout_atte,
                            ),
                        ),
                        FeedForward(dim, mlp_dim, dropout=dropout_ffn),
                    ]
                )
            )

    def forward(self, q, k, v):
        for attn, ff in self.layers:
            x = attn(q, k, v)
            x = ff(x) + x
            q = x
            k = x
            v = x
        return x
class PreNormattention_qkv(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, q, k, v, **kwargs):
        return self.fn(self.norm(q), self.norm(k), self.norm(v), **kwargs) + q


class Attention_qkv(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head**-0.5

        self.attend = nn.Softmax(dim=-1)
        # self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, q, k, v):
        b, n, _, h = *q.shape, self.heads
        bk = k.shape[0]
        # qkv = self.to_qkv(x).chunk(3, dim = -1)
        q = grad_ckpt(self.to_q, q)
        k = grad_ckpt(self.to_k, k)
        v = grad_ckpt(self.to_v, v)
        # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)
        k = rearrange(k, "b n (h d) -> b h n d", b=bk, h=h)
        v = rearrange(v, "b n (h d) -> b h n d", b=bk, h=h)

        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        attn = self.attend(dots)  # [30, 8, 8, 5]
        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # self.net = nn.Sequential(
        #     nn.Linear(dim, hidden_dim),
        #     nn.GELU(),
        #     nn.Dropout(dropout),
        #     nn.Linear(hidden_dim, dim),
        #     nn.Dropout(dropout),
        # )
        self.ff1 = nn.Linear(dim, hidden_dim)
        self.act1 = nn.GELU()
        self.drop1 = nn.Dropout(dropout)
        self.ff2 = nn.Linear(hidden_dim, dim)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.drop1(self.act1(grad_ckpt(self.ff1, x)))
        x = self.drop2(grad_ckpt(self.ff2, x))
        return x
def extract_class_indices(labels, which_class):
    """
    Helper method to extract the indices of elements which have the specified label.
    :param labels: (torch.tensor) Labels of the context set.
    :param which_class: Label for which indices are extracted.
    :return: (torch.tensor) Indices in the form of a mask that indicate the locations of the specified label.
    """
    class_mask = torch.eq(
        labels, which_class
    )  # binary mask of labels equal to which_class
    class_mask_indices = class_mask.nonzero(
        as_tuple=False
    )  # indices of labels equal to which class
    return torch.reshape(class_mask_indices, (-1,))


def cos_sim(x, y, epsilon=0.01):
    """
    Calculates the cosine similarity along the last dimension of two tensors.
    """
    numerator = torch.matmul(x, y.transpose(-1, -2))
    xnorm = torch.norm(x, dim=-1).unsqueeze(-1)
    ynorm = torch.norm(y, dim=-1).unsqueeze(-1)
    denominator = torch.matmul(xnorm, ynorm.transpose(-1, -2)) + epsilon
    dists = torch.div(numerator, denominator)
    return dists
def OTAM_cum_dist_v2(dists, lbda=0.5):
    """
    Calculates the OTAM distances for sequences in one direction (e.g. query to support).
    :input: Tensor with frame similarity scores of shape [n_queries, n_support, query_seq_len, support_seq_len]
    TODO: clean up if possible - currently messy to work with pt1.8. Possibly due to stack operation?
    """
    # Pad the support-frame axis so the alignment path can enter and leave at either end.
    dists = F.pad(dists, (1, 1), "constant", 0)  # [n_queries, n_support, query_len, support_len + 2]

    cum_dists = torch.zeros(dists.shape, device=dists.device)

    # top row
    for m in range(1, dists.shape[3]):
        # cum_dists[:,:,0,m] = dists[:,:,0,m] - lbda * torch.log( torch.exp(- cum_dists[:,:,0,m-1]))
        # paper does continuous relaxation of the cum_dists entry, but it trains faster without, so using the simpler version for now:
        cum_dists[:, :, 0, m] = dists[:, :, 0, m] + cum_dists[:, :, 0, m - 1]

    # remaining rows: each entry is the local distance plus a smoothed minimum (soft-min) over
    # the admissible predecessor cells, i.e. -lbda * log(sum_j exp(-cum_dists_j / lbda)).
    for l in range(1, dists.shape[2]):
        # first non-zero column
        cum_dists[:, :, l, 1] = dists[:, :, l, 1] - lbda * torch.log(
            torch.exp(-cum_dists[:, :, l - 1, 0] / lbda)
            + torch.exp(-cum_dists[:, :, l - 1, 1] / lbda)
            + torch.exp(-cum_dists[:, :, l, 0] / lbda)
        )

        # middle columns
        for m in range(2, dists.shape[3] - 1):
            cum_dists[:, :, l, m] = dists[:, :, l, m] - lbda * torch.log(
                torch.exp(-cum_dists[:, :, l - 1, m - 1] / lbda)
                + torch.exp(-cum_dists[:, :, l, m - 1] / lbda)
            )

        # last column
        # cum_dists[:,:,l,-1] = dists[:,:,l,-1] - lbda * torch.log( torch.exp(- cum_dists[:,:,l-1,-2] / lbda) + torch.exp(- cum_dists[:,:,l,-2] / lbda) )
        cum_dists[:, :, l, -1] = dists[:, :, l, -1] - lbda * torch.log(
            torch.exp(-cum_dists[:, :, l - 1, -2] / lbda)
            + torch.exp(-cum_dists[:, :, l - 1, -1] / lbda)
            + torch.exp(-cum_dists[:, :, l, -2] / lbda)
        )

    return cum_dists[:, :, -1, -1]
if __name__ == "__main__":
    main()
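For reference, the training loop above accumulates gradients for gradient_accumulation_steps iterations before each optimizer step. A common variant, which the gist itself does not use, also divides the loss by the number of accumulation steps so that the accumulated gradient matches a single larger batch. A self-contained toy sketch of that pattern (hypothetical, with a throwaway linear model):

    # Hypothetical standalone sketch of loss scaling under gradient accumulation;
    # not part of the reproduction script above.
    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    gradient_accumulation_steps = 4
    for step in range(8):
        x = torch.randn(3, 10)
        target = torch.randint(0, 2, (3,))
        loss = F.cross_entropy(model(x), target)
        # Scale so the summed gradients approximate one larger batch.
        (loss / gradient_accumulation_steps).backward()
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()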