{
"cells": [
{
"cell_type": "markdown",
"id": "mixtral-poking-title",
"metadata": {},
"source": [
"# Mixtral poking\n",
"\n",
"@joelburget, June 9, 2024"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4fb7e0bc-4ef5-40c8-8222-336e83bd6e66",
"metadata": {},
"outputs": [],
"source": [
"%pip install transformers huggingface_hub"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be241e96-3bbb-46a4-a4d4-0213eb094d6e",
"metadata": {},
"outputs": [],
"source": [
"%pip install git+https://github.com/TransformerLensOrg/TransformerLens.git"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6d7341d8-881c-41c3-8199-ae9590d51a5a",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8ea54c24323c4e04b3890f01257e6549",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from huggingface_hub import login\n",
"login()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cba8adb4-03a4-4061-b62b-18bcc091b8af",
"metadata": {},
"outputs": [],
"source": [
"import einops\n",
"model_id = \"mistralai/Mixtral-8x7B-v0.1\"\n",
"text = \"Hello my name is\""
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2c3cb338-cf1b-4775-b278-302999164e6a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b256c265e0664a47a55c4ef269aba32a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/19 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded pretrained model mistralai/Mixtral-8x7B-v0.1 into HookedTransformer\n"
]
}
],
"source": [
"from transformer_lens import HookedTransformer\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"tl_model = HookedTransformer.from_pretrained_no_processing(\n",
" \"mistralai/Mixtral-8x7B-v0.1\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "617609fe-060e-48b5-9daf-cb45cb6cc5d2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Hello my name is Daniel. and butorr on2 for6 to amongst24 to to forfor so in has stuck to extreme continues to,)wh on” we have for every in bag to for in! to every to8 for for around'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tl_model.generate(\n",
" text,\n",
" verbose=False,\n",
" max_new_tokens=50,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4f3ccaf0-1650-4621-84f4-dbcdef132447",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0b50aa852274439a963f8594be58e4cf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/19 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"hf_model = AutoModelForCausalLM.from_pretrained(model_id)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "97cec596-b5fc-48c2-82db-08af7e238a0b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello my name is Katie and I am a 20 year old student from the UK. I am currently studying a degree in English Literature and Creative Writing at the University of Winchester. I have always had a passion for writing and I am hoping to pursue\n"
]
}
],
"source": [
"inputs = tokenizer(text, return_tensors=\"pt\")\n",
"outputs = hf_model.generate(**inputs, max_new_tokens=50)\n",
"print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
]
},
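{
"cell_type": "markdown",
"id": "added-logit-comparison-md",
"metadata": {},
"source": [
"The Hugging Face model completes the prompt coherently while the TransformerLens load generates garbage, which suggests the port is not faithful. A follow-up sketch (added here, not part of the original gist; the tolerances are a guess): compare the two models' logits on the same prompt, which should agree up to floating-point error for a faithful port."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-logit-comparison",
"metadata": {},
"outputs": [],
"source": [
"# Added check (not in the original gist): compare logits directly.\n",
"# A faithful port would pass this; given the garbled generation above it may fail.\n",
"import torch\n",
"from torch.testing import assert_close\n",
"\n",
"input_ids = tokenizer(text, return_tensors=\"pt\")[\"input_ids\"]\n",
"with torch.no_grad():\n",
"    hf_logits = hf_model(input_ids).logits\n",
"    tl_logits = tl_model(input_ids)\n",
"assert_close(tl_logits, hf_logits, rtol=1e-4, atol=1e-4)"
]
},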
{
"cell_type": "code",
"execution_count": 45,
"id": "1df1c967-4dfd-41ef-bab4-dea6072678aa",
"metadata": {},
"outputs": [],
"source": [
"from torch.testing import assert_close\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "cfa7d1c3-72e8-45b3-850c-091669f48b1e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.all(\n",
" einops.rearrange(tl_model.blocks[0].attn.W_Q, \"n m h -> (n h) m\") ==\n",
" hf_model.model.layers[0].self_attn.q_proj.weight\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "83649934-f06b-4f94-8004-59b8d4098589",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([32, 4096, 128]), torch.Size([1024, 4096]))"
]
},
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tl_model.blocks[0].attn.W_K.shape, hf_model.model.layers[0].self_attn.k_proj.weight.shape"
]
},
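{
"cell_type": "markdown",
"id": "added-gqa-note-md",
"metadata": {},
"source": [
"Mixtral uses grouped-query attention: the HF k_proj above maps to 1024 = 8 * 128 dimensions, i.e. 8 key/value heads shared by 32 query heads, while TransformerLens exposes W_K with those 8 heads repeated to 32. The reduces below collapse the repeated copies. A quick look at the config fields involved (added, not in the original gist):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-gqa-note",
"metadata": {},
"outputs": [],
"source": [
"# Added: head counts behind the (n repeat) reduce pattern used below.\n",
"tl_model.cfg.n_heads, tl_model.cfg.n_key_value_heads, tl_model.cfg.d_head"
]
},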
{
"cell_type": "code",
"execution_count": 104,
"id": "4fa20cf5-b720-4946-a7e5-e1d2e6277f6c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.all(\n",
" einops.reduce(\n",
" tl_model.blocks[0].attn.W_K, \"(n repeat) m h -> (n h) m\",\n",
" 'max',\n",
" n=tl_model.cfg.n_key_value_heads,\n",
" repeat=4) ==\n",
" hf_model.model.layers[0].self_attn.k_proj.weight\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "ef6f7ea9-ef0b-4091-8d00-504b481fc59a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.all(\n",
" einops.reduce(\n",
" tl_model.blocks[0].attn.W_V, \"(n repeat) m h -> (n h) m\",\n",
" 'max',\n",
" n=tl_model.cfg.n_key_value_heads,\n",
" repeat=4) ==\n",
" hf_model.model.layers[0].self_attn.v_proj.weight\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "04b8f4be-ce7d-4dc2-acda-d023c721525c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.all(\n",
" einops.rearrange(tl_model.blocks[0].attn.W_O, \"n h m -> m (n h)\") ==\n",
" hf_model.model.layers[0].self_attn.o_proj.weight\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "1e10ed87-31b5-4c1c-b726-7a3f49fbd136",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]], requires_grad=True)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tl_model.blocks[0].attn.b_Q"
]
},
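{
"cell_type": "markdown",
"id": "added-bias-check-md",
"metadata": {},
"source": [
"b_Q is all zeros, consistent with Mixtral's attention having no bias terms (the HF attention modules further down are built with bias=False). A small added check, assuming the other TransformerLens bias parameters follow the usual b_K / b_V / b_O naming:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-bias-check",
"metadata": {},
"outputs": [],
"source": [
"# Added: every attention bias should be zero, since Mixtral has none.\n",
"attn = tl_model.blocks[0].attn\n",
"[torch.all(b == 0).item() for b in (attn.b_Q, attn.b_K, attn.b_V, attn.b_O)]"
]
},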
{
"cell_type": "code",
"execution_count": 46,
"id": "6caf9d98-adb2-45e7-8357-34288b2156f2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.all(hf_model.model.layers[0].block_sparse_moe.gate.weight.T == tl_model.blocks[0].mlp.W_gate)"
]
},
{
"cell_type": "code",
"execution_count": 114,
"id": "00e9ea5d-74c2-4c2a-8e9d-6fc196cb8fc3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.float32, torch.float32)"
]
},
"execution_count": 114,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf_model.model.layers[0].block_sparse_moe.gate.weight.dtype, tl_model.blocks[0].mlp.W_gate.dtype"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "44a55507-e639-414a-a297-e68e1c0696f9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.all(\n",
" tl_model.blocks[0].mlp.experts[0].W_in ==\n",
" hf_model.model.layers[0].block_sparse_moe.experts[0].w3.weight.T\n",
")"
]
},
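{
"cell_type": "markdown",
"id": "added-expert-weights-md",
"metadata": {},
"source": [
"The check above covers only the expert's w3 (up projection) against W_in. Assuming each TransformerLens expert is a gated MLP that also exposes W_gate and W_out (an assumption, not verified in the original gist), the analogous checks for w1 and w2 would be:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-expert-weights",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical extension: w1 should map to W_gate and w2 to W_out (transposed),\n",
"# mirroring the w3 -> W_in check above. The W_gate/W_out attribute names are assumed.\n",
"(\n",
"    torch.all(\n",
"        tl_model.blocks[0].mlp.experts[0].W_gate ==\n",
"        hf_model.model.layers[0].block_sparse_moe.experts[0].w1.weight.T\n",
"    ),\n",
"    torch.all(\n",
"        tl_model.blocks[0].mlp.experts[0].W_out ==\n",
"        hf_model.model.layers[0].block_sparse_moe.experts[0].w2.weight.T\n",
"    ),\n",
")"
]
},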
{
"cell_type": "code",
"execution_count": 56,
"id": "03944deb-aa8d-46ff-83dd-4f7ee955656c",
"metadata": {},
"outputs": [],
"source": [
"test_tensor = torch.randn((1, 1, 4096,))"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "eb0109ee-b82a-4ea0-b50b-8e6408647cea",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(False)"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.all(\n",
" hf_model.model.layers[0].block_sparse_moe(test_tensor)[0] ==\n",
" tl_model.blocks[0].mlp(test_tensor)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "25ce75bf-706e-4ae8-8f74-bc9c40e88c25",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.3826, 0.0153, 0.0993, ..., -0.2474, 0.4459, -0.3026]]],\n",
" grad_fn=<ReshapeAliasBackward0>)"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf_model.model.layers[0].block_sparse_moe(test_tensor)[0]"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "c016430e-0a30-426b-bfd0-0b1b423b3ff6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.3826, 0.0153, 0.0993, ..., -0.2474, 0.4459, -0.3026]]],\n",
" grad_fn=<IndexPutBackward0>)"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tl_model.blocks[0].mlp(test_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 107,
"id": "46353486-0a3f-4241-9cf5-ed25c7539f71",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 1, 4096])"
]
},
"execution_count": 107,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tl_model.blocks[0].mlp(test_tensor).shape"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "e25ada54-4e3c-42b7-8f35-ba67bfa500e3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[False, False, False, ..., False, False, False]]])"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf_model.model.layers[0].block_sparse_moe(test_tensor)[0] == tl_model.blocks[0].mlp(test_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "8f3a2865-645d-4441-95fb-32446f866760",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(89)"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.sum(hf_model.model.layers[0].block_sparse_moe(test_tensor)[0] == tl_model.blocks[0].mlp(test_tensor))"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "ac306e1c-9972-466a-8f4a-f3eb56042f53",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.38261502981185913"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf_model.model.layers[0].block_sparse_moe(test_tensor)[0][0, 0, 0].item()"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "e9481397-6e87-435a-a0cf-ef409630d17c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.38261523842811584"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tl_model.blocks[0].mlp(test_tensor)[0, 0, 0].item()"
]
},
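{
"cell_type": "markdown",
"id": "added-moe-tolerance-md",
"metadata": {},
"source": [
"The two MoE outputs agree to roughly six significant figures, so the mismatch looks like floating-point rounding rather than a logic difference. A tolerance-based comparison (added; the tolerances are a guess and may need loosening) using the assert_close imported earlier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-moe-tolerance",
"metadata": {},
"outputs": [],
"source": [
"# Added: the MoE outputs should match within floating-point tolerance,\n",
"# even though exact elementwise equality mostly fails above.\n",
"assert_close(\n",
"    tl_model.blocks[0].mlp(test_tensor),\n",
"    hf_model.model.layers[0].block_sparse_moe(test_tensor)[0],\n",
"    rtol=1e-4,\n",
"    atol=1e-5,\n",
")"
]
},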
{
"cell_type": "code",
"execution_count": 74,
"id": "89f55163-fa2e-44f2-b112-7b052a6e85af",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MixtralAttention(\n",
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (rotary_emb): MixtralRotaryEmbedding()\n",
")"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf_model.model.layers[0].self_attn"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "f0186768-d8f0-4d55-a94c-606c4ba3f7ca",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[[-0.6824, 2.0180, -1.6793, ..., -1.3551, -0.2033, -0.8247]]],\n",
" grad_fn=<AddBackward0>),)"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf_model.model.layers[0](test_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "432bb274-b499-44c9-98d1-777d03425daa",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.6824, 2.0180, -1.6793, ..., -1.3551, -0.2033, -0.8247]]],\n",
" grad_fn=<AddBackward0>)"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tl_model.blocks[0](test_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"id": "4a440811-e7f0-4092-b8e7-f7cac80dc84a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ True, True, True, ..., True, False, True]]])"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf_model.model.layers[0](test_tensor)[0] == tl_model.blocks[0](test_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 110,
"id": "8ed65bb3-6990-48e5-9ef2-1becd9dfaffc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(-0.20334099233150482, -0.20334100723266602)"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf_model.model.layers[0](test_tensor)[0][0, 0, -2].item(), tl_model.blocks[0](test_tensor)[0, 0, -2].item()"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "763f6c2e-b71f-4724-b2f7-f79a9ab29caf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(3218)"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.sum(hf_model.model.layers[0](test_tensor)[0] == tl_model.blocks[0](test_tensor))"
]
},
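{
"cell_type": "markdown",
"id": "added-block-diff-md",
"metadata": {},
"source": [
"As with the MoE sub-block, the full transformer block outputs match at most positions but not all. A quick look at how large the disagreement actually is (added, not in the original gist):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-block-diff",
"metadata": {},
"outputs": [],
"source": [
"# Added: maximum absolute difference between the HF layer and the TL block.\n",
"(hf_model.model.layers[0](test_tensor)[0] - tl_model.blocks[0](test_tensor)).abs().max()"
]
},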
{
"cell_type": "code",
"execution_count": 84,
"id": "5172efa2-0066-4ae0-a6a2-530d815b053b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.2402, -0.0474, 0.0492, ..., -0.3049, 0.0006, 0.1522]]],\n",
" grad_fn=<AddBackward0>)"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tl_model.blocks[0].attn.forward(test_tensor, test_tensor, test_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "92781a06-e16d-43f9-be4c-3ef04b3d4b08",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.2402, -0.0474, 0.0492, ..., -0.3049, 0.0006, 0.1522]]],\n",
" grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hf_model.model.layers[0].self_attn.forward(test_tensor)[0]"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "943cd506-2bb8-45bf-afc7-7f6b4f8043f1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[False, False, False, ..., False, False, True]]])"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(tl_model.blocks[0].attn.forward(test_tensor, test_tensor, test_tensor) == \n",
" hf_model.model.layers[0].self_attn.forward(test_tensor)[0])"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "57ffc181-abed-4784-86eb-6e6b4f174bc5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(254)"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.sum(tl_model.blocks[0].attn.forward(test_tensor, test_tensor, test_tensor) == \n",
" hf_model.model.layers[0].self_attn.forward(test_tensor)[0])"
]
}
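,
{
"cell_type": "markdown",
"id": "added-attn-diff-md",
"metadata": {},
"source": [
"Only 254 of the 4096 attention output entries match bit-for-bit, but again the interesting question is how far apart the rest are. A closing check (added, not in the original gist) of the maximum absolute difference:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-attn-diff",
"metadata": {},
"outputs": [],
"source": [
"# Added: size of the attention-output discrepancy, rather than exact equality.\n",
"(tl_model.blocks[0].attn.forward(test_tensor, test_tensor, test_tensor) -\n",
" hf_model.model.layers[0].self_attn.forward(test_tensor)[0]).abs().max()"
]
}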
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}