
@jaymody
Created September 24, 2022 00:53
Parameter counting for the original transformer architecture from "Attention Is All You Need".
import math


def count_params(
    N=6,               # number of layers in each of the encoder and decoder stacks
    d_model=512,       # model (embedding) dimension
    d_ff=2048,         # inner dimension of the position-wise feed-forward network
    h=8,               # number of attention heads
    d_k=64,            # query/key dimension per head
    d_v=64,            # value dimension per head
    vocab_size=37000,  # shared source-target BPE vocabulary size
):
    # source and target token embedding tables
    src_emb_count = vocab_size * d_model
    trg_emb_count = vocab_size * d_model

    # multi-head attention: per-head W_Q, W_K, W_V projections plus the output projection W_O
    WQ = [h, d_model, d_k]
    WK = [h, d_model, d_k]
    WV = [h, d_model, d_v]
    WO = [h * d_v, d_model]
    mh_attn_count = sum(map(math.prod, [WQ, WK, WV, WO]))

    # position-wise feed-forward network: two linear layers (weights + biases)
    ffn_count = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)

    # layer norm: gain and bias vectors, each of size d_model
    layer_norm_count = d_model + d_model

    # encoder layer: self-attention + FFN, each sub-layer followed by a layer norm
    enc_layer_count = 1 * mh_attn_count + 1 * ffn_count + 2 * layer_norm_count
    # decoder layer: masked self-attention + encoder-decoder attention + FFN, each with a layer norm
    dec_layer_count = 2 * mh_attn_count + 1 * ffn_count + 3 * layer_norm_count

    enc_stack_count = enc_layer_count * N
    dec_stack_count = dec_layer_count * N

    # final linear projection from d_model to vocabulary logits (weight + bias)
    final_layer_count = d_model * vocab_size + vocab_size

    return src_emb_count + trg_emb_count + enc_stack_count + dec_stack_count + final_layer_count
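
As a usage example: with the defaults above (the paper's base configuration) the function returns 100,970,632, roughly 101M parameters. This comes out larger than the ~65M the paper reports for the base model because the paper shares a single weight matrix between the two embedding layers and the pre-softmax linear transformation, whereas this count treats them as three separate matrices. A minimal driver, assuming the counting convention above:

if __name__ == "__main__":
    total = count_params()
    print(f"{total:,}")  # prints 100,970,632 with the default hyperparameters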