@mikasenghaas
Created April 25, 2025 23:46
Compute maximum batch size
# /// script
# requires-python = ">=3.10"
# dependencies = ["numpy", "pandas"]
# ///
import numpy as np
import pandas as pd
# Model parameters for Llama-2 7B / 13B / 70B:
# P = parameter count, H = head dimension, K = number of KV heads (70B uses GQA), L = layers
P = np.array([6738415616, 13015864320, 68976648192])
H = np.array([128, 128, 128])
K = np.array([32, 40, 8])
L = np.array([32, 40, 80])
T = 4096  # sequence length in tokens
# Hardware parameters (GPU memory in bytes, one row per GPU size)
total_memory = (np.array([24, 40, 80]) * 1e9).reshape(-1, 1)
# Compute maximum batch size
model_size = 2 * P  # fp16 weights: 2 bytes per parameter
kv_cache_size = 2 * 2 * H * K * L  # per-token KV cache: K and V tensors, 2 bytes each, per head dim x KV head x layer
ephemeral_size = 0.1 * (model_size + T * kv_cache_size)  # 10% headroom over weights plus one sequence's KV cache
max_batch_size = (total_memory - model_size - ephemeral_size) // (T * kv_cache_size)
max_batch_size = np.maximum(max_batch_size, 0)  # clamp negatives: the model alone does not fit
# Show result: rows are GPU memory sizes, columns are models; 0 means OOM
df = pd.DataFrame(
    max_batch_size,
    index=["24GB", "40GB", "80GB"],
    columns=["Llama-2 7B", "Llama-2 13B", "Llama-2 70B"],
)
print(df.map(lambda x: "OOM" if x == 0 else str(int(x))))
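As a sanity check, the formula can be worked through by hand for one cell of the table. The sketch below (a hypothetical walkthrough, not part of the gist) reproduces the Llama-2 7B entry on a 24 GB GPU:

```python
# Worked example: Llama-2 7B on a 24 GB GPU (sanity check of the formula above).
head_dim, kv_heads, layers, seq_len = 128, 32, 32, 4096

# Per-token KV cache: 2 tensors (K and V) * 2 bytes (fp16) * head_dim * kv_heads * layers
kv_per_token = 2 * 2 * head_dim * kv_heads * layers  # 524288 bytes = 0.5 MiB per token
kv_per_seq = seq_len * kv_per_token                  # ~2.1 GB per 4096-token sequence

weights = 2 * 6738415616                             # ~13.5 GB of fp16 weights
ephemeral = 0.1 * (weights + kv_per_seq)             # ~1.6 GB activation headroom
free = 24e9 - weights - ephemeral                    # memory left for KV caches

print(int(free // kv_per_seq))  # prints 4, matching the table below
```

In other words, each 4096-token sequence costs roughly 2.1 GB of KV cache, and after weights and headroom only about 9 GB remain, so four sequences fit.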

mikasenghaas commented Apr 25, 2025

Run with

wget https://gist.githubusercontent.com/mikasenghaas/b90ebeb48cb37b9075b8a1c158c6c6b9/raw/e89f763c5e2a8f83524f09fc3a28258a363e6fcd/max-batch-size.py && uv run max-batch-size.py

to get the following output:

     Llama-2 7B Llama-2 13B Llama-2 70B
24GB          4         OOM         OOM
40GB         11           3         OOM
80GB         30          15         OOM
