Skip to content

Instantly share code, notes, and snippets.

@chu-tianxiang
Created March 20, 2024 04:40
Show Gist options
  • Save chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1 to your computer and use it in GitHub Desktop.
Save chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1 to your computer and use it in GitHub Desktop.
Convert Grok-1 weights to PyTorch
import os

import numpy as np
import torch
import jax
from tqdm import tqdm

from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model
# Directory holding the JAX checkpoint that ModelRunner loads.
CKPT_PATH = "./checkpoints"

# Grok-1 hyper-parameters; this config is passed to the ModelRunner that
# loads the checkpoint.
grok_1_model = LanguageModelConfig(
    # Tokenizer / sequence settings.
    vocab_size=128 * 1024,
    pad_token=0,
    eos_token=2,
    sequence_len=8192,
    # Fixed scaling constants.
    embedding_init_scale=1.0,
    embedding_multiplier_scale=78.38367176906169,  # ~= sqrt(6144), i.e. sqrt(emb_size)
    output_multiplier_scale=0.5773502691896257,  # ~= 1 / sqrt(3)
    model=TransformerConfig(
        # Width and attention layout: 48 query heads, 8 key/value heads,
        # key size 128.
        emb_size=48 * 128,  # = num_q_heads * key_size
        key_size=128,
        num_q_heads=48,
        num_kv_heads=8,
        num_layers=64,
        widening_factor=8,
        attn_output_multiplier=0.08838834764831845,  # ~= 1 / sqrt(key_size)
        # Mixture-of-experts: num_selected_experts=2 out of num_experts=8.
        num_experts=8,
        num_selected_experts=2,
        # Activation sharding.
        shard_activations=True,
        data_axis="data",
        model_axis="model",
    ),
)
# Build a runner over the checkpoint directory using the config above.
runner = ModelRunner(
    checkpoint_path=CKPT_PATH,
    model=grok_1_model,
    bs_per_device=0.125,
)

# One all-zero int32 token batch of shape (1, 256) for both inputs and
# targets. Presumably only its shape/dtype matter for setting up the
# runner - TODO confirm against ModelRunner.
dummy_data = {
    name: np.zeros((1, 256), dtype=np.int32) for name in ("inputs", "targets")
}

runner.transform_forward = True
# NOTE(review): the two (1, 1) tuples look like mesh shapes (single device);
# confirm against the ModelRunner.initialize signature.
runner.initialize(dummy_data, (1, 1), (1, 1))
# Parameter tree keyed by module path; each value is a dict of arrays.
params = runner.load_or_init(dummy_data)
# --- Convert every loaded parameter into a PyTorch state-dict entry. ---


def _rename_key(key):
    """Map a slash-separated parameter path to its dotted state-dict name.

    For example ``'language_model/decoder_layer_0/moe'`` becomes
    ``'transformer.decoder_layer.0.moe.weight'``.
    """
    new_key = (
        key.replace('/', '.')
        .replace('decoder_layer_', 'decoder_layer.')
        .replace('language_model', 'transformer')
    )
    return new_key + '.weight'


def _torch_dtype(array):
    """Keep float32 arrays in float32; store everything else as bfloat16."""
    return torch.float32 if array.dtype == np.float32 else torch.bfloat16


def _as_tensor(array, dtype):
    """Copy ``array`` into a torch tensor of ``dtype``.

    torch cannot convert bfloat16 arrays directly, so the data is widened to
    float32 first and cast to ``dtype`` afterwards.
    """
    return torch.from_numpy(np.asarray(array).astype(np.float32)).to(dtype)


def _dequantize(qweight):
    """Expand a quantized weight (``.weight`` times ``.scales``) to a dense tensor.

    The result is float32 when the scales are float32, otherwise bfloat16.
    """
    dtype = _torch_dtype(qweight.scales)
    weight = _as_tensor(qweight.weight, dtype)
    scale = _as_tensor(qweight.scales, dtype)
    if scale.ndim >= 2 and scale.shape[-2] != 1:
        # Row-parallel layers have a sharded scale: one row of scales per
        # shard of the weight's second-to-last axis. Split that axis into
        # the actual number of shards (previously hard-coded as 8), scale
        # each group, then merge the groups back.
        num_shards = scale.shape[-2]
        grouped = weight.view(*weight.shape[:-2], num_shards, -1, weight.shape[-1])
        return (grouped * scale[..., None, :]).view(weight.shape)
    return weight * scale


new_params = {}
keys = list(params.keys())
for key in tqdm(keys):
    new_key = _rename_key(key)
    # Only the first entry of each parameter dict is converted.
    # NOTE(review): assumes every module holds exactly one array.
    v = next(iter(params[key].values()))
    if hasattr(v, 'scales'):
        weight = _dequantize(v)
    else:
        weight = _as_tensor(v, _torch_dtype(v))
    # Swap the last two axes of every matrix except 'in_out_embed'
    # (presumably JAX [in, out] -> torch nn.Linear [out, in]; confirm).
    if weight.ndim >= 2 and 'in_out_embed' not in new_key:
        weight = weight.transpose(-1, -2).contiguous()
    if 'moe' in new_key:
        # MoE weights are stacked along axis 0. Emit one entry per expert
        # rather than assuming exactly 8 experts.
        for i in range(weight.shape[0]):
            new_params[new_key.replace('moe', f'moe.{i}')] = weight[i]
    else:
        new_params[new_key] = weight
    # Drop the source entry once converted (presumably to limit peak memory).
    del params[key]
# Write the converted state dict. torch.save does not create missing parent
# directories, so make sure the output directory exists first.
OUTPUT_DIR = 'hf'
os.makedirs(OUTPUT_DIR, exist_ok=True)
torch.save(new_params, os.path.join(OUTPUT_DIR, 'pytorch_model.bin'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment