@pacman100 · Created October 4, 2022
# Checking that the Megatron-LM → Transformers checkpoint conversion is correct
Code:
import sys

import torch
from megatron import get_args, initialize_megatron, mpu
from megatron.checkpointing import load_checkpoint
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.utils import get_ltor_masks_and_position_ids
from transformers import AutoTokenizer, GPT2LMHeadModel

trfs_path = "/home/sourab/megatron_lm_gpt/hf_checkpoint"
mlm_path = "/home/sourab/megatron_lm_gpt/megatron_lm_checkpoint"


def initialize():
    model_path = mlm_path
    sys.argv.extend(
        [
            "--num-layers", "12",
            "--hidden-size", "768",
            "--num-attention-heads", "12",
            "--seq-length", "1024",
            "--max-position-embeddings", "1024",
            "--bf16",
            "--load", str(model_path),
            "--micro-batch-size", "1",
            "--checkpoint-activations",
            "--no-scaled-masked-softmax-fusion",
            "--no-load-rng",
            "--no-load-optim",
        ]
    )
    initialize_megatron(ignore_unknown_args=True)


if __name__ == "__main__":
    if mpu.is_unitialized():  # sic: this is the actual Megatron-LM function name
        initialize()
    args = get_args()
    args.padded_vocab_size = 50432  # padded from 50257 (divisible-by 128, TP size 2)
    tokenizer = AutoTokenizer.from_pretrained(trfs_path)

    def model_provider(pre_process=True, post_process=True):
        # parallel_output=False gathers the logits across tensor-parallel ranks.
        return GPTModel(
            num_tokentypes=0,
            parallel_output=False,
            pre_process=pre_process,
            post_process=post_process,
        )

    # Build the Megatron-LM model and load the original checkpoint.
    model = get_model(model_provider)
    load_checkpoint(model, None, None)
    model = model[0]
    model.eval()

    inputs = "Hi, how are you doing?"
    input_ids = torch.tensor(tokenizer.encode(inputs)).unsqueeze(0)
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        input_ids,
        0,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False,
    )
    input_ids = input_ids.cuda()
    position_ids = position_ids.cuda()
    attention_mask = attention_mask.cuda()
    logits = model(input_ids, position_ids, attention_mask)

    # Run the same prompt through the converted Hugging Face checkpoint.
    model2 = GPT2LMHeadModel.from_pretrained(trfs_path)
    model2 = model2.cuda()
    model2.eval()
    out = model2(input_ids)

    # Element-wise comparison with a 10% relative tolerance (bf16 is coarse).
    for j in range(out.logits.shape[1]):
        for i in range(out.logits.shape[2]):
            a, b = out.logits[0, j, i].item(), logits[0, j, i].item()
            assert abs(a - b) / max(max(abs(a), abs(b)), 0.5) < 0.1
    print(logits)
    print(out.logits)
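A vectorized equivalent of the per-element assertion loop above, as a minimal sketch (not part of the original script); it trims the Megatron logits to the Transformers vocab width before comparing:

# Sketch: same relative-tolerance check as the loop above, vectorized.
vocab_size = out.logits.shape[-1]  # 50257 for the HF model
a = out.logits.float()
b = logits[..., :vocab_size].float()
rel_err = (a - b).abs() / torch.clamp(torch.maximum(a.abs(), b.abs()), min=0.5)
assert rel_err.max().item() < 0.1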
Command: torchrun --nnodes 1 --nproc_per_node 2 conversion_checkpoint_test.py \
--tensor-model-parallel-size 2
Output:
> number of parameters on (tensor, pipeline) model parallel rank (1, 0): 62708736
> number of parameters on (tensor, pipeline) model parallel rank (0, 0): 62708736
loading release checkpoint from /home/sourab/megatron_lm_gpt/megatron_lm_checkpoint
checkpoint version 3.0
successfully loaded checkpoint from /home/sourab/megatron_lm_gpt/megatron_lm_checkpoint at iteration 0
tensor([[[ 1.1406,  1.8672,  2.7031,  ..., -7.1250, -7.1250, -7.1250],
         [-0.4043, -0.2793, -0.9531,  ..., -5.7188, -5.7188, -5.7188],
         [ 0.4668,  1.7031,  1.9375,  ..., -6.5312, -6.5312, -6.5312],
         ...,
         [ 0.5664,  2.5625,  3.2188,  ..., -5.3125, -5.3125, -5.3125],
         [-0.1152,  1.6562,  3.6094,  ..., -6.3125, -6.3125, -6.3125],
         [ 3.8125,  3.1875,  3.3281,  ..., -6.8125, -6.8125, -6.8125]]],
       device='cuda:0', grad_fn=<ToCopyBackward0>)
tensor([[[ 1.1352,  1.8615,  2.7035,  ..., -3.2931, -3.9886, -7.1244],
         [-0.4087, -0.2925, -0.9624,  ..., -3.4504, -2.9110, -5.7412],
         [ 0.4740,  1.7097,  1.9377,  ..., -4.9377, -5.9048, -6.5181],
         ...,
         [ 0.5715,  2.5654,  3.2345,  ..., -3.7216, -5.5251, -5.3296],
         [-0.1051,  1.6484,  3.6077,  ..., -3.4981, -6.4280, -6.3206],
         [ 3.8139,  3.1724,  3.3257,  ..., -3.3109, -5.6192, -6.8090]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)
tensor([[[ 1.1406,  1.8672,  2.7031,  ..., -7.1250, -7.1250, -7.1250],
         [-0.4043, -0.2793, -0.9531,  ..., -5.7188, -5.7188, -5.7188],
         [ 0.4668,  1.7031,  1.9375,  ..., -6.5312, -6.5312, -6.5312],
         ...,
         [ 0.5664,  2.5625,  3.2188,  ..., -5.3125, -5.3125, -5.3125],
         [-0.1152,  1.6562,  3.6094,  ..., -6.3125, -6.3125, -6.3125],
         [ 3.8125,  3.1875,  3.3281,  ..., -6.8125, -6.8125, -6.8125]]],
       device='cuda:1', grad_fn=<ToCopyBackward0>)
tensor([[[ 1.1352,  1.8615,  2.7035,  ..., -3.2931, -3.9886, -7.1244],
         [-0.4087, -0.2925, -0.9624,  ..., -3.4504, -2.9110, -5.7412],
         [ 0.4740,  1.7097,  1.9377,  ..., -4.9377, -5.9048, -6.5181],
         ...,
         [ 0.5715,  2.5654,  3.2345,  ..., -3.7216, -5.5251, -5.3296],
         [-0.1051,  1.6484,  3.6077,  ..., -3.4981, -6.4280, -6.3206],
         [ 3.8139,  3.1724,  3.3257,  ..., -3.3109, -5.6192, -6.8090]]],
       device='cuda:1', grad_fn=<UnsafeViewBackward0>)
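For each rank, the first tensor is the Megatron-LM output (bf16, over the padded 50432-entry vocab) and the second is the Transformers output (over the original 50257-entry vocab), so the elided trailing columns printed by torch are not the same logits; the leading columns agree to within bf16 precision. A quick shape check, as a sketch appended to the end of the script above:

# Sketch: why the printed tails differ; the vocab dimensions do not match.
print(logits.shape)      # (1, seq_len, 50432): Megatron logits include 175 padded entries
print(out.logits.shape)  # (1, seq_len, 50257): Transformers logits over the GPT-2 vocab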
Output of the Megatron-LM → Transformers conversion (run beforehand to produce the HF checkpoint):
Loading Megatron-LM checkpoint arguments from: /home/sourab/megatron_lm_gpt/iter_0005000/mp_rank_00/model_rng.pt
50257
Converting
Converting embeddings
Converting transformer layers
Converting pipeline parallel rank 0
Converting final layernorm
Converting LM head
Conversion from Megatron-LM to Transformers is done!
# transformer.wpe.weight : torch.Size([1024, 768])
# transformer.wte.weight : torch.Size([50257, 768])
# transformer.h.0.ln_1.weight : torch.Size([768])
# transformer.h.0.ln_1.bias : torch.Size([768])
# transformer.h.0.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.0.attn.masked_bias : torch.Size([])
# transformer.h.0.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.0.attn.c_attn.bias : torch.Size([2304])
# transformer.h.0.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.0.attn.c_proj.bias : torch.Size([768])
# transformer.h.0.ln_2.weight : torch.Size([768])
# transformer.h.0.ln_2.bias : torch.Size([768])
# transformer.h.0.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.0.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.0.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.0.mlp.c_proj.bias : torch.Size([768])
# transformer.h.1.ln_1.weight : torch.Size([768])
# transformer.h.1.ln_1.bias : torch.Size([768])
# transformer.h.1.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.1.attn.masked_bias : torch.Size([])
# transformer.h.1.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.1.attn.c_attn.bias : torch.Size([2304])
# transformer.h.1.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.1.attn.c_proj.bias : torch.Size([768])
# transformer.h.1.ln_2.weight : torch.Size([768])
# transformer.h.1.ln_2.bias : torch.Size([768])
# transformer.h.1.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.1.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.1.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.1.mlp.c_proj.bias : torch.Size([768])
# transformer.h.2.ln_1.weight : torch.Size([768])
# transformer.h.2.ln_1.bias : torch.Size([768])
# transformer.h.2.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.2.attn.masked_bias : torch.Size([])
# transformer.h.2.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.2.attn.c_attn.bias : torch.Size([2304])
# transformer.h.2.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.2.attn.c_proj.bias : torch.Size([768])
# transformer.h.2.ln_2.weight : torch.Size([768])
# transformer.h.2.ln_2.bias : torch.Size([768])
# transformer.h.2.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.2.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.2.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.2.mlp.c_proj.bias : torch.Size([768])
# transformer.h.3.ln_1.weight : torch.Size([768])
# transformer.h.3.ln_1.bias : torch.Size([768])
# transformer.h.3.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.3.attn.masked_bias : torch.Size([])
# transformer.h.3.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.3.attn.c_attn.bias : torch.Size([2304])
# transformer.h.3.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.3.attn.c_proj.bias : torch.Size([768])
# transformer.h.3.ln_2.weight : torch.Size([768])
# transformer.h.3.ln_2.bias : torch.Size([768])
# transformer.h.3.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.3.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.3.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.3.mlp.c_proj.bias : torch.Size([768])
# transformer.h.4.ln_1.weight : torch.Size([768])
# transformer.h.4.ln_1.bias : torch.Size([768])
# transformer.h.4.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.4.attn.masked_bias : torch.Size([])
# transformer.h.4.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.4.attn.c_attn.bias : torch.Size([2304])
# transformer.h.4.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.4.attn.c_proj.bias : torch.Size([768])
# transformer.h.4.ln_2.weight : torch.Size([768])
# transformer.h.4.ln_2.bias : torch.Size([768])
# transformer.h.4.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.4.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.4.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.4.mlp.c_proj.bias : torch.Size([768])
# transformer.h.5.ln_1.weight : torch.Size([768])
# transformer.h.5.ln_1.bias : torch.Size([768])
# transformer.h.5.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.5.attn.masked_bias : torch.Size([])
# transformer.h.5.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.5.attn.c_attn.bias : torch.Size([2304])
# transformer.h.5.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.5.attn.c_proj.bias : torch.Size([768])
# transformer.h.5.ln_2.weight : torch.Size([768])
# transformer.h.5.ln_2.bias : torch.Size([768])
# transformer.h.5.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.5.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.5.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.5.mlp.c_proj.bias : torch.Size([768])
# transformer.h.6.ln_1.weight : torch.Size([768])
# transformer.h.6.ln_1.bias : torch.Size([768])
# transformer.h.6.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.6.attn.masked_bias : torch.Size([])
# transformer.h.6.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.6.attn.c_attn.bias : torch.Size([2304])
# transformer.h.6.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.6.attn.c_proj.bias : torch.Size([768])
# transformer.h.6.ln_2.weight : torch.Size([768])
# transformer.h.6.ln_2.bias : torch.Size([768])
# transformer.h.6.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.6.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.6.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.6.mlp.c_proj.bias : torch.Size([768])
# transformer.h.7.ln_1.weight : torch.Size([768])
# transformer.h.7.ln_1.bias : torch.Size([768])
# transformer.h.7.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.7.attn.masked_bias : torch.Size([])
# transformer.h.7.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.7.attn.c_attn.bias : torch.Size([2304])
# transformer.h.7.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.7.attn.c_proj.bias : torch.Size([768])
# transformer.h.7.ln_2.weight : torch.Size([768])
# transformer.h.7.ln_2.bias : torch.Size([768])
# transformer.h.7.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.7.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.7.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.7.mlp.c_proj.bias : torch.Size([768])
# transformer.h.8.ln_1.weight : torch.Size([768])
# transformer.h.8.ln_1.bias : torch.Size([768])
# transformer.h.8.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.8.attn.masked_bias : torch.Size([])
# transformer.h.8.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.8.attn.c_attn.bias : torch.Size([2304])
# transformer.h.8.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.8.attn.c_proj.bias : torch.Size([768])
# transformer.h.8.ln_2.weight : torch.Size([768])
# transformer.h.8.ln_2.bias : torch.Size([768])
# transformer.h.8.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.8.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.8.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.8.mlp.c_proj.bias : torch.Size([768])
# transformer.h.9.ln_1.weight : torch.Size([768])
# transformer.h.9.ln_1.bias : torch.Size([768])
# transformer.h.9.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.9.attn.masked_bias : torch.Size([])
# transformer.h.9.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.9.attn.c_attn.bias : torch.Size([2304])
# transformer.h.9.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.9.attn.c_proj.bias : torch.Size([768])
# transformer.h.9.ln_2.weight : torch.Size([768])
# transformer.h.9.ln_2.bias : torch.Size([768])
# transformer.h.9.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.9.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.9.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.9.mlp.c_proj.bias : torch.Size([768])
# transformer.h.10.ln_1.weight : torch.Size([768])
# transformer.h.10.ln_1.bias : torch.Size([768])
# transformer.h.10.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.10.attn.masked_bias : torch.Size([])
# transformer.h.10.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.10.attn.c_attn.bias : torch.Size([2304])
# transformer.h.10.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.10.attn.c_proj.bias : torch.Size([768])
# transformer.h.10.ln_2.weight : torch.Size([768])
# transformer.h.10.ln_2.bias : torch.Size([768])
# transformer.h.10.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.10.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.10.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.10.mlp.c_proj.bias : torch.Size([768])
# transformer.h.11.ln_1.weight : torch.Size([768])
# transformer.h.11.ln_1.bias : torch.Size([768])
# transformer.h.11.attn.bias : torch.Size([1, 1, 1024, 1024])
# transformer.h.11.attn.masked_bias : torch.Size([])
# transformer.h.11.attn.c_attn.weight : torch.Size([768, 2304])
# transformer.h.11.attn.c_attn.bias : torch.Size([2304])
# transformer.h.11.attn.c_proj.weight : torch.Size([768, 768])
# transformer.h.11.attn.c_proj.bias : torch.Size([768])
# transformer.h.11.ln_2.weight : torch.Size([768])
# transformer.h.11.ln_2.bias : torch.Size([768])
# transformer.h.11.mlp.c_fc.weight : torch.Size([768, 3072])
# transformer.h.11.mlp.c_fc.bias : torch.Size([3072])
# transformer.h.11.mlp.c_proj.weight : torch.Size([3072, 768])
# transformer.h.11.mlp.c_proj.bias : torch.Size([768])
# transformer.ln_f.weight : torch.Size([768])
# transformer.ln_f.bias : torch.Size([768])
# lm_head.weight : torch.Size([50257, 768])
Saving config
Adding GPT2TokenizerFast tokenizer files
The model is bigger than the maximum size per checkpoint (200MB) and is going to be split in 4 checkpoint shards. You can find where each parameters has been saved in the index located at /home/sourab/megatron_lm_gpt/hf_checkpoint/pytorch_model.bin.index.json.
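Because the checkpoint was saved in four shards, loading it back needs only the directory; from_pretrained resolves pytorch_model.bin.index.json automatically (a minimal sketch):

from transformers import GPT2LMHeadModel

# Loads the 4 shards listed in pytorch_model.bin.index.json transparently.
model = GPT2LMHeadModel.from_pretrained("/home/sourab/megatron_lm_gpt/hf_checkpoint")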
Output of the reverse Transformers → Megatron-LM conversion (producing the tensor-parallel checkpoint shards):
Converting
converting embedding layer
> padded vocab (size: 50257) with 175 dummy tokens (new size: 50432)
converting transformer layers
Checkpoint structure of model state dict shard belonging to TP rank 0 and PP rank 0:
# model
..# language_model
....# embedding
......# position_embeddings
........# weight : torch.Size([1024, 768])
......# word_embeddings
........# weight : torch.Size([25216, 768])
....# encoder
......# layers.0.input_layernorm.weight : torch.Size([768])
......# layers.0.input_layernorm.bias : torch.Size([768])
......# layers.0.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.0.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.0.self_attention.dense.weight : torch.Size([768, 384])
......# layers.0.self_attention.dense.bias : torch.Size([768])
......# layers.0.post_attention_layernorm.weight : torch.Size([768])
......# layers.0.post_attention_layernorm.bias : torch.Size([768])
......# layers.0.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.0.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.0.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.0.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.1.input_layernorm.weight : torch.Size([768])
......# layers.1.input_layernorm.bias : torch.Size([768])
......# layers.1.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.1.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.1.self_attention.dense.weight : torch.Size([768, 384])
......# layers.1.self_attention.dense.bias : torch.Size([768])
......# layers.1.post_attention_layernorm.weight : torch.Size([768])
......# layers.1.post_attention_layernorm.bias : torch.Size([768])
......# layers.1.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.1.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.1.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.1.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.2.input_layernorm.weight : torch.Size([768])
......# layers.2.input_layernorm.bias : torch.Size([768])
......# layers.2.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.2.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.2.self_attention.dense.weight : torch.Size([768, 384])
......# layers.2.self_attention.dense.bias : torch.Size([768])
......# layers.2.post_attention_layernorm.weight : torch.Size([768])
......# layers.2.post_attention_layernorm.bias : torch.Size([768])
......# layers.2.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.2.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.2.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.2.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.3.input_layernorm.weight : torch.Size([768])
......# layers.3.input_layernorm.bias : torch.Size([768])
......# layers.3.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.3.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.3.self_attention.dense.weight : torch.Size([768, 384])
......# layers.3.self_attention.dense.bias : torch.Size([768])
......# layers.3.post_attention_layernorm.weight : torch.Size([768])
......# layers.3.post_attention_layernorm.bias : torch.Size([768])
......# layers.3.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.3.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.3.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.3.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.4.input_layernorm.weight : torch.Size([768])
......# layers.4.input_layernorm.bias : torch.Size([768])
......# layers.4.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.4.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.4.self_attention.dense.weight : torch.Size([768, 384])
......# layers.4.self_attention.dense.bias : torch.Size([768])
......# layers.4.post_attention_layernorm.weight : torch.Size([768])
......# layers.4.post_attention_layernorm.bias : torch.Size([768])
......# layers.4.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.4.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.4.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.4.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.5.input_layernorm.weight : torch.Size([768])
......# layers.5.input_layernorm.bias : torch.Size([768])
......# layers.5.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.5.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.5.self_attention.dense.weight : torch.Size([768, 384])
......# layers.5.self_attention.dense.bias : torch.Size([768])
......# layers.5.post_attention_layernorm.weight : torch.Size([768])
......# layers.5.post_attention_layernorm.bias : torch.Size([768])
......# layers.5.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.5.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.5.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.5.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.6.input_layernorm.weight : torch.Size([768])
......# layers.6.input_layernorm.bias : torch.Size([768])
......# layers.6.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.6.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.6.self_attention.dense.weight : torch.Size([768, 384])
......# layers.6.self_attention.dense.bias : torch.Size([768])
......# layers.6.post_attention_layernorm.weight : torch.Size([768])
......# layers.6.post_attention_layernorm.bias : torch.Size([768])
......# layers.6.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.6.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.6.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.6.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.7.input_layernorm.weight : torch.Size([768])
......# layers.7.input_layernorm.bias : torch.Size([768])
......# layers.7.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.7.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.7.self_attention.dense.weight : torch.Size([768, 384])
......# layers.7.self_attention.dense.bias : torch.Size([768])
......# layers.7.post_attention_layernorm.weight : torch.Size([768])
......# layers.7.post_attention_layernorm.bias : torch.Size([768])
......# layers.7.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.7.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.7.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.7.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.8.input_layernorm.weight : torch.Size([768])
......# layers.8.input_layernorm.bias : torch.Size([768])
......# layers.8.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.8.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.8.self_attention.dense.weight : torch.Size([768, 384])
......# layers.8.self_attention.dense.bias : torch.Size([768])
......# layers.8.post_attention_layernorm.weight : torch.Size([768])
......# layers.8.post_attention_layernorm.bias : torch.Size([768])
......# layers.8.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.8.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.8.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.8.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.9.input_layernorm.weight : torch.Size([768])
......# layers.9.input_layernorm.bias : torch.Size([768])
......# layers.9.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.9.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.9.self_attention.dense.weight : torch.Size([768, 384])
......# layers.9.self_attention.dense.bias : torch.Size([768])
......# layers.9.post_attention_layernorm.weight : torch.Size([768])
......# layers.9.post_attention_layernorm.bias : torch.Size([768])
......# layers.9.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.9.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.9.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.9.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.10.input_layernorm.weight : torch.Size([768])
......# layers.10.input_layernorm.bias : torch.Size([768])
......# layers.10.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.10.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.10.self_attention.dense.weight : torch.Size([768, 384])
......# layers.10.self_attention.dense.bias : torch.Size([768])
......# layers.10.post_attention_layernorm.weight : torch.Size([768])
......# layers.10.post_attention_layernorm.bias : torch.Size([768])
......# layers.10.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.10.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.10.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.10.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.11.input_layernorm.weight : torch.Size([768])
......# layers.11.input_layernorm.bias : torch.Size([768])
......# layers.11.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.11.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.11.self_attention.dense.weight : torch.Size([768, 384])
......# layers.11.self_attention.dense.bias : torch.Size([768])
......# layers.11.post_attention_layernorm.weight : torch.Size([768])
......# layers.11.post_attention_layernorm.bias : torch.Size([768])
......# layers.11.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.11.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.11.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.11.mlp.dense_4h_to_h.bias : torch.Size([768])
......# final_layernorm.weight : torch.Size([768])
......# final_layernorm.bias : torch.Size([768])
# checkpoint_version : 3.0
# args : namespace(orig_vocab_size=50257, max_position_embeddings=1024, hidden_size=768, num_layers=12, num_attention_heads=12, ffn_hidden_size=3072, tensor_model_parallel_size=2, pipeline_model_parallel_size=1, data_parallel_size=2, make_vocab_size_divisible_by=128, rank=0, tokenizer_type=None, bias_gelu_fusion=True, openai_gelu=False, params_dtype=torch.bfloat16, padded_vocab_size=50432)
# optimizer
..# step : 0
..# param_groups : [{'lr': 0.0, 'beta1': 0.0, 'beta2': 0.0, 'eps': 0.0, 'weight_decay': 0.0, 'correct_bias': False, 'params': []}]
Checkpoint structure of model state dict shard belonging to TP rank 1 and PP rank 0:
# model
..# language_model
....# embedding
......# position_embeddings
........# weight : torch.Size([1024, 768])
......# word_embeddings
........# weight : torch.Size([25216, 768])
....# encoder
......# layers.0.input_layernorm.weight : torch.Size([768])
......# layers.0.input_layernorm.bias : torch.Size([768])
......# layers.0.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.0.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.0.self_attention.dense.weight : torch.Size([768, 384])
......# layers.0.self_attention.dense.bias : torch.Size([768])
......# layers.0.post_attention_layernorm.weight : torch.Size([768])
......# layers.0.post_attention_layernorm.bias : torch.Size([768])
......# layers.0.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.0.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.0.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.0.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.1.input_layernorm.weight : torch.Size([768])
......# layers.1.input_layernorm.bias : torch.Size([768])
......# layers.1.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.1.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.1.self_attention.dense.weight : torch.Size([768, 384])
......# layers.1.self_attention.dense.bias : torch.Size([768])
......# layers.1.post_attention_layernorm.weight : torch.Size([768])
......# layers.1.post_attention_layernorm.bias : torch.Size([768])
......# layers.1.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.1.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.1.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.1.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.2.input_layernorm.weight : torch.Size([768])
......# layers.2.input_layernorm.bias : torch.Size([768])
......# layers.2.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.2.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.2.self_attention.dense.weight : torch.Size([768, 384])
......# layers.2.self_attention.dense.bias : torch.Size([768])
......# layers.2.post_attention_layernorm.weight : torch.Size([768])
......# layers.2.post_attention_layernorm.bias : torch.Size([768])
......# layers.2.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.2.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.2.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.2.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.3.input_layernorm.weight : torch.Size([768])
......# layers.3.input_layernorm.bias : torch.Size([768])
......# layers.3.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.3.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.3.self_attention.dense.weight : torch.Size([768, 384])
......# layers.3.self_attention.dense.bias : torch.Size([768])
......# layers.3.post_attention_layernorm.weight : torch.Size([768])
......# layers.3.post_attention_layernorm.bias : torch.Size([768])
......# layers.3.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.3.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.3.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.3.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.4.input_layernorm.weight : torch.Size([768])
......# layers.4.input_layernorm.bias : torch.Size([768])
......# layers.4.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.4.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.4.self_attention.dense.weight : torch.Size([768, 384])
......# layers.4.self_attention.dense.bias : torch.Size([768])
......# layers.4.post_attention_layernorm.weight : torch.Size([768])
......# layers.4.post_attention_layernorm.bias : torch.Size([768])
......# layers.4.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.4.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.4.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.4.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.5.input_layernorm.weight : torch.Size([768])
......# layers.5.input_layernorm.bias : torch.Size([768])
......# layers.5.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.5.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.5.self_attention.dense.weight : torch.Size([768, 384])
......# layers.5.self_attention.dense.bias : torch.Size([768])
......# layers.5.post_attention_layernorm.weight : torch.Size([768])
......# layers.5.post_attention_layernorm.bias : torch.Size([768])
......# layers.5.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.5.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.5.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.5.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.6.input_layernorm.weight : torch.Size([768])
......# layers.6.input_layernorm.bias : torch.Size([768])
......# layers.6.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.6.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.6.self_attention.dense.weight : torch.Size([768, 384])
......# layers.6.self_attention.dense.bias : torch.Size([768])
......# layers.6.post_attention_layernorm.weight : torch.Size([768])
......# layers.6.post_attention_layernorm.bias : torch.Size([768])
......# layers.6.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.6.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.6.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.6.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.7.input_layernorm.weight : torch.Size([768])
......# layers.7.input_layernorm.bias : torch.Size([768])
......# layers.7.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.7.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.7.self_attention.dense.weight : torch.Size([768, 384])
......# layers.7.self_attention.dense.bias : torch.Size([768])
......# layers.7.post_attention_layernorm.weight : torch.Size([768])
......# layers.7.post_attention_layernorm.bias : torch.Size([768])
......# layers.7.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.7.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.7.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.7.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.8.input_layernorm.weight : torch.Size([768])
......# layers.8.input_layernorm.bias : torch.Size([768])
......# layers.8.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.8.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.8.self_attention.dense.weight : torch.Size([768, 384])
......# layers.8.self_attention.dense.bias : torch.Size([768])
......# layers.8.post_attention_layernorm.weight : torch.Size([768])
......# layers.8.post_attention_layernorm.bias : torch.Size([768])
......# layers.8.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.8.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.8.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.8.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.9.input_layernorm.weight : torch.Size([768])
......# layers.9.input_layernorm.bias : torch.Size([768])
......# layers.9.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.9.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.9.self_attention.dense.weight : torch.Size([768, 384])
......# layers.9.self_attention.dense.bias : torch.Size([768])
......# layers.9.post_attention_layernorm.weight : torch.Size([768])
......# layers.9.post_attention_layernorm.bias : torch.Size([768])
......# layers.9.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.9.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.9.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.9.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.10.input_layernorm.weight : torch.Size([768])
......# layers.10.input_layernorm.bias : torch.Size([768])
......# layers.10.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.10.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.10.self_attention.dense.weight : torch.Size([768, 384])
......# layers.10.self_attention.dense.bias : torch.Size([768])
......# layers.10.post_attention_layernorm.weight : torch.Size([768])
......# layers.10.post_attention_layernorm.bias : torch.Size([768])
......# layers.10.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.10.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.10.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.10.mlp.dense_4h_to_h.bias : torch.Size([768])
......# layers.11.input_layernorm.weight : torch.Size([768])
......# layers.11.input_layernorm.bias : torch.Size([768])
......# layers.11.self_attention.query_key_value.weight : torch.Size([1152, 768])
......# layers.11.self_attention.query_key_value.bias : torch.Size([1152])
......# layers.11.self_attention.dense.weight : torch.Size([768, 384])
......# layers.11.self_attention.dense.bias : torch.Size([768])
......# layers.11.post_attention_layernorm.weight : torch.Size([768])
......# layers.11.post_attention_layernorm.bias : torch.Size([768])
......# layers.11.mlp.dense_h_to_4h.weight : torch.Size([1536, 768])
......# layers.11.mlp.dense_h_to_4h.bias : torch.Size([1536])
......# layers.11.mlp.dense_4h_to_h.weight : torch.Size([768, 1536])
......# layers.11.mlp.dense_4h_to_h.bias : torch.Size([768])
......# final_layernorm.weight : torch.Size([768])
......# final_layernorm.bias : torch.Size([768])
# checkpoint_version : 3.0
# args : namespace(orig_vocab_size=50257, max_position_embeddings=1024, hidden_size=768, num_layers=12, num_attention_heads=12, ffn_hidden_size=3072, tensor_model_parallel_size=2, pipeline_model_parallel_size=1, data_parallel_size=2, make_vocab_size_divisible_by=128, rank=0, tokenizer_type=None, bias_gelu_fusion=True, openai_gelu=False, params_dtype=torch.bfloat16, padded_vocab_size=50432)
# optimizer
..# step : 0
..# param_groups : [{'lr': 0.0, 'beta1': 0.0, 'beta2': 0.0, 'eps': 0.0, 'weight_decay': 0.0, 'correct_bias': False, 'params': []}]
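The "padded vocab (size: 50257) with 175 dummy tokens (new size: 50432)" line and the halved shard shapes above (25216-row word embeddings, 1152-row query_key_value, 1536-row dense_h_to_4h) follow from the printed args: Megatron-LM pads the vocab to a multiple of make_vocab_size_divisible_by times tensor_model_parallel_size. A sketch of the arithmetic:

# Sketch of Megatron-LM's vocab padding rule, using the args printed above.
orig_vocab_size = 50257
make_vocab_size_divisible_by = 128
tensor_model_parallel_size = 2

multiple = make_vocab_size_divisible_by * tensor_model_parallel_size  # 256
padded_vocab_size = ((orig_vocab_size + multiple - 1) // multiple) * multiple
print(padded_vocab_size)                                 # 50432
print(padded_vocab_size - orig_vocab_size)               # 175 dummy tokens
print(padded_vocab_size // tensor_model_parallel_size)   # 25216 embedding rows per TP shard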