-
-
Save mtisz/de47d181ee3faaa44254c15bc43d24a1 to your computer and use it in GitHub Desktop.
Convert Grok-1 weights to PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import jax
import numpy as np
import torch
from tqdm import tqdm

from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model
# Directory the runner loads the original (JAX) Grok-1 checkpoint from.
CKPT_PATH = "./checkpoints"

# Transformer body of Grok-1. The attention multiplier is numerically equal to
# 1/sqrt(key_size) = 1/sqrt(128); emb_size equals num_q_heads * key_size.
_grok_1_transformer = TransformerConfig(
    emb_size=48 * 128,
    widening_factor=8,
    key_size=128,
    num_q_heads=48,
    num_kv_heads=8,
    num_layers=64,
    attn_output_multiplier=0.08838834764831845,
    shard_activations=True,
    # Mixture-of-experts routing.
    num_experts=8,
    num_selected_experts=2,
    # Mesh axis names used for activation sharding.
    data_axis="data",
    model_axis="model",
)

# Full language-model config wrapping the transformer above.
# output_multiplier_scale is numerically 1/sqrt(3), and
# embedding_multiplier_scale is numerically sqrt(6144) = sqrt(emb_size).
grok_1_model = LanguageModelConfig(
    vocab_size=128 * 1024,
    pad_token=0,
    eos_token=2,
    sequence_len=8192,
    embedding_init_scale=1.0,
    output_multiplier_scale=0.5773502691896257,
    embedding_multiplier_scale=78.38367176906169,
    model=_grok_1_transformer,
)
# Build a runner for the config above, initialise it with a zero-filled dummy
# batch and load the checkpoint parameters from CKPT_PATH into ``params``.
runner = ModelRunner(
    model=grok_1_model,
    checkpoint_path=CKPT_PATH,
    # NOTE(review): fractional per-device batch size; its meaning is defined
    # by runners.ModelRunner — confirm before changing.
    bs_per_device=0.125,
)

# One (1, 256) int32 array per field; presumably only the shapes/dtypes matter
# for initialisation — TODO confirm against runners.ModelRunner.
dummy_data = {
    field: np.zeros((1, 256), dtype=np.int32)
    for field in ("inputs", "targets")
}

runner.transform_forward = True
# NOTE(review): the two positional (1, 1) tuples are runner configuration
# (likely local / between-host mesh shapes) — verify against
# runners.ModelRunner.initialize.
runner.initialize(dummy_data, (1, 1), (1, 1))
params = runner.load_or_init(dummy_data)
# Convert the loaded parameter tree into a flat PyTorch state dict and write it
# to hf/pytorch_model.bin.
#
# Each entry of ``params`` maps a '/'-separated module path to a mapping of that
# module's parameters. Every module must hold exactly one parameter, which is
# exported as ``<dotted path>.weight``. 8-bit quantized weights (objects with
# ``scales``) are dequantized. Tensors with 2+ dims, except the embedding, have
# their last two dims swapped (presumably JAX (in, out) -> torch nn.Linear
# (out, in)). MoE tensors are split into one entry per expert along dim 0.


def _to_torch(array, dtype):
    """Return ``array`` as a torch tensor of ``dtype``.

    torch cannot convert bfloat16 arrays directly, so values are upcast to
    float32 on the NumPy side and cast to ``dtype`` afterwards.
    """
    return torch.from_numpy(np.asarray(array).astype(np.float32)).to(dtype)


def _dequantize(qweight):
    """Return the dense tensor ``qweight.weight * qweight.scales``.

    The result is float32 when the scales are float32, otherwise bfloat16.
    Row-parallel layers store one scale row per input shard
    (``scale.shape[-2] > 1``). The weight's second-to-last dim is then split
    into that many equal shards, and each shard is scaled by its own row.
    """
    dtype = torch.float32 if qweight.scales.dtype == np.float32 else torch.bfloat16
    weight = _to_torch(qweight.weight, dtype)
    scale = _to_torch(qweight.scales, dtype)
    if scale.dim() < 2 or scale.shape[-2] == 1:
        return weight * scale
    # Derive the shard count from the scales rather than hard-coding 8, so the
    # reshape always matches the scale tensor it is broadcast against.
    num_shards = scale.shape[-2]
    sharded = weight.view(*weight.shape[:-2], num_shards, -1, weight.shape[-1])
    return (sharded * scale[..., None, :]).view(weight.shape)


def _to_dense(value):
    """Return a parameter as a torch tensor, dequantizing it if it has ``scales``."""
    if hasattr(value, 'scales'):
        return _dequantize(value)
    dtype = torch.float32 if value.dtype == np.float32 else torch.bfloat16
    return _to_torch(value, dtype)


def _torch_key(key):
    """Map a '/'-separated module path to a dotted state-dict key.

    'decoder_layer_<n>' becomes 'decoder_layer.<n>', 'language_model' becomes
    'transformer', and '.weight' is appended.
    """
    dotted = (
        key.replace('/', '.')
        .replace('decoder_layer_', 'decoder_layer.')
        .replace('language_model', 'transformer')
    )
    return dotted + '.weight'


new_params = {}
keys = list(params.keys())
for key in tqdm(keys):
    new_key = _torch_key(key)
    module_params = params[key]
    # Refuse to silently drop parameters: only the single-parameter layout
    # maps onto one '<path>.weight' key.
    if len(module_params) != 1:
        raise ValueError(
            f"expected exactly one parameter under {key!r}, "
            f"found {list(module_params)}"
        )
    (value,) = module_params.values()
    weight = _to_dense(value)
    # Transpose linear matrices; the embedding table keeps its layout.
    if weight.dim() >= 2 and 'in_out_embed' not in new_key:
        weight = weight.transpose(-1, -2).contiguous()
    if 'moe' not in new_key:
        new_params[new_key] = weight
    else:
        # Split stacked expert weights; the expert count is dim 0.
        for i in range(weight.shape[0]):
            new_params[new_key.replace('moe', f'moe.{i}')] = weight[i]
    # Drop the converted source entry (presumably to limit peak memory).
    del params[key]

# torch.save does not create missing directories; make sure the output exists
# so the long conversion above is not lost at the final step.
os.makedirs('hf', exist_ok=True)
torch.save(new_params, 'hf/pytorch_model.bin')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment