Skip to content

Instantly share code, notes, and snippets.

@maldevide
Last active February 14, 2024 10:00
Show Gist options
  • Save maldevide/08829eada04ad9bd78e46c1a3787d42b to your computer and use it in GitHub Desktop.
Save maldevide/08829eada04ad9bd78e46c1a3787d42b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/biggy/ai/notebook/txvenv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import transformers\n",
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"# Load the model\n",
"model_o_path = \"../llama/raw/TinyLlama-1.1B-intermediate-step-1431k-3T/\"\n",
"model_a_path = \"../llama/raw/TinyLlama-1.1B-32k\"\n",
"model_b_path = \"../llama/raw/TinyLlama-1.1B-3T-openhermes/\"\n",
"model_c_path = \"../llama/raw/TinyLlama-3T-Cinder-v1.3/\"\n",
"model_d_path = \"../llama/raw/TinyLlama-1.1B-32k\"\n",
"models = [\n",
" model_o_path,\n",
" model_a_path,\n",
" model_b_path,\n",
" model_c_path,\n",
" model_d_path,\n",
"]\n",
"trust_remote_code = True\n",
"config = {\n",
" 'torch_dtype': torch.float16,\n",
" 'low_cpu_mem_usage': True,\n",
" 'trust_remote_code': trust_remote_code,\n",
"}\n",
"\n",
"models = [\n",
" transformers.AutoModelForCausalLM.from_pretrained(\n",
" model_path, **config\n",
" ) for model_path in models\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"201"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"keys = models[0].state_dict().keys()\n",
"len(keys)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0 model.embed_tokens.weight torch.Size([32000, 2048]) torch.float16 cpu True\n",
" 1 model.layers.0.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 2 model.layers.0.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 3 model.layers.0.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 4 model.layers.0.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 5 model.layers.0.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 6 model.layers.0.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 7 model.layers.0.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 8 model.layers.0.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 9 model.layers.0.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 10 model.layers.1.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 11 model.layers.1.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 12 model.layers.1.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 13 model.layers.1.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 14 model.layers.1.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 15 model.layers.1.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 16 model.layers.1.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 17 model.layers.1.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 18 model.layers.1.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 19 model.layers.2.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 20 model.layers.2.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 21 model.layers.2.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 22 model.layers.2.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 23 model.layers.2.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 24 model.layers.2.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 25 model.layers.2.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 26 model.layers.2.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 27 model.layers.2.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 28 model.layers.3.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 29 model.layers.3.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 30 model.layers.3.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 31 model.layers.3.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 32 model.layers.3.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 33 model.layers.3.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 34 model.layers.3.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 35 model.layers.3.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 36 model.layers.3.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 37 model.layers.4.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 38 model.layers.4.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 39 model.layers.4.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 40 model.layers.4.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 41 model.layers.4.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 42 model.layers.4.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 43 model.layers.4.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 44 model.layers.4.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 45 model.layers.4.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 46 model.layers.5.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 47 model.layers.5.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 48 model.layers.5.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 49 model.layers.5.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 50 model.layers.5.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 51 model.layers.5.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 52 model.layers.5.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 53 model.layers.5.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 54 model.layers.5.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 55 model.layers.6.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 56 model.layers.6.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 57 model.layers.6.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 58 model.layers.6.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 59 model.layers.6.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 60 model.layers.6.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 61 model.layers.6.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 62 model.layers.6.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 63 model.layers.6.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 64 model.layers.7.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 65 model.layers.7.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 66 model.layers.7.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 67 model.layers.7.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 68 model.layers.7.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 69 model.layers.7.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 70 model.layers.7.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 71 model.layers.7.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 72 model.layers.7.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 73 model.layers.8.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 74 model.layers.8.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 75 model.layers.8.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 76 model.layers.8.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 77 model.layers.8.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 78 model.layers.8.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 79 model.layers.8.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 80 model.layers.8.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 81 model.layers.8.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 82 model.layers.9.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 83 model.layers.9.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 84 model.layers.9.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 85 model.layers.9.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 86 model.layers.9.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 87 model.layers.9.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 88 model.layers.9.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 89 model.layers.9.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 90 model.layers.9.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 91 model.layers.10.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 92 model.layers.10.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 93 model.layers.10.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
" 94 model.layers.10.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
" 95 model.layers.10.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 96 model.layers.10.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
" 97 model.layers.10.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
" 98 model.layers.10.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
" 99 model.layers.10.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"100 model.layers.11.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"101 model.layers.11.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"102 model.layers.11.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"103 model.layers.11.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"104 model.layers.11.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"105 model.layers.11.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"106 model.layers.11.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"107 model.layers.11.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"108 model.layers.11.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"109 model.layers.12.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"110 model.layers.12.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"111 model.layers.12.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"112 model.layers.12.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"113 model.layers.12.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"114 model.layers.12.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"115 model.layers.12.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"116 model.layers.12.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"117 model.layers.12.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"118 model.layers.13.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"119 model.layers.13.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"120 model.layers.13.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"121 model.layers.13.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"122 model.layers.13.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"123 model.layers.13.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"124 model.layers.13.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"125 model.layers.13.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"126 model.layers.13.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"127 model.layers.14.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"128 model.layers.14.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"129 model.layers.14.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"130 model.layers.14.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"131 model.layers.14.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"132 model.layers.14.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"133 model.layers.14.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"134 model.layers.14.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"135 model.layers.14.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"136 model.layers.15.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"137 model.layers.15.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"138 model.layers.15.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"139 model.layers.15.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"140 model.layers.15.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"141 model.layers.15.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"142 model.layers.15.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"143 model.layers.15.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"144 model.layers.15.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"145 model.layers.16.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"146 model.layers.16.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"147 model.layers.16.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"148 model.layers.16.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"149 model.layers.16.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"150 model.layers.16.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"151 model.layers.16.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"152 model.layers.16.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"153 model.layers.16.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"154 model.layers.17.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"155 model.layers.17.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"156 model.layers.17.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"157 model.layers.17.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"158 model.layers.17.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"159 model.layers.17.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"160 model.layers.17.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"161 model.layers.17.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"162 model.layers.17.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"163 model.layers.18.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"164 model.layers.18.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"165 model.layers.18.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"166 model.layers.18.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"167 model.layers.18.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"168 model.layers.18.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"169 model.layers.18.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"170 model.layers.18.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"171 model.layers.18.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"172 model.layers.19.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"173 model.layers.19.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"174 model.layers.19.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"175 model.layers.19.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"176 model.layers.19.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"177 model.layers.19.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"178 model.layers.19.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"179 model.layers.19.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"180 model.layers.19.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"181 model.layers.20.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"182 model.layers.20.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"183 model.layers.20.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"184 model.layers.20.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"185 model.layers.20.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"186 model.layers.20.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"187 model.layers.20.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"188 model.layers.20.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"189 model.layers.20.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"190 model.layers.21.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"191 model.layers.21.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"192 model.layers.21.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"193 model.layers.21.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"194 model.layers.21.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"195 model.layers.21.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"196 model.layers.21.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"197 model.layers.21.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"198 model.layers.21.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"199 model.norm.weight torch.Size([2048]) torch.float16 cpu True\n",
"200 lm_head.weight torch.Size([32000, 2048]) torch.float16 cpu True\n"
]
}
],
"source": [
"for i, k in enumerate(keys):\n",
" tensor = models[0].state_dict()[k]\n",
" print(f\"{i:3d} {k} {tensor.shape} {tensor.dtype} {tensor.device} {tensor.is_contiguous()}\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello my dog is cute and loves to play with me. I love to play with him. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. 
He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. 
He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. 
He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with me. He is very smart and loves to play with me. He is very friendly and loves to play with\n"
]
}
],
"source": [
"base_model = models[0].to(\"cuda\")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_o_path, torch_dtype=torch.float16)\n",
"\n",
"input_text = \"Hello my dog is cute and\"\n",
"inputs = tokenizer(input_text, return_tensors=\"pt\").to(\"cuda\")\n",
"\n",
"with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):\n",
" outputs = base_model.generate(**inputs)\n",
"\n",
"print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
"base_model = base_model.to(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model model.embed_tokens.weight torch.Size([32000, 2048]) torch.float16 cpu True\n",
"Model model.layers.0.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.0.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.0.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.0.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.0.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.0.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.0.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.0.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.0.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.1.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.1.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.1.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.1.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.1.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.1.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.1.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.1.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.1.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.2.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.2.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.2.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.2.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.2.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.2.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.2.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.2.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.2.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.3.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.3.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.3.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.3.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.3.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.3.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.3.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.3.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.3.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.4.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.4.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.4.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.4.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.4.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.4.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.4.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.4.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.4.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.5.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.5.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.5.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.5.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.5.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.5.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.5.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.5.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.5.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.6.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.6.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.6.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.6.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.6.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.6.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.6.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.6.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.6.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.7.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.7.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.7.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.7.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.7.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.7.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.7.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.7.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.7.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.8.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.8.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.8.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.8.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.8.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.8.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.8.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.8.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.8.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.9.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.9.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.9.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.9.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.9.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.9.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.9.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.9.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.9.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.10.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.10.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.10.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.10.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.10.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.10.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.10.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.10.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.10.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.11.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.11.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.11.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.11.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.11.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.11.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.11.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.11.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.11.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.12.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.12.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.12.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.12.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.12.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.12.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.12.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.12.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.12.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.13.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.13.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.13.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.13.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.13.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.13.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.13.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.13.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.13.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.14.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.14.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.14.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.14.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.14.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.14.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.14.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.14.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.14.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.15.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.15.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.15.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.15.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.15.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.15.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.15.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.15.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.15.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.16.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.16.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.16.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.16.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.16.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.16.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.16.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.16.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.16.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.17.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.17.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.17.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.17.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.17.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.17.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.17.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.17.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.17.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.18.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.18.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.18.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.18.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.18.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.18.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.18.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.18.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.18.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.19.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.19.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.19.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.19.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.19.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.19.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.19.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.19.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.19.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.20.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.20.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.20.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.20.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.20.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.20.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.20.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.20.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.20.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.21.self_attn.q_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.21.self_attn.k_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.21.self_attn.v_proj.weight torch.Size([256, 2048]) torch.float16 cpu True\n",
"Model model.layers.21.self_attn.o_proj.weight torch.Size([2048, 2048]) torch.float16 cpu True\n",
"Model model.layers.21.mlp.gate_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.21.mlp.up_proj.weight torch.Size([5632, 2048]) torch.float16 cpu True\n",
"Model model.layers.21.mlp.down_proj.weight torch.Size([2048, 5632]) torch.float16 cpu True\n",
"Model model.layers.21.input_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.layers.21.post_attention_layernorm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model model.norm.weight torch.Size([2048]) torch.float16 cpu True\n",
"Model lm_head.weight torch.Size([32000, 2048]) torch.float16 cpu True\n"
]
}
],
"source": [
"from daremerge.ddare.merge import merge_tensors\n",
"from daremerge.ddare.tensor import dare_ties_sparsification, relative_norm, divide_tensor_into_sets\n",
"from daremerge.ddare.util import get_device\n",
"import re\n",
"from typing import Dict, Tuple\n",
"\n",
"# Decoder-layer keys look like \"model.layers.<idx>.<param>\". Dots are escaped so\n",
"# they only match literal dots; the pattern is compiled once, not on every call.\n",
"LAYER_KEY_RE = re.compile(r\"model\\.layers\\.(\\d+)\\.(.+)\")\n",
"\n",
"def get_layer_type(k : str) -> Tuple[int, str]:\n",
"    \"\"\"Split a state-dict key into (layer index, parameter name inside the layer).\n",
"\n",
"    Keys outside the decoder layers map to index -1 and a short tag: \"norm\",\n",
"    \"embed\", \"head\", or \"unknown\" (after printing a warning).\n",
"    \"\"\"\n",
"    m = LAYER_KEY_RE.match(k)\n",
"    if m is None:\n",
"        if \"model.norm.weight\" == k:\n",
"            return -1, \"norm\"\n",
"        if \"model.embed_tokens.weight\" == k:\n",
"            return -1, \"embed\"\n",
"        if \"lm_head.weight\" == k:\n",
"            return -1, \"head\"\n",
"        print(f\"Unknown key {k}\")\n",
"        return -1, \"unknown\"\n",
"    return int(m.group(1)), m.group(2)\n",
"\n",
"# Attention projection names exactly as get_layer_type returns them for LLaMA keys.\n",
"# (Stable Diffusion UNet names such as \"to_q\"/\"to_k\"/\"to_v\" never occur here.)\n",
"Q_PROJ = \"self_attn.q_proj.weight\"\n",
"K_PROJ = \"self_attn.k_proj.weight\"\n",
"V_PROJ = \"self_attn.v_proj.weight\"\n",
"\n",
"# Slerp factor toward each donor, per parameter name. Presumably 0.0 keeps the base\n",
"# (m0) weights -- confirm against merge_tensors' slerp convention.\n",
"default_ratio = 0.5\n",
"ratio_overrides = {Q_PROJ: 0.0, K_PROJ: 0.0, V_PROJ: 0.0}\n",
"\n",
"# Optional DARE-TIES sparsification of each donor before slerp (off by default).\n",
"use_dare = False\n",
"# Extra keyword arguments for dare_ties_sparsification. Deliberately not named\n",
"# `config`, which holds the model-loading options from the first cell.\n",
"dare_kwargs = {}\n",
"\n",
"result_dict : Dict[str, torch.Tensor] = {}\n",
"device = get_device()\n",
"# Build every state dict once instead of calling state_dict() for each key.\n",
"state_dicts = [mdl.state_dict() for mdl in models]\n",
"base_sd, donor_sds = state_dicts[0], state_dicts[1:]\n",
"\n",
"for k in keys:\n",
"    _, layer_type = get_layer_type(k)\n",
"    m0 : torch.Tensor = base_sd[k]\n",
"    result = m0.clone()\n",
"    print(f\"Model {k} {m0.shape} {m0.dtype} {m0.device} {m0.is_contiguous()}\")\n",
"    # One set id per element; elements in set i are merged from donor i.\n",
"    sets = divide_tensor_into_sets(tensor=m0, n_sets=len(donor_sds))\n",
"    ratio = ratio_overrides.get(layer_type, default_ratio)\n",
"\n",
"    # Donor tensors stay on m0's device so torch.where below never mixes devices.\n",
"    for i, donor_sd in enumerate(donor_sds):\n",
"        tensor = donor_sd[k]\n",
"        if layer_type == K_PROJ:\n",
"            # Divide by the same factor this donor's q_proj is multiplied by,\n",
"            # so the donor's q.k products are unchanged by the rescaling.\n",
"            q_key = k.replace(\"k_proj\", \"q_proj\")\n",
"            scale = relative_norm(donor_sd[q_key], base_sd[q_key])\n",
"            tensor = tensor / scale\n",
"            del scale\n",
"        elif layer_type == Q_PROJ:\n",
"            scale = relative_norm(tensor, m0)\n",
"            tensor = tensor * scale\n",
"            del scale\n",
"        if use_dare:\n",
"            tensor = dare_ties_sparsification(model_a_param=m0, model_b_param=tensor, drop_rate=0.68, ties=\"sum\", rescale=\"off\", device=device, **dare_kwargs).to(m0.device)\n",
"        slice_mask = (sets == i).bool()\n",
"        new_tensor = merge_tensors(\"slerp\", m0, tensor, ratio)\n",
"        result = torch.where(slice_mask, new_tensor, result)\n",
"        del new_tensor, slice_mask\n",
"\n",
"    result_dict[k] = result\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"model_out_path = \"../llama/raw/TinyLlama-QuadMerge-Alpha\"\n",
"# Load the base model with explicit options rather than the shared `config`\n",
"# global, which other cells may rebind; the merged tensors are float16.\n",
"model = transformers.AutoModelForCausalLM.from_pretrained(\n",
"    model_o_path,\n",
"    torch_dtype=torch.float16,\n",
"    low_cpu_mem_usage=True,\n",
"    trust_remote_code=trust_remote_code,\n",
")\n",
"# Hand the merged weights to save_pretrained explicitly instead of\n",
"# monkey-patching model.state_dict on the in-memory model.\n",
"model.save_pretrained(model_out_path, state_dict=result_dict)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello my dog is cute and loves to play with other dogs. He is very friendly and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. 
He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. 
He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. 
He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very playful and loves to play with other dogs. He is very\n"
]
}
],
"source": [
"# Smoke test: reload the merged checkpoint from disk onto CUDA and generate a\n",
"# continuation of a short prompt. `tokenizer` must come from an earlier cell.\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    model_out_path,\n",
"    torch_dtype=torch.float16,\n",
").to(\"cuda\")\n",
"\n",
"input_text = \"Hello my dog is cute and\"\n",
"inputs = tokenizer(input_text, return_tensors=\"pt\").to(\"cuda\")\n",
"\n",
"# Restrict scaled-dot-product attention to the flash kernel only\n",
"# (math and memory-efficient fallbacks disabled).\n",
"with torch.backends.cuda.sdp_kernel(\n",
"    enable_flash=True,\n",
"    enable_math=False,\n",
"    enable_mem_efficient=False,\n",
"):\n",
"    outputs = model.generate(**inputs)\n",
"\n",
"print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jupyterenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment