Skip to content

Instantly share code, notes, and snippets.

@arjunguha
Created April 9, 2024 10:16
Show Gist options
  • Save arjunguha/46c0557ee0323f748eddbbce1981c5a5 to your computer and use it in GitHub Desktop.
Save arjunguha/46c0557ee0323f748eddbbce1981c5a5 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1e146d86-0145-47b8-b633-6bbe23b22d3c",
"metadata": {},
"outputs": [],
"source": [
"from nnsight import LanguageModel\n",
"import torch\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "markdown",
"id": "e4c03961-206c-42ac-9355-490317e09cdc",
"metadata": {},
"source": [
"I'm using Code Llama, but this should work with any model."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8ab7a262-1dfb-483f-ae18-6dacae634edd",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6e3c62c61ca4422684b50a0e5fe002af",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"LlamaForCausalLM(\n",
" (model): LlamaModel(\n",
" (embed_tokens): Embedding(32016, 4096)\n",
" (layers): ModuleList(\n",
" (0-31): 32 x LlamaDecoderLayer(\n",
" (self_attn): LlamaSdpaAttention(\n",
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (rotary_emb): LlamaRotaryEmbedding()\n",
" )\n",
" (mlp): LlamaMLP(\n",
" (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
" (act_fn): SiLU()\n",
" )\n",
" (input_layernorm): LlamaRMSNorm()\n",
" (post_attention_layernorm): LlamaRMSNorm()\n",
" )\n",
" )\n",
" (norm): LlamaRMSNorm()\n",
" )\n",
" (lm_head): Linear(in_features=4096, out_features=32016, bias=False)\n",
" (generator): WrapperModule()\n",
")"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = LanguageModel(\n",
" \"/work/arjunguha-research-group/arjun/models/codellama_7b_base\",\n",
" device_map=\"cuda\", dispatch=True)\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e3d5988c-5d0f-4ce8-9cc7-7871be9b5328",
"metadata": {},
"outputs": [],
"source": [
"PROMPT = \"\"\"(function (t: boolean) { return 5; }) (\"\"\""
]
},
{
"cell_type": "markdown",
"id": "55d08c71-c414-4e12-9331-ed6e303ede32",
"metadata": {},
"source": [
"This is basically http://nnsight.net/notebooks/features/modules/:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b0193829-0876-477f-b285-cc34ae71795a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You're using a CodeLlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
]
},
{
"data": {
"text/plain": [
"torch.Size([1, 15, 32016])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with model.trace(PROMPT) as tracer:\n",
" (layer10_out, layer10_out_rest) = model.model.layers[10].output\n",
" out_0 = model.lm_head(layer10_out).save()\n",
"out_0.shape"
]
},
{
"cell_type": "markdown",
"id": "c56343c7-78db-4a55-8510-311af66c9569",
"metadata": {},
"source": [
"This does not work."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "41d72cb1-0718-41e7-9c08-2d8ee28cd2d2",
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x14d5a53d2f10>) from active fake mode 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x14d5a53f9210>) from fake tensor input 0\n\nfake mode from active fake mode 0 allocated at:\n File \"<frozen runpy>\", line 198, in _run_module_as_main\n File \"<frozen runpy>\", line 88, in _run_code\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel_launcher.py\", line 17, in <module>\n app.launch_new_instance()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n app.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelapp.py\", line 701, in start\n self.io_loop.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/tornado/platform/asyncio.py\", line 205, in start\n self.asyncio_loop.run_forever()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 607, in run_forever\n self._run_once()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 1922, in _run_once\n handle._run()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/events.py\", line 80, in _run\n self._context.run(self._callback, *self._args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n await self.process_one()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n await dispatch(*args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n await result\n File 
\"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n reply_content = await reply_content\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n res = shell.run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n return super().run_cell(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n result = self._run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n result = runner(coro)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n coro.send(None)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n if await self.run_code(code, result, async_=asy):\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"/tmp/ipykernel_1636525/3636171273.py\", line 3, in <module>\n out_1 = model.model.layers[11](layer10_out) # error here\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/envoy.py\", line 336, in __call__\n proxy = module_proxy.forward(*args, **kwargs)\n File 
\"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/intervention.py\", line 144, in __call__\n return super().__call__(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Proxy.py\", line 80, in __call__\n return self.node.add(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Node.py\", line 199, in add\n return self.graph.add(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Graph.py\", line 140, in add\n with FakeTensorMode(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py\", line 1364, in __init__\n self.stack = \"\".join(traceback.format_stack())\n\nfake mode from fake tensor input 0 allocated at:\n File \"<frozen runpy>\", line 198, in _run_module_as_main\n File \"<frozen runpy>\", line 88, in _run_code\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel_launcher.py\", line 17, in <module>\n app.launch_new_instance()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n app.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelapp.py\", line 701, in start\n self.io_loop.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/tornado/platform/asyncio.py\", line 205, in start\n self.asyncio_loop.run_forever()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 607, in run_forever\n self._run_once()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 1922, in _run_once\n handle._run()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/events.py\", line 80, in _run\n 
self._context.run(self._callback, *self._args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n await self.process_one()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n await dispatch(*args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n await result\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n reply_content = await reply_content\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n res = shell.run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n return super().run_cell(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n result = self._run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n result = runner(coro)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n coro.send(None)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n if await self.run_code(code, result, 
async_=asy):\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"/tmp/ipykernel_1636525/3636171273.py\", line 1, in <module>\n with model.trace(PROMPT) as tracer:\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/models/NNsightModel.py\", line 200, in trace\n runner.invoke(*inputs, **invoker_args).__enter__()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/contexts/Invoker.py\", line 60, in __enter__\n with FakeTensorMode(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py\", line 1364, in __init__\n self.stack = \"\".join(traceback.format_stack())\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43;01mwith\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mPROMPT\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mas\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtracer\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mlayer10_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlayer10_out_rest\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moutput\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mout_1\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m11\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlayer10_out\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# error here\u001b[39;49;00m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/contexts/Runner.py:41\u001b[0m, in \u001b[0;36mRunner.__exit__\u001b[0;34m(self, exc_type, exc_val, exc_tb)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"On exit, run and generate using the model whether locally or on the server.\"\"\"\u001b[39;00m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(exc_val, \u001b[38;5;167;01mBaseException\u001b[39;00m):\n\u001b[0;32m---> 41\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exc_val\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mremote:\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_server()\n",
"Cell \u001b[0;32mIn[5], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m model\u001b[38;5;241m.\u001b[39mtrace(PROMPT) \u001b[38;5;28;01mas\u001b[39;00m tracer:\n\u001b[1;32m 2\u001b[0m (layer10_out, layer10_out_rest) \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mlayers[\u001b[38;5;241m10\u001b[39m]\u001b[38;5;241m.\u001b[39moutput\n\u001b[0;32m----> 3\u001b[0m out_1 \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m11\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlayer10_out\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# error here\u001b[39;00m\n\u001b[1;32m 4\u001b[0m out_1 \u001b[38;5;241m=\u001b[39m out_1\u001b[38;5;241m.\u001b[39msave()\n\u001b[1;32m 5\u001b[0m out_1\u001b[38;5;241m.\u001b[39mshape\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/envoy.py:336\u001b[0m, in \u001b[0;36mEnvoy.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 332\u001b[0m device \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 334\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_default_device(device)\n\u001b[0;32m--> 336\u001b[0m proxy \u001b[38;5;241m=\u001b[39m \u001b[43mmodule_proxy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m torch\u001b[38;5;241m.\u001b[39m_GLOBAL_DEVICE_CONTEXT\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__exit__\u001b[39m(\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 340\u001b[0m torch\u001b[38;5;241m.\u001b[39m_GLOBAL_DEVICE_CONTEXT \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/intervention.py:144\u001b[0m, in \u001b[0;36mInterventionProxy.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnode\u001b[38;5;241m.\u001b[39mgraph\u001b[38;5;241m.\u001b[39mn_backward_calls \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnode\u001b[38;5;241m.\u001b[39madd(\n\u001b[1;32m 138\u001b[0m value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 139\u001b[0m target\u001b[38;5;241m=\u001b[39mProxy\u001b[38;5;241m.\u001b[39mproxy_call,\n\u001b[1;32m 140\u001b[0m args\u001b[38;5;241m=\u001b[39m[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnode] \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(args),\n\u001b[1;32m 141\u001b[0m kwargs\u001b[38;5;241m=\u001b[39mkwargs,\n\u001b[1;32m 142\u001b[0m )\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Proxy.py:80\u001b[0m, in \u001b[0;36mProxy.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Self:\n\u001b[1;32m 73\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;124;03m Calling a Proxy object just creates a Proxy.proxy_call operation.\u001b[39;00m\n\u001b[1;32m 75\u001b[0m \n\u001b[1;32m 76\u001b[0m \u001b[38;5;124;03m Returns:\u001b[39;00m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;124;03m Proxy: New call proxy.\u001b[39;00m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 80\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 81\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mProxy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mproxy_call\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 82\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 83\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 84\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Node.py:199\u001b[0m, in \u001b[0;36mNode.add\u001b[0;34m(self, target, value, args, kwargs, name)\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_tracing():\n\u001b[1;32m 188\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m Proxy(\n\u001b[1;32m 189\u001b[0m Node(\n\u001b[1;32m 190\u001b[0m name\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 196\u001b[0m )\n\u001b[1;32m 197\u001b[0m )\u001b[38;5;241m.\u001b[39mvalue\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgraph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\n\u001b[1;32m 201\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Graph.py:146\u001b[0m, in \u001b[0;36mGraph.add\u001b[0;34m(self, target, value, args, kwargs, name)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m FakeTensorMode(\n\u001b[1;32m 141\u001b[0m allow_non_fake_inputs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 142\u001b[0m shape_env\u001b[38;5;241m=\u001b[39mShapeEnv(assume_static_by_default\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m),\n\u001b[1;32m 143\u001b[0m ) \u001b[38;5;28;01mas\u001b[39;00m fake_mode:\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m FakeCopyMode(fake_mode):\n\u001b[0;32m--> 146\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[43mtarget\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mNode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprepare_proxy_values\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_args\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 148\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mNode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprepare_proxy_values\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_kwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 149\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m target_name \u001b[38;5;241m=\u001b[39m target \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(target, \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m target\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m target_name \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname_idx:\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Proxy.py:32\u001b[0m, in \u001b[0;36mProxy.proxy_call\u001b[0;34m(callable, *args, **kwargs)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mproxy_call\u001b[39m(\u001b[38;5;28mcallable\u001b[39m: Callable, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Self:\n\u001b[0;32m---> 32\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcallable\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:736\u001b[0m, in \u001b[0;36mLlamaDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m 730\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 731\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPassing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 732\u001b[0m )\n\u001b[1;32m 734\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[0;32m--> 736\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_layernorm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 738\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[1;32m 739\u001b[0m hidden_states, self_attn_weights, present_key_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mself_attn(\n\u001b[1;32m 740\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mhidden_states,\n\u001b[1;32m 741\u001b[0m attention_mask\u001b[38;5;241m=\u001b[39mattention_mask,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 747\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 748\u001b[0m )\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/nn/modules/module.py:1572\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1569\u001b[0m called_always_called_hooks\u001b[38;5;241m.\u001b[39madd(hook_id)\n\u001b[1;32m 1571\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m hook_id \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks_with_kwargs:\n\u001b[0;32m-> 1572\u001b[0m hook_result \u001b[38;5;241m=\u001b[39m \u001b[43mhook\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresult\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1573\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1574\u001b[0m hook_result \u001b[38;5;241m=\u001b[39m hook(\u001b[38;5;28mself\u001b[39m, args, result)\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/envoy.py:180\u001b[0m, in \u001b[0;36mEnvoy._hook\u001b[0;34m(self, module, input, input_kwargs, output)\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_hook\u001b[39m(\n\u001b[1;32m 177\u001b[0m \u001b[38;5;28mself\u001b[39m, module: torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mModule, \u001b[38;5;28minput\u001b[39m: Any, input_kwargs: Dict, output: Any\n\u001b[1;32m 178\u001b[0m ):\n\u001b[0;32m--> 180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mdetect_fake_mode\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset_proxies(propagate\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28minput\u001b[39m, input_kwargs)\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/_guards.py:826\u001b[0m, in \u001b[0;36mdetect_fake_mode\u001b[0;34m(inputs)\u001b[0m\n\u001b[1;32m 824\u001b[0m fake_mode, desc1, i1 \u001b[38;5;241m=\u001b[39m fake_modes[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 825\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m m, desc2, i2 \u001b[38;5;129;01min\u001b[39;00m fake_modes[\u001b[38;5;241m1\u001b[39m:]:\n\u001b[0;32m--> 826\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m fake_mode \u001b[38;5;129;01mis\u001b[39;00m m, (\n\u001b[1;32m 827\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfake mode (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfake_mode\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdesc1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt match mode (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mm\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdesc2\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi2\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 828\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfake mode from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdesc1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m allocated at:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfake_mode\u001b[38;5;241m.\u001b[39mstack\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 829\u001b[0m 
\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfake mode from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdesc2\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi2\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m allocated at:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mm\u001b[38;5;241m.\u001b[39mstack\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 830\u001b[0m )\n\u001b[1;32m 831\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fake_mode\n\u001b[1;32m 832\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
"\u001b[0;31mAssertionError\u001b[0m: fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x14d5a53d2f10>) from active fake mode 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x14d5a53f9210>) from fake tensor input 0\n\nfake mode from active fake mode 0 allocated at:\n File \"<frozen runpy>\", line 198, in _run_module_as_main\n File \"<frozen runpy>\", line 88, in _run_code\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel_launcher.py\", line 17, in <module>\n app.launch_new_instance()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n app.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelapp.py\", line 701, in start\n self.io_loop.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/tornado/platform/asyncio.py\", line 205, in start\n self.asyncio_loop.run_forever()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 607, in run_forever\n self._run_once()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 1922, in _run_once\n handle._run()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/events.py\", line 80, in _run\n self._context.run(self._callback, *self._args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n await self.process_one()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n await dispatch(*args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n await result\n File 
\"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n reply_content = await reply_content\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n res = shell.run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n return super().run_cell(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n result = self._run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n result = runner(coro)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n coro.send(None)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n if await self.run_code(code, result, async_=asy):\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"/tmp/ipykernel_1636525/3636171273.py\", line 3, in <module>\n out_1 = model.model.layers[11](layer10_out) # error here\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/envoy.py\", line 336, in __call__\n proxy = module_proxy.forward(*args, **kwargs)\n File 
\"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/intervention.py\", line 144, in __call__\n return super().__call__(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Proxy.py\", line 80, in __call__\n return self.node.add(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Node.py\", line 199, in add\n return self.graph.add(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Graph.py\", line 140, in add\n with FakeTensorMode(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py\", line 1364, in __init__\n self.stack = \"\".join(traceback.format_stack())\n\nfake mode from fake tensor input 0 allocated at:\n File \"<frozen runpy>\", line 198, in _run_module_as_main\n File \"<frozen runpy>\", line 88, in _run_code\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel_launcher.py\", line 17, in <module>\n app.launch_new_instance()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n app.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelapp.py\", line 701, in start\n self.io_loop.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/tornado/platform/asyncio.py\", line 205, in start\n self.asyncio_loop.run_forever()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 607, in run_forever\n self._run_once()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 1922, in _run_once\n handle._run()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/events.py\", line 80, in _run\n 
self._context.run(self._callback, *self._args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n await self.process_one()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n await dispatch(*args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n await result\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n reply_content = await reply_content\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n res = shell.run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n return super().run_cell(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n result = self._run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n result = runner(coro)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n coro.send(None)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n if await self.run_code(code, result, 
async_=asy):\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"/tmp/ipykernel_1636525/3636171273.py\", line 1, in <module>\n with model.trace(PROMPT) as tracer:\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/models/NNsightModel.py\", line 200, in trace\n runner.invoke(*inputs, **invoker_args).__enter__()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/contexts/Invoker.py\", line 60, in __enter__\n with FakeTensorMode(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py\", line 1364, in __init__\n self.stack = \"\".join(traceback.format_stack())\n"
]
}
],
"source": [
"with model.trace(PROMPT) as tracer:\n",
" (layer10_out, layer10_out_rest) = model.model.layers[10].output\n",
" out_1 = model.model.layers[11](layer10_out) # error here\n",
" out_1 = out_1.save()\n",
"out_1.shape"
]
},
{
"cell_type": "markdown",
"id": "821ea7cd-92c0-4605-be35-dfeef0498dbe",
"metadata": {},
"source": [
"This variation also does not work; it raises the same `AssertionError` (fake tensor mode mismatch):"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "be2c6e9f-ad3e-47d6-94bd-7eaa8dfb5a87",
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x14d592e1e850>) from active fake mode 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x14d592fc2450>) from fake tensor input 0\n\nfake mode from active fake mode 0 allocated at:\n File \"<frozen runpy>\", line 198, in _run_module_as_main\n File \"<frozen runpy>\", line 88, in _run_code\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel_launcher.py\", line 17, in <module>\n app.launch_new_instance()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n app.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelapp.py\", line 701, in start\n self.io_loop.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/tornado/platform/asyncio.py\", line 205, in start\n self.asyncio_loop.run_forever()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 607, in run_forever\n self._run_once()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 1922, in _run_once\n handle._run()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/events.py\", line 80, in _run\n self._context.run(self._callback, *self._args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n await self.process_one()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n await dispatch(*args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n await result\n File 
\"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n reply_content = await reply_content\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n res = shell.run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n return super().run_cell(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n result = self._run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n result = runner(coro)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n coro.send(None)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n if await self.run_code(code, result, async_=asy):\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"/tmp/ipykernel_1636525/1197697725.py\", line 3, in <module>\n out_1 = model.model.layers[11](layer10_out, layer10_out_rest).save()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/envoy.py\", line 336, in __call__\n proxy = module_proxy.forward(*args, **kwargs)\n File 
\"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/intervention.py\", line 144, in __call__\n return super().__call__(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Proxy.py\", line 80, in __call__\n return self.node.add(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Node.py\", line 199, in add\n return self.graph.add(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Graph.py\", line 140, in add\n with FakeTensorMode(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py\", line 1364, in __init__\n self.stack = \"\".join(traceback.format_stack())\n\nfake mode from fake tensor input 0 allocated at:\n File \"<frozen runpy>\", line 198, in _run_module_as_main\n File \"<frozen runpy>\", line 88, in _run_code\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel_launcher.py\", line 17, in <module>\n app.launch_new_instance()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n app.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelapp.py\", line 701, in start\n self.io_loop.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/tornado/platform/asyncio.py\", line 205, in start\n self.asyncio_loop.run_forever()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 607, in run_forever\n self._run_once()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 1922, in _run_once\n handle._run()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/events.py\", line 80, in _run\n 
self._context.run(self._callback, *self._args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n await self.process_one()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n await dispatch(*args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n await result\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n reply_content = await reply_content\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n res = shell.run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n return super().run_cell(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n result = self._run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n result = runner(coro)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n coro.send(None)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n if await self.run_code(code, result, 
async_=asy):\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"/tmp/ipykernel_1636525/1197697725.py\", line 1, in <module>\n with model.trace(PROMPT) as tracer:\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/models/NNsightModel.py\", line 200, in trace\n runner.invoke(*inputs, **invoker_args).__enter__()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/contexts/Invoker.py\", line 60, in __enter__\n with FakeTensorMode(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py\", line 1364, in __init__\n self.stack = \"\".join(traceback.format_stack())\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43;01mwith\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mPROMPT\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mas\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtracer\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mlayer10_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlayer10_out_rest\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moutput\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mout_1\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m11\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlayer10_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlayer10_out_rest\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/contexts/Runner.py:41\u001b[0m, in \u001b[0;36mRunner.__exit__\u001b[0;34m(self, exc_type, exc_val, exc_tb)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"On exit, run and generate using the model whether locally or on the server.\"\"\"\u001b[39;00m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(exc_val, \u001b[38;5;167;01mBaseException\u001b[39;00m):\n\u001b[0;32m---> 41\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exc_val\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mremote:\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_server()\n",
"Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m model\u001b[38;5;241m.\u001b[39mtrace(PROMPT) \u001b[38;5;28;01mas\u001b[39;00m tracer:\n\u001b[1;32m 2\u001b[0m (layer10_out, layer10_out_rest) \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mlayers[\u001b[38;5;241m10\u001b[39m]\u001b[38;5;241m.\u001b[39moutput\n\u001b[0;32m----> 3\u001b[0m out_1 \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m11\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlayer10_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlayer10_out_rest\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39msave()\n\u001b[1;32m 4\u001b[0m out_1\u001b[38;5;241m.\u001b[39mshape\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/envoy.py:336\u001b[0m, in \u001b[0;36mEnvoy.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 332\u001b[0m device \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 334\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_default_device(device)\n\u001b[0;32m--> 336\u001b[0m proxy \u001b[38;5;241m=\u001b[39m \u001b[43mmodule_proxy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m torch\u001b[38;5;241m.\u001b[39m_GLOBAL_DEVICE_CONTEXT\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__exit__\u001b[39m(\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 340\u001b[0m torch\u001b[38;5;241m.\u001b[39m_GLOBAL_DEVICE_CONTEXT \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/intervention.py:144\u001b[0m, in \u001b[0;36mInterventionProxy.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnode\u001b[38;5;241m.\u001b[39mgraph\u001b[38;5;241m.\u001b[39mn_backward_calls \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnode\u001b[38;5;241m.\u001b[39madd(\n\u001b[1;32m 138\u001b[0m value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 139\u001b[0m target\u001b[38;5;241m=\u001b[39mProxy\u001b[38;5;241m.\u001b[39mproxy_call,\n\u001b[1;32m 140\u001b[0m args\u001b[38;5;241m=\u001b[39m[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnode] \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(args),\n\u001b[1;32m 141\u001b[0m kwargs\u001b[38;5;241m=\u001b[39mkwargs,\n\u001b[1;32m 142\u001b[0m )\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Proxy.py:80\u001b[0m, in \u001b[0;36mProxy.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Self:\n\u001b[1;32m 73\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;124;03m Calling a Proxy object just creates a Proxy.proxy_call operation.\u001b[39;00m\n\u001b[1;32m 75\u001b[0m \n\u001b[1;32m 76\u001b[0m \u001b[38;5;124;03m Returns:\u001b[39;00m\n\u001b[1;32m 77\u001b[0m \u001b[38;5;124;03m Proxy: New call proxy.\u001b[39;00m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 80\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 81\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mProxy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mproxy_call\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 82\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 83\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 84\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Node.py:199\u001b[0m, in \u001b[0;36mNode.add\u001b[0;34m(self, target, value, args, kwargs, name)\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mis_tracing():\n\u001b[1;32m 188\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m Proxy(\n\u001b[1;32m 189\u001b[0m Node(\n\u001b[1;32m 190\u001b[0m name\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 196\u001b[0m )\n\u001b[1;32m 197\u001b[0m )\u001b[38;5;241m.\u001b[39mvalue\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgraph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtarget\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\n\u001b[1;32m 201\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Graph.py:146\u001b[0m, in \u001b[0;36mGraph.add\u001b[0;34m(self, target, value, args, kwargs, name)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m FakeTensorMode(\n\u001b[1;32m 141\u001b[0m allow_non_fake_inputs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 142\u001b[0m shape_env\u001b[38;5;241m=\u001b[39mShapeEnv(assume_static_by_default\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m),\n\u001b[1;32m 143\u001b[0m ) \u001b[38;5;28;01mas\u001b[39;00m fake_mode:\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m FakeCopyMode(fake_mode):\n\u001b[0;32m--> 146\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[43mtarget\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mNode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprepare_proxy_values\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_args\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 148\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mNode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprepare_proxy_values\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_kwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 149\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m target_name \u001b[38;5;241m=\u001b[39m target \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(target, \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m target\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m target_name \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname_idx:\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Proxy.py:32\u001b[0m, in \u001b[0;36mProxy.proxy_call\u001b[0;34m(callable, *args, **kwargs)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;129m@staticmethod\u001b[39m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mproxy_call\u001b[39m(\u001b[38;5;28mcallable\u001b[39m: Callable, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Self:\n\u001b[0;32m---> 32\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcallable\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:736\u001b[0m, in \u001b[0;36mLlamaDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m 730\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 731\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPassing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 732\u001b[0m )\n\u001b[1;32m 734\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[0;32m--> 736\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_layernorm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 738\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[1;32m 739\u001b[0m hidden_states, self_attn_weights, present_key_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mself_attn(\n\u001b[1;32m 740\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mhidden_states,\n\u001b[1;32m 741\u001b[0m attention_mask\u001b[38;5;241m=\u001b[39mattention_mask,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 747\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 748\u001b[0m )\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/nn/modules/module.py:1572\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1569\u001b[0m called_always_called_hooks\u001b[38;5;241m.\u001b[39madd(hook_id)\n\u001b[1;32m 1571\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m hook_id \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks_with_kwargs:\n\u001b[0;32m-> 1572\u001b[0m hook_result \u001b[38;5;241m=\u001b[39m \u001b[43mhook\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresult\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1573\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1574\u001b[0m hook_result \u001b[38;5;241m=\u001b[39m hook(\u001b[38;5;28mself\u001b[39m, args, result)\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/envoy.py:180\u001b[0m, in \u001b[0;36mEnvoy._hook\u001b[0;34m(self, module, input, input_kwargs, output)\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_hook\u001b[39m(\n\u001b[1;32m 177\u001b[0m \u001b[38;5;28mself\u001b[39m, module: torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mModule, \u001b[38;5;28minput\u001b[39m: Any, input_kwargs: Dict, output: Any\n\u001b[1;32m 178\u001b[0m ):\n\u001b[0;32m--> 180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mdetect_fake_mode\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 182\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset_proxies(propagate\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28minput\u001b[39m, input_kwargs)\n",
"File \u001b[0;32m/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/_guards.py:826\u001b[0m, in \u001b[0;36mdetect_fake_mode\u001b[0;34m(inputs)\u001b[0m\n\u001b[1;32m 824\u001b[0m fake_mode, desc1, i1 \u001b[38;5;241m=\u001b[39m fake_modes[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 825\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m m, desc2, i2 \u001b[38;5;129;01min\u001b[39;00m fake_modes[\u001b[38;5;241m1\u001b[39m:]:\n\u001b[0;32m--> 826\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m fake_mode \u001b[38;5;129;01mis\u001b[39;00m m, (\n\u001b[1;32m 827\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfake mode (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfake_mode\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdesc1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt match mode (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mm\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdesc2\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi2\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 828\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfake mode from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdesc1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m allocated at:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfake_mode\u001b[38;5;241m.\u001b[39mstack\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 829\u001b[0m 
\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfake mode from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdesc2\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi2\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m allocated at:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mm\u001b[38;5;241m.\u001b[39mstack\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 830\u001b[0m )\n\u001b[1;32m 831\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fake_mode\n\u001b[1;32m 832\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
"\u001b[0;31mAssertionError\u001b[0m: fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x14d592e1e850>) from active fake mode 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0x14d592fc2450>) from fake tensor input 0\n\nfake mode from active fake mode 0 allocated at:\n File \"<frozen runpy>\", line 198, in _run_module_as_main\n File \"<frozen runpy>\", line 88, in _run_code\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel_launcher.py\", line 17, in <module>\n app.launch_new_instance()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n app.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelapp.py\", line 701, in start\n self.io_loop.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/tornado/platform/asyncio.py\", line 205, in start\n self.asyncio_loop.run_forever()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 607, in run_forever\n self._run_once()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 1922, in _run_once\n handle._run()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/events.py\", line 80, in _run\n self._context.run(self._callback, *self._args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n await self.process_one()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n await dispatch(*args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n await result\n File 
\"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n reply_content = await reply_content\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n res = shell.run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n return super().run_cell(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n result = self._run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n result = runner(coro)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n coro.send(None)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n if await self.run_code(code, result, async_=asy):\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"/tmp/ipykernel_1636525/1197697725.py\", line 3, in <module>\n out_1 = model.model.layers[11](layer10_out, layer10_out_rest).save()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/envoy.py\", line 336, in __call__\n proxy = module_proxy.forward(*args, **kwargs)\n File 
\"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/intervention.py\", line 144, in __call__\n return super().__call__(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Proxy.py\", line 80, in __call__\n return self.node.add(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Node.py\", line 199, in add\n return self.graph.add(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/tracing/Graph.py\", line 140, in add\n with FakeTensorMode(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py\", line 1364, in __init__\n self.stack = \"\".join(traceback.format_stack())\n\nfake mode from fake tensor input 0 allocated at:\n File \"<frozen runpy>\", line 198, in _run_module_as_main\n File \"<frozen runpy>\", line 88, in _run_code\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel_launcher.py\", line 17, in <module>\n app.launch_new_instance()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/traitlets/config/application.py\", line 1075, in launch_instance\n app.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelapp.py\", line 701, in start\n self.io_loop.start()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/tornado/platform/asyncio.py\", line 205, in start\n self.asyncio_loop.run_forever()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 607, in run_forever\n self._run_once()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/base_events.py\", line 1922, in _run_once\n handle._run()\n File \"/home/a.guha/miniconda3/lib/python3.11/asyncio/events.py\", line 80, in _run\n 
self._context.run(self._callback, *self._args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 534, in dispatch_queue\n await self.process_one()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 523, in process_one\n await dispatch(*args)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 429, in dispatch_shell\n await result\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/kernelbase.py\", line 767, in execute_request\n reply_content = await reply_content\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/ipkernel.py\", line 429, in do_execute\n res = shell.run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/ipykernel/zmqshell.py\", line 549, in run_cell\n return super().run_cell(*args, **kwargs)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3051, in run_cell\n result = self._run_cell(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3106, in _run_cell\n result = runner(coro)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/async_helpers.py\", line 129, in _pseudo_sync_runner\n coro.send(None)\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3311, in run_cell_async\n has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3493, in run_ast_nodes\n if await self.run_code(code, result, 
async_=asy):\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"/tmp/ipykernel_1636525/1197697725.py\", line 1, in <module>\n with model.trace(PROMPT) as tracer:\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/models/NNsightModel.py\", line 200, in trace\n runner.invoke(*inputs, **invoker_args).__enter__()\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/nnsight/contexts/Invoker.py\", line 60, in __enter__\n with FakeTensorMode(\n File \"/work/arjunguha-research-group/arjun/venvs/jan2024/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py\", line 1364, in __init__\n self.stack = \"\".join(traceback.format_stack())\n"
]
}
],
"source": [
"with model.trace(PROMPT) as tracer:\n",
" (layer10_out, layer10_out_rest) = model.model.layers[10].output\n",
" out_1 = model.model.layers[11](layer10_out, layer10_out_rest).save()\n",
"out_1.shape"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "jan2024",
"language": "python",
"name": "jan2024"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment