Skip to content

Instantly share code, notes, and snippets.

@poedator
Created March 18, 2023 15:36
Show Gist options
  • Save poedator/0a28b18a0cc275cfbc972a29c7185b55 to your computer and use it in GitHub Desktop.
Save poedator/0a28b18a0cc275cfbc972a29c7185b55 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "e10492b6-5a32-422c-8ce5-2050347df3aa",
"metadata": {
"tags": []
},
"source": [
"# WHOLE LAYER TEST. L46"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bebc0c8e-d531-4a2a-9c31-1892beae7c70",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sat Mar 18 16:13:42 MSK 2023\n"
]
}
],
"source": [
"!date"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "40692f46-77d4-42f7-9388-19e4a53bc499",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_VISIBLE_DEVICES=4\n",
"env: TRANSFORMERS_CACHE=/extra_disk_1/yozh/transformers_cache\n"
]
},
{
"data": {
"text/html": [
"<style>.container { width:100% !important; }</style>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CUDA extension not installed.\n"
]
}
],
"source": [
"%env CUDA_VISIBLE_DEVICES=4\n",
"%env TRANSFORMERS_CACHE=/extra_disk_1/yozh/transformers_cache\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import numpy as np\n",
"import os\n",
"import re\n",
"import gc\n",
"import copy\n",
"from typing import Optional, Tuple, Union, Sequence\n",
"import time\n",
"from sklearn.model_selection import train_test_split\n",
"from operator import attrgetter\n",
"from tqdm.auto import trange, tqdm\n",
"from collections import defaultdict\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import transformers\n",
"\n",
"torch.backends.cuda.matmul.allow_tf32 = False\n",
"torch.backends.cudnn.allow_tf32 = False\n",
"\n",
"from IPython.display import display, HTML\n",
"display(HTML(\"<style>.container { width:100% !important; }</style>\"))\n",
"\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from petals.bloom.from_pretrained import load_pretrained_block\n",
"from transformers.models.bloom.modeling_bloom import BloomGelu\n",
"\n",
"# local imports\n",
"from compress import *\n",
"from utils import * \n",
"from compressable_layers import *"
]
},
{
"cell_type": "markdown",
"id": "fde12a06-da0a-4262-9506-085df94df3d5",
"metadata": {},
"source": [
"### load model\n",
"\n",
"see block code at \n",
"- https://github.com/bigscience-workshop/petals/blob/main/src/petals/bloom/block.py\n",
"- https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f8d259ca-7b4c-4bed-a6f6-0907fa2acd94",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Mar 18 16:14:04.070 [\u001b[1m\u001b[34mINFO\u001b[0m] Loaded bigscience/bloom-petals block 46, <All keys matched successfully>\n"
]
}
],
"source": [
"# Loading single block (petals.bloom.block.WrappedBloomBlock)\n",
"\n",
"model_name = 'bigscience/bloom-petals'\n",
"block_index = 46\n",
"\n",
"block = load_pretrained_block(model_name, block_index=block_index, cache_dir=os.environ['TRANSFORMERS_CACHE'])\n",
"for param in block.parameters():\n",
" param.requires_grad=False\n",
" \n",
"config = transformers.AutoConfig.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "857f9d41-d57d-47a3-97a5-98400819465c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"BloomConfig {\n",
" \"_name_or_path\": \"bigscience/bloom-petals\",\n",
" \"apply_residual_connection_post_layernorm\": false,\n",
" \"architectures\": [\n",
" \"BloomForCausalLM\"\n",
" ],\n",
" \"attention_dropout\": 0.0,\n",
" \"attention_softmax_in_fp32\": true,\n",
" \"bos_token_id\": 1,\n",
" \"dht_prefix\": \"bigscience/bloom-petals\",\n",
" \"eos_token_id\": 2,\n",
" \"hidden_dropout\": 0.0,\n",
" \"hidden_size\": 14336,\n",
" \"initializer_range\": 0.02,\n",
" \"layer_norm_epsilon\": 1e-05,\n",
" \"masked_softmax_fusion\": true,\n",
" \"model_type\": \"bloom\",\n",
" \"n_head\": 112,\n",
" \"n_layer\": 70,\n",
" \"pad_token_id\": 3,\n",
" \"pretraining_tp\": 4,\n",
" \"slow_but_exact\": false,\n",
" \"torch_dtype\": \"bfloat16\",\n",
" \"transformers_version\": \"4.25.1\",\n",
" \"use_cache\": true,\n",
" \"vocab_size\": 250880\n",
"}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"config"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "55d73e8f-8eec-4be0-8b65-b72e3739b1a4",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"WrappedBloomBlock(\n",
" (input_layernorm): LayerNorm((14336,), eps=1e-05, elementwise_affine=True)\n",
" (self_attention): BloomAttention(\n",
" (query_key_value): Linear(in_features=14336, out_features=43008, bias=True)\n",
" (dense): Linear(in_features=14336, out_features=14336, bias=True)\n",
" (attention_dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (post_attention_layernorm): LayerNorm((14336,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): BloomMLP(\n",
" (dense_h_to_4h): Linear(in_features=14336, out_features=57344, bias=True)\n",
" (gelu_impl): BloomGelu()\n",
" (dense_4h_to_h): Linear(in_features=57344, out_features=14336, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"block"
]
},
{
"cell_type": "markdown",
"id": "0a2da100-e4a0-4f82-a3a1-f2e8299bbe4f",
"metadata": {},
"source": [
"### load data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3cc2695b-135f-4046-876a-163fbb19e879",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import re\n",
"import torch\n",
"from petals.bloom.from_pretrained import load_pretrained_block\n",
"from transformers.models.bloom.modeling_bloom import BloomGelu"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "817d231a-ea4c-405f-af9a-660770a4e54e",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a8a1a1bb973c45908f0c33a26db906d3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/271 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loaded 271 samples with shape torch.Size([1, 1024, 14336]) each\n"
]
}
],
"source": [
"path46 = '/home/jheuristic/blocks_46_47'\n",
"samples_paths = [p for p in Path(path46).iterdir()]\n",
"samples_paths.sort(key = lambda x: int(re.search(r'sample(\\d+)_', x.name)[1]))\n",
"x_samples = [torch.load(p) for p in tqdm(samples_paths)]\n",
"# X = torch.cat(x_slices, dim=1).squeeze(0) # merging tensors into one\n",
"# x_samples = [dict(hidden_states=x) for i, x in enumerate(x_samples)]\n",
"print(f\"loaded {len(x_samples)} samples with shape {x_samples[0].shape} each\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "999864ee-521f-40f6-8878-93b5b126dcc3",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"(243, 28, torch.bfloat16)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train_Test_split using whole samples of 1024 tokens\n",
"train_idx, test_idx = train_test_split(range(len(x_samples)), test_size=0.1)\n",
"X0_train = [x_samples[i] for i in train_idx]\n",
"X0_test = [x_samples[i] for i in test_idx]\n",
"# del x_samples\n",
"len(X0_train), len(X0_test), X0_train[0].dtype"
]
},
{
"cell_type": "markdown",
"id": "738a4c93-c75f-4f3c-9346-6dc0decadf68",
"metadata": {},
"source": [
"# FORWARD PASS"
]
},
{
"cell_type": "markdown",
"id": "dad32079-9514-4bd7-842f-e57fac7b7c68",
"metadata": {
"tags": []
},
"source": [
"#### 0. setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15bb8bf3-5b00-481d-9b58-c7dae42d3395",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def calc_losses(i, layer_name, tasks=('train', 'test'), suffixes=('_q', '_q_cum', '_refit')):\n",
" loss1 = {'layer_name': layer_name}\n",
" for t in tasks:\n",
" for suffix in suffixes:\n",
" tag = t + suffix\n",
" if tag in X[i].keys():\n",
" mse_loss, rse_loss = batch_mse_loss(X[i][t], X[i][tag], desc = tag)\n",
" loss1.update({f'{t}{suffix}': (mse_loss, rse_loss)})\n",
" return loss1"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "1fc39bad-d0fe-4a80-8797-a6e56a053eb7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"wbits = 3"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "322af863-23a5-4cf1-99ab-7ea628aa32f7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"losses = dict() # storage for losses by layer"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9e679b55-abce-42e9-88c5-2864a9b5f235",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# main storage for interim data\n",
"X = defaultdict(dict)\n",
"X[0]={'train': [x_samples[i] for i in train_idx],\n",
" 'test': [x_samples[i] for i in test_idx]\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a3284c09-abcb-4499-97e3-2d61f2c87749",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a9d471710dcf464784f78e85185d8e23",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c0d6478772f8424384025302cda6e2d2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Add extra fields to data (alibi, residual, masks)\n",
"for k, v in X[0].items():\n",
" X[0][k] = [augment_batch(xb, block) for xb in tqdm(v)]"
]
},
{
"cell_type": "markdown",
"id": "6626bc4d-ef33-4e23-bdfc-36aff66b7a2d",
"metadata": {
"tags": []
},
"source": [
"#### 1. layer_norm"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "708f7b21-a267-4e9d-95e8-71836b17bfa6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# run this cell before any other one for this layer\n",
"i = 1\n",
"layer_name = '46_layer_norm_0'\n",
"\n",
"layer_n0 = LayerNormLayer(module=block.input_layernorm)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "43856496-e37e-4998-8892-5add70afbc8d",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a9fb9a320d3746fa92a2e04927e50f7b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"LNorm: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "010882312cff4b24a12b969644fc3fd3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"LNorm: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Regular test and train data with original layer \n",
"layer = layer_n0\n",
"suffix = ''\n",
"\n",
"for k, v in X[i - 1].items():\n",
" X[i][k + suffix] = layer(X[i - 1][k])"
]
},
{
"cell_type": "markdown",
"id": "46f99e48-2431-4f33-aaeb-fa5ed1d03ca4",
"metadata": {
"tags": []
},
"source": [
"#### 2. attention"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ed94e7f3-7461-4289-9921-021cdf08c2e0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# run this cell before any other one for this layer\n",
"i = 2\n",
"layer_name = '46_attention'\n",
"layer_a = AttnLayer(module=block.self_attention)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0f9d0d6b-5bee-41bb-91bd-b22f15e01ad1",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ab31c2fd407c465d93b278d183172a94",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2620cfc5204e47c49b24fb16f23a7351",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Regular test and train data with original layer \n",
"layer = layer_a\n",
"suffix = ''\n",
"\n",
"for k in ['train', 'test']:\n",
" X[i][k + suffix] = layer(X[i - 1][k])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "41b71794-e004-4dbf-a83c-ca80c5f468ae",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f8ef1dce6d0e4f2cacd24cc7dcd42574",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57608e6dab6a48c3892246f1218aa2f8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"gptq: 0%| | 0/112 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"layer_a_q = layer_a.compress_gptq(X[1]['train'], wbits=3)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "53403baf-22de-4a50-87d9-5cffa3fbcfd6",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "922bdcdd2d6b46aa8057c056f05edb8f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e7ceb848bf1f41698ddd0821d217515c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Regular test and train data with GPTQ-compressed layer\n",
"layer = layer_a_q\n",
"suffix = '_q'\n",
"\n",
"for k in ['train', 'test']:\n",
" X[i][k + suffix] = layer(X[i - 1][k])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "3afa2322-3275-4302-95f1-22ce31098b35",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'layer_name': '46_attention', 'train_q': (0.0108642578125, 0.07802173495292664), 'test_q': (0.01251220703125, 0.09104174375534058)}\n"
]
}
],
"source": [
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n",
"print(losses[layer_name]) "
]
},
{
"cell_type": "markdown",
"id": "41e1d789-5ee2-4585-b047-f408a6692bd1",
"metadata": {
"tags": []
},
"source": [
"#### 3. dense"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "569e031a-f2ce-45aa-a739-84cdaa3fd1b9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# run this cell before any other one for this layer\n",
"i = 3\n",
"layer_name = '46_dense'\n",
"\n",
"layer_d = LinearLayer(block.self_attention.dense)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "968506d0-d324-4ff9-b507-d541bdcf89f8",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f8b9a59021494d1c99a273626aaacafe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "02d4879b827144c986dc17ecd0e9081e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Regular test and train data with original layer \n",
"layer = layer_d\n",
"suffix = ''\n",
"\n",
"for k in 'train test'.split():\n",
" X[i][k + suffix] = layer(X[i - 1][k])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "cebfcbe0-d942-487c-b620-a13e76a965c5",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7ec0b04ac63f48b3a9e3fc891441cb73",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "effd521588b84de9b2e298b465e700ee",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"gptq: 0%| | 0/112 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"layer_d_q = layer_d.compress_gptq(X[2]['train'], wbits=wbits)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "0e918b87-9d90-4125-ad1b-a51fae71361f",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "840383a6117e464a8bbe46ed47e90daa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "00ba87913125465aa22014bda9545383",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Regular test and train data with GPTQ layer\n",
"layer = layer_d_q\n",
"suffix = '_q'\n",
"\n",
"for k in 'train test'.split():\n",
" X[i][k + suffix] = layer(X[i - 1][k])"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "7145a99c-8e5d-4fbf-a6c2-1f3ca7d19ee3",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "59a89b78fc28449886d038c8a6c05771",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ba2041c7b81541c792aacefa6febae4a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# CUMULATIVE GPTQ test and train data with GPTQ layer - NO REFIT\n",
"layer = layer_d_q\n",
"suffix = '_q_cum'\n",
"\n",
"for k in 'train test'.split():\n",
" X[i][k + suffix] = layer(X[i - 1][k + '_q'])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "09070b35-c8ad-42b7-8bed-a26593d820ef",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"without validation set, using train loss for early stopping\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "18ea0df6bd0e4c0f91849c4e2b5b9515",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Refitter starting: 0%| | 0/3037 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EARLY STOP\n",
"Fit complete with train loss 0.010352, test loss 0.010410\n"
]
}
],
"source": [
"layer_d_refit = layer_d.refit(X_train=X[2]['train_q'], Y_train=X[3]['train'])"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "3d7c7d4d-f11a-48e5-850f-8c00744b0a50",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d96b65f7f39a4dc8b7731fcfb78b98b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5c6e36e025924c81865ea16af4b151ac",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"gptq: 0%| | 0/112 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"layer_d_refit_q = layer_d_refit.compress_gptq(X[2]['train_q'], wbits=wbits)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "a4837c77-7f15-49a8-a200-b7509134afc3",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f8baabd8bd5049a9a2d3544be6033857",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e5926c2da60345649b7c8891188212ad",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# CUMULATIVE GPTQ test and train data with REFITTED GPTQ layer\n",
"layer = layer_d_refit_q\n",
"suffix_in = '_q'\n",
"suffix_out = '_refit'\n",
"\n",
"for k in 'train test'.split():\n",
" X[i][k + suffix_out] = layer(X[i - 1][k + suffix_in])"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "2442e9c4-c08c-4c73-9cd7-b38d4c6a58a1",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'layer_name': '46_dense', 'train_q': (0.01177978515625, 0.0545545294880867), 'train_q_cum': (0.03173828125, 0.14698629081249237), 'train_refit': (0.020751953125, 0.09610642492771149), 'test_q': (0.01513671875, 0.07037457823753357), 'test_q_cum': (0.032470703125, 0.15096482634544373), 'test_refit': (0.030029296875, 0.139614075422287)}\n"
]
}
],
"source": [
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n",
"print(losses[layer_name]) "
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "c4b5c209-aaae-473c-ae99-3f6246a2808b",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['train', 'test', 'train_q', 'test_q', 'train_q_cum', 'test_q_cum', 'train_refit', 'test_refit'])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[3].keys()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "e63831f9-c77b-4d69-b0c1-de29e3634284",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"allocated: 0.0 GB, max: 8.8 GB, before flush: 0.0 GB\n"
]
}
],
"source": [
"cuda_mem_use()"
]
},
{
"cell_type": "markdown",
"id": "0129840e-049c-4820-ad1e-cfd03ab20e47",
"metadata": {
"tags": []
},
"source": [
"#### 4. add resid 0"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "f537e13d-606d-4635-bd42-76b3f6e4ef9d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# run this cell before any other one for this layer\n",
"i = 4\n",
"layer_name = '46_add_resid_0'"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "f65d9f4d-695d-4623-882a-c4e491d757dc",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def add_sequences(aa, bb):\n",
" \"\"\"adds 2 lists of tensors pairwise, using cuda\"\"\"\n",
" cc = []\n",
" for a, b in tqdm(zip(aa, bb), leave=False, total=len(aa)):\n",
" cc.append((a.cuda() + b.cuda()).cpu())\n",
" return cc"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "44ba45e0-1310-4abf-9cbb-ce85232882f9",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['hidden_states', 'attention_mask', 'alibi', 'residual'])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[0]['test'][0].keys()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "32145b2f-2da6-4ce5-9a11-d690b3ba91de",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "974650bb35704699a32a0a22d1a1c1a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train train; "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test test; "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_q train; "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test_q test; "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_q_cum train; "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test_q_cum test; "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_refit train; "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test_refit test; "
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# adding corresponding residual\n",
"\n",
"for key, val in tqdm(X[i - 1].items()):\n",
" resid_key = key.split('_')[0]\n",
" print(key, resid_key, end = '; ')\n",
" X[i][key] = add_sequences(val, [xb['hidden_states'] for xb in X[0][resid_key]])"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "edfe7a0c-ea47-49d6-833a-ebdaad5a8d25",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['train', 'test', 'train_q', 'test_q', 'train_q_cum', 'test_q_cum', 'train_refit', 'test_refit'])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[4].keys()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "9e7ec44d-d124-44a7-8dc8-5f0c172c694b",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'layer_name': '46_add_resid_0', 'train_q': (0.01202392578125, 0.00022347424237523228), 'train_q_cum': (0.031982421875, 0.000594418786931783), 'train_refit': (0.020751953125, 0.0003856915864162147), 'test_q': (0.01531982421875, 0.0002849726297426969), 'test_q_cum': (0.032470703125, 0.0006040057633072138), 'test_refit': (0.0301513671875, 0.0005608624778687954)}\n"
]
}
],
"source": [
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n",
"print(losses[layer_name]) "
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "8fe06f37-af23-42ce-85d6-30b7dddf75b8",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"allocated: 0.0 GB, max: 8.8 GB, before flush: 0.0 GB\n"
]
}
],
"source": [
"cuda_mem_use()"
]
},
{
"cell_type": "markdown",
"id": "910a19ef-8e81-4fe3-b7af-c119c538fd89",
"metadata": {
"tags": []
},
"source": [
"#### 5. layer_norm_1"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "ed0b38ed-4334-430a-ba18-65233be15eb4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# run this cell before any other one for this layer\n",
"i = 5\n",
"layer_name = '46_layer_norm_1'\n",
"\n",
"layer_n1 = LayerNormLayer(module=block.post_attention_layernorm)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "6470829a-3315-4ba5-b849-2b341ceb5e91",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "70432747c2ff4407ba3320b1c9a19873",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0a66e81823314cfbb29c74b14ae92282",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"LNorm: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d8097cdc2ec64044a1f29ec4fbc8ae2e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"LNorm: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9f012accc78d44f698965a5e6c49e49d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"LNorm: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5264afb399c24504b4c293fd406dd8de",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"LNorm: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e93621c538104940ba1418f680c6926d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"LNorm: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a098f4ff1a684e6b963fb97c7aeaffb9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"LNorm: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "af731628c3084d3f82dd1f98466f8259",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"LNorm: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2baa2c78fc32410395e8e78f2b7410ae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"LNorm: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for k, v in tqdm(X[i - 1].items()):\n",
" X[i][k] = layer_n1(v)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "8d8cf8bd-a867-4635-aa24-f008f4e0409c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'layer_name': '46_layer_norm_1', 'train_q': (0.000545501708984375, 0.0007736138068139553), 'train_q_cum': (0.00115203857421875, 0.0016337857814505696), 'train_refit': (0.00104522705078125, 0.001482308958657086), 'test_q': (0.00066375732421875, 0.0009506119531579316), 'test_q_cum': (0.00141143798828125, 0.002021416090428829), 'test_refit': (0.00130462646484375, 0.0018684441456571221)}\n"
]
}
],
"source": [
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n",
"print(losses[layer_name]) "
]
},
{
"cell_type": "markdown",
"id": "0c68088b-5dd9-498a-9052-af9159f7bea9",
"metadata": {},
"source": [
"#### 6. MLP-1"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "9a031013-08b8-4b64-8d72-e8f1667f3499",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# run this cell before any other one for this layer\n",
"i = 6\n",
"layer_name = '46_mlp1'"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "3f8b934c-d794-4e87-98c8-b8179e0966d9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"layer_mlp1 = LinearLayer(block.mlp.dense_h_to_4h)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "534ec464-95da-428f-8c6a-e3731c39b21a",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8268f7a712334665ad8aaeaced347ca1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8337d76791fe43be8616048e4ae88a46",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Regular test and train data with original layer \n",
"layer = layer_mlp1\n",
"suffix = ''\n",
"\n",
"for k in ('train', 'test'):\n",
" X[i][k + suffix] = layer(X[i - 1][k])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "f8ebdbed-bee3-4202-95c8-e1dbce28d76c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "20192bb0d2ac46d9ad218d02e1f51ae2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "53e039472df545d2b62ba448693b1ab3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"gptq: 0%| | 0/112 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"layer_mlp1_q = layer_mlp1.compress_gptq(X[5]['train'], wbits=wbits)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "d03a9e50-ab92-49d7-88d1-13e27bfa93e0",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([57344, 14336])"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"layer_mlp1_q.module.weight.data.shape"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "b6006798-48da-4cd0-b6d1-b53d14afe76c",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3787e343b303496c9cd8723586e5f3f4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6d5ae4b0e68e42199d055b3aa20ee13b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Regular test and train data with GPTQ layer\n",
"layer = layer_mlp1_q\n",
"suffix = '_q'\n",
"\n",
"for k in ('train', 'test'):\n",
" X[i][k + suffix] = layer(X[i - 1][k])"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "43118fe7-eb98-46fd-9966-009fe7abbf0e",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f109242060e84cd5baaf807f333ef6c5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "31bda58082b24a00a1cc29c189edfa26",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# CUMULATIVE GPTQ test and train data with GPTQ layer - NO REFIT\n",
"layer = layer_mlp1_q\n",
"suffix = '_q_cum'\n",
"\n",
"for k in ('train', 'test'):\n",
" X[i][k + suffix] = layer(X[i - 1][k + suffix])"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "697fdd8b-4d50-4c86-a59b-6db17bfebd48",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"without validation set, using train loss for early stopping\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ebb3a29cee0f417895418f5800126b2a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Refitter starting: 0%| | 0/3037 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EARLY STOP\n",
"Fit complete with train loss 0.006116, test loss 0.006138\n"
]
}
],
"source": [
"layer_mlp1_refit = layer_mlp1.refit(X_train=X[5]['train_refit'], Y_train=X[6]['train'],\n",
" X_val=X[5]['test_refit'], Y_val=X[6]['test'], )"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "002dd949-157c-41bb-b22d-5d269b678173",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "974668df15924537a989541458c15929",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4b312552c0774199b453f20145723d47",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"gptq: 0%| | 0/112 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"layer_mlp1_refit_q = layer_mlp1_refit.compress_gptq(X[5]['train_refit'], wbits=wbits)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "7c46503d-5b3a-446f-9ee2-0c4a91ac1a10",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2a7c52ec399f4236bbbb265fa27cdb97",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9cd451339e704afcb84ffa376cce9b4f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# CUMULATIVE GPTQ test and train data with REFITTED GPTQ layer\n",
"layer = layer_mlp1_refit_q\n",
"suffix_in = '_refit'\n",
"suffix_out = '_refit'\n",
"\n",
"for k in 'train test'.split():\n",
" X[i][k + suffix_out] = layer(X[i - 1][k + suffix_in])"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "ebacab4d-e9bb-4d38-91fd-f55011d227d2",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'layer_name': '46_mlp1', 'train_q': (0.047119140625, 0.024921249598264694), 'train_q_cum': (0.0517578125, 0.027374638244509697), 'train_refit': (0.04736328125, 0.02505037561058998), 'test_q': (0.053955078125, 0.02883181720972061), 'test_q_cum': (0.0576171875, 0.030788728967308998), 'test_refit': (0.056396484375, 0.03013642504811287)}\n"
]
}
],
"source": [
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n",
"print(losses[layer_name]) "
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "ca036c50-57b0-4995-9d5b-35d2ea556f2d",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['train', 'test', 'train_q', 'test_q', 'train_q_cum', 'test_q_cum', 'train_refit', 'test_refit'])"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[6].keys()"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "4306f711-f93b-448e-969d-7d1149d9618c",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"allocated: 7.5 GB, max: 19.1 GB, before flush: 7.5 GB\n"
]
}
],
"source": [
"cuda_mem_use()"
]
},
{
"cell_type": "markdown",
"id": "4a0886d4-0cd8-4613-bd40-5c14b26f6948",
"metadata": {},
"source": [
"#### 7. Gelu"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "7ab9375b-8874-48f1-a893-55a2526ef321",
"metadata": {},
"outputs": [],
"source": [
"# run this cell before any other one for this layer\n",
"i = 7\n",
"layer_name = '46_mlp_gelu'\n",
"layer_mlp_gelu = Layer(module=block.mlp.gelu_impl)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "0f537a25-a8e9-40a2-9b3e-3993c63854fa",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "171fd1c17127480f86cdde743f7ea017",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0c927f08bfc240448214e6f6367e3a86",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6c5b6b93ea784963a2ba7c687eb8bd79",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ab2b6b10421e439bae1d2a688cbea1b6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2f118d7ead39444cbd2fe14b82387ae7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "66c4dcb4b1cd4aec8be587d3bee10df4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3c3b0aca107a4bdb8fa4c79bd067ab49",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "39b37eff929d4fbeb18857313c824f1a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d5d22480048b42e8a118ab6181932d81",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for k, v in tqdm(X[i - 1].items()):\n",
" X[i][k] = layer_mlp_gelu(v)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "8d5088ea-d2a4-46ea-98f9-b195bd7ce081",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'layer_name': '46_mlp_gelu', 'train_q': (0.00421142578125, 0.059492044150829315), 'train_q_cum': (0.00439453125, 0.062078654766082764), 'train_refit': (0.004119873046875, 0.05819873884320259), 'test_q': (0.00445556640625, 0.06271477788686752), 'test_q_cum': (0.0047607421875, 0.0670103132724762), 'test_refit': (0.004364013671875, 0.06142611801624298)}\n"
]
}
],
"source": [
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n",
"print(losses[layer_name]) "
]
},
{
"cell_type": "markdown",
"id": "2d4c0299-b612-4ccc-ba42-0258d1d6b155",
"metadata": {},
"source": [
"#### 8. MLP-2"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "07aa5a66-7bb4-4e79-8385-136b06661447",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# run this cell before any other one for this layer\n",
"i = 8\n",
"layer_name = '46_mlp2'\n",
"\n",
"layer_mlp2 = LinearLayer(block.mlp.dense_4h_to_h)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "ca3108dc-9b31-4cb5-a075-5718ddcbc770",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c5c6fd78af2047d88b4a4e8df3a9adea",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "95a4cef45bc8484f9e843384be835091",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Regular test and train data with original layer \n",
"layer = layer_mlp2\n",
"suffix = ''\n",
"\n",
"for k in ('train', 'test'):\n",
" X[i][k + suffix] = layer(X[i - 1][k])"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "2dbf0a6d-dc3e-454c-b721-3ae6b32354e3",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "84c33a4d457b47a7a8e78a9636ab4d4f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ad49cc1229a748eab84333f886b0f22e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"gptq: 0%| | 0/448 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"layer_mlp2_q = layer_mlp2.compress_gptq(X[7]['train'], wbits=wbits)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "8d52b692-1179-4b9d-87bd-e130301e0a08",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9f8103e9b6b64f3eb6c1214e1a8ee009",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "89a57becd4ad4fde8d665c58ac3120b7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Regular test and train data with GPTQ layer\n",
"layer = layer_mlp2_q\n",
"suffix = '_q'\n",
"\n",
"for k in ('train', 'test'):\n",
" X[i][k + suffix] = layer(X[i - 1][k])"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "535f21d4-27c6-47fe-9015-9976b6354393",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1680828d1e7240a78ee434f1972ce28f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e907dee31e6b4d1bb5009c4a8152da1f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# CUMULATIVE GPTQ test and train data with GPTQ layer - NO REFIT\n",
"layer = layer_mlp2_q\n",
"suffix = '_q_cum'\n",
"\n",
"for k in ('train', 'test'):\n",
" X[i][k + suffix] = layer(X[i - 1][k + suffix])"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "281ac375-10a8-41a0-b248-025f1f2fd780",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"early stopping when val loss stops improving.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9d100acf27e6423983e04e0d6006e1da",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Refitter starting: 0%| | 0/3037 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EARLY STOP\n",
"Fit complete with train loss 0.022766, test loss 0.040771\n"
]
}
],
"source": [
"# INITIAL REFIT ATTEMPT \n",
"layer_mlp2_refit = layer_mlp2.refit(X[7]['train_refit'], X[8]['train'], \n",
" X[7]['test_refit'], X[8]['test'])"
]
},
{
"cell_type": "code",
"execution_count": 69,
"id": "a7e83c08-171c-4a2d-9a13-7ee7811da1a9",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "027049bee96a40fd9711293e3121187c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "55ee5cef79f5473285f2f9f8c7f2e4a1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"gptq: 0%| | 0/448 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"layer_mlp2_refit_q = layer_mlp2_refit.compress_gptq(X[7]['train_refit'], wbits=wbits)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "f197e4b4-0081-447d-b90c-db5fdb731daa",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3de37ecdaa534243b78d102c840857fe",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ba3cbeb60c414554872d90bf438595eb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# CUMULATIVE GPTQ test and train data with REFITTED GPTQ layer\n",
"layer = layer_mlp2_refit_q\n",
"suffix_in = '_refit'\n",
"suffix_out = '_refit'\n",
"\n",
"for k in 'train test'.split():\n",
" X[i][k + suffix_out] = layer(X[i - 1][k + suffix_in])"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "8c7ce3e4-b764-456c-a621-94765f440ffb",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'layer_name': '46_mlp2', 'train_q': (0.043212890625, 0.09118527173995972), 'train_q_cum': (0.0732421875, 0.15455131232738495), 'train_refit': (0.0703125, 0.14836925268173218), 'test_q': (0.0654296875, 0.13786007463932037), 'test_q_cum': (0.1005859375, 0.2119341492652893), 'test_refit': (0.103515625, 0.21810698509216309)}\n"
]
}
],
"source": [
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n",
"print(losses[layer_name]) "
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "993d5f35-e022-4611-b9b1-147b0f9b8646",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['train', 'test', 'train_q', 'test_q', 'train_q_cum', 'test_q_cum', 'train_refit', 'test_refit'])"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[8].keys()"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "e8d09e8e-2bac-48aa-bbf7-880fc01beef3",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"allocated: 7.5 GB, max: 56.5 GB, before flush: 7.5 GB\n"
]
}
],
"source": [
"cuda_mem_use()"
]
},
{
"cell_type": "markdown",
"id": "afe7a7df-9073-48a7-9ae5-d468f57ba155",
"metadata": {},
"source": [
"#### 9. add resid 1"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "5128996a-9506-4919-8b16-61cc037592ff",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# run this cell before any other one for this layer\n",
"i = 9\n",
"layer_name = '46_add_resid_1'"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "d64d96f8-3d2f-444e-b213-3590bc5a26f6",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['train', 'test', 'train_q', 'test_q', 'train_q_cum', 'test_q_cum', 'train_refit', 'test_refit'])"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# source for resid\n",
"X[4].keys()"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "ea809e07-5f30-463c-a3fb-7f5d52a088a4",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7722f73aae254526b96e8ab8511f687d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# adding corresponding residual\n",
"\n",
"for key, val in tqdm(X[i - 1].items()):\n",
" resid_key = key\n",
" # print(key, resid_key, end = '; ')\n",
" X[i][key] = add_sequences(val, X[4][resid_key])"
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "1b358f33-0140-4269-a800-5a0aefbc0c49",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'layer_name': '46_add_resid_1', 'train_q': (0.06494140625, 0.0011800023494288325), 'train_q_cum': (0.109375, 0.0019873722922056913), 'train_refit': (0.08056640625, 0.0014639126602560282), 'test_q': (0.08154296875, 0.0014840448275208473), 'test_q_cum': (0.1318359375, 0.0023993540089577436), 'test_refit': (0.1318359375, 0.0023993540089577436)}\n"
]
}
],
"source": [
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n",
"print(losses[layer_name]) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b34a15c3-7cc0-4af2-bf6c-4872cc5a4acd",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 91,
"id": "113c8f3c-8a6b-4a21-9d0e-9da16b43ea00",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"loss_calc: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(0.1318359375, 0.0023993540089577436)"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"j = 0\n",
"ii = 9\n",
"# torch.equal(X[ii]['test_refit'][ii], X[4]['test_q_cum'][j])\n",
"t = 'test'\n",
"tag = 'test_q_cum'\n",
"batch_mse_loss(X[ii][t], X[ii][tag])"
]
},
{
"cell_type": "code",
"execution_count": 92,
"id": "5f092017-81d8-43dc-a006-a51396471f45",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"loss_calc: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(0.1318359375, 0.0023993540089577436)"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"j = 0\n",
"ii = 9\n",
"# torch.equal(X[ii]['test_refit'][ii], X[4]['test_q_cum'][j])\n",
"t = 'test'\n",
"tag = 'test_refit'\n",
"batch_mse_loss(X[ii][t], X[ii][tag])"
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "c82bf042-0bb9-4fba-b129-371b50f884a1",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"loss_calc: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(0.1640625, 0.3178808093070984)"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"j = 0\n",
"ii = 8\n",
"# torch.equal(X[ii]['test_refit'][ii], X[4]['test_q_cum'][j])\n",
"t = 'test_refit'\n",
"tag = 'test_q_cum'\n",
"batch_mse_loss(X[ii][t], X[ii][tag])"
]
},
{
"cell_type": "markdown",
"id": "10c2511b-bfcf-4286-8ab0-f9c6cb6845fc",
"metadata": {},
"source": [
"#### 10. Final check: `X[9]['test']` vs. output of `block` on `X[0]['test']`"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "be37a466-4677-43b9-8c66-ec40d3d569fe",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8c48bed48ad64b4a80cd6c3459e4da6a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.17 s, sys: 0 ns, total: 2.17 s\n",
"Wall time: 2.16 s\n"
]
}
],
"source": [
"%%time\n",
"Y_ref_test = [block.cuda()(xb['hidden_states'].cuda())[0].cpu() for xb in tqdm(X[0]['test'])]"
]
},
{
"cell_type": "code",
"execution_count": 79,
"id": "9934336e-212f-4a19-9a6e-6c87fcff0135",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"loss_calc: 0%| | 0/28 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(0.0, 0.0)"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch_mse_loss(X[9]['test'], Y_ref_test)"
]
},
{
"cell_type": "markdown",
"id": "537cce12-e89d-40b8-bcef-b00ab70b0a13",
"metadata": {},
"source": [
"### Summary of losses"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "df9c7fad-93f5-434f-aebb-bf9d546840b7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"mses = {kk: {k:v[0] for k, v in vv.items() if k != 'layer_name'} for kk, vv in losses.items()}\n",
"df = pd.DataFrame(mses).T"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "4b86d3ef-274b-44e4-8734-d653de2a6fa2",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>train_q</th>\n",
" <th>test_q</th>\n",
" <th>train_q_cum</th>\n",
" <th>train_refit</th>\n",
" <th>test_q_cum</th>\n",
" <th>test_refit</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>46_attention</th>\n",
" <td>0.010864</td>\n",
" <td>0.012512</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_dense</th>\n",
" <td>0.011780</td>\n",
" <td>0.015137</td>\n",
" <td>0.031738</td>\n",
" <td>0.020752</td>\n",
" <td>0.032471</td>\n",
" <td>0.030029</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_add_resid_0</th>\n",
" <td>0.012024</td>\n",
" <td>0.015320</td>\n",
" <td>0.031982</td>\n",
" <td>0.020752</td>\n",
" <td>0.032471</td>\n",
" <td>0.030151</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_layer_norm_1</th>\n",
" <td>0.000546</td>\n",
" <td>0.000664</td>\n",
" <td>0.001152</td>\n",
" <td>0.001045</td>\n",
" <td>0.001411</td>\n",
" <td>0.001305</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_mlp1</th>\n",
" <td>0.047119</td>\n",
" <td>0.053955</td>\n",
" <td>0.051758</td>\n",
" <td>0.047363</td>\n",
" <td>0.057617</td>\n",
" <td>0.056396</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_mlp_gelu</th>\n",
" <td>0.004211</td>\n",
" <td>0.004456</td>\n",
" <td>0.004395</td>\n",
" <td>0.004120</td>\n",
" <td>0.004761</td>\n",
" <td>0.004364</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_mlp2</th>\n",
" <td>0.043213</td>\n",
" <td>0.065430</td>\n",
" <td>0.073242</td>\n",
" <td>0.070312</td>\n",
" <td>0.100586</td>\n",
" <td>0.103516</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_add_resid_1</th>\n",
" <td>0.064941</td>\n",
" <td>0.081543</td>\n",
" <td>0.109375</td>\n",
" <td>0.080566</td>\n",
" <td>0.131836</td>\n",
" <td>0.131836</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" train_q test_q train_q_cum train_refit test_q_cum \\\n",
"46_attention 0.010864 0.012512 NaN NaN NaN \n",
"46_dense 0.011780 0.015137 0.031738 0.020752 0.032471 \n",
"46_add_resid_0 0.012024 0.015320 0.031982 0.020752 0.032471 \n",
"46_layer_norm_1 0.000546 0.000664 0.001152 0.001045 0.001411 \n",
"46_mlp1 0.047119 0.053955 0.051758 0.047363 0.057617 \n",
"46_mlp_gelu 0.004211 0.004456 0.004395 0.004120 0.004761 \n",
"46_mlp2 0.043213 0.065430 0.073242 0.070312 0.100586 \n",
"46_add_resid_1 0.064941 0.081543 0.109375 0.080566 0.131836 \n",
"\n",
" test_refit \n",
"46_attention NaN \n",
"46_dense 0.030029 \n",
"46_add_resid_0 0.030151 \n",
"46_layer_norm_1 0.001305 \n",
"46_mlp1 0.056396 \n",
"46_mlp_gelu 0.004364 \n",
"46_mlp2 0.103516 \n",
"46_add_resid_1 0.131836 "
]
},
"execution_count": 81,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df # MSEs"
]
},
{
"cell_type": "markdown",
"id": "88d8d651-d6bd-4f25-9235-bc54083e2bf3",
"metadata": {},
"source": [
"#### Relative standard errors"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "a4786349-0578-4b2f-b3d9-479f46c8a49f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"rses = {kk: {k:v[1] for k, v in vv.items() if k != 'layer_name'} for kk, vv in losses.items()}\n",
"dfr = pd.DataFrame(rses).T"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "392d5ead-83d1-4f8e-86e1-f35211783d27",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>train_q</th>\n",
" <th>test_q</th>\n",
" <th>train_q_cum</th>\n",
" <th>train_refit</th>\n",
" <th>test_q_cum</th>\n",
" <th>test_refit</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>46_attention</th>\n",
" <td>0.078022</td>\n",
" <td>0.091042</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_dense</th>\n",
" <td>0.054555</td>\n",
" <td>0.070375</td>\n",
" <td>0.146986</td>\n",
" <td>0.096106</td>\n",
" <td>0.150965</td>\n",
" <td>0.139614</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_add_resid_0</th>\n",
" <td>0.000223</td>\n",
" <td>0.000285</td>\n",
" <td>0.000594</td>\n",
" <td>0.000386</td>\n",
" <td>0.000604</td>\n",
" <td>0.000561</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_layer_norm_1</th>\n",
" <td>0.000774</td>\n",
" <td>0.000951</td>\n",
" <td>0.001634</td>\n",
" <td>0.001482</td>\n",
" <td>0.002021</td>\n",
" <td>0.001868</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_mlp1</th>\n",
" <td>0.024921</td>\n",
" <td>0.028832</td>\n",
" <td>0.027375</td>\n",
" <td>0.025050</td>\n",
" <td>0.030789</td>\n",
" <td>0.030136</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_mlp_gelu</th>\n",
" <td>0.059492</td>\n",
" <td>0.062715</td>\n",
" <td>0.062079</td>\n",
" <td>0.058199</td>\n",
" <td>0.067010</td>\n",
" <td>0.061426</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_mlp2</th>\n",
" <td>0.091185</td>\n",
" <td>0.137860</td>\n",
" <td>0.154551</td>\n",
" <td>0.148369</td>\n",
" <td>0.211934</td>\n",
" <td>0.218107</td>\n",
" </tr>\n",
" <tr>\n",
" <th>46_add_resid_1</th>\n",
" <td>0.001180</td>\n",
" <td>0.001484</td>\n",
" <td>0.001987</td>\n",
" <td>0.001464</td>\n",
" <td>0.002399</td>\n",
" <td>0.002399</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" train_q test_q train_q_cum train_refit test_q_cum \\\n",
"46_attention 0.078022 0.091042 NaN NaN NaN \n",
"46_dense 0.054555 0.070375 0.146986 0.096106 0.150965 \n",
"46_add_resid_0 0.000223 0.000285 0.000594 0.000386 0.000604 \n",
"46_layer_norm_1 0.000774 0.000951 0.001634 0.001482 0.002021 \n",
"46_mlp1 0.024921 0.028832 0.027375 0.025050 0.030789 \n",
"46_mlp_gelu 0.059492 0.062715 0.062079 0.058199 0.067010 \n",
"46_mlp2 0.091185 0.137860 0.154551 0.148369 0.211934 \n",
"46_add_resid_1 0.001180 0.001484 0.001987 0.001464 0.002399 \n",
"\n",
" test_refit \n",
"46_attention NaN \n",
"46_dense 0.139614 \n",
"46_add_resid_0 0.000561 \n",
"46_layer_norm_1 0.001868 \n",
"46_mlp1 0.030136 \n",
"46_mlp_gelu 0.061426 \n",
"46_mlp2 0.218107 \n",
"46_add_resid_1 0.002399 "
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfr # RELATIVE STANDARD ERRORS"
]
},
{
"cell_type": "markdown",
"id": "71371f17-2cec-424e-9601-bbf58dddb3e1",
"metadata": {},
"source": [
"#### Activation variances by layer (first element of each slice)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "173ba439-419c-42fe-9e80-726fe1c7a522",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"xx = []\n",
"for i in X.keys():\n",
" for j in X[i].keys():\n",
" q = X[i][j][0]\n",
" if isinstance(q, dict):\n",
" q = q['hidden_states']\n",
" xx.append([i, j, q.var().item()])"
]
},
{
"cell_type": "code",
"execution_count": 108,
"id": "ec5ba45d-c7d3-4a66-8fe3-24206932ebb7",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>slice</th>\n",
" <th>test</th>\n",
" <th>test_q</th>\n",
" <th>test_q_cum</th>\n",
" <th>test_refit</th>\n",
" <th>train</th>\n",
" <th>train_q</th>\n",
" <th>train_q_cum</th>\n",
" <th>train_refit</th>\n",
" </tr>\n",
" <tr>\n",
" <th>layer</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>53.500000</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>54.000000</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.636719</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0.652344</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.143555</td>\n",
" <td>0.150391</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0.166016</td>\n",
" <td>0.172852</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.230469</td>\n",
" <td>0.244141</td>\n",
" <td>0.253906</td>\n",
" <td>0.231445</td>\n",
" <td>0.261719</td>\n",
" <td>0.275391</td>\n",
" <td>0.285156</td>\n",
" <td>0.261719</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>54.500000</td>\n",
" <td>54.500000</td>\n",
" <td>54.500000</td>\n",
" <td>54.500000</td>\n",
" <td>55.000000</td>\n",
" <td>55.000000</td>\n",
" <td>55.000000</td>\n",
" <td>55.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.695312</td>\n",
" <td>0.695312</td>\n",
" <td>0.695312</td>\n",
" <td>0.695312</td>\n",
" <td>0.734375</td>\n",
" <td>0.738281</td>\n",
" <td>0.738281</td>\n",
" <td>0.734375</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>1.875000</td>\n",
" <td>1.921875</td>\n",
" <td>1.921875</td>\n",
" <td>1.875000</td>\n",
" <td>1.992188</td>\n",
" <td>2.046875</td>\n",
" <td>2.046875</td>\n",
" <td>1.992188</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>0.073730</td>\n",
" <td>0.077637</td>\n",
" <td>0.078613</td>\n",
" <td>0.073730</td>\n",
" <td>0.073730</td>\n",
" <td>0.078125</td>\n",
" <td>0.078125</td>\n",
" <td>0.073242</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>0.494141</td>\n",
" <td>0.558594</td>\n",
" <td>0.593750</td>\n",
" <td>0.539062</td>\n",
" <td>0.494141</td>\n",
" <td>0.542969</td>\n",
" <td>0.574219</td>\n",
" <td>0.523438</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>55.750000</td>\n",
" <td>55.750000</td>\n",
" <td>55.750000</td>\n",
" <td>55.750000</td>\n",
" <td>56.000000</td>\n",
" <td>56.250000</td>\n",
" <td>56.250000</td>\n",
" <td>56.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"slice test test_q test_q_cum test_refit train train_q \\\n",
"layer \n",
"0 53.500000 NaN NaN NaN 54.000000 NaN \n",
"1 0.636719 NaN NaN NaN 0.652344 NaN \n",
"2 0.143555 0.150391 NaN NaN 0.166016 0.172852 \n",
"3 0.230469 0.244141 0.253906 0.231445 0.261719 0.275391 \n",
"4 54.500000 54.500000 54.500000 54.500000 55.000000 55.000000 \n",
"5 0.695312 0.695312 0.695312 0.695312 0.734375 0.738281 \n",
"6 1.875000 1.921875 1.921875 1.875000 1.992188 2.046875 \n",
"7 0.073730 0.077637 0.078613 0.073730 0.073730 0.078125 \n",
"8 0.494141 0.558594 0.593750 0.539062 0.494141 0.542969 \n",
"9 55.750000 55.750000 55.750000 55.750000 56.000000 56.250000 \n",
"\n",
"slice train_q_cum train_refit \n",
"layer \n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 0.285156 0.261719 \n",
"4 55.000000 55.000000 \n",
"5 0.738281 0.734375 \n",
"6 2.046875 1.992188 \n",
"7 0.078125 0.073242 \n",
"8 0.574219 0.523438 \n",
"9 56.250000 56.000000 "
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfv = pd.DataFrame(xx, columns = ('layer', 'slice', 'var'))\n",
"dfv.set_index(['layer', 'slice'])['var'].unstack()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1dfa04d7-b287-439f-8030-fc7210ebdbea",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 84,
"id": "cfa3d0ce-0ab9-4373-a40b-85f381d7bd17",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"allocated: 12.1 GB, max: 56.5 GB, before flush: 12.1 GB\n"
]
}
],
"source": [
"cuda_mem_use()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment