"""llm-arithmetic: flops, memory footprint, memory access I/O, and more for both
training and inference (prefill and generation)."""


class GPTModel:
    def __init__(
        self,
        name: str = None,
        vocab_size: int = 51200,
        sequence_length: int = 2048,
        attention_heads: int = 32,
        hidden_size: int = 2304,
        layers: int = 24,
        micro_batch_size: int = 1,
        global_batch_size: int = 512,
        batch_size: int = 1,
        query_attention_groups: int = 1,
    ):
        self.name = name
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.attention_heads = attention_heads
        self.hidden_size = hidden_size
        self.layers = layers
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.batch_size = batch_size
        self.query_attention_groups = query_attention_groups
    def pretty_print_number(self, number: int = 1):
        return format(number, ",")

    def pretty_print_memory(self, size: int = 1):
        if size < 2**10:
            return f"{size} B"
        elif size < 2**20:
            return f"{size / 2 ** 10:.2f} KB"
        elif size < 2**30:
            return f"{size / 2 ** 20:.2f} MB"
        elif size < 2**40:
            return f"{size / 2 ** 30:.2f} GB"
        else:
            return f"{size / 2 ** 40:.2f} TB"
    def number_of_parameters(self):
        # embeddings: hidden_size x (vocab_size + sequence_length)
        #
        # Wq, Wk, Wv:
        # - weight: 3 x (hidden_size x hidden_size)
        # - bias: 3 x hidden_size
        # FFN:
        # - weight: 2 x (4 x hidden_size x hidden_size)
        # - bias: hidden_size + 4 x hidden_size
        # Wo:
        # - weight: hidden_size x hidden_size
        # - bias: hidden_size
        #
        # input layer norm, post attention layer norm: 2 x (hidden_size x 2 /* weight & bias */)
        parameters = (
            12 * self.layers * (self.hidden_size**2)
            + self.layers * self.hidden_size * 13
            + self.hidden_size * (self.vocab_size + self.sequence_length)
        )
        return parameters, self.pretty_print_number(parameters)
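    # Sanity check (plain arithmetic on the formula above): for the megatron_gpt_1_7b
    # config below (hidden_size=2304, layers=24, vocab_size=51200, sequence_length=2048)
    # this yields ~1.65e9 parameters, i.e. the "1.7B" Megatron configuration.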
    def kv_cache_per_token(self):
        # 2 bytes per value (fp16), x2 for K and V; grouped-query attention shrinks
        # the KV width by the number of query groups.
        nbytes_fp16 = 2
        kv_cache = self.hidden_size // self.query_attention_groups * self.layers * nbytes_fp16 * 2
        return kv_cache, self.pretty_print_memory(kv_cache)
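    # Example (plain arithmetic on the formula above): for llama_2_7_b below
    # (hidden_size=4096, layers=32, query_attention_groups=1) this gives
    # 4096 * 32 * 2 * 2 = 524,288 bytes, i.e. 512 KB of KV cache per token in fp16.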
    def kv_cache(
        self,
        batch_size: int = None,
        sequence_length: int = None,
        generation_length: int = None,
    ):
        if batch_size is None:
            batch_size = self.batch_size
        context_length = self.sequence_length
        if sequence_length is not None and generation_length is not None:
            context_length = min(context_length, sequence_length + generation_length)
        elif sequence_length is not None:
            context_length = min(context_length, sequence_length)
        elif generation_length is not None:
            context_length = min(context_length, generation_length)
        kv_cache_per_token, _ = self.kv_cache_per_token()
        kv_cache = batch_size * context_length * kv_cache_per_token
        return kv_cache, self.pretty_print_memory(kv_cache)
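    # Example: a full-length context for llama_2_7_b (batch_size=1, context_length=4096)
    # costs 4096 * 512 KB = 2 GB of KV cache.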
    # For gemm: A (m x k) * B (k x n) = C (m x n)
    # - flops: 2 * m * n * k
    # - memory: m * k + k * n + m * n
    def flops_forward_per_layer(self, batch_size: int = None, sequence_length: int = None):
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # forward:
        # - kqv: batch_size x 3 x 2 /* add, mul */ x (sequence_length x hidden_size x hidden_size)
        # - qk^T: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x sequence_length)
        # - softmax(qk^T) * v: batch_size x 2 /* add, mul */ x (sequence_length x sequence_length x hidden_size)
        # - post attention linear: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x hidden_size)
        # - FFN: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x (4 x hidden_size) + sequence_length x (4 x hidden_size) x hidden_size)
        return 24 * batch_size * sequence_length * (self.hidden_size * self.hidden_size) + \
            2 * batch_size * sequence_length * (self.hidden_size * sequence_length) + \
            2 * batch_size * sequence_length * (sequence_length * self.hidden_size)
    def flops_inference(self,
                        batch_size: int = None,
                        sequence_length: int = None,
                        generate_length: int = None):
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # transformer blocks: forward
        # embed layer forward: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x vocab_size)
        # NOTE: generate_length is currently unused; this is the cost of a single
        # forward pass over sequence_length tokens.
        transformer_flops = self.layers * self.flops_forward_per_layer(batch_size, sequence_length)
        embedding_flops = 2 * batch_size * sequence_length * self.hidden_size * self.vocab_size
        flops = transformer_flops + embedding_flops
        return flops, self.pretty_print_number(flops)
    def flops_training(self, batch_size: int = None):
        if batch_size is None:
            batch_size = self.micro_batch_size
        # transformer blocks: forward + backward /* 2 x forward */ + recomputation /* 1 x forward */
        # embedding layer forward: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x vocab_size)
        # embedding layer (no recomputation): forward + backward /* 2 x forward */
        transformer_flops = self.layers * self.flops_forward_per_layer(batch_size, self.sequence_length)
        embedding_flops = 2 * batch_size * self.sequence_length * self.hidden_size * self.vocab_size
        flops = (1 + 2 + 1) * transformer_flops + (1 + 2) * embedding_flops
        return flops, self.pretty_print_number(flops)
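    # Rule-of-thumb check: the transformer forward pass costs roughly 2 * parameters *
    # tokens, so the total above is ~8 * parameters * tokens with full activation
    # recomputation (which this formula assumes), and would be ~6 * parameters * tokens
    # without it, matching the factors used in training_time() below.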
    def flops_prefill(self,
                      batch_size: int = None,
                      sequence_length: int = None):
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # forward (per transformer layer; multiply by self.layers for the whole model):
        # - kqv: batch_size x 3 x 2 /* add, mul */ x (sequence_length x hidden_size x hidden_size)
        # - qk^T: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x sequence_length)
        # - softmax(qk^T) * v: batch_size x 2 /* add, mul */ x (sequence_length x sequence_length x hidden_size)
        # - post attention linear: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x hidden_size)
        # - FFN: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x (4 x hidden_size) + sequence_length x (4 x hidden_size) x hidden_size)
        flops = 6 * batch_size * sequence_length * (self.hidden_size * self.hidden_size) + \
            2 * batch_size * sequence_length * (self.hidden_size * sequence_length) + \
            2 * batch_size * sequence_length * (sequence_length * self.hidden_size) + \
            2 * batch_size * sequence_length * (self.hidden_size * self.hidden_size) + \
            2 * batch_size * sequence_length * (self.hidden_size * (4 * self.hidden_size) + (4 * self.hidden_size) * self.hidden_size)
        return flops, self.pretty_print_number(flops)
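    # Example (per layer): for llama_2_7_b with a 4096-token prompt this is
    # 24 * 4096 * 4096^2 + 4 * 4096^2 * 4096 ≈ 1.92e12 flops per layer,
    # i.e. ~6.2e13 flops across all 32 layers of a single prefill.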
    def flops_generation(self,
                         batch_size: int = None,
                         sequence_length: int = None,
                         parallel_decoding_length: int = 1):
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # forward (per transformer layer, for one decoding step over the KV cache;
        # multiply by self.layers for the whole model):
        # - attn: batch_size x 3 x 2 /* add, mul */ x (parallel_decoding_length x hidden_size x hidden_size)
        # - qk^T: batch_size x 2 /* add, mul */ x (parallel_decoding_length x hidden_size x (sequence_length + parallel_decoding_length))
        # - softmax(qk^T) * v: batch_size x 2 /* add, mul */ x (parallel_decoding_length x (sequence_length + parallel_decoding_length) x hidden_size)
        # - post attention linear: batch_size x 2 /* add, mul */ x (parallel_decoding_length x hidden_size x hidden_size)
        # - FFN: batch_size x 2 /* add, mul */ x (parallel_decoding_length x hidden_size x (4 x hidden_size) + parallel_decoding_length x (4 x hidden_size) x hidden_size)
        flops = 6 * batch_size * (parallel_decoding_length * self.hidden_size * self.hidden_size) + \
            2 * batch_size * (parallel_decoding_length * self.hidden_size * (sequence_length + parallel_decoding_length)) + \
            2 * batch_size * (parallel_decoding_length * (sequence_length + parallel_decoding_length) * self.hidden_size) + \
            2 * batch_size * (parallel_decoding_length * self.hidden_size * self.hidden_size) + \
            2 * batch_size * (parallel_decoding_length * self.hidden_size * (4 * self.hidden_size) + parallel_decoding_length * (4 * self.hidden_size) * self.hidden_size)
        return flops, self.pretty_print_number(flops)
    def memory_io_prefill(self,
                          batch_size: int = None,
                          sequence_length: int = None):
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # forward (per transformer layer), counting reads of inputs/weights and writes of
        # outputs as element counts (multiply by the dtype size, e.g. 2 for fp16, to get bytes):
        # - kqv: 3 x (batch_size x sequence_length x hidden_size + hidden_size x hidden_size + batch_size x sequence_length x hidden_size)
        # - qk^T (whole tensor): batch_size x (sequence_length x hidden_size + hidden_size x sequence_length + sequence_length x sequence_length)
        #   per head (MHA, as implemented): batch_size x attention_heads x (sequence_length x (hidden_size / attention_heads) + (hidden_size / attention_heads) x sequence_length + sequence_length x sequence_length)
        # - softmax(qk^T) * v (whole tensor): batch_size x (sequence_length x sequence_length + sequence_length x hidden_size + sequence_length x hidden_size)
        #   per head (MHA, as implemented): batch_size x attention_heads x (sequence_length x sequence_length + sequence_length x (hidden_size / attention_heads) + sequence_length x (hidden_size / attention_heads))
        # - post attention linear: batch_size x sequence_length x hidden_size + hidden_size x hidden_size + batch_size x sequence_length x hidden_size
        # - FFN: batch_size x sequence_length x hidden_size + hidden_size x (4 x hidden_size) + (4 x hidden_size) x hidden_size + batch_size x sequence_length x hidden_size
        mem_io = 3 * (batch_size * sequence_length * self.hidden_size + self.hidden_size * self.hidden_size + batch_size * sequence_length * self.hidden_size) + \
            batch_size * self.attention_heads * (sequence_length * (self.hidden_size / self.attention_heads) + (self.hidden_size / self.attention_heads) * sequence_length + sequence_length * sequence_length) + \
            batch_size * self.attention_heads * (sequence_length * sequence_length + sequence_length * (self.hidden_size / self.attention_heads) + sequence_length * (self.hidden_size / self.attention_heads)) + \
            batch_size * sequence_length * self.hidden_size + self.hidden_size * self.hidden_size + batch_size * sequence_length * self.hidden_size + \
            batch_size * sequence_length * self.hidden_size + self.hidden_size * (4 * self.hidden_size) + (4 * self.hidden_size) * self.hidden_size + batch_size * sequence_length * self.hidden_size
        return mem_io, self.pretty_print_memory(mem_io)
    def memory_io_generation(self,
                             batch_size: int = None,
                             sequence_length: int = None,
                             parallel_decoding_length: int = 1):
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # forward (per transformer layer, one decoding step), element counts as in memory_io_prefill:
        # - kqv: 3 x (batch_size x parallel_decoding_length x hidden_size + hidden_size x hidden_size + batch_size x parallel_decoding_length x hidden_size)
        # - qk^T (whole tensor): batch_size x (parallel_decoding_length x hidden_size + hidden_size x (sequence_length + parallel_decoding_length) + parallel_decoding_length x (sequence_length + parallel_decoding_length))
        #   per head (MHA, as implemented): batch_size x attention_heads x (parallel_decoding_length x (hidden_size / attention_heads) + (hidden_size / attention_heads) x (sequence_length + parallel_decoding_length) + parallel_decoding_length x (sequence_length + parallel_decoding_length))
        # - softmax(qk^T) * v (whole tensor): batch_size x (parallel_decoding_length x (sequence_length + parallel_decoding_length) + (sequence_length + parallel_decoding_length) x hidden_size + parallel_decoding_length x hidden_size)
        #   per head (MHA, as implemented): batch_size x attention_heads x (parallel_decoding_length x (sequence_length + parallel_decoding_length) + (sequence_length + parallel_decoding_length) x (hidden_size / attention_heads) + parallel_decoding_length x (hidden_size / attention_heads))
        # - post attention linear: batch_size x parallel_decoding_length x hidden_size + hidden_size x hidden_size + batch_size x parallel_decoding_length x hidden_size
        # - FFN: batch_size x parallel_decoding_length x hidden_size + hidden_size x (4 x hidden_size) + (4 x hidden_size) x hidden_size + batch_size x parallel_decoding_length x hidden_size
        mem_io = 3 * (batch_size * parallel_decoding_length * self.hidden_size + self.hidden_size * self.hidden_size + batch_size * parallel_decoding_length * self.hidden_size) + \
            batch_size * self.attention_heads * (parallel_decoding_length * (self.hidden_size / self.attention_heads) + (self.hidden_size / self.attention_heads) * (sequence_length + parallel_decoding_length) + parallel_decoding_length * (sequence_length + parallel_decoding_length)) + \
            batch_size * self.attention_heads * (parallel_decoding_length * (sequence_length + parallel_decoding_length) + (sequence_length + parallel_decoding_length) * (self.hidden_size / self.attention_heads) + parallel_decoding_length * (self.hidden_size / self.attention_heads)) + \
            batch_size * parallel_decoding_length * self.hidden_size + self.hidden_size * self.hidden_size + batch_size * parallel_decoding_length * self.hidden_size + \
            batch_size * parallel_decoding_length * self.hidden_size + self.hidden_size * (4 * self.hidden_size) + (4 * self.hidden_size) * self.hidden_size + batch_size * parallel_decoding_length * self.hidden_size
        return mem_io, self.pretty_print_memory(mem_io)
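    # Why generation is memory-bound: with parallel_decoding_length=1 the per-layer flops
    # are ~24 * hidden_size^2 while the weight traffic alone is ~12 * hidden_size^2
    # elements, so the arithmetic intensity is only a few flops per element moved,
    # versus hundreds to thousands for prefill, where sequence_length amortizes the
    # weight reads.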
    def training_time(
        self,
        number_of_tokens: float = 1e12,
        batch_size: int = None,
        number_of_gpus: int = 8,
        empirical_throughput: float = 140 * 1e12,
        activation_recomputation: bool = False,
    ):
        if batch_size is None:
            batch_size = self.micro_batch_size
        # Estimated end-to-end training time in seconds, assuming 6 flops per parameter
        # per token (8 with full activation recomputation).
        #
        # see also:
        #
        # - megatron paper
        # - https://blog.eleuther.ai/transformer-math/
        # - https://zhuanlan.zhihu.com/p/624740065
        if activation_recomputation:
            factor = 8
        else:
            factor = 6
        num_parameters, _ = self.number_of_parameters()
        return (
            factor
            * number_of_tokens
            * num_parameters
            / (number_of_gpus * empirical_throughput)
        )
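    # Example: for the megatron_175_b config below, 300e9 tokens on 1024 GPUs at
    # 140 TFLOPs/GPU with activation_recomputation=True gives roughly
    # 8 * 300e9 * 175e9 / (1024 * 140e12) ≈ 2.9e6 s, i.e. about 34 days, in line with
    # the estimate quoted in the Megatron-LM scaling paper.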
    def activation_per_layer(
        self,
        batch_size: int = None,
        sequence_length: int = None,
        tensor_parallelism: int = 1,
        sequence_parallelism: bool = False,
        selective_activation_recomputation: bool = False,
        full_activation_recomputation: bool = False,
    ):
        # Per-layer activation memory in bytes, assuming fp16 activations; the factors
        # follow "Reducing Activation Recomputation in Large Transformer Models"
        # (Korthikanti et al., 2022).
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        assert not (selective_activation_recomputation and full_activation_recomputation)
        if full_activation_recomputation:
            # only the layer input needs to be stored
            factor = 2
        else:
            if sequence_parallelism and selective_activation_recomputation:
                factor = 34 / tensor_parallelism
            elif sequence_parallelism:
                factor = 34 / tensor_parallelism + 5 * self.attention_heads * sequence_length / (self.hidden_size * tensor_parallelism)
            elif selective_activation_recomputation:
                factor = 10 + 24 / tensor_parallelism
            else:
                factor = 10 + 24 / tensor_parallelism + 5 * self.attention_heads * sequence_length / (self.hidden_size * tensor_parallelism)
        mem = sequence_length * batch_size * self.hidden_size * factor
        return mem, self.pretty_print_memory(mem)
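    # Example: for the megatron_175_b config (sequence_length=2048, hidden_size=12288,
    # attention_heads=96) with micro_batch_size=1 and no parallelism or recomputation,
    # the factor is 10 + 24 + 5 * 96 * 2048 / 12288 = 114, i.e. roughly 2.7 GB of
    # activations per layer.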
    def activation_training_first_stage(
        self,
        batch_size: int = None,
        tensor_parallelism: int = 1,
        sequence_parallelism: bool = False,
        selective_activation_recomputation: bool = False,
        pipeline_parallelism: int = 1,
        pipeline_parallelism_interleaving: int = 1,
    ):
        if batch_size is None:
            batch_size = self.micro_batch_size
        activation, _ = self.activation_per_layer(
            batch_size,
            self.sequence_length,
            tensor_parallelism,
            sequence_parallelism,
            selective_activation_recomputation,
        )
        activation *= self.layers
        # with interleaved pipeline schedules the first stage also holds activations
        # for extra in-flight microbatches
        scale = 1 + (pipeline_parallelism - 1) / (
            pipeline_parallelism * pipeline_parallelism_interleaving
        )
        mem = activation * scale
        return mem, self.pretty_print_memory(mem)
    def activating_inference(self,
                             batch_size: int = None,
                             sequence_length: int = None,
                             generate_length: int = None):
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        activation, _ = self.activation_per_layer(batch_size, sequence_length)
        # don't multiply by num_layers for inference: without a backward pass, only one
        # layer's activations are live at a time
        mem = activation
        return mem, self.pretty_print_memory(mem)


megatron_gpt_1_7b = GPTModel(
    name="megatron_gpt_1_7b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=24,
    hidden_size=2304,
    layers=24,
    micro_batch_size=8,
    global_batch_size=512,
)
megatron_gpt_3_6b = GPTModel(
    name="megatron_gpt_3_6b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=32,
    hidden_size=3072,
    layers=30,
    micro_batch_size=8,
    global_batch_size=512,
)
megatron_gpt_7_5b = GPTModel(
    name="megatron_gpt_7_5b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=32,
    hidden_size=4096,
    layers=36,
    micro_batch_size=8,
    global_batch_size=512,
)
megatron_gpt_76_1b = GPTModel(
    name="megatron_gpt_76_1b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=80,
    hidden_size=10240,
    layers=60,
    micro_batch_size=1,
    global_batch_size=1792,
)
megatron_gpt_145_6b = GPTModel(
    name="megatron_gpt_145_6b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=96,
    hidden_size=12288,
    layers=80,
    micro_batch_size=1,
    global_batch_size=2304,
)
megatron_1008_b = GPTModel(
    name="megatron_1008_b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=160,
    hidden_size=25600,
    layers=128,
    micro_batch_size=1,
    global_batch_size=3072,
)
megatron_22_b = GPTModel(
    name="megatron_22_b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=64,
    hidden_size=6144,
    layers=48,
    micro_batch_size=1,
    global_batch_size=4,
)
megatron_175_b = GPTModel(
    name="megatron_175_b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=96,
    hidden_size=12288,
    layers=96,
    micro_batch_size=1,
    global_batch_size=64,
)
megatron_530_b = GPTModel(
    name="megatron_530_b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=128,
    hidden_size=20480,
    layers=105,
    micro_batch_size=1,
    global_batch_size=280,
)
megatron_1000_b = GPTModel(
    name="megatron_1000_b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=160,
    hidden_size=25600,
    layers=128,
    micro_batch_size=1,
    global_batch_size=512,
)
pipedream_gpt2 = GPTModel(
    name="pipedream_gpt2",
    vocab_size=50257,
    sequence_length=1024,
    attention_heads=16,
    hidden_size=1024,
    layers=24,
    micro_batch_size=1,
    global_batch_size=64,
)
# NOTE: the Llama-2 vocabulary size is 32000; the configs below otherwise reuse the
# GPT-style arithmetic above (4 x hidden_size FFN, learned positional embeddings),
# so their parameter counts are approximations.
llama_2_7_b = GPTModel(
    name="llama-2-7B",
    vocab_size=32000,
    sequence_length=4096,
    attention_heads=32,
    hidden_size=4096,
    layers=32,
    micro_batch_size=1,
    global_batch_size=64,
    query_attention_groups=1,
)
llama_2_13_b = GPTModel(
    name="llama-2-13B",
    vocab_size=32000,
    sequence_length=4096,
    attention_heads=40,
    hidden_size=5120,
    layers=40,
    micro_batch_size=1,
    global_batch_size=64,
    query_attention_groups=1,
)
llama_2_70_b = GPTModel(
    name="llama-2-70B",
    vocab_size=32000,
    sequence_length=4096,
    attention_heads=64,
    hidden_size=8192,
    layers=80,
    micro_batch_size=1,
    global_batch_size=64,
    query_attention_groups=8,
)
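

# A minimal usage sketch exercising a few of the estimators above for one Megatron
# config and one Llama-2 config, plus a rough training-time estimate.
if __name__ == "__main__":
    for model in (megatron_gpt_1_7b, llama_2_7_b):
        _, params = model.number_of_parameters()
        _, kv_per_token = model.kv_cache_per_token()
        _, kv_total = model.kv_cache(batch_size=1)
        _, prefill_flops = model.flops_prefill(batch_size=1)
        _, decode_flops = model.flops_generation(batch_size=1)
        print(f"{model.name}: parameters={params}, kv/token={kv_per_token}, "
              f"kv cache={kv_total}, prefill flops/layer={prefill_flops}, "
              f"decode flops/layer={decode_flops}")
    # 300B tokens on 1024 GPUs at 140 TFLOPs per GPU, with full activation recomputation.
    days = megatron_175_b.training_time(
        number_of_tokens=300e9,
        number_of_gpus=1024,
        activation_recomputation=True,
    ) / 86400
    print(f"megatron_175_b training time: ~{days:.1f} days")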