@lapp0
Last active March 2, 2024 23:18
LLM Memory Requirement Calculator Script (Full Finetune and Inference)
import urllib.request
import json


def bits_to_gb(bits):
    return bits / (8 * 1024**3)


def calculate_train_vram_requirements(
    batch_size, seq_len, params, precision, num_layers, num_attn_heads, hidden_size, **ignored
):
    """
    full train, not lora
    source: https://arxiv.org/pdf/2205.05198.pdf (section 4.1)
    credit: https://medium.com/@siddheshgunjal82/understanding-vram-requirements-to-train-inference-with-large-language-models-llms-a3edd0f09d9f
    """
    # Calculate activations using the provided formula
    activations = (
        num_layers * (5/2) * num_attn_heads * batch_size * seq_len**2
        + 17 * batch_size * hidden_size * seq_len
    )
    # Calculate VRAM using the provided formula
    vram_bits = precision * (activations + params)
    # Convert VRAM from bits to Gigabytes
    return bits_to_gb(vram_bits)
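
# Rough sanity check of the formula above (a sketch with hypothetical numbers,
# not taken from any real config): for num_layers=32, num_attn_heads=32,
# hidden_size=4096, batch_size=1, seq_len=2048, and 7e9 params at 16-bit:
#   activations = 32 * 2.5 * 32 * 1 * 2048**2 + 17 * 1 * 4096 * 2048 ≈ 1.09e10
#   vram_bits   = 16 * (1.09e10 + 7e9)                               ≈ 2.86e11
# which is roughly 33 GB. The seq_len**2 activation term dominates at long
# contexts, which is why the training table at the bottom of this file grows
# so quickly with sequence length.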


def calculate_inference_vram_requirements(
    batch_size, seq_len, params, precision, num_layers, hidden_size,
    num_attn_heads, num_kv_heads, gqa=True
):
    """
    source 1: https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/
    source 2: https://www.databricks.com/blog/llm-inference-performance-engineering-best-practices
    - same as source 1, but with the KV cache reduced by a factor of (n_heads / n_kv_heads) for GQA
    - "GQA helps with keeping the KV cache size down by sharing Keys/Values"
    - defaults to gqa=True since Mistral, Yi, and Llama 2 all use GQA
    """
    # Two cached tensors (K and V) per layer, each hidden_size wide, per token
    kv_cache = batch_size * seq_len * 2 * num_layers * hidden_size
    if gqa:
        # GQA shares K/V across query-head groups, shrinking the cache
        kv_cache *= num_kv_heads / num_attn_heads
    vram_bits = precision * (kv_cache + params)
    return bits_to_gb(vram_bits)
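
# Rough sanity check using the config values that reproduce the Yi-34B-200K
# tables below (num_layers=60, hidden_size=7168, num_attn_heads=56,
# num_kv_heads=8), with batch_size=1 and seq_len=4096:
#   kv_cache = 1 * 4096 * 2 * 60 * 7168 * (8 / 56) ≈ 5.03e8 cached values
#   at 16-bit that is ~0.9 GB on top of ~64.1 GB of weights, i.e. ~65.0 GB,
#   matching the seq_len=4096 row of the 16-bit inference column.
# Without the GQA factor the KV cache would be 56 / 8 = 7x larger.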


def get_model_params(model_uri):
    url = f"https://huggingface.co/{model_uri}/raw/main/config.json"
    with urllib.request.urlopen(url) as response:
        return json.loads(response.read())
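
# Example usage (performs a network request; the keys below are the ones
# print_table relies on):
#   cfg = get_model_params("01-ai/Yi-34B-200K")
#   cfg["num_hidden_layers"], cfg["hidden_size"], cfg["num_attention_heads"],
#   cfg["num_key_value_heads"], cfg["max_position_embeddings"]
# Note: models without GQA may omit num_key_value_heads from config.json; in
# that case a fallback such as
# cfg.get("num_key_value_heads", cfg["num_attention_heads"]) would be needed.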


def print_table(model_uri, bparams, batch_size=1, precisions=None, mode="infer"):
    precisions = precisions or [4, 6, 8, 16]
    model_params = get_model_params(model_uri)
    seq_lens = (
        [2**i for i in range(8, 20) if 2**i < model_params["max_position_embeddings"]]
        + [model_params["max_position_embeddings"]]
    )
    calc_params = {
        "num_layers": model_params["num_hidden_layers"],
        "hidden_size": model_params["hidden_size"],
        "num_attn_heads": model_params["num_attention_heads"],
        "num_kv_heads": model_params["num_key_value_heads"],
    }
    if mode == "infer":
        vram_calculator = calculate_inference_vram_requirements
    elif mode == "train":
        vram_calculator = calculate_train_vram_requirements
    elif mode == "train_lora":
        raise NotImplementedError("train_lora mode is not implemented")
    else:
        raise ValueError(f"unknown mode: {mode}")
    column_width = 10
    # Build the header of the table with one column per precision
    header = f"{'SL / BP':>{column_width}}" + "".join([f" | {p:^10}" for p in precisions])
    results = [
        f"Model: {model_uri}",
        f"Params: {bparams}B",
        f"Batch Size: {batch_size}",
        f"Mode: {mode}",
        "",
        "Sequence Length vs Bit Precision - Memory Requirements"
    ]
    results.append(header)
    results.append("-" * len(header))
    # Iterate over each seq_len and calculate VRAM for each precision
    for seq_len in seq_lens:
        seq_len_label = f"{seq_len:>{column_width}}"
        if seq_len == max(seq_lens):
            seq_len_label = "*" + seq_len_label[1:]
        row_data = [seq_len_label]
        for precision in precisions:
            vram_required = vram_calculator(
                batch_size=batch_size,
                seq_len=seq_len,
                precision=precision,
                params=bparams * 1e9,
                **calc_params  # Unpack the model-specific parameters
            )
            row_data.append(f"{vram_required:8.1f}GB")  # Format with 1 decimal point
        # Append each completed row of the table
        results.append(" | ".join(row_data))
    results += ["", "* Model Max Context Size"]
    results += ["", "Code: https://gist.github.com/lapp0/d28931ebc9f59838800faa7c73e3a0dc/edit"]
    print(" " + "\n ".join(results))
print_table("01-ai/Yi-34B-200K", bparams=34.395, mode="infer")
"""
Model: 01-ai/Yi-34B-200K
Params: 34.395B
Batch Size: 1
Mode: infer
Sequence Length vs Bit Precision - Memory Requirements
   SL / BP |     4      |     6      |     8      |     16
--------------------------------------------------------------
       256 |     16.0GB |     24.0GB |     32.1GB |     64.1GB
       512 |     16.0GB |     24.1GB |     32.1GB |     64.2GB
      1024 |     16.1GB |     24.1GB |     32.2GB |     64.3GB
      2048 |     16.1GB |     24.2GB |     32.3GB |     64.5GB
      4096 |     16.3GB |     24.4GB |     32.5GB |     65.0GB
      8192 |     16.5GB |     24.7GB |     33.0GB |     65.9GB
     16384 |     17.0GB |     25.4GB |     33.9GB |     67.8GB
     32768 |     17.9GB |     26.8GB |     35.8GB |     71.6GB
     65536 |     19.8GB |     29.6GB |     39.5GB |     79.1GB
    131072 |     23.5GB |     35.3GB |     47.0GB |     94.1GB
*   200000 |     27.5GB |     41.2GB |     54.9GB |    109.8GB
* Model Max Context Size
Code: https://gist.github.com/lapp0/d28931ebc9f59838800faa7c73e3a0dc/edit
"""
print_table("01-ai/Yi-34B-200K", bparams=34.395, mode="train")
"""
Model: 01-ai/Yi-34B-200K
Params: 34.395B
Batch Size: 1
Mode: train
Sequence Length vs Bit Precision - Memory Requirements
   SL / BP |     4      |     6      |     8      |     16
--------------------------------------------------------------
       256 |     16.3GB |     24.4GB |     32.6GB |     65.1GB
       512 |     17.1GB |     25.6GB |     34.1GB |     68.3GB
      1024 |     20.2GB |     30.3GB |     40.4GB |     80.7GB
      2048 |     32.5GB |     48.8GB |     65.1GB |    130.2GB
      4096 |     81.9GB |    122.8GB |    163.7GB |    327.5GB
      8192 |    279.0GB |    418.5GB |    558.0GB |   1115.9GB
     16384 |   1066.9GB |   1600.4GB |   2133.9GB |   4267.8GB
     32768 |   4217.9GB |   6326.8GB |   8435.8GB |  16871.5GB
     65536 |  16819.7GB |  25229.6GB |  33639.5GB |  67278.9GB
    131072 |  67223.5GB | 100835.2GB | 134446.9GB | 268893.8GB
*   200000 | 156489.6GB | 234734.3GB | 312979.1GB | 625958.2GB
* Model Max Context Size
Code: https://gist.github.com/lapp0/d28931ebc9f59838800faa7c73e3a0dc/edit
"""