@stas00
Created December 18, 2023 03:12
memory allocations breakdown
{
"cells": [
{
"cell_type": "raw",
"id": "8e3cebc7-369d-4cf0-b9bf-63555f042bb2",
"metadata": {},
"source": [
"pip install transformers nvidia-ml-py3 einops ipyexperiments"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "43d0272a-78b0-48ac-b6d1-e7b57dc01650",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import pynvml\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from ipyexperiments import IPyExperimentsPytorch\n",
"import gc\n",
"import os\n",
"os.environ['CUDA_MODULE_LOADING'] = 'EAGER' # force kernel preloading"
]
},
{
"cell_type": "markdown",
"id": "ef691ce6-276d-4d79-855a-58ad721b5af0",
"metadata": {},
"source": [
"# Run parameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "00cb2ead-719b-4e17-8558-3c1ae4bb0d3f",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\")\n",
"model_name_or_path = \"microsoft/phi-1_5\" # microsoft/phi-1_5, microsoft/phi-2, NousResearch/Llama-2-7b-hf, mistralai/Mistral-7B-v0.1, gpt2, gpt2-medium, gpt2-large, gpt2-xl\n",
"dtype = torch.float32\n",
"mixed_precision_training = True\n",
"bs = 2\n",
"seq_length = 128\n",
"get_optimizer = lambda parameters: torch.optim.SGD(parameters, lr=0.1, momentum=0.9) # SGD(parameters, lr=0.1), SGD(parameters, lr=0.1, momentum=0.9), AdamW(parameters, lr=0.1)\n",
"\n",
"if mixed_precision_training:\n",
" assert dtype == torch.float32"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3d48830c-cd59-489b-b500-459eb647c1cd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA kernels VRAM: 955 MiB\n",
"\n",
"*** Experiment started with the Pytorch backend\n",
"Device: ID 0, NVIDIA A100 80GB PCIe (81920 RAM)\n",
"\n",
"\n",
"*** Current state:\n",
"RAM: Used Free Total Util\n",
"CPU: 3,106 85,241 128,649 MB 2.41% \n",
"GPU: 1,885 80,034 81,920 MB 2.30% \n",
"\n",
"\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.000\n",
"・ CPU: 0 0 3,106 MB |\n",
"・ GPU: 0 0 1,885 MB |\n"
]
}
],
"source": [
"n_bytes_per_param = 2 if dtype in (torch.float16, torch.bfloat16) else 4\n",
"\n",
"pynvml.nvmlInit()\n",
"handle = pynvml.nvmlDeviceGetHandleByIndex(0)\n",
"get_vram = lambda: pynvml.nvmlDeviceGetMemoryInfo(handle).used / 2**20 # MiB\n",
"\n",
"start_vram = get_vram()\n",
"\n",
"# Initializing CUDA kernels\n",
"a = torch.ones((1,1)).to(device); del a; torch.cuda.empty_cache()\n",
"cuda_kernels_vram = get_vram() - start_vram\n",
"print(f\"CUDA kernels VRAM: {cuda_kernels_vram:.0f} MiB\")\n",
"\n",
"exp = IPyExperimentsPytorch()"
]
},
{
"cell_type": "markdown",
"id": "1689d757-e854-45b8-a35d-3e6e31994b83",
"metadata": {},
"source": [
"# Loading model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "56bcc214-a1f6-43f5-836c-157be2afd2de",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.304\n",
"・ CPU: 36 9 3,143 MB |\n",
"・ GPU: 0 0 1,885 MB |\n"
]
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
"if tokenizer.pad_token is None:\n",
" tokenizer.pad_token = tokenizer.eos_token"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "90d8cc03-2889-4ee0-9869-9d932bd86ac1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PhiConfig {\n",
" \"_name_or_path\": \"microsoft/phi-1_5\",\n",
" \"activation_function\": \"gelu_new\",\n",
" \"architectures\": [\n",
" \"PhiForCausalLM\"\n",
" ],\n",
" \"attn_pdrop\": 0.0,\n",
" \"auto_map\": {\n",
" \"AutoConfig\": \"microsoft/phi-1_5--configuration_phi.PhiConfig\",\n",
" \"AutoModelForCausalLM\": \"microsoft/phi-1_5--modeling_phi.PhiForCausalLM\"\n",
" },\n",
" \"embd_pdrop\": 0.0,\n",
" \"flash_attn\": false,\n",
" \"flash_rotary\": false,\n",
" \"fused_dense\": false,\n",
" \"initializer_range\": 0.02,\n",
" \"layer_norm_epsilon\": 1e-05,\n",
" \"model_type\": \"phi-msft\",\n",
" \"n_embd\": 2048,\n",
" \"n_head\": 32,\n",
" \"n_head_kv\": null,\n",
" \"n_inner\": null,\n",
" \"n_layer\": 24,\n",
" \"n_positions\": 2048,\n",
" \"resid_pdrop\": 0.0,\n",
" \"rotary_dim\": 32,\n",
" \"tie_word_embeddings\": false,\n",
" \"torch_dtype\": \"float32\",\n",
" \"transformers_version\": \"4.37.0.dev0\",\n",
" \"use_cache\": false,\n",
" \"vocab_size\": 51200\n",
"}\n",
"\n",
"===========================================================================\n",
"PhiForCausalLM(\n",
" (transformer): PhiModel(\n",
" (embd): Embedding(\n",
" (wte): Embedding(51200, 2048)\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (h): ModuleList(\n",
" (0-23): 24 x ParallelBlock(\n",
" (ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
" (resid_dropout): Dropout(p=0.0, inplace=False)\n",
" (mixer): MHA(\n",
" (rotary_emb): RotaryEmbedding()\n",
" (Wqkv): Linear(in_features=2048, out_features=6144, bias=True)\n",
" (out_proj): Linear(in_features=2048, out_features=2048, bias=True)\n",
" (inner_attn): SelfAttention(\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (inner_cross_attn): CrossAttention(\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (mlp): MLP(\n",
" (fc1): Linear(in_features=2048, out_features=8192, bias=True)\n",
" (fc2): Linear(in_features=8192, out_features=2048, bias=True)\n",
" (act): NewGELUActivation()\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (lm_head): CausalLMHead(\n",
" (ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
" (linear): Linear(in_features=2048, out_features=51200, bias=True)\n",
" )\n",
" (loss): CausalLMLoss(\n",
" (loss_fct): CrossEntropyLoss()\n",
" )\n",
")\n",
"===========================================================================\n",
"Number of parameters: 1.418 B (1418271104)\n",
"Model VRAM usage: 5496 MiB (expected 5410 MiB, error 1.6 %)\n",
"===========================================================================\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:04.714\n",
"・ CPU: 333 7,823 3,476 MB |\n",
"・ GPU: 5,496 0 7,381 MB |\n"
]
}
],
"source": [
"model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype, trust_remote_code=True).to(device)\n",
"model.config.use_cache = False\n",
"\n",
"n_parameters = sum(p.numel() for p in model.parameters()) + sum(p.numel() for p in model.buffers())\n",
"model_estimated_vram = n_parameters * n_bytes_per_param / 2**20\n",
"model_actual_vram = get_vram() - cuda_kernels_vram - start_vram\n",
"\n",
"#n_buffers = sum(p.numel() for p in model.buffers())\n",
"\n",
"print(model.config)\n",
"print(\"=\" * 75)\n",
"print(model)\n",
"print(\"=\" * 75)\n",
"print(f\"Number of parameters: {(n_parameters / 1e9):.3f} B ({n_parameters})\")\n",
"print(f\"Model VRAM usage: {(model_actual_vram):.0f} MiB (expected {(model_estimated_vram):.0f} MiB, error {((model_actual_vram - model_estimated_vram) * 100 / model_actual_vram):.1f} %)\")\n",
"print(\"=\" * 75)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "90db98fc-8de2-4c23-97b1-3bdbf780b7c7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For batch of 4 items with a sequence length of 512 it will consume 0.046875 MiB VRAM\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.000\n",
"・ CPU: 0 0 3,477 MB |\n",
"・ GPU: 0 0 7,381 MB |\n"
]
}
],
"source": [
"bs = 4\n",
"seq_length = 512\n",
"\n",
"batch_vram = 3 * bs * seq_length * 8 # 3 for input_ids, attention_masks and labels; 8 for each i64\n",
"print(f\"For batch of {bs} items with a sequence length of {seq_length} it will consume {batch_vram / 2**20} MiB VRAM\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5762f438-77e1-4c6d-a3c9-7eaeab21dd85",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[28640, 46785, 7134, ..., 32685, 9896, 1042],\n",
" [28685, 26733, 956, ..., 28010, 12865, 29406],\n",
" [19038, 45183, 9541, ..., 14378, 25289, 32570],\n",
" [45000, 35482, 9371, ..., 11262, 33852, 2560]], device='cuda:0')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"tensor([[1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.036\n",
"・ CPU: 13 0 3,490 MB |\n",
"・ GPU: 0 0 7,381 MB |\n"
]
}
],
"source": [
"input_ids = torch.randint(0, len(tokenizer), (bs, seq_length)).to(device)\n",
"attention_mask = torch.ones((bs, seq_length)).to(device)\n",
"input_ids\n",
"attention_mask"
]
},
{
"cell_type": "markdown",
"id": "0b004db1-7540-47c7-a498-aa8f43b910b6",
"metadata": {},
"source": [
"# Warmup Inference forward pass"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "41403d8a-7934-4294-b5ba-a50a6d5b69f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:01.220\n",
"・ CPU: 1,865 0 5,355 MB |\n",
"・ GPU: 394 0 7,775 MB |\n"
]
}
],
"source": [
"# warmup - could possibly load some modules / allocate structures - this is your missing eps_ram\n",
"_ = model.eval()\n",
"input_ids_1 = torch.randint(0, len(tokenizer), (1, 1)).to(device)\n",
"attention_mask_1 = torch.ones((1, 1)).to(device)\n",
"with torch.no_grad():\n",
" out = model(input_ids=input_ids_1, attention_mask=attention_mask_1)\n",
" # probs = F.softmax(out.logits[:, -1, :], dim=-1) # for inference we need probabilities only over the last token; omit this as it is very small\n",
" del out"
]
},
{
"cell_type": "markdown",
"id": "8dc419ea-cacd-4709-8a05-942c06d4887f",
"metadata": {},
"source": [
"# Real Inference forward pass"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c701ccd7-3b5a-42e1-a7e3-2de4a78eeb66",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.063\n",
"・ CPU: 0 0 5,355 MB |\n",
"・ GPU: 400 566 8,175 MB |\n"
]
}
],
"source": [
"# real run\n",
"_ = model.eval()\n",
"\n",
"with torch.no_grad():\n",
" out = model(input_ids=input_ids, attention_mask=attention_mask)\n",
" # probs = F.softmax(out.logits[:, -1, :], dim=-1) # for inference we need probabilities only over the last token; omit this as it is very small"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "75bbcef2-4b7b-4a21-b6c3-fe5206ee3ec5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Out tensor dtype: torch.float32\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.000\n",
"・ CPU: 0 0 5,355 MB |\n",
"・ GPU: -400 0 7,775 MB |\n"
]
}
],
"source": [
"out_bs, out_sequence_length, out_embedding_size = out.logits.shape\n",
"n_bytes_per_param_out = 2 if out.logits.dtype in (torch.float16, torch.bfloat16) else 4\n",
"output_estimated_vram = out_bs * out_sequence_length * out_embedding_size * n_bytes_per_param_out / 2**20\n",
"print(f\"Out tensor dtype: {out.logits.dtype}\")\n",
"del out; torch.cuda.empty_cache() # calling `free` on allocated memory for `out` tensor"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2481fa80-b294-4c48-906b-5b4631670ff2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total forward pass VRAM usage: 394 MiB\n",
"Output tensor with bs 4, seq length 512 and emb size 51200 VRAM usage: 0 MiB (expected 400 MiB)\n",
"Activations VRAM usage: 0 MiB\n",
"Random eps VRAM: 394 MiB\n",
"===========================================================================\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.009\n",
"・ CPU: 0 0 5,355 MB |\n",
"・ GPU: 0 0 7,775 MB |\n"
]
}
],
"source": [
"total_forward_pass_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram\n",
"torch.cuda.empty_cache() # calling `free` on allocated memory for forward pass\n",
"output_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram\n",
"\n",
"eps_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram # idk what is that, but it is small\n",
"\n",
"output_actual_vram = output_vram - eps_vram\n",
"activations_actual_vram = total_forward_pass_vram - output_actual_vram - eps_vram\n",
"\n",
"print(f\"Total forward pass VRAM usage: {total_forward_pass_vram:.0f} MiB\")\n",
"print(f\"Output tensor with bs {out_bs}, seq length {out_sequence_length} and emb size {out_embedding_size} VRAM usage: {output_actual_vram:.0f} MiB (expected {output_estimated_vram:.0f} MiB)\")\n",
"print(f\"Activations VRAM usage: {activations_actual_vram:.0f} MiB\")\n",
"print(f\"Random eps VRAM: {eps_vram:.0f} MiB\")\n",
"#print(torch.cuda.memory_summary())\n",
"print(\"=\" * 75)"
]
},
{
"cell_type": "markdown",
"id": "fd1f8bac-c9a4-42b5-bad8-3f5e41fba771",
"metadata": {},
"source": [
"# Warm up Training step"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "43fe4164-5c23-43e9-93e2-df08e6c09927",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.013\n",
"・ CPU: 0 0 5,355 MB |\n",
"・ GPU: 0 0 7,775 MB |\n"
]
}
],
"source": [
"# warmup\n",
"_ = model.train()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "df65b3fe-4ffb-4def-a139-a21bf65dcf43",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.065\n",
"・ CPU: 2 0 5,358 MB |\n",
"・ GPU: 2 2,556 7,777 MB |\n"
]
}
],
"source": [
"# check forward - we already run fwd during inference - so expecting no additional memory allocated \n",
"with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=mixed_precision_training):\n",
" out = model(input_ids=input_ids_1, attention_mask=attention_mask_1)\n",
" probs = F.softmax(out.logits, dim=-1)\n",
" loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids_1) # mapping tokens into themselves\n",
"\n",
"del out\n",
"del probs\n",
"del loss\n",
"\n",
"# no leaks here"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "36c4e5a6-f830-4a20-b43b-5f634a061b52",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.001\n",
"・ CPU: 0 0 5,358 MB |\n",
"・ GPU: 0 0 7,777 MB |\n"
]
}
],
"source": [
"optimizer = get_optimizer(model.parameters())\n",
"scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision_training)\n",
"del scaler\n",
"del optimizer\n",
"\n",
"# no leaks here\n",
"# and it didn't even allocate any memory for either object"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a72252fb-b2ef-4e57-affd-c8e99aa57272",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"grads memory: 5410.275 MB\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.130\n",
"・ CPU: 3 0 5,362 MB |\n",
"・ GPU: 5,636 1,848 13,413 MB |\n"
]
}
],
"source": [
"# now running backward for the first time - so expects the grads memory allocation to occur\n",
"with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=mixed_precision_training):\n",
" out = model(input_ids=input_ids_1, attention_mask=attention_mask_1)\n",
" probs = F.softmax(out.logits, dim=-1)\n",
" loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids_1) # mapping tokens into themselves\n",
"loss.backward()\n",
"\n",
"del out\n",
"del probs\n",
"del loss\n",
"\n",
"# backward manifested grads here \n",
"grads_estimated_vram = n_parameters * n_bytes_per_param / 2**20\n",
"\n",
"print(f\"grads memory: {grads_estimated_vram:.3f} MB\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "33c391dc-cfa8-4c20-b0cc-49692f7c9073",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.098\n",
"・ CPU: 0 0 5,362 MB |\n",
"・ GPU: 2 2,948 13,415 MB |\n"
]
}
],
"source": [
"# now we expect the optimizer states to be manifested\n",
"optimizer = get_optimizer(model.parameters())\n",
"\n",
"with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=mixed_precision_training):\n",
" out = model(input_ids=input_ids_1, attention_mask=attention_mask_1)\n",
" probs = F.softmax(out.logits, dim=-1)\n",
" loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids_1) # mapping tokens into themselves\n",
"loss.backward()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "a4e449a0-7ea0-488e-965d-faaa5e947844",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"optim_states estimated memory: 5410.275 MB\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.076\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: 5,374 2 18,789 MB |\n"
]
}
],
"source": [
"optimizer.step() # we can see here the optim states get allocated only on the first step()\n",
"\n",
"del out\n",
"del probs\n",
"del loss\n",
"\n",
"optim_states_estimated_vram = n_parameters * n_bytes_per_param / 2**20\n",
"\n",
"print(f\"optim_states estimated memory: {optim_states_estimated_vram:.3f} MB\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "68ef439d-6f4b-42df-ac87-8537b850cbe0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.001\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: -5,404 0 13,385 MB |\n"
]
}
],
"source": [
"# free grads\n",
"optimizer.zero_grad(set_to_none=True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "3675c67e-1411-4fb8-a361-28c61923eedb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.001\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: -5,578 0 7,807 MB |\n"
]
}
],
"source": [
"# free optim states\n",
"del optimizer"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "75894651-a847-4fc5-84de-f0a871e28e30",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"grads + optim_states estimated_vram: 10820.5 MB\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.000\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: 0 0 7,807 MB |\n"
]
}
],
"source": [
"# optim states + grads memory\n",
"grads_n_optim_states_vram = grads_estimated_vram + optim_states_estimated_vram \n",
"print(f\"grads + optim_states estimated_vram: {grads_n_optim_states_vram:.1f} MB\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "d4237cf4-1a48-4e43-be99-c3018f98429c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.192\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: 0 12,478 7,807 MB |\n"
]
}
],
"source": [
"# full warmup train step with reset - check that allocated memory before and after is the same\n",
"_ = model.train()\n",
"optimizer = get_optimizer(model.parameters())\n",
"\n",
"with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=mixed_precision_training):\n",
" out = model(input_ids=input_ids_1, attention_mask=attention_mask_1)\n",
" probs = F.softmax(out.logits, dim=-1)\n",
" loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids_1) # mapping tokens into themselves\n",
"loss.backward()\n",
"optimizer.step()\n",
"optimizer.zero_grad(set_to_none=True)\n",
"\n",
"del out\n",
"del probs\n",
"del loss\n",
"del optimizer\n",
"\n",
"# from peak memory delta we can see that optim states + grads that were allocated and then freed - do checkout - the rest of the peak memory is activations memory"
]
},
{
"cell_type": "markdown",
"id": "3cacc100-5376-43e5-85ea-ef938f600c8f",
"metadata": {},
"source": [
"# Real Training step"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "ca1837d7-5e00-480a-99f0-0fa0cd88e700",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model gradients type: torch.float32\n",
"Total train forward pass VRAM usage (activations + output tensor): 15930 MiB (expect 2705 MiB of these to be for fp16 weights copy)\n",
"Activations VRAM usage: 12825 MiB\n",
"Gradients VRAM usage: 4864 MiB (model weights were 5496 MiB)\n",
"Actual optimizer states VRAM usage: 6312 MiB\n",
"Random eps VRAM usage: 426 MiB\n",
"===========================================================================\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.601\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: 0 17,898 7,807 MB |\n"
]
}
],
"source": [
"_ = model.train()\n",
"optimizer = get_optimizer(model.parameters())\n",
"\n",
"with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=mixed_precision_training):\n",
" out = model(input_ids=input_ids, attention_mask=attention_mask)\n",
" total_train_forward_pass_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
" \n",
" probs = F.softmax(out.logits, dim=-1)\n",
" probs_vram = get_vram() - total_train_forward_pass_vram - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
" \n",
" loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids) # mapping tokens into themselves\n",
" loss_calculation_vram = get_vram() - probs_vram - total_train_forward_pass_vram - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"loss.backward()\n",
"optimizer.step()\n",
"\n",
"backward_vram = get_vram() - loss_calculation_vram - probs_vram - total_train_forward_pass_vram - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"\n",
"print(f\"Model gradients type: {next(model.parameters()).grad.dtype}\")\n",
"print(f\"Total train forward pass VRAM usage (activations + output tensor): {total_train_forward_pass_vram:.0f} MiB\" + (f\" (expect {(n_parameters * 2 / 2**20):.0f} MiB of these to be for fp16 weights copy)\" if mixed_precision_training else \"\"))\n",
"print(f\"Activations VRAM usage: {(total_train_forward_pass_vram - (n_parameters * 2 / 2**20 if mixed_precision_training else 0) - output_estimated_vram):.0f} MiB\")\n",
"#print(f\"Actual probs tensor VRAM usage: {probs_vram:.0f} MiB\")\n",
"#print(f\"Loss calculation VRAM usage: {loss_calculation_vram:.0f} MiB\")\n",
"#print(f\"Backward calculation VRAM usage: {backward_vram:.0f} MiB\")\n",
"\n",
"del out\n",
"del probs\n",
"del loss\n",
"torch.cuda.empty_cache() # calling `free` on allocated memory for activations and outputs\n",
"\n",
"gradients_optimizer_total_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"optimizer.zero_grad(set_to_none=True); torch.cuda.empty_cache()\n",
"optimizer_total_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"del optimizer; torch.cuda.empty_cache()\n",
"eps_2_vram = get_vram() - model_actual_vram - cuda_kernels_vram - start_vram - eps_vram\n",
"\n",
"gradients_actual_vram = gradients_optimizer_total_vram - optimizer_total_vram\n",
"optimizer_actual_vram = optimizer_total_vram - eps_2_vram\n",
"print(f\"Gradients VRAM usage: {gradients_actual_vram:.0f} MiB (model weights were {model_actual_vram:.0f} MiB)\")\n",
"print(f\"Actual optimizer states VRAM usage: {optimizer_actual_vram:.0f} MiB\")\n",
"\n",
"eps_vram += eps_2_vram\n",
"print(f\"Random eps VRAM usage: {eps_vram:.0f} MiB\")\n",
"print(\"=\" * 75)"
]
},
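{
"cell_type": "markdown",
"id": "1f2e3d4c-5b6a-4978-8899-aabbccddeeff",
"metadata": {},
"source": [
"A rough sanity check (a sketch, assuming the components measured above simply add up): their sum should land near the observed peak; the true peak is somewhat higher due to transient backward buffers and allocator fragmentation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a3b4c5d-6e7f-4081-92a3-b4c5d6e7f809",
"metadata": {},
"outputs": [],
"source": [
"# approximate peak training VRAM as the sum of the components measured above (rough sketch only)\n",
"approx_peak_vram = (\n",
" model_actual_vram # weights\n",
" + gradients_actual_vram # gradients\n",
" + optimizer_actual_vram # optimizer states (SGD momentum buffers)\n",
" + total_train_forward_pass_vram # activations + logits (+ autocast fp16 weight copies)\n",
" + eps_vram # small unexplained allocations\n",
" + cuda_kernels_vram # CUDA context / kernels\n",
")\n",
"print(f\"Approximate peak training VRAM from measured components: {approx_peak_vram:.0f} MiB\")"
]
},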
{
"cell_type": "markdown",
"id": "cfd62ab9-22e3-4e84-943e-82f4b63762a1",
"metadata": {},
"source": [
"# Estimation activations"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ca072747-a715-4827-a4a1-3335b9c844a5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Calculating size of activation for single block with:\n",
"hidden size 2048\n",
"num attention heads 32\n",
"num key value heads 32\n",
"intermediate size 8192\n",
"head dim 64\n",
"num hidden layers 24\n",
"===========================================================================\n",
"Single layer (out of 24) estimated activations VRAM usage: 296 MiB\n",
"All layers estimated activations VRAM usage: 7104 MiB\n",
"Estimated activations on inference forward pass VRAM usage (softmax output + v): 72 MiB\n",
"===========================================================================\n",
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.001\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: 0 0 7,807 MB |\n"
]
}
],
"source": [
"n_bytes_per_param = 2 if mixed_precision_training or dtype in (torch.float16, torch.bfloat16) else 4\n",
"\n",
"hidden_size = model.config.hidden_size\n",
"num_attention_heads = model.config.num_attention_heads\n",
"num_key_value_heads = model.config.num_key_value_heads if hasattr(model.config, \"num_key_value_heads\") else model.config.num_attention_heads # different from num_attention_heads in case of GQA\n",
"intermediate_size = model.config.intermediate_size if hasattr(model.config, \"intermediate_size\") else 4 * model.config.hidden_size # MLP projection\n",
"num_hidden_layers = model.config.num_hidden_layers\n",
"head_dim = hidden_size // num_attention_heads\n",
"print(f\"Calculating size of activation for single block with:\\nhidden size {hidden_size}\\nnum attention heads {num_attention_heads}\\nnum key value heads {num_key_value_heads}\\nintermediate size {intermediate_size}\\nhead dim {head_dim}\\nnum hidden layers {num_hidden_layers}\")\n",
"print(\"=\" * 75)\n",
"\n",
"attention_input = n_bytes_per_param * bs * seq_length * hidden_size\n",
"q = n_bytes_per_param * bs * seq_length * head_dim * num_attention_heads # for Q @ K.T\n",
"k = n_bytes_per_param * bs * seq_length * head_dim * num_key_value_heads # num_key_value_heads might be different from num_attention_heads in case of GQA\n",
"softmax_output = n_bytes_per_param * bs * num_attention_heads * seq_length ** 2 # to multiply with V\n",
"softmax_dropout_mask = 1 * bs * num_attention_heads * seq_length ** 2 # single byte per elem\n",
"dropout_output = n_bytes_per_param * bs * num_attention_heads * seq_length ** 2\n",
"v = n_bytes_per_param * bs * seq_length * head_dim * num_key_value_heads\n",
"out_proj_input = n_bytes_per_param * bs * seq_length * num_attention_heads * head_dim\n",
"attention_dropout = 1 * bs * seq_length * hidden_size\n",
"#attention_block = attention_input + q + k + softmax_output + v + out_proj_input\n",
"attention_block = attention_input + q + k + softmax_output + v + out_proj_input + softmax_dropout_mask + dropout_output + attention_dropout\n",
"\n",
"mlp_input = n_bytes_per_param * bs * seq_length * hidden_size\n",
"activation_input = n_bytes_per_param * bs * seq_length * intermediate_size # SiLU\n",
"down_proj_input = n_bytes_per_param * bs * seq_length * intermediate_size\n",
"dropout_mask = 1 * bs * seq_length * hidden_size # single byte per elem\n",
"#mlp_block = mlp_input + activation_input + down_proj_input\n",
"mlp_block = mlp_input + activation_input + down_proj_input + dropout_mask\n",
"\n",
"layer_norms = n_bytes_per_param * bs * seq_length * hidden_size * 2 # 2 layer norms\n",
"\n",
"layer = attention_block + mlp_block + layer_norms\n",
"print(f\"Single layer (out of {num_hidden_layers}) estimated activations VRAM usage: {layer // 2**20} MiB\")\n",
"print(f\"All layers estimated activations VRAM usage: {layer * num_hidden_layers // 2**20} MiB\")\n",
"print(f\"Estimated activations on inference forward pass VRAM usage (softmax output + v): {(softmax_output + v) // 2**20} MiB\")\n",
"print(\"=\" * 75)"
]
},
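{
"cell_type": "markdown",
"id": "3c4d5e6f-7a8b-4c9d-8e0f-1a2b3c4d5e6f",
"metadata": {},
"source": [
"The next cell cross-checks the per-tensor accounting above against the closed-form per-layer estimate from https://arxiv.org/pdf/2205.05198.pdf (standard transformer layer, with dropout, no sequence parallelism or activation recomputation), which in bytes is approximately\n",
"\n",
"$$s \\cdot b \\cdot h \\left(34 + \\frac{5 \\cdot a \\cdot s}{h}\\right)$$\n",
"\n",
"where $s$ is the sequence length, $b$ the micro-batch size, $h$ the hidden size and $a$ the number of attention heads."
]
},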
{
"cell_type": "code",
"execution_count": 24,
"id": "7d70deac-a86a-403d-a21b-097e77e932fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"・ RAM: △Consumed △Peaked Used Total | Exec time 0:00:00.007\n",
"・ CPU: 0 0 5,363 MB |\n",
"・ GPU: 0 0 7,807 MB |\n"
]
}
],
"source": [
"# https://arxiv.org/pdf/2205.05198.pdf\n",
"\n",
"def calculate_attention_block():\n",
" return 11 * seq_length * bs * hidden_size + 5 * num_attention_heads * seq_length ** 2 * bs\n",
"\n",
"def calculate_mlp_block():\n",
" return 19 * seq_length * bs * hidden_size\n",
"\n",
"def calculate_layernorms():\n",
" return 4 * seq_length * bs * hidden_size\n",
"\n",
"def calculate_per_layer():\n",
" return seq_length * bs * hidden_size * (34 + 5 * num_attention_heads * seq_length / hidden_size)\n",
"\n",
"assert calculate_attention_block() + calculate_mlp_block() + calculate_layernorms() == calculate_per_layer()"
]
},
{
"cell_type": "raw",
"id": "23628603-07f8-4f3f-9731-018939acf519",
"metadata": {},
"source": [
"from torch.profiler import profile, record_function, ProfilerActivity\n",
"\n",
"with profile(activities=[ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:\n",
" with torch.no_grad():\n",
" out = model(input_ids=input_ids, attention_mask=attention_mask)\n",
"\n",
"prof.key_averages().table(sort_by=\"self_cuda_memory_usage\", row_limit=10)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}