-
-
Save joelburget/bae5ea4d997e804b2a65d02d5b61f5bc to your computer and use it in GitHub Desktop.
Mixtral poking
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4fb7e0bc-4ef5-40c8-8222-336e83bd6e66", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%pip install transformers huggingface_hub" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "be241e96-3bbb-46a4-a4d4-0213eb094d6e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%pip install git+https://github.com/TransformerLensOrg/TransformerLens.git" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "6d7341d8-881c-41c3-8199-ae9590d51a5a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "8ea54c24323c4e04b3890f01257e6549", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"from huggingface_hub import login\n", | |
"login()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "cba8adb4-03a4-4061-b62b-18bcc091b8af", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import einops\n", | |
"model_id = \"mistralai/Mixtral-8x7B-v0.1\"\n", | |
"text = \"Hello my name is\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "2c3cb338-cf1b-4775-b278-302999164e6a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", | |
" warnings.warn(\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "b256c265e0664a47a55c4ef269aba32a", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Loading checkpoint shards: 0%| | 0/19 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", | |
" warnings.warn(\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loaded pretrained model mistralai/Mixtral-8x7B-v0.1 into HookedTransformer\n" | |
] | |
} | |
], | |
"source": [ | |
"from transformer_lens import HookedTransformer\n", | |
"from transformers import AutoModelForCausalLM, AutoTokenizer\n", | |
"\n", | |
"tl_model = HookedTransformer.from_pretrained_no_processing(\n", | |
" \"mistralai/Mixtral-8x7B-v0.1\"\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "617609fe-060e-48b5-9daf-cb45cb6cc5d2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'Hello my name is Daniel. and butorr on2 for6 to amongst24 to to forfor so in has stuck to extreme continues to,)wh on” we have for every in bag to for in! to every to8 for for around'" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tl_model.generate(\n", | |
" text,\n", | |
" verbose=False,\n", | |
" max_new_tokens=50,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "4f3ccaf0-1650-4621-84f4-dbcdef132447", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", | |
" warnings.warn(\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "0b50aa852274439a963f8594be58e4cf", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Loading checkpoint shards: 0%| | 0/19 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"from transformers import AutoModelForCausalLM, AutoTokenizer\n", | |
"\n", | |
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n", | |
"hf_model = AutoModelForCausalLM.from_pretrained(model_id)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "97cec596-b5fc-48c2-82db-08af7e238a0b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Hello my name is Katie and I am a 20 year old student from the UK. I am currently studying a degree in English Literature and Creative Writing at the University of Winchester. I have always had a passion for writing and I am hoping to pursue\n" | |
] | |
} | |
], | |
"source": [ | |
"inputs = tokenizer(text, return_tensors=\"pt\")\n", | |
"outputs = hf_model.generate(**inputs, max_new_tokens=50)\n", | |
"print(tokenizer.decode(outputs[0], skip_special_tokens=True))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"id": "1df1c967-4dfd-41ef-bab4-dea6072678aa", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch.testing import assert_close\n", | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 48, | |
"id": "cfa7d1c3-72e8-45b3-850c-091669f48b1e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(True)" | |
] | |
}, | |
"execution_count": 48, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.all(\n", | |
" einops.rearrange(tl_model.blocks[0].attn.W_Q, \"n m h -> (n h) m\") ==\n", | |
" hf_model.model.layers[0].self_attn.q_proj.weight\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 106, | |
"id": "83649934-f06b-4f94-8004-59b8d4098589", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(torch.Size([32, 4096, 128]), torch.Size([1024, 4096]))" | |
] | |
}, | |
"execution_count": 106, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tl_model.blocks[0].attn.W_K.shape, hf_model.model.layers[0].self_attn.k_proj.weight.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4fa20cf5-b720-4946-a7e5-e1d2e6277f6c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The TransformerLens W_K has more head slices than the HF k_proj has\n", | |
"# key/value heads, so each key/value head appears several times along the\n", | |
"# leading axis. Split that axis into (kv head, copy) and compare EVERY copy\n", | |
"# against the HF weight: collapsing the copies with 'max' would hide a copy\n", | |
"# that is element-wise smaller than its siblings.\n", | |
"w_k_copies = einops.rearrange(\n", | |
"    tl_model.blocks[0].attn.W_K,\n", | |
"    \"(n repeat) m h -> repeat (n h) m\",\n", | |
"    n=tl_model.cfg.n_key_value_heads,  # `repeat` is inferred by einops\n", | |
")\n", | |
"# The HF weight broadcasts across the leading `repeat` axis.\n", | |
"torch.all(w_k_copies == hf_model.model.layers[0].self_attn.k_proj.weight)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ef6f7ea9-ef0b-4091-8d00-504b481fc59a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Same check as for the keys, applied to the value projection: split the\n", | |
"# head axis into (kv head, copy) and require every copy to equal the HF\n", | |
"# v_proj weight, instead of comparing only the element-wise 'max' of the\n", | |
"# copies (which cannot detect a copy that is smaller than the others).\n", | |
"w_v_copies = einops.rearrange(\n", | |
"    tl_model.blocks[0].attn.W_V,\n", | |
"    \"(n repeat) m h -> repeat (n h) m\",\n", | |
"    n=tl_model.cfg.n_key_value_heads,  # `repeat` is inferred by einops\n", | |
")\n", | |
"# The HF weight broadcasts across the leading `repeat` axis.\n", | |
"torch.all(w_v_copies == hf_model.model.layers[0].self_attn.v_proj.weight)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 50, | |
"id": "04b8f4be-ce7d-4dc2-acda-d023c721525c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(True)" | |
] | |
}, | |
"execution_count": 50, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.all(\n", | |
" einops.rearrange(tl_model.blocks[0].attn.W_O, \"n h m -> m (n h)\") ==\n", | |
" hf_model.model.layers[0].self_attn.o_proj.weight\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "1e10ed87-31b5-4c1c-b726-7a3f49fbd136", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Parameter containing:\n", | |
"tensor([[0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" ...,\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.],\n", | |
" [0., 0., 0., ..., 0., 0., 0.]], requires_grad=True)" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tl_model.blocks[0].attn.b_Q" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"id": "6caf9d98-adb2-45e7-8357-34288b2156f2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(True)" | |
] | |
}, | |
"execution_count": 46, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.all(hf_model.model.layers[0].block_sparse_moe.gate.weight.T == tl_model.blocks[0].mlp.W_gate)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 114, | |
"id": "00e9ea5d-74c2-4c2a-8e9d-6fc196cb8fc3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(torch.float32, torch.float32)" | |
] | |
}, | |
"execution_count": 114, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hf_model.model.layers[0].block_sparse_moe.gate.weight.dtype, tl_model.blocks[0].mlp.W_gate.dtype" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 51, | |
"id": "44a55507-e639-414a-a297-e68e1c0696f9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(True)" | |
] | |
}, | |
"execution_count": 51, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.all(\n", | |
" tl_model.blocks[0].mlp.experts[0].W_in ==\n", | |
" hf_model.model.layers[0].block_sparse_moe.experts[0].w3.weight.T\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "03944deb-aa8d-46ff-83dd-4f7ee955656c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Seed the RNG so the HF-vs-TransformerLens comparisons below are fed the\n", | |
"# same random input on every Restart & Run All.\n", | |
"torch.manual_seed(0)\n", | |
"test_tensor = torch.randn((1, 1, 4096,))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "eb0109ee-b82a-4ea0-b50b-8e6408647cea", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Two float32 implementations of the same MoE layer are not expected to\n", | |
"# agree bit-for-bit, so an exact `==` reports a mismatch even when the\n", | |
"# results differ only by rounding. Compare with a tolerance instead; the\n", | |
"# values below are the float32 defaults of torch.testing.assert_close.\n", | |
"hf_moe_out = hf_model.model.layers[0].block_sparse_moe(test_tensor)[0]\n", | |
"tl_moe_out = tl_model.blocks[0].mlp(test_tensor)\n", | |
"torch.allclose(hf_moe_out, tl_moe_out, rtol=1.3e-6, atol=1e-5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"id": "25ce75bf-706e-4ae8-8f74-bc9c40e88c25", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[ 0.3826, 0.0153, 0.0993, ..., -0.2474, 0.4459, -0.3026]]],\n", | |
" grad_fn=<ReshapeAliasBackward0>)" | |
] | |
}, | |
"execution_count": 63, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hf_model.model.layers[0].block_sparse_moe(test_tensor)[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"id": "c016430e-0a30-426b-bfd0-0b1b423b3ff6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[ 0.3826, 0.0153, 0.0993, ..., -0.2474, 0.4459, -0.3026]]],\n", | |
" grad_fn=<IndexPutBackward0>)" | |
] | |
}, | |
"execution_count": 62, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tl_model.blocks[0].mlp(test_tensor)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 107, | |
"id": "46353486-0a3f-4241-9cf5-ed25c7539f71", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([1, 1, 4096])" | |
] | |
}, | |
"execution_count": 107, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tl_model.blocks[0].mlp(test_tensor).shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 67, | |
"id": "e25ada54-4e3c-42b7-8f35-ba67bfa500e3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[False, False, False, ..., False, False, False]]])" | |
] | |
}, | |
"execution_count": 67, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hf_model.model.layers[0].block_sparse_moe(test_tensor)[0] == tl_model.blocks[0].mlp(test_tensor)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"id": "8f3a2865-645d-4441-95fb-32446f866760", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(89)" | |
] | |
}, | |
"execution_count": 83, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.sum(hf_model.model.layers[0].block_sparse_moe(test_tensor)[0] == tl_model.blocks[0].mlp(test_tensor))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"id": "ac306e1c-9972-466a-8f4a-f3eb56042f53", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.38261502981185913" | |
] | |
}, | |
"execution_count": 72, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hf_model.model.layers[0].block_sparse_moe(test_tensor)[0][0, 0, 0].item()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"id": "e9481397-6e87-435a-a0cf-ef409630d17c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.38261523842811584" | |
] | |
}, | |
"execution_count": 73, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tl_model.blocks[0].mlp(test_tensor)[0, 0, 0].item()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"id": "89f55163-fa2e-44f2-b112-7b052a6e85af", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"MixtralAttention(\n", | |
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", | |
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", | |
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", | |
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", | |
" (rotary_emb): MixtralRotaryEmbedding()\n", | |
")" | |
] | |
}, | |
"execution_count": 74, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hf_model.model.layers[0].self_attn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"id": "f0186768-d8f0-4d55-a94c-606c4ba3f7ca", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(tensor([[[-0.6824, 2.0180, -1.6793, ..., -1.3551, -0.2033, -0.8247]]],\n", | |
" grad_fn=<AddBackward0>),)" | |
] | |
}, | |
"execution_count": 76, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hf_model.model.layers[0](test_tensor)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 77, | |
"id": "432bb274-b499-44c9-98d1-777d03425daa", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[-0.6824, 2.0180, -1.6793, ..., -1.3551, -0.2033, -0.8247]]],\n", | |
" grad_fn=<AddBackward0>)" | |
] | |
}, | |
"execution_count": 77, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tl_model.blocks[0](test_tensor)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4a440811-e7f0-4092-b8e7-f7cac80dc84a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Element-wise comparison of the full layer-0 outputs. Exact `==` flags\n", | |
"# float32 rounding noise as a mismatch, so use a tolerance (the float32\n", | |
"# defaults of torch.testing.assert_close); the result is still one boolean\n", | |
"# per element.\n", | |
"torch.isclose(\n", | |
"    hf_model.model.layers[0](test_tensor)[0],\n", | |
"    tl_model.blocks[0](test_tensor),\n", | |
"    rtol=1.3e-6,\n", | |
"    atol=1e-5,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 110, | |
"id": "8ed65bb3-6990-48e5-9ef2-1becd9dfaffc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(-0.20334099233150482, -0.20334100723266602)" | |
] | |
}, | |
"execution_count": 110, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hf_model.model.layers[0](test_tensor)[0][0, 0, -2].item(), tl_model.blocks[0](test_tensor)[0, 0, -2].item()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"id": "763f6c2e-b71f-4724-b2f7-f79a9ab29caf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(3218)" | |
] | |
}, | |
"execution_count": 80, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.sum(hf_model.model.layers[0](test_tensor)[0] == tl_model.blocks[0](test_tensor))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"id": "5172efa2-0066-4ae0-a6a2-530d815b053b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[-0.2402, -0.0474, 0.0492, ..., -0.3049, 0.0006, 0.1522]]],\n", | |
" grad_fn=<AddBackward0>)" | |
] | |
}, | |
"execution_count": 84, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tl_model.blocks[0].attn.forward(test_tensor, test_tensor, test_tensor)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"id": "92781a06-e16d-43f9-be4c-3ef04b3d4b08", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[-0.2402, -0.0474, 0.0492, ..., -0.3049, 0.0006, 0.1522]]],\n", | |
" grad_fn=<UnsafeViewBackward0>)" | |
] | |
}, | |
"execution_count": 86, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hf_model.model.layers[0].self_attn.forward(test_tensor)[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "943cd506-2bb8-45bf-afc7-7f6b4f8043f1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Element-wise comparison of the layer-0 attention outputs. Exact `==`\n", | |
"# flags float32 rounding noise as a mismatch, so use a tolerance (the\n", | |
"# float32 defaults of torch.testing.assert_close); the result is still one\n", | |
"# boolean per element.\n", | |
"torch.isclose(\n", | |
"    tl_model.blocks[0].attn.forward(test_tensor, test_tensor, test_tensor),\n", | |
"    hf_model.model.layers[0].self_attn.forward(test_tensor)[0],\n", | |
"    rtol=1.3e-6,\n", | |
"    atol=1e-5,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 89, | |
"id": "57ffc181-abed-4784-86eb-6e6b4f174bc5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(254)" | |
] | |
}, | |
"execution_count": 89, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.sum(tl_model.blocks[0].attn.forward(test_tensor, test_tensor, test_tensor) == \n", | |
" hf_model.model.layers[0].self_attn.forward(test_tensor)[0])" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.14" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment