-
-
Save mikasenghaas/b90ebeb48cb37b9075b8a1c158c6c6b9 to your computer and use it in GitHub Desktop.
Compute maximum batch size
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# requires-python = ">=3.10"
# dependencies = ["numpy", "pandas"]
# ///
"""Estimate the largest batch size that fits in GPU memory for several models.

For every (GPU, model) pair the script subtracts the weight footprint and a
fixed overhead allowance from the GPU capacity, then counts how many
full-length (T-token) KV caches fit in what is left. Pairs where nothing
fits are reported as "OOM".
"""
import numpy as np
import pandas as pd

# Model parameters (P, H, K, L, T), one entry per model column.
# Presumed meaning, inferred from the Llama-2 column labels (confirm):
#   P = parameter count, H = head dimension, K = number of KV heads,
#   L = number of layers, T = sequence length in tokens.
# int64 is requested explicitly: T * kv_cache_size exceeds 2**31 and would
# silently wrap on platforms where NumPy's default integer is 32-bit.
P = np.array([6738415616, 13015864320, 68976648192], dtype=np.int64)
H = np.array([128, 128, 128], dtype=np.int64)
K = np.array([32, 40, 8], dtype=np.int64)
L = np.array([32, 40, 80], dtype=np.int64)
T = 4096

# Hardware parameters (GPU memory), in bytes. Reshaped to a column vector so
# it broadcasts against the per-model row vectors: rows = GPUs, cols = models.
total_memory = (np.array([24, 40, 80]) * 1e9).reshape(-1, 1)

# Compute maximum batch size
model_size = 2 * P  # presumably 2 bytes per parameter (16-bit weights)
kv_cache_size = 2 * 2 * H * K * L  # per token; presumably (K and V) * 2 bytes
# 10% of model size + kv cache size.
# NOTE(review): this allowance uses the KV cache of a single sequence and
# does not grow with the batch size -- confirm that this is intended.
ephemeral_size = 0.1 * (model_size + T * kv_cache_size)
max_batch_size = (total_memory - model_size - ephemeral_size) // (T * kv_cache_size)
# Negative results mean even the weights plus overhead do not fit; clamp to 0.
max_batch_size = np.where(max_batch_size > 0, max_batch_size, 0)

# Show result
# Cells are formatted with a plain comprehension rather than DataFrame.map,
# which only exists in pandas >= 2.1 (the header does not pin a version).
table = [["OOM" if n == 0 else str(int(n)) for n in row] for row in max_batch_size]
print(pd.DataFrame(table, index=["24GB", "40GB", "80GB"], columns=["Llama-2 7B", "Llama-2 13B", "Llama-2 70B"]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Run with
wget https://gist.githubusercontent.com/mikasenghaas/b90ebeb48cb37b9075b8a1c158c6c6b9/raw/e89f763c5e2a8f83524f09fc3a28258a363e6fcd/max-batch-size.py && uv run max-batch-size.py
To get the following output