Skip to content

Instantly share code, notes, and snippets.

Created June 9, 2024 17:04
Show Gist options
  • Save joelburget/bae5ea4d997e804b2a65d02d5b61f5bc to your computer and use it in GitHub Desktop.
Save joelburget/bae5ea4d997e804b2a65d02d5b61f5bc to your computer and use it in GitHub Desktop.
Mixtral poking
Display the source blob
Display the rendered blob
"cells": [
"cell_type": "code",
"execution_count": null,
"id": "4fb7e0bc-4ef5-40c8-8222-336e83bd6e66",
"metadata": {},
"outputs": [],
"source": [
"%pip install transformers huggingface_hub"
"cell_type": "code",
"execution_count": null,
"id": "be241e96-3bbb-46a4-a4d4-0213eb094d6e",
"metadata": {},
"outputs": [],
"source": [
"%pip install git+"
"cell_type": "code",
"execution_count": 1,
"id": "6d7341d8-881c-41c3-8199-ae9590d51a5a",
"metadata": {},
"outputs": [
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8ea54c24323c4e04b3890f01257e6549",
"version_major": 2,
"version_minor": 0
"text/plain": [
"VBox(children=(HTML(value='<center> <img\\nsrc=…"
"metadata": {},
"output_type": "display_data"
"source": [
"from huggingface_hub import login\n",
"cell_type": "code",
"execution_count": 3,
"id": "cba8adb4-03a4-4061-b62b-18bcc091b8af",
"metadata": {},
"outputs": [],
"source": [
"import einops\n",
"model_id = \"mistralai/Mixtral-8x7B-v0.1\"\n",
"text = \"Hello my name is\""
"cell_type": "code",
"execution_count": 1,
"id": "2c3cb338-cf1b-4775-b278-302999164e6a",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/ FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b256c265e0664a47a55c4ef269aba32a",
"version_major": 2,
"version_minor": 0
"text/plain": [
"Loading checkpoint shards: 0%| | 0/19 [00:00<?, ?it/s]"
"metadata": {},
"output_type": "display_data"
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/ FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded pretrained model mistralai/Mixtral-8x7B-v0.1 into HookedTransformer\n"
"source": [
"from transformer_lens import HookedTransformer\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"tl_model = HookedTransformer.from_pretrained_no_processing(\n",
" \"mistralai/Mixtral-8x7B-v0.1\"\n",
"cell_type": "code",
"execution_count": 4,
"id": "617609fe-060e-48b5-9daf-cb45cb6cc5d2",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"'Hello my name is Daniel. and butorr on2 for6 to amongst24 to to forfor so in has stuck to extreme continues to,)wh on” we have for every in bag to for in! to every to8 for for around'"
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
"source": [
" text,\n",
" verbose=False,\n",
" max_new_tokens=50,\n",
"cell_type": "code",
"execution_count": 5,
"id": "4f3ccaf0-1650-4621-84f4-dbcdef132447",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/ FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0b50aa852274439a963f8594be58e4cf",
"version_major": 2,
"version_minor": 0
"text/plain": [
"Loading checkpoint shards: 0%| | 0/19 [00:00<?, ?it/s]"
"metadata": {},
"output_type": "display_data"
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"hf_model = AutoModelForCausalLM.from_pretrained(model_id)"
"cell_type": "code",
"execution_count": 6,
"id": "97cec596-b5fc-48c2-82db-08af7e238a0b",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
"name": "stdout",
"output_type": "stream",
"text": [
"Hello my name is Katie and I am a 20 year old student from the UK. I am currently studying a degree in English Literature and Creative Writing at the University of Winchester. I have always had a passion for writing and I am hoping to pursue\n"
"source": [
"# Greedy-decode 50 new tokens with the Hugging Face reference model.\n",
"inputs = tokenizer(text, return_tensors=\"pt\")\n",
"# Without an explicit pad_token_id, generate() logs\n",
"# 'Setting `pad_token_id` to `eos_token_id`' and falls back to the EOS id;\n",
"# make that choice explicit so the cell runs without the warning.\n",
"outputs = hf_model.generate(\n",
"    **inputs,\n",
"    max_new_tokens=50,\n",
"    pad_token_id=tokenizer.eos_token_id,\n",
")\n",
"print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
"cell_type": "code",
"execution_count": 45,
"id": "1df1c967-4dfd-41ef-bab4-dea6072678aa",
"metadata": {},
"outputs": [],
"source": [
"from torch.testing import assert_close\n",
"import torch"
"cell_type": "code",
"execution_count": 48,
"id": "cfa7d1c3-72e8-45b3-850c-091669f48b1e",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
"source": [
" einops.rearrange(tl_model.blocks[0].attn.W_Q, \"n m h -> (n h) m\") ==\n",
" hf_model.model.layers[0].self_attn.q_proj.weight\n",
"cell_type": "code",
"execution_count": 106,
"id": "83649934-f06b-4f94-8004-59b8d4098589",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"(torch.Size([32, 4096, 128]), torch.Size([1024, 4096]))"
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
"source": [
"# Show the TransformerLens W_K shape next to the HF k_proj weight shape.\n",
"(\n",
"    tl_model.blocks[0].attn.W_K.shape,\n",
"    hf_model.model.layers[0].self_attn.k_proj.weight.shape,\n",
")"
"cell_type": "code",
"execution_count": 104,
"id": "4fa20cf5-b720-4946-a7e5-e1d2e6277f6c",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
"source": [
" einops.reduce(\n",
" tl_model.blocks[0].attn.W_K, \"(n repeat) m h -> (n h) m\",\n",
" 'max',\n",
" n=tl_model.cfg.n_key_value_heads,\n",
" repeat=4) ==\n",
" hf_model.model.layers[0].self_attn.k_proj.weight\n",
"cell_type": "code",
"execution_count": 105,
"id": "ef6f7ea9-ef0b-4091-8d00-504b481fc59a",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
"source": [
" einops.reduce(\n",
" tl_model.blocks[0].attn.W_V, \"(n repeat) m h -> (n h) m\",\n",
" 'max',\n",
" n=tl_model.cfg.n_key_value_heads,\n",
" repeat=4) ==\n",
" hf_model.model.layers[0].self_attn.v_proj.weight\n",
"cell_type": "code",
"execution_count": 50,
"id": "04b8f4be-ce7d-4dc2-acda-d023c721525c",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
"source": [
" einops.rearrange(tl_model.blocks[0].attn.W_O, \"n h m -> m (n h)\") ==\n",
" hf_model.model.layers[0].self_attn.o_proj.weight\n",
"cell_type": "code",
"execution_count": 29,
"id": "1e10ed87-31b5-4c1c-b726-7a3f49fbd136",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]], requires_grad=True)"
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 46,
"id": "6caf9d98-adb2-45e7-8357-34288b2156f2",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
"source": [
"# Element-wise check that the transposed HF router (gate) weight matches\n",
"# the TransformerLens W_gate parameter for layer 0.\n",
"(\n",
"    hf_model.model.layers[0].block_sparse_moe.gate.weight.T\n",
"    == tl_model.blocks[0].mlp.W_gate\n",
").all()"
"cell_type": "code",
"execution_count": 114,
"id": "00e9ea5d-74c2-4c2a-8e9d-6fc196cb8fc3",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"(torch.float32, torch.float32)"
"execution_count": 114,
"metadata": {},
"output_type": "execute_result"
"source": [
"# Show the dtype of the router weight in each implementation.\n",
"(\n",
"    hf_model.model.layers[0].block_sparse_moe.gate.weight.dtype,\n",
"    tl_model.blocks[0].mlp.W_gate.dtype,\n",
")"
"cell_type": "code",
"execution_count": 51,
"id": "44a55507-e639-414a-a297-e68e1c0696f9",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
"source": [
" tl_model.blocks[0].mlp.experts[0].W_in ==\n",
" hf_model.model.layers[0].block_sparse_moe.experts[0].w3.weight.T\n",
"cell_type": "code",
"execution_count": 56,
"id": "03944deb-aa8d-46ff-83dd-4f7ee955656c",
"metadata": {},
"outputs": [],
"source": [
"# Seed the global RNG so the random probe input (and every number derived\n",
"# from it in the cells below) is reproducible after a kernel restart.\n",
"torch.manual_seed(0)\n",
"test_tensor = torch.randn((1, 1, 4096,))"
"cell_type": "code",
"execution_count": 65,
"id": "eb0109ee-b82a-4ea0-b50b-8e6408647cea",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
"source": [
" hf_model.model.layers[0].block_sparse_moe(test_tensor)[0] ==\n",
" tl_model.blocks[0].mlp(test_tensor)\n",
"cell_type": "code",
"execution_count": 63,
"id": "25ce75bf-706e-4ae8-8f74-bc9c40e88c25",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"tensor([[[ 0.3826, 0.0153, 0.0993, ..., -0.2474, 0.4459, -0.3026]]],\n",
" grad_fn=<ReshapeAliasBackward0>)"
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 62,
"id": "c016430e-0a30-426b-bfd0-0b1b423b3ff6",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"tensor([[[ 0.3826, 0.0153, 0.0993, ..., -0.2474, 0.4459, -0.3026]]],\n",
" grad_fn=<IndexPutBackward0>)"
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 107,
"id": "46353486-0a3f-4241-9cf5-ed25c7539f71",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"torch.Size([1, 1, 4096])"
"execution_count": 107,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 67,
"id": "e25ada54-4e3c-42b7-8f35-ba67bfa500e3",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"tensor([[[False, False, False, ..., False, False, False]]])"
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
"source": [
"# Compare the MoE block outputs with a tolerance (torch.isclose defaults:\n",
"# rtol=1e-05, atol=1e-08). Exact `==` on floating-point results flags values\n",
"# that differ only by rounding as mismatches.\n",
"torch.isclose(\n",
"    hf_model.model.layers[0].block_sparse_moe(test_tensor)[0],\n",
"    tl_model.blocks[0].mlp(test_tensor),\n",
")"
"cell_type": "code",
"execution_count": 83,
"id": "8f3a2865-645d-4441-95fb-32446f866760",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
"source": [
"# Count MoE output elements that agree within floating-point tolerance;\n",
"# exact `==` would count rounding-level differences as mismatches.\n",
"torch.sum(\n",
"    torch.isclose(\n",
"        hf_model.model.layers[0].block_sparse_moe(test_tensor)[0],\n",
"        tl_model.blocks[0].mlp(test_tensor),\n",
"    )\n",
")"
"cell_type": "code",
"execution_count": 72,
"id": "ac306e1c-9972-466a-8f4a-f3eb56042f53",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
"source": [
"hf_model.model.layers[0].block_sparse_moe(test_tensor)[0][0, 0, 0].item()"
"cell_type": "code",
"execution_count": 73,
"id": "e9481397-6e87-435a-a0cf-ef409630d17c",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
"source": [
"tl_model.blocks[0].mlp(test_tensor)[0, 0, 0].item()"
"cell_type": "code",
"execution_count": 74,
"id": "89f55163-fa2e-44f2-b112-7b052a6e85af",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (rotary_emb): MixtralRotaryEmbedding()\n",
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 76,
"id": "f0186768-d8f0-4d55-a94c-606c4ba3f7ca",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"(tensor([[[-0.6824, 2.0180, -1.6793, ..., -1.3551, -0.2033, -0.8247]]],\n",
" grad_fn=<AddBackward0>),)"
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 77,
"id": "432bb274-b499-44c9-98d1-777d03425daa",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"tensor([[[-0.6824, 2.0180, -1.6793, ..., -1.3551, -0.2033, -0.8247]]],\n",
" grad_fn=<AddBackward0>)"
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 79,
"id": "4a440811-e7f0-4092-b8e7-f7cac80dc84a",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"tensor([[[ True, True, True, ..., True, False, True]]])"
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
"source": [
"# Compare full layer-0 outputs with a tolerance instead of exact `==`,\n",
"# so that differences at the level of float32 rounding are not reported\n",
"# as mismatches.\n",
"torch.isclose(\n",
"    hf_model.model.layers[0](test_tensor)[0],\n",
"    tl_model.blocks[0](test_tensor),\n",
")"
"cell_type": "code",
"execution_count": 110,
"id": "8ed65bb3-6990-48e5-9ef2-1becd9dfaffc",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"(-0.20334099233150482, -0.20334100723266602)"
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
"source": [
"# Second-to-last element of the layer-0 output from each implementation,\n",
"# extracted as Python floats.\n",
"(\n",
"    hf_model.model.layers[0](test_tensor)[0][0, 0, -2].item(),\n",
"    tl_model.blocks[0](test_tensor)[0, 0, -2].item(),\n",
")"
"cell_type": "code",
"execution_count": 80,
"id": "763f6c2e-b71f-4724-b2f7-f79a9ab29caf",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
"source": [
"# Count layer-0 output elements that agree within floating-point tolerance;\n",
"# exact `==` would count rounding-level differences as mismatches.\n",
"torch.sum(\n",
"    torch.isclose(\n",
"        hf_model.model.layers[0](test_tensor)[0],\n",
"        tl_model.blocks[0](test_tensor),\n",
"    )\n",
")"
"cell_type": "code",
"execution_count": 84,
"id": "5172efa2-0066-4ae0-a6a2-530d815b053b",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"tensor([[[-0.2402, -0.0474, 0.0492, ..., -0.3049, 0.0006, 0.1522]]],\n",
" grad_fn=<AddBackward0>)"
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
"source": [
"tl_model.blocks[0].attn.forward(test_tensor, test_tensor, test_tensor)"
"cell_type": "code",
"execution_count": 86,
"id": "92781a06-e16d-43f9-be4c-3ef04b3d4b08",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"tensor([[[-0.2402, -0.0474, 0.0492, ..., -0.3049, 0.0006, 0.1522]]],\n",
" grad_fn=<UnsafeViewBackward0>)"
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 88,
"id": "943cd506-2bb8-45bf-afc7-7f6b4f8043f1",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"tensor([[[False, False, False, ..., False, False, True]]])"
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
"source": [
"# Compare the attention outputs of both implementations with a tolerance\n",
"# (torch.isclose defaults) rather than exact floating-point `==`.\n",
"torch.isclose(\n",
"    tl_model.blocks[0].attn.forward(test_tensor, test_tensor, test_tensor),\n",
"    hf_model.model.layers[0].self_attn.forward(test_tensor)[0],\n",
")"
"cell_type": "code",
"execution_count": 89,
"id": "57ffc181-abed-4784-86eb-6e6b4f174bc5",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
"source": [
"# Count attention output elements that agree within floating-point\n",
"# tolerance; exact `==` would count rounding-level differences as mismatches.\n",
"torch.sum(\n",
"    torch.isclose(\n",
"        tl_model.blocks[0].attn.forward(test_tensor, test_tensor, test_tensor),\n",
"        hf_model.model.layers[0].self_attn.forward(test_tensor)[0],\n",
"    )\n",
")"
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"nbformat": 4,
"nbformat_minor": 5
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment