# Gist by @MicPie, created August 28, 2022
# https://gist.github.com/MicPie/974e8a0a5c31a04a686e35f869c094f3
import os

import numpy as np
import torch

def take_out(flat_params, shape):
    # Peel the first prod(shape) elements off the flat buffer; return (tensor, rest).
    return flat_params[:np.prod(shape)].view(*shape), flat_params[np.prod(shape):]

def get_slice(shard_size, shard_i):
    # Index range covered by shard_i along a sharded dimension.
    return slice(shard_size * shard_i, shard_size * (shard_i + 1))

def create_mmap_file(name, shape, dtype='float16', mode='w+'):
    os.makedirs("layer_mmaps", exist_ok=True)  # np.memmap does not create directories
    return np.memmap("layer_mmaps/" + name + ".mmap", dtype=dtype, mode=mode, shape=shape)
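
# Usage sketch (an assumption, not part of the original gist): the files written
# below can later be consumed lazily by reopening them read-only with the same
# dtype and shape, e.g.
#
#   w = np.memmap("layer_mmaps/model.decoder.layers.0.fc1.weight.mmap",
#                 dtype='float16', mode='r', shape=(49152, 12288))
#
# mode='r' maps the file into the address space without loading it into RAM.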

#def load_sharded_weights(opt_model, sharded_checkpoint_list):
def load_sharded_weights(sharded_checkpoint_list):
    # Hard-coded OPT-175B configuration (originally taken from opt_model.config):
    #config = opt_model.config
    vocab_size = 50272
    num_attention_heads = 96
    hidden_size = 12288
    ffn_dim = 49152
    max_position_embeddings = 2048
    num_hidden_layers = 96
    #model = opt_model.model.decoder
    num_shards = len(sharded_checkpoint_list)
    # Sizes each shard contributes along its tensor's sharded dimension.
    vocab_size_per_shard = vocab_size // num_shards
    heads_per_shard = num_attention_heads // num_shards
    hidden_size_per_shard = hidden_size // num_shards
    ffn_dim_per_shard = ffn_dim // num_shards
    dims_per_head = hidden_size // num_attention_heads
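    # Worked example, assuming 8 shards (hypothetical; the real count is
    # len(sharded_checkpoint_list)): 50272 // 8 = 6284 vocab rows,
    # 96 // 8 = 12 heads, 12288 // 8 = 1536 hidden units, and
    # 49152 // 8 = 6144 ffn units per shard; dims_per_head = 12288 // 96 = 128.
    # The integer divisions assume num_shards divides each dimension evenly.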
    # mmap dicts for each layer
    model_decoder_layers_self_attn_k_proj_weight = {}
    model_decoder_layers_self_attn_v_proj_weight = {}
    model_decoder_layers_self_attn_q_proj_weight = {}
    model_decoder_layers_self_attn_k_proj_bias = {}
    model_decoder_layers_self_attn_v_proj_bias = {}
    model_decoder_layers_self_attn_q_proj_bias = {}
    model_decoder_layers_self_attn_out_proj_weight = {}
    model_decoder_layers_self_attn_out_proj_bias = {}
    model_decoder_layers_self_attn_layer_norm_weight = {}
    model_decoder_layers_self_attn_layer_norm_bias = {}
    model_decoder_layers_fc1_weight = {}
    model_decoder_layers_fc1_bias = {}
    model_decoder_layers_fc2_weight = {}
    model_decoder_layers_fc2_bias = {}
    model_decoder_layers_final_layer_norm_weight = {}
    model_decoder_layers_final_layer_norm_bias = {}
    for shard_i in range(num_shards):
        loaded = torch.load(sharded_checkpoint_list[shard_i], map_location="cpu")
        if len(loaded["model"]) == 2:
            # small model
            flat_params = loaded["model"]["flat_param_0"]
            load_final_layer_norm_first = False
        else:
            # big model
            load_final_layer_norm_first = True
            flat_params = torch.cat([
                v.flatten()
                for k, v in loaded["model"].items()
                if k != "decoder.version"
            ])
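        # Everything below relies on the flat buffer storing parameters in
        # exactly the order they are peeled off: embeddings, positions,
        # (for big models) the final LayerNorm, then the per-layer tensors.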
        # Vocab
        print(shard_i, "lm_head.weight")
        if shard_i == 0:
            lm_head_weight = create_mmap_file(
                "lm_head.weight",
                shape=(vocab_size, hidden_size),
            )
            model_decoder_embed_tokens_weight = create_mmap_file(
                "model.decoder.embed_tokens.weight",
                shape=(vocab_size, hidden_size),
            )
        out, flat_params = take_out(flat_params, (vocab_size_per_shard, hidden_size))
        #model.embed_tokens.weight.data[get_slice(vocab_size_per_shard, shard_i)] = out
        model_decoder_embed_tokens_weight[get_slice(vocab_size_per_shard, shard_i)] = out.numpy()
        model_decoder_embed_tokens_weight.flush()
        # lm_head is tied to embed_tokens, so write the same shard into its mmap
        # (rebinding the variable alone would leave the lm_head file empty).
        #opt_model.lm_head.weight = model.embed_tokens.weight
        lm_head_weight[get_slice(vocab_size_per_shard, shard_i)] = out.numpy()
        lm_head_weight.flush()
        # Pos encoding (fixed offset=2)
        print(shard_i, "model.decoder.embed_positions")
        if shard_i == 0:
            model_decoder_embed_positions_weight = create_mmap_file(
                "model.decoder.embed_positions.weight",
                shape=(max_position_embeddings + 2, hidden_size),
            )
        out, flat_params = take_out(flat_params, (max_position_embeddings + 2, hidden_size))
        #model.embed_positions.weight.data[:] = out
        model_decoder_embed_positions_weight[:] = out.numpy()
        model_decoder_embed_positions_weight.flush()
        # TODO: This is not sharded (each shard carries a full copy)? Check!
        if load_final_layer_norm_first:
            # Final LayerNorm (big-model checkpoints store it before the layers)
            print(shard_i, "model.decoder.final_layer_norm")
            if shard_i == 0:
                model_decoder_final_layer_norm_weight = create_mmap_file(
                    "model.decoder.final_layer_norm.weight",
                    shape=(hidden_size,),
                )
                model_decoder_final_layer_norm_bias = create_mmap_file(
                    "model.decoder.final_layer_norm.bias",
                    shape=(hidden_size,),
                )
            out, flat_params = take_out(flat_params, (hidden_size,))
            #model.final_layer_norm.weight.data[:] = out
            model_decoder_final_layer_norm_weight[:] = out.numpy()
            model_decoder_final_layer_norm_weight.flush()
            out, flat_params = take_out(flat_params, (hidden_size,))
            #model.final_layer_norm.bias.data[:] = out
            model_decoder_final_layer_norm_bias[:] = out.numpy()
            model_decoder_final_layer_norm_bias.flush()
            # If the code fails here, you need to update transformers:
            # an earlier version was missing the final_layer_norm.
        for layer_i in range(num_hidden_layers):
            # K/V/Q weights
            print(shard_i, layer_i, f"model.decoder.layers.{layer_i}.self_attn")
            # model.decoder.layers.0.self_attn.k_proj.weight
            if shard_i == 0:
                model_decoder_layers_self_attn_k_proj_weight[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.self_attn.k_proj.weight",
                    shape=(num_attention_heads, dims_per_head, hidden_size),
                )
                model_decoder_layers_self_attn_v_proj_weight[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.self_attn.v_proj.weight",
                    shape=(num_attention_heads, dims_per_head, hidden_size),
                )
                model_decoder_layers_self_attn_q_proj_weight[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.self_attn.q_proj.weight",
                    shape=(num_attention_heads, dims_per_head, hidden_size),
                )
            out, flat_params = take_out(flat_params, (heads_per_shard, dims_per_head, hidden_size))
            #model.layers[layer_i].self_attn.k_proj.weight.data.reshape(
            #    num_attention_heads, dims_per_head, hidden_size,
            #)[get_slice(heads_per_shard, shard_i), :, :] = out
            model_decoder_layers_self_attn_k_proj_weight[layer_i][get_slice(heads_per_shard, shard_i), :, :] = out.numpy()
            model_decoder_layers_self_attn_k_proj_weight[layer_i].flush()
            out, flat_params = take_out(flat_params, (heads_per_shard, dims_per_head, hidden_size))
            #model.layers[layer_i].self_attn.v_proj.weight.data.reshape(
            #    num_attention_heads, dims_per_head, hidden_size,
            #)[get_slice(heads_per_shard, shard_i), :, :] = out
            model_decoder_layers_self_attn_v_proj_weight[layer_i][get_slice(heads_per_shard, shard_i), :, :] = out.numpy()
            model_decoder_layers_self_attn_v_proj_weight[layer_i].flush()
            out, flat_params = take_out(flat_params, (heads_per_shard, dims_per_head, hidden_size))
            #model.layers[layer_i].self_attn.q_proj.weight.data.reshape(
            #    num_attention_heads, dims_per_head, hidden_size,
            #)[get_slice(heads_per_shard, shard_i), :, :] = out
            model_decoder_layers_self_attn_q_proj_weight[layer_i][get_slice(heads_per_shard, shard_i), :, :] = out.numpy()
            model_decoder_layers_self_attn_q_proj_weight[layer_i].flush()
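            # Note: q/k/v projections are sharded along the head (output)
            # dimension, so each shard contributes a contiguous block of heads.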
            # K/V/Q bias
            if shard_i == 0:
                model_decoder_layers_self_attn_k_proj_bias[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.self_attn.k_proj.bias",
                    shape=(hidden_size,),
                )
                model_decoder_layers_self_attn_v_proj_bias[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.self_attn.v_proj.bias",
                    shape=(hidden_size,),
                )
                model_decoder_layers_self_attn_q_proj_bias[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.self_attn.q_proj.bias",
                    shape=(hidden_size,),
                )
            out, flat_params = take_out(flat_params, (hidden_size_per_shard,))
            #model.layers[layer_i].self_attn.k_proj.bias.data[
            #    get_slice(hidden_size_per_shard, shard_i)] = out
            model_decoder_layers_self_attn_k_proj_bias[layer_i][get_slice(hidden_size_per_shard, shard_i)] = out.numpy()
            model_decoder_layers_self_attn_k_proj_bias[layer_i].flush()
            out, flat_params = take_out(flat_params, (hidden_size_per_shard,))
            #model.layers[layer_i].self_attn.v_proj.bias.data[
            #    get_slice(hidden_size_per_shard, shard_i)] = out
            model_decoder_layers_self_attn_v_proj_bias[layer_i][get_slice(hidden_size_per_shard, shard_i)] = out.numpy()
            model_decoder_layers_self_attn_v_proj_bias[layer_i].flush()
            out, flat_params = take_out(flat_params, (hidden_size_per_shard,))
            #model.layers[layer_i].self_attn.q_proj.bias.data[
            #    get_slice(hidden_size_per_shard, shard_i)] = out
            model_decoder_layers_self_attn_q_proj_bias[layer_i][get_slice(hidden_size_per_shard, shard_i)] = out.numpy()
            model_decoder_layers_self_attn_q_proj_bias[layer_i].flush()
            # O weight, O bias
            print(shard_i, layer_i, f"model.decoder.layers.{layer_i}.self_attn.out_proj")
            if shard_i == 0:
                model_decoder_layers_self_attn_out_proj_weight[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.self_attn.out_proj.weight",
                    shape=(hidden_size, num_attention_heads, dims_per_head),
                )
                model_decoder_layers_self_attn_out_proj_bias[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.self_attn.out_proj.bias",
                    shape=(hidden_size,),
                )
            out, flat_params = take_out(flat_params, (hidden_size, heads_per_shard, dims_per_head))
            #model.layers[layer_i].self_attn.out_proj.weight.data.reshape(
            #    hidden_size, num_attention_heads, dims_per_head,
            #)[:, get_slice(heads_per_shard, shard_i), :] = out
            model_decoder_layers_self_attn_out_proj_weight[layer_i][:, get_slice(heads_per_shard, shard_i), :] = out.numpy()
            model_decoder_layers_self_attn_out_proj_weight[layer_i].flush()
            out, flat_params = take_out(flat_params, (hidden_size,))
            #model.layers[layer_i].self_attn.out_proj.bias.data[:] = out
            model_decoder_layers_self_attn_out_proj_bias[layer_i][:] = out.numpy()
            model_decoder_layers_self_attn_out_proj_bias[layer_i].flush()
            # Input LayerNorm
            print(shard_i, layer_i, f"model.decoder.layers.{layer_i}.self_attn_layer_norm")
            if shard_i == 0:
                model_decoder_layers_self_attn_layer_norm_weight[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.self_attn_layer_norm.weight",
                    shape=(hidden_size,),
                )
                model_decoder_layers_self_attn_layer_norm_bias[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.self_attn_layer_norm.bias",
                    shape=(hidden_size,),
                )
            out, flat_params = take_out(flat_params, (hidden_size,))
            #model.layers[layer_i].self_attn_layer_norm.weight.data[:] = out
            model_decoder_layers_self_attn_layer_norm_weight[layer_i][:] = out.numpy()
            model_decoder_layers_self_attn_layer_norm_weight[layer_i].flush()
            out, flat_params = take_out(flat_params, (hidden_size,))
            #model.layers[layer_i].self_attn_layer_norm.bias.data[:] = out
            model_decoder_layers_self_attn_layer_norm_bias[layer_i][:] = out.numpy()
            model_decoder_layers_self_attn_layer_norm_bias[layer_i].flush()
            # MLP dense_h_to_4h
            print(shard_i, layer_i, f"model.decoder.layers.{layer_i}.fc1")
            if shard_i == 0:
                model_decoder_layers_fc1_weight[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.fc1.weight",
                    shape=(ffn_dim, hidden_size),
                )
                model_decoder_layers_fc1_bias[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.fc1.bias",
                    shape=(ffn_dim,),
                )
            out, flat_params = take_out(flat_params, (ffn_dim_per_shard, hidden_size))
            #model.layers[layer_i].fc1.weight.data[
            #    get_slice(ffn_dim_per_shard, shard_i), :] = out
            model_decoder_layers_fc1_weight[layer_i][get_slice(ffn_dim_per_shard, shard_i), :] = out.numpy()
            model_decoder_layers_fc1_weight[layer_i].flush()
            out, flat_params = take_out(flat_params, (ffn_dim_per_shard,))
            #model.layers[layer_i].fc1.bias.data[
            #    get_slice(ffn_dim_per_shard, shard_i)] = out
            model_decoder_layers_fc1_bias[layer_i][get_slice(ffn_dim_per_shard, shard_i)] = out.numpy()
            model_decoder_layers_fc1_bias[layer_i].flush()
            # MLP dense_4h_to_h
            print(shard_i, layer_i, f"model.decoder.layers.{layer_i}.fc2")
            # Create only once: mode='w+' truncates an existing file.
            if shard_i == 0:
                model_decoder_layers_fc2_weight[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.fc2.weight",
                    shape=(hidden_size, ffn_dim),
                )
                model_decoder_layers_fc2_bias[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.fc2.bias",
                    shape=(hidden_size,),
                )
            out, flat_params = take_out(flat_params, (hidden_size, ffn_dim_per_shard))
            #model.layers[layer_i].fc2.weight.data[
            #    :, get_slice(ffn_dim_per_shard, shard_i)] = out
            model_decoder_layers_fc2_weight[layer_i][:, get_slice(ffn_dim_per_shard, shard_i)] = out.numpy()
            model_decoder_layers_fc2_weight[layer_i].flush()
            out, flat_params = take_out(flat_params, (hidden_size,))
            #model.layers[layer_i].fc2.bias.data[:] = out
            model_decoder_layers_fc2_bias[layer_i][:] = out.numpy()
            model_decoder_layers_fc2_bias[layer_i].flush()
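            # Megatron-style MLP split: fc1 is sharded along its output
            # dimension (ffn_dim) and fc2 along its input dimension, which is
            # why fc2's weight slice runs over the second axis.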
            # Post-attention LayerNorm
            print(shard_i, layer_i, f"model.decoder.layers.{layer_i}.final_layer_norm")
            if shard_i == 0:
                model_decoder_layers_final_layer_norm_weight[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.final_layer_norm.weight",
                    shape=(hidden_size,),
                )
                model_decoder_layers_final_layer_norm_bias[layer_i] = create_mmap_file(
                    f"model.decoder.layers.{layer_i}.final_layer_norm.bias",
                    shape=(hidden_size,),
                )
            out, flat_params = take_out(flat_params, (hidden_size,))
            #model.layers[layer_i].final_layer_norm.weight.data[:] = out
            model_decoder_layers_final_layer_norm_weight[layer_i][:] = out.numpy()
            model_decoder_layers_final_layer_norm_weight[layer_i].flush()
            out, flat_params = take_out(flat_params, (hidden_size,))
            #model.layers[layer_i].final_layer_norm.bias.data[:] = out
            model_decoder_layers_final_layer_norm_bias[layer_i][:] = out.numpy()
            model_decoder_layers_final_layer_norm_bias[layer_i].flush()
        if not load_final_layer_norm_first:
            # Final LayerNorm (small-model checkpoints store it after the layers)
            print(shard_i, "model.decoder.final_layer_norm")
            if shard_i == 0:
                model_decoder_final_layer_norm_weight = create_mmap_file(
                    "model.decoder.final_layer_norm.weight",
                    shape=(hidden_size,),
                )
                model_decoder_final_layer_norm_bias = create_mmap_file(
                    "model.decoder.final_layer_norm.bias",
                    shape=(hidden_size,),
                )
            out, flat_params = take_out(flat_params, (hidden_size,))
            #model.final_layer_norm.weight.data[:] = out
            model_decoder_final_layer_norm_weight[:] = out.numpy()
            model_decoder_final_layer_norm_weight.flush()
            out, flat_params = take_out(flat_params, (hidden_size,))
            #model.final_layer_norm.bias.data[:] = out
            model_decoder_final_layer_norm_bias[:] = out.numpy()
            model_decoder_final_layer_norm_bias.flush()
            # If the code fails here, you need to update transformers:
            # an earlier version was missing the final_layer_norm.
        # Every element of this shard's flat buffer must now be consumed.
        assert flat_params.numel() == 0
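

# Hypothetical driver (not part of the original gist); the glob pattern is an
# assumed naming scheme for the metaseq shard files and may need adjusting.
if __name__ == "__main__":
    import glob
    shard_files = sorted(glob.glob("checkpoints/reshard-model_part-*.pt"))
    load_sharded_weights(shard_files)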