llm-arithmetic: FLOPs, memory footprint, memory-access I/O, and more, for both training and inference (prefill and generation).
class GPTModel:
    def __init__(
        self,
        name: str = None,
        vocab_size: int = 51200,
        sequence_length: int = 2048,
        attention_heads: int = 32,
        hidden_size: int = 2304,
        layers: int = 24,
        micro_batch_size: int = 1,
        global_batch_size: int = 512,
        batch_size: int = 1,
        query_attention_groups: int = 1,  # query heads per KV head (1 = MHA)
    ):
        self.name = name
        self.vocab_size = vocab_size
        self.sequence_length = sequence_length
        self.attention_heads = attention_heads
        self.hidden_size = hidden_size
        self.layers = layers
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.batch_size = batch_size
        self.query_attention_groups = query_attention_groups
    def pretty_print_number(self, number: int = 1):
        return format(number, ",")

    def pretty_print_memory(self, size: int = 1):
        if size < 2**10:
            return f"{size} B"
        elif size < 2**20:
            return f"{size / 2**10:.2f} KB"
        elif size < 2**30:
            return f"{size / 2**20:.2f} MB"
        elif size < 2**40:
            return f"{size / 2**30:.2f} GB"
        else:
            return f"{size / 2**40:.2f} TB"
    def number_of_parameters(self):
        # embeddings: hidden_size x (vocab_size + sequence_length)
        #
        # Wq, Wk, Wv:
        # - weight: 3 x (hidden_size x hidden_size)
        # - bias: 3 x hidden_size
        # FFN:
        # - weight: 2 x (4 x hidden_size x hidden_size)
        # - bias: hidden_size + 4 x hidden_size
        # Wo:
        # - weight: hidden_size x hidden_size
        # - bias: hidden_size
        #
        # input layer norm + post-attention layer norm: 2 x (hidden_size x 2 /* weight & bias */)
        parameters = (
            12 * self.layers * (self.hidden_size**2)
            + self.layers * self.hidden_size * 13
            + self.hidden_size * (self.vocab_size + self.sequence_length)
        )
        return parameters, self.pretty_print_number(parameters)
    def kv_cache_per_token(self):
        nbytes_fp16 = 2
        # 2 tensors (K and V) per layer, each hidden_size / query_attention_groups wide
        kv_cache = self.hidden_size // self.query_attention_groups * self.layers * nbytes_fp16 * 2
        return kv_cache, self.pretty_print_memory(kv_cache)
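
    # Worked example (illustrative): for the llama-2-70B config below
    # (hidden_size=8192, query_attention_groups=8, layers=80), each token costs
    #   8192 / 8 * 80 * 2 bytes * 2 (K and V) = 327,680 B = 320 KB
    # of fp16 KV cache.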
    def kv_cache(
        self,
        batch_size: int = None,
        sequence_length: int = None,
        generation_length: int = None,
    ):
        if batch_size is None:
            batch_size = self.batch_size
        context_length = self.sequence_length
        if sequence_length is not None and generation_length is not None:
            context_length = min(context_length, sequence_length + generation_length)
        elif sequence_length is not None:
            context_length = min(context_length, sequence_length)
        elif generation_length is not None:
            context_length = min(context_length, generation_length)
        kv_cache_per_token, _ = self.kv_cache_per_token()
        kv_cache = batch_size * context_length * kv_cache_per_token
        return kv_cache, self.pretty_print_memory(kv_cache)
    # For gemm: A (m x k) * B (k x n) = C (m x n)
    # - flops: 2 * m * n * k
    # - memory: m * k + k * n + m * n
    def flops_forward_per_layer(self, batch_size: int = None, sequence_length: int = None):
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # forward:
        # - qkv: batch_size x 3 x 2 /* add, mul */ x (sequence_length x hidden_size x hidden_size)
        # - qk^T: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x sequence_length)
        # - softmax(qk^T) * v: batch_size x 2 /* add, mul */ x (sequence_length x sequence_length x hidden_size)
        # - post-attention linear: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x hidden_size)
        # - FFN: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x (4 x hidden_size) + sequence_length x (4 x hidden_size) x hidden_size)
        return 24 * batch_size * sequence_length * (self.hidden_size * self.hidden_size) + \
            2 * batch_size * sequence_length * (self.hidden_size * sequence_length) + \
            2 * batch_size * sequence_length * (sequence_length * self.hidden_size)
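
    # Equivalently, per layer: 2 * batch_size * sequence_length * (12 * hidden_size^2 + 2 * sequence_length * hidden_size).
    # The quadratic attention term (4 * b * s^2 * h) overtakes the weight term
    # (24 * b * s * h^2) once sequence_length exceeds 6 * hidden_size.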
    def flops_inference(self,
                        batch_size: int = None,
                        sequence_length: int = None,
                        generate_length: int = None):
        # generate_length is accepted for API symmetry but not used here; this
        # counts a single full forward pass over sequence_length tokens
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # transformer blocks: forward
        # embedding layer forward: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x vocab_size)
        transformer_flops = self.layers * self.flops_forward_per_layer(batch_size, sequence_length)
        embedding_flops = 2 * batch_size * sequence_length * self.hidden_size * self.vocab_size
        flops = transformer_flops + embedding_flops
        return flops, self.pretty_print_number(flops)
    def flops_training(self, batch_size: int = None):
        if batch_size is None:
            batch_size = self.micro_batch_size
        # transformer blocks: forward + backward /* 2 x forward */ + recomputation
        # embedding layer forward: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x vocab_size)
        # embedding layer (no recomputation): forward + backward /* 2 x forward */
        transformer_flops = self.layers * self.flops_forward_per_layer(batch_size, self.sequence_length)
        embedding_flops = 2 * batch_size * self.sequence_length * self.hidden_size * self.vocab_size
        flops = (1 + 2 + 1) * transformer_flops + (1 + 2) * embedding_flops
        return flops, self.pretty_print_number(flops)
    def flops_prefill(self,
                      batch_size: int = None,
                      sequence_length: int = None):
        # per-layer prefill cost; multiply by self.layers for the whole model
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # forward:
        # - qkv: batch_size x 3 x 2 /* add, mul */ x (sequence_length x hidden_size x hidden_size)
        # - qk^T: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x sequence_length)
        # - softmax(qk^T) * v: batch_size x 2 /* add, mul */ x (sequence_length x sequence_length x hidden_size)
        # - post-attention linear: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x hidden_size)
        # - FFN: batch_size x 2 /* add, mul */ x (sequence_length x hidden_size x (4 x hidden_size) + sequence_length x (4 x hidden_size) x hidden_size)
        flops = 6 * batch_size * sequence_length * (self.hidden_size * self.hidden_size) + \
            2 * batch_size * sequence_length * (self.hidden_size * sequence_length) + \
            2 * batch_size * sequence_length * (sequence_length * self.hidden_size) + \
            2 * batch_size * sequence_length * (self.hidden_size * self.hidden_size) + \
            2 * batch_size * sequence_length * (self.hidden_size * (4 * self.hidden_size) + (4 * self.hidden_size) * self.hidden_size)
        return flops, self.pretty_print_number(flops)
    def flops_generation(self,
                         batch_size: int = None,
                         sequence_length: int = None,
                         parallel_decoding_length: int = 1):
        # per-layer decode cost for parallel_decoding_length new tokens on top of
        # a sequence_length-token KV cache; multiply by self.layers for the whole model
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # forward:
        # - qkv: batch_size x 3 x 2 /* add, mul */ x (parallel_decoding_length x hidden_size x hidden_size)
        # - qk^T: batch_size x 2 /* add, mul */ x (parallel_decoding_length x hidden_size x (sequence_length + parallel_decoding_length))
        # - softmax(qk^T) * v: batch_size x 2 /* add, mul */ x (parallel_decoding_length x (sequence_length + parallel_decoding_length) x hidden_size)
        # - post-attention linear: batch_size x 2 /* add, mul */ x (parallel_decoding_length x hidden_size x hidden_size)
        # - FFN: batch_size x 2 /* add, mul */ x (parallel_decoding_length x hidden_size x (4 x hidden_size) + parallel_decoding_length x (4 x hidden_size) x hidden_size)
        flops = 6 * batch_size * (parallel_decoding_length * self.hidden_size * self.hidden_size) + \
            2 * batch_size * (parallel_decoding_length * self.hidden_size * (sequence_length + parallel_decoding_length)) + \
            2 * batch_size * (parallel_decoding_length * (sequence_length + parallel_decoding_length) * self.hidden_size) + \
            2 * batch_size * (parallel_decoding_length * self.hidden_size * self.hidden_size) + \
            2 * batch_size * (parallel_decoding_length * self.hidden_size * (4 * self.hidden_size) + parallel_decoding_length * (4 * self.hidden_size) * self.hidden_size)
        return flops, self.pretty_print_number(flops)
    def memory_io_prefill(self,
                          batch_size: int = None,
                          sequence_length: int = None):
        # per-layer memory traffic, counted in tensor elements (the pretty-printed
        # size assumes one byte per element; scale by the dtype size, e.g. 2 for
        # fp16, to get bytes)
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # forward:
        # - qkv: 3 x (batch_size x sequence_length x hidden_size + hidden_size x hidden_size + batch_size x sequence_length x hidden_size)
        # - qk^T, per head (MHA): batch_size x attention_heads x (sequence_length x (hidden_size / attention_heads) + (hidden_size / attention_heads) x sequence_length + sequence_length x sequence_length)
        # - softmax(qk^T) * v, per head (MHA): batch_size x attention_heads x (sequence_length x sequence_length + sequence_length x (hidden_size / attention_heads) + sequence_length x (hidden_size / attention_heads))
        # - post-attention linear: batch_size x sequence_length x hidden_size + hidden_size x hidden_size + batch_size x sequence_length x hidden_size
        # - FFN: batch_size x sequence_length x hidden_size + hidden_size x (4 x hidden_size) + (4 x hidden_size) x hidden_size + batch_size x sequence_length x hidden_size
        head_size = self.hidden_size // self.attention_heads
        mem_io = 3 * (batch_size * sequence_length * self.hidden_size + self.hidden_size * self.hidden_size + batch_size * sequence_length * self.hidden_size) + \
            batch_size * self.attention_heads * (sequence_length * head_size + head_size * sequence_length + sequence_length * sequence_length) + \
            batch_size * self.attention_heads * (sequence_length * sequence_length + sequence_length * head_size + sequence_length * head_size) + \
            batch_size * sequence_length * self.hidden_size + self.hidden_size * self.hidden_size + batch_size * sequence_length * self.hidden_size + \
            batch_size * sequence_length * self.hidden_size + self.hidden_size * (4 * self.hidden_size) + (4 * self.hidden_size) * self.hidden_size + batch_size * sequence_length * self.hidden_size
        return mem_io, self.pretty_print_memory(mem_io)
    def memory_io_generation(self,
                             batch_size: int = None,
                             sequence_length: int = None,
                             parallel_decoding_length: int = 1):
        # per-layer memory traffic, counted in tensor elements (the pretty-printed
        # size assumes one byte per element; scale by the dtype size for bytes)
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        # forward:
        # - qkv: 3 x (batch_size x parallel_decoding_length x hidden_size + hidden_size x hidden_size + batch_size x parallel_decoding_length x hidden_size)
        # - qk^T, per head (MHA): batch_size x attention_heads x (parallel_decoding_length x (hidden_size / attention_heads) + (hidden_size / attention_heads) x (sequence_length + parallel_decoding_length) + parallel_decoding_length x (sequence_length + parallel_decoding_length))
        # - softmax(qk^T) * v, per head (MHA): batch_size x attention_heads x (parallel_decoding_length x (sequence_length + parallel_decoding_length) + (sequence_length + parallel_decoding_length) x (hidden_size / attention_heads) + parallel_decoding_length x (hidden_size / attention_heads))
        # - post-attention linear: batch_size x parallel_decoding_length x hidden_size + hidden_size x hidden_size + batch_size x parallel_decoding_length x hidden_size
        # - FFN: batch_size x parallel_decoding_length x hidden_size + hidden_size x (4 x hidden_size) + (4 x hidden_size) x hidden_size + batch_size x parallel_decoding_length x hidden_size
        head_size = self.hidden_size // self.attention_heads
        context_length = sequence_length + parallel_decoding_length
        mem_io = 3 * (batch_size * parallel_decoding_length * self.hidden_size + self.hidden_size * self.hidden_size + batch_size * parallel_decoding_length * self.hidden_size) + \
            batch_size * self.attention_heads * (parallel_decoding_length * head_size + head_size * context_length + parallel_decoding_length * context_length) + \
            batch_size * self.attention_heads * (parallel_decoding_length * context_length + context_length * head_size + parallel_decoding_length * head_size) + \
            batch_size * parallel_decoding_length * self.hidden_size + self.hidden_size * self.hidden_size + batch_size * parallel_decoding_length * self.hidden_size + \
            batch_size * parallel_decoding_length * self.hidden_size + self.hidden_size * (4 * self.hidden_size) + (4 * self.hidden_size) * self.hidden_size + batch_size * parallel_decoding_length * self.hidden_size
        return mem_io, self.pretty_print_memory(mem_io)
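
    # Why decode is bandwidth-bound (illustrative): per layer and per decoded
    # token, flops_generation does roughly 24 * hidden_size^2 flops while
    # memory_io_generation still touches all ~12 * hidden_size^2 weight
    # elements, so arithmetic intensity is only a couple of flops per weight
    # element for small batch_size x parallel_decoding_length. Prefill
    # amortizes the same weight traffic over sequence_length tokens and is
    # compute-bound instead.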
    def training_time(
        self,
        number_of_tokens: float = 1e12,
        batch_size: int = None,
        number_of_gpus: int = 8,
        empirical_throughput: float = 140 * 1e12,  # achieved flops/s per GPU
        activation_recomputation: bool = False,
    ):
        # returns the estimated wall-clock training time in seconds
        if batch_size is None:
            batch_size = self.micro_batch_size
        # see also:
        #
        # - the Megatron-LM paper
        # - https://blog.eleuther.ai/transformer-math/
        # - https://zhuanlan.zhihu.com/p/624740065
        #
        # 6 flops per parameter per token (forward + backward), or 8 with full
        # activation recomputation (one extra forward)
        if activation_recomputation:
            factor = 8
        else:
            factor = 6
        num_parameters, _ = self.number_of_parameters()
        return (
            factor
            * number_of_tokens
            * num_parameters
            / (number_of_gpus * empirical_throughput)
        )
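
    # Worked example (illustrative): a 7e9-parameter model trained on 1e12
    # tokens across 8 GPUs at an achieved 140 TFLOPs/s each, without
    # recomputation:
    #   6 * 1e12 * 7e9 / (8 * 140e12) = 3.75e7 s, i.e. about 434 days.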
    def activation_per_layer(
        self,
        batch_size: int = None,
        sequence_length: int = None,
        tensor_parallelism: int = 1,
        sequence_parallelism: bool = False,
        selective_activation_recomputation: bool = False,
        full_activation_recomputation: bool = False,
    ):
        # see also: Reducing Activation Recomputation in Large Transformer Models;
        # bytes per layer, following the paper's sbh * (34 + 5 * a * s / h)
        # accounting for fp16 activations
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        assert not (selective_activation_recomputation and full_activation_recomputation)
        if full_activation_recomputation:
            factor = 2
        else:
            if sequence_parallelism and selective_activation_recomputation:
                factor = 34 / tensor_parallelism
            elif sequence_parallelism:
                factor = 34 / tensor_parallelism + 5 * self.attention_heads * sequence_length / (self.hidden_size * tensor_parallelism)
            elif selective_activation_recomputation:
                factor = 10 + 24 / tensor_parallelism
            else:
                factor = 10 + 24 / tensor_parallelism + 5 * self.attention_heads * sequence_length / (self.hidden_size * tensor_parallelism)
        mem = sequence_length * batch_size * self.hidden_size * factor
        return mem, self.pretty_print_memory(mem)
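
    # Worked example (illustrative): for a GPT-3 175B-like layer
    # (sequence_length=2048, hidden_size=12288, attention_heads=96) with
    # batch_size=1 and no parallelism or recomputation, the factor is
    #   34 + 5 * 96 * 2048 / 12288 = 114,
    # so per-layer activations are 2048 * 12288 * 114 B, about 2.67 GB.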
    def activation_training_first_stage(
        self,
        batch_size: int = None,
        tensor_parallelism: int = 1,
        sequence_parallelism: bool = False,
        selective_activation_recomputation: bool = False,
        pipeline_parallelism: int = 1,
        pipeline_parallelism_interleaving: int = 1,
    ):
        if batch_size is None:
            batch_size = self.micro_batch_size
        activation, _ = self.activation_per_layer(
            batch_size,
            self.sequence_length,
            tensor_parallelism,
            sequence_parallelism,
            selective_activation_recomputation,
        )
        activation *= self.layers
        # the first pipeline stage holds activations for extra in-flight
        # microbatches; interleaving shrinks the overhead
        scale = 1 + (pipeline_parallelism - 1) / (
            pipeline_parallelism * pipeline_parallelism_interleaving
        )
        mem = activation * scale
        return mem, self.pretty_print_memory(mem)
    def activation_inference(self,
                             batch_size: int = None,
                             sequence_length: int = None,
                             generate_length: int = None):
        # generate_length is accepted for API symmetry but not used here
        if batch_size is None:
            batch_size = self.micro_batch_size
        if sequence_length is None:
            sequence_length = self.sequence_length
        activation, _ = self.activation_per_layer(batch_size, sequence_length)
        mem = activation  # don't multiply by num_layers for inference
        return mem, self.pretty_print_memory(mem)
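

# A quick sanity check (illustrative): flops_training counts forward +
# backward + recomputation (4x forward), so the per-token cost should land
# near the familiar 8 flops per parameter per token; the attention
# (4 * s^2 * h per layer) and vocabulary-projection terms push the measured
# ratio somewhat above 8 for small hidden sizes (about 8.9 here).
if __name__ == "__main__":
    _m = GPTModel(name="sanity-check")
    _flops, _ = _m.flops_training(batch_size=1)
    _params, _ = _m.number_of_parameters()
    _per_token = _flops / (1 * _m.sequence_length)
    print(f"training flops per token / parameters: {_per_token / _params:.2f}")
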
megatron_gpt_1_7b = GPTModel(
    name="megatron_gpt_1_7b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=24,
    hidden_size=2304,
    layers=24,
    micro_batch_size=8,
    global_batch_size=512,
)
megatron_gpt_3_6b = GPTModel(
    name="megatron_gpt_3_6b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=32,
    hidden_size=3072,
    layers=30,
    micro_batch_size=8,
    global_batch_size=512,
)
megatron_gpt_7_5b = GPTModel(
    name="megatron_gpt_7_5b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=32,
    hidden_size=4096,
    layers=36,
    micro_batch_size=8,
    global_batch_size=512,
)
megatron_gpt_76_1b = GPTModel(
    name="megatron_gpt_76_1b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=80,
    hidden_size=10240,
    layers=60,
    micro_batch_size=1,
    global_batch_size=1792,
)
megatron_gpt_145_6b = GPTModel(
    name="megatron_gpt_145_6b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=96,
    hidden_size=12288,
    layers=80,
    micro_batch_size=1,
    global_batch_size=2304,
)
megatron_1008_b = GPTModel(
    name="megatron_1008_b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=160,
    hidden_size=25600,
    layers=128,
    micro_batch_size=1,
    global_batch_size=3072,
)
megatron_22_b = GPTModel(
    name="megatron_22_b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=64,
    hidden_size=6144,
    layers=48,
    micro_batch_size=1,
    global_batch_size=4,
)
megatron_175_b = GPTModel(
    name="megatron_175_b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=96,
    hidden_size=12288,
    layers=96,
    micro_batch_size=1,
    global_batch_size=64,
)
megatron_530_b = GPTModel(
    name="megatron_530_b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=128,
    hidden_size=20480,
    layers=105,
    micro_batch_size=1,
    global_batch_size=280,
)
megatron_1000_b = GPTModel(
    name="megatron_1000_b",
    vocab_size=51200,
    sequence_length=2048,
    attention_heads=160,
    hidden_size=25600,
    layers=128,
    micro_batch_size=1,
    global_batch_size=512,
)
pipedream_gpt2 = GPTModel(
    name="pipedream_gpt2",
    vocab_size=50257,
    sequence_length=1024,
    attention_heads=16,
    hidden_size=1024,
    layers=24,
    micro_batch_size=1,
    global_batch_size=64,
)
# The llama-2 configs below reuse the GPT-style accounting above (dense
# 4 x hidden_size FFN, learned positional embeddings), which only
# approximates the real architecture (SwiGLU FFN, RoPE, no biases).
llama_2_7_b = GPTModel(
    name="llama-2-7B",
    vocab_size=32000,
    sequence_length=4096,
    attention_heads=32,
    hidden_size=4096,
    layers=32,
    micro_batch_size=1,
    global_batch_size=64,
    query_attention_groups=1,
)
llama_2_13_b = GPTModel(
    name="llama-2-13B",
    vocab_size=32000,
    sequence_length=4096,
    attention_heads=40,
    hidden_size=5120,
    layers=40,
    micro_batch_size=1,
    global_batch_size=64,
    query_attention_groups=1,
)
llama_2_70_b = GPTModel(
    name="llama-2-70B",
    vocab_size=32000,
    sequence_length=4096,
    attention_heads=64,
    hidden_size=8192,
    layers=80,
    micro_batch_size=1,
    global_batch_size=64,
    query_attention_groups=8,  # grouped-query attention: 8 query heads per KV head
)
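
# Example usage (illustrative; the numbers inherit the approximations above,
# and flops_prefill is a per-layer quantity):
if __name__ == "__main__":
    model = llama_2_70_b
    _, params_str = model.number_of_parameters()
    print(f"{model.name}: {params_str} parameters")
    _, kv_tok = model.kv_cache_per_token()
    print(f"kv cache per token: {kv_tok}")
    _, kv_total = model.kv_cache(batch_size=16, sequence_length=4096)
    print(f"kv cache, batch 16 @ 4k context: {kv_total}")
    _, prefill = model.flops_prefill(batch_size=1, sequence_length=4096)
    print(f"prefill flops per layer: {prefill}")
    days = model.training_time(number_of_tokens=2e12, number_of_gpus=2048) / 86400
    print(f"training 2e12 tokens on 2048 GPUs: {days:.0f} days")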