Created
March 18, 2023 15:36
-
-
Save poedator/0a28b18a0cc275cfbc972a29c7185b55 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "e10492b6-5a32-422c-8ce5-2050347df3aa", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"# WHOLE LAYER TEST: BLOOM block 46" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "bebc0c8e-d531-4a2a-9c31-1892beae7c70", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Sat Mar 18 16:13:42 MSK 2023\n" | |
] | |
} | |
], | |
"source": [ | |
"!date" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "40692f46-77d4-42f7-9388-19e4a53bc499", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"env: CUDA_VISIBLE_DEVICES=4\n", | |
"env: TRANSFORMERS_CACHE=/extra_disk_1/yozh/transformers_cache\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<style>.container { width:100% !important; }</style>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CUDA extension not installed.\n" | |
] | |
} | |
], | |
"source": [ | |
"%env CUDA_VISIBLE_DEVICES=4\n", | |
"%env TRANSFORMERS_CACHE=/extra_disk_1/yozh/transformers_cache\n", | |
"\n", | |
"%load_ext autoreload\n", | |
"%autoreload 2\n", | |
"\n", | |
"import numpy as np\n", | |
"import os\n", | |
"import re\n", | |
"import gc\n", | |
"import copy\n", | |
"from typing import Optional, Tuple, Union, Sequence\n", | |
"import time\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"from operator import attrgetter\n", | |
"from tqdm.auto import trange, tqdm\n", | |
"from collections import defaultdict\n", | |
"\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"import transformers\n", | |
"\n", | |
"torch.backends.cuda.matmul.allow_tf32 = False\n", | |
"torch.backends.cudnn.allow_tf32 = False\n", | |
"\n", | |
"from IPython.display import display, HTML\n", | |
"display(HTML(\"<style>.container { width:100% !important; }</style>\"))\n", | |
"\n", | |
"import pandas as pd\n", | |
"import seaborn as sns\n", | |
"from matplotlib import pyplot as plt\n", | |
"\n", | |
"from petals.bloom.from_pretrained import load_pretrained_block\n", | |
"from transformers.models.bloom.modeling_bloom import BloomGelu\n", | |
"\n", | |
"# local imports\n", | |
"from compress import *\n", | |
"from utils import * \n", | |
"from compressable_layers import *" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fde12a06-da0a-4262-9506-085df94df3d5", | |
"metadata": {}, | |
"source": [ | |
"### load model\n", | |
"\n", | |
"see block code at \n", | |
"- https://github.com/bigscience-workshop/petals/blob/main/src/petals/bloom/block.py\n", | |
"- https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "f8d259ca-7b4c-4bed-a6f6-0907fa2acd94", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Mar 18 16:14:04.070 [\u001b[1m\u001b[34mINFO\u001b[0m] Loaded bigscience/bloom-petals block 46, <All keys matched successfully>\n" | |
] | |
} | |
], | |
"source": [ | |
"# Load one pretrained block (petals.bloom.block.WrappedBloomBlock) plus the model config,\n", | |
"# then freeze the original block weights.\n", | |
"model_name = 'bigscience/bloom-petals'\n", | |
"block_index = 46\n", | |
"\n", | |
"block = load_pretrained_block(model_name, block_index=block_index, cache_dir=os.environ['TRANSFORMERS_CACHE'])\n", | |
"block.requires_grad_(False)  # sets requires_grad=False on every parameter of the block\n", | |
"\n", | |
"config = transformers.AutoConfig.from_pretrained(model_name)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "857f9d41-d57d-47a3-97a5-98400819465c", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"BloomConfig {\n", | |
" \"_name_or_path\": \"bigscience/bloom-petals\",\n", | |
" \"apply_residual_connection_post_layernorm\": false,\n", | |
" \"architectures\": [\n", | |
" \"BloomForCausalLM\"\n", | |
" ],\n", | |
" \"attention_dropout\": 0.0,\n", | |
" \"attention_softmax_in_fp32\": true,\n", | |
" \"bos_token_id\": 1,\n", | |
" \"dht_prefix\": \"bigscience/bloom-petals\",\n", | |
" \"eos_token_id\": 2,\n", | |
" \"hidden_dropout\": 0.0,\n", | |
" \"hidden_size\": 14336,\n", | |
" \"initializer_range\": 0.02,\n", | |
" \"layer_norm_epsilon\": 1e-05,\n", | |
" \"masked_softmax_fusion\": true,\n", | |
" \"model_type\": \"bloom\",\n", | |
" \"n_head\": 112,\n", | |
" \"n_layer\": 70,\n", | |
" \"pad_token_id\": 3,\n", | |
" \"pretraining_tp\": 4,\n", | |
" \"slow_but_exact\": false,\n", | |
" \"torch_dtype\": \"bfloat16\",\n", | |
" \"transformers_version\": \"4.25.1\",\n", | |
" \"use_cache\": true,\n", | |
" \"vocab_size\": 250880\n", | |
"}" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"config" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "55d73e8f-8eec-4be0-8b65-b72e3739b1a4", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"WrappedBloomBlock(\n", | |
" (input_layernorm): LayerNorm((14336,), eps=1e-05, elementwise_affine=True)\n", | |
" (self_attention): BloomAttention(\n", | |
" (query_key_value): Linear(in_features=14336, out_features=43008, bias=True)\n", | |
" (dense): Linear(in_features=14336, out_features=14336, bias=True)\n", | |
" (attention_dropout): Dropout(p=0.0, inplace=False)\n", | |
" )\n", | |
" (post_attention_layernorm): LayerNorm((14336,), eps=1e-05, elementwise_affine=True)\n", | |
" (mlp): BloomMLP(\n", | |
" (dense_h_to_4h): Linear(in_features=14336, out_features=57344, bias=True)\n", | |
" (gelu_impl): BloomGelu()\n", | |
" (dense_4h_to_h): Linear(in_features=57344, out_features=14336, bias=True)\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"block" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0a2da100-e4a0-4f82-a3a1-f2e8299bbe4f", | |
"metadata": {}, | |
"source": [ | |
"### load data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "3cc2695b-135f-4046-876a-163fbb19e879", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"from pathlib import Path\n", | |
"import re\n", | |
"import torch\n", | |
"from petals.bloom.from_pretrained import load_pretrained_block\n", | |
"from transformers.models.bloom.modeling_bloom import BloomGelu" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "817d231a-ea4c-405f-af9a-660770a4e54e", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "a8a1a1bb973c45908f0c33a26db906d3", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/271 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"loaded 271 samples with shape torch.Size([1, 1024, 14336]) each\n" | |
] | |
} | |
], | |
"source": [ | |
"# Directory with the cached per-sample activations (one file per sample); override via env var if needed.\n", | |
"path46 = os.environ.get('BLOCK46_SAMPLES_DIR', '/home/jheuristic/blocks_46_47')\n", | |
"sample_re = re.compile(r'sample(\\d+)_')\n", | |
"\n", | |
"# Keep only files named 'sample<N>_...' and order them numerically by <N>\n", | |
"# (a name sort would put sample10 before sample2; non-matching files would break the sort key).\n", | |
"samples_paths = sorted(\n", | |
"    (p for p in Path(path46).iterdir() if sample_re.search(p.name)),\n", | |
"    key=lambda p: int(sample_re.search(p.name)[1]),\n", | |
")\n", | |
"if not samples_paths:\n", | |
"    raise FileNotFoundError(f\"no 'sample<N>_*' files found in {path46}\")\n", | |
"\n", | |
"# NOTE: torch.load unpickles the file contents, so only point this at trusted data.\n", | |
"x_samples = [torch.load(p) for p in tqdm(samples_paths)]\n", | |
"print(f\"loaded {len(x_samples)} samples with shape {x_samples[0].shape} each\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "999864ee-521f-40f6-8878-93b5b126dcc3", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(243, 28, torch.bfloat16)" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Train/test split over whole samples (1024 tokens each), so no sample is shared between sets.\n", | |
"# A fixed seed makes the split -- and every loss computed downstream -- reproducible after a restart.\n", | |
"SPLIT_SEED = 0\n", | |
"train_idx, test_idx = train_test_split(range(len(x_samples)), test_size=0.1, random_state=SPLIT_SEED)\n", | |
"X0_train = [x_samples[i] for i in train_idx]\n", | |
"X0_test = [x_samples[i] for i in test_idx]\n", | |
"# x_samples is intentionally kept: the interim-storage cell below indexes it with train_idx / test_idx\n", | |
"len(X0_train), len(X0_test), X0_train[0].dtype" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "738a4c93-c75f-4f3c-9346-6dc0decadf68", | |
"metadata": {}, | |
"source": [ | |
"# FORWARD PASS" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "dad32079-9514-4bd7-842f-e57fac7b7c68", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"#### 0. setup" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "15bb8bf3-5b00-481d-9b58-c7dae42d3395", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def calc_losses(i, layer_name, tasks=('train', 'test'), suffixes=('_q', '_q_cum', '_refit'), data=None):\n", | |
"    \"\"\"Compare each stored variant of stage `i` against its reference data.\n", | |
"\n", | |
"    For every task `t` and suffix `s` whose key `t + s` exists in `data[i]`, calls\n", | |
"    `batch_mse_loss(data[i][t], data[i][t + s], desc=t + s)` and records the result.\n", | |
"\n", | |
"    Args:\n", | |
"        i: stage index into the interim-data store.\n", | |
"        layer_name: label stored under the 'layer_name' key of the result.\n", | |
"        tasks: reference keys (e.g. 'train', 'test') to compare against.\n", | |
"        suffixes: variant suffixes to look for; variants that are absent are skipped.\n", | |
"        data: interim-data store indexed as data[stage][key]; defaults to the\n", | |
"            notebook-global `X`, looked up at call time.\n", | |
"\n", | |
"    Returns:\n", | |
"        dict with 'layer_name' plus one `(mse_loss, rse_loss)` tuple per variant found,\n", | |
"        keyed by `t + s`.\n", | |
"    \"\"\"\n", | |
"    if data is None:\n", | |
"        data = X  # late-bound global keeps existing calls working unchanged\n", | |
"    # .get avoids inserting an empty stage into a defaultdict store as a side effect\n", | |
"    stage = data.get(i, {})\n", | |
"    loss1 = {'layer_name': layer_name}\n", | |
"    for t in tasks:\n", | |
"        for suffix in suffixes:\n", | |
"            tag = t + suffix\n", | |
"            if tag in stage:\n", | |
"                mse_loss, rse_loss = batch_mse_loss(stage[t], stage[tag], desc=tag)\n", | |
"                loss1[tag] = (mse_loss, rse_loss)\n", | |
"    return loss1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "1fc39bad-d0fe-4a80-8797-a6e56a053eb7", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"wbits = 3" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "322af863-23a5-4cf1-99ab-7ea628aa32f7", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"losses = dict() # storage for losses by layer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "9e679b55-abce-42e9-88c5-2864a9b5f235", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Main storage for interim data, indexed as X[stage][key];\n", | |
"# stage 0 holds the loaded samples split into 'train' and 'test'.\n", | |
"X = defaultdict(dict)\n", | |
"X[0] = {split: [x_samples[j] for j in idx] for split, idx in (('train', train_idx), ('test', test_idx))}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "a3284c09-abcb-4499-97e3-2d61f2c87749", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "a9d471710dcf464784f78e85185d8e23", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "c0d6478772f8424384025302cda6e2d2", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Add extra fields to data (alibi, residual, masks)\n", | |
"for k, v in X[0].items():\n", | |
" X[0][k] = [augment_batch(xb, block) for xb in tqdm(v)]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6626bc4d-ef33-4e23-bdfc-36aff66b7a2d", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"#### 1. layer_norm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "708f7b21-a267-4e9d-95e8-71836b17bfa6", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# run this cell before any other one for this layer\n", | |
"i = 1\n", | |
"layer_name = '46_layer_norm_0'\n", | |
"\n", | |
"layer_n0 = LayerNormLayer(module=block.input_layernorm)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "43856496-e37e-4998-8892-5add70afbc8d", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "a9fb9a320d3746fa92a2e04927e50f7b", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"LNorm: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "010882312cff4b24a12b969644fc3fd3", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"LNorm: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Regular test and train data with original layer \n", | |
"layer = layer_n0\n", | |
"suffix = ''\n", | |
"\n", | |
"for k, v in X[i - 1].items():\n", | |
" X[i][k + suffix] = layer(X[i - 1][k])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "46f99e48-2431-4f33-aaeb-fa5ed1d03ca4", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"#### 2. attention" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "ed94e7f3-7461-4289-9921-021cdf08c2e0", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# run this cell before any other one for this layer\n", | |
"i = 2\n", | |
"layer_name = '46_attention'\n", | |
"layer_a = AttnLayer(module=block.self_attention)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "0f9d0d6b-5bee-41bb-91bd-b22f15e01ad1", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ab31c2fd407c465d93b278d183172a94", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "2620cfc5204e47c49b24fb16f23a7351", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Regular test and train data with original layer \n", | |
"layer = layer_a\n", | |
"suffix = ''\n", | |
"\n", | |
"for k in ['train', 'test']:\n", | |
" X[i][k + suffix] = layer(X[i - 1][k])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "41b71794-e004-4dbf-a83c-ca80c5f468ae", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f8ef1dce6d0e4f2cacd24cc7dcd42574", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "57608e6dab6a48c3892246f1218aa2f8", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"gptq: 0%| | 0/112 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# use the shared `wbits` setting so every layer in this test is quantized to the same bit width\n", | |
"layer_a_q = layer_a.compress_gptq(X[1]['train'], wbits=wbits)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "53403baf-22de-4a50-87d9-5cffa3fbcfd6", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "922bdcdd2d6b46aa8057c056f05edb8f", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "e7ceb848bf1f41698ddd0821d217515c", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Regular test and train data with GPTQ-compressed layer\n", | |
"layer = layer_a_q\n", | |
"suffix = '_q'\n", | |
"\n", | |
"for k in ['train', 'test']:\n", | |
" X[i][k + suffix] = layer(X[i - 1][k])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "3afa2322-3275-4302-95f1-22ce31098b35", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'layer_name': '46_attention', 'train_q': (0.0108642578125, 0.07802173495292664), 'test_q': (0.01251220703125, 0.09104174375534058)}\n" | |
] | |
} | |
], | |
"source": [ | |
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n", | |
"print(losses[layer_name]) " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "41e1d789-5ee2-4585-b047-f408a6692bd1", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"#### 3. dense" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "569e031a-f2ce-45aa-a739-84cdaa3fd1b9", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# run this cell before any other one for this layer\n", | |
"i = 3\n", | |
"layer_name = '46_dense'\n", | |
"\n", | |
"layer_d = LinearLayer(block.self_attention.dense)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "968506d0-d324-4ff9-b507-d541bdcf89f8", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f8b9a59021494d1c99a273626aaacafe", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "02d4879b827144c986dc17ecd0e9081e", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Regular test and train data with original layer \n", | |
"layer = layer_d\n", | |
"suffix = ''\n", | |
"\n", | |
"for k in 'train test'.split():\n", | |
" X[i][k + suffix] = layer(X[i - 1][k])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "cebfcbe0-d942-487c-b620-a13e76a965c5", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "7ec0b04ac63f48b3a9e3fc891441cb73", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "effd521588b84de9b2e298b465e700ee", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"gptq: 0%| | 0/112 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"layer_d_q = layer_d.compress_gptq(X[2]['train'], wbits=wbits)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "0e918b87-9d90-4125-ad1b-a51fae71361f", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "840383a6117e464a8bbe46ed47e90daa", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "00ba87913125465aa22014bda9545383", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Regular test and train data with GPTQ layer\n", | |
"layer = layer_d_q\n", | |
"suffix = '_q'\n", | |
"\n", | |
"for k in 'train test'.split():\n", | |
" X[i][k + suffix] = layer(X[i - 1][k])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "7145a99c-8e5d-4fbf-a6c2-1f3ca7d19ee3", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "59a89b78fc28449886d038c8a6c05771", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ba2041c7b81541c792aacefa6febae4a", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# CUMULATIVE GPTQ test and train data with GPTQ layer - NO REFIT\n", | |
"layer = layer_d_q\n", | |
"suffix = '_q_cum'\n", | |
"\n", | |
"for k in 'train test'.split():\n", | |
" X[i][k + suffix] = layer(X[i - 1][k + '_q'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "09070b35-c8ad-42b7-8bed-a26593d820ef", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"without validation set, using train loss for early stopping\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "18ea0df6bd0e4c0f91849c4e2b5b9515", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Refitter starting: 0%| | 0/3037 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EARLY STOP\n", | |
"Fit complete with train loss 0.010352, test loss 0.010410\n" | |
] | |
} | |
], | |
"source": [ | |
"layer_d_refit = layer_d.refit(X_train=X[2]['train_q'], Y_train=X[3]['train'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "3d7c7d4d-f11a-48e5-850f-8c00744b0a50", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "d96b65f7f39a4dc8b7731fcfb78b98b0", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "5c6e36e025924c81865ea16af4b151ac", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"gptq: 0%| | 0/112 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"layer_d_refit_q = layer_d_refit.compress_gptq(X[2]['train_q'], wbits=wbits)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "a4837c77-7f15-49a8-a200-b7509134afc3", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f8baabd8bd5049a9a2d3544be6033857", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "e5926c2da60345649b7c8891188212ad", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# CUMULATIVE GPTQ test and train data with REFITTED GPTQ layer\n", | |
"layer = layer_d_refit_q\n", | |
"suffix_in = '_q'\n", | |
"suffix_out = '_refit'\n", | |
"\n", | |
"for k in 'train test'.split():\n", | |
" X[i][k + suffix_out] = layer(X[i - 1][k + suffix_in])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"id": "2442e9c4-c08c-4c73-9cd7-b38d4c6a58a1", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'layer_name': '46_dense', 'train_q': (0.01177978515625, 0.0545545294880867), 'train_q_cum': (0.03173828125, 0.14698629081249237), 'train_refit': (0.020751953125, 0.09610642492771149), 'test_q': (0.01513671875, 0.07037457823753357), 'test_q_cum': (0.032470703125, 0.15096482634544373), 'test_refit': (0.030029296875, 0.139614075422287)}\n" | |
] | |
} | |
], | |
"source": [ | |
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n", | |
"print(losses[layer_name]) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "c4b5c209-aaae-473c-ae99-3f6246a2808b", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dict_keys(['train', 'test', 'train_q', 'test_q', 'train_q_cum', 'test_q_cum', 'train_refit', 'test_refit'])" | |
] | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X[3].keys()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"id": "e63831f9-c77b-4d69-b0c1-de29e3634284", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"allocated: 0.0 GB, max: 8.8 GB, before flush: 0.0 GB\n" | |
] | |
} | |
], | |
"source": [ | |
"cuda_mem_use()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0129840e-049c-4820-ad1e-cfd03ab20e47", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"#### 4. add resid 0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "f537e13d-606d-4635-bd42-76b3f6e4ef9d", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# run this cell before any other one for this layer\n", | |
"i = 4\n", | |
"layer_name = '46_add_resid_0'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"id": "f65d9f4d-695d-4623-882a-c4e491d757dc", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"def add_sequences(aa, bb, device='cuda'):\n", | |
"    \"\"\"Add two sequences of tensors pairwise, computing each sum on `device`.\n", | |
"\n", | |
"    Each pair `(a, b)` is moved to `device`, added, and the sum is moved back to CPU,\n", | |
"    so the returned results live on CPU.\n", | |
"\n", | |
"    Args:\n", | |
"        aa, bb: sized sequences of tensors with equal length; paired tensors must be\n", | |
"            addable (same or broadcast-compatible shapes).\n", | |
"        device: device the additions run on; default 'cuda' is the current CUDA device.\n", | |
"\n", | |
"    Returns:\n", | |
"        list of CPU tensors `a + b`, in input order.\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if `aa` and `bb` differ in length (zip would otherwise silently truncate).\n", | |
"    \"\"\"\n", | |
"    if len(aa) != len(bb):\n", | |
"        raise ValueError(f'add_sequences: length mismatch ({len(aa)} vs {len(bb)})')\n", | |
"    return [\n", | |
"        (a.to(device) + b.to(device)).cpu()\n", | |
"        for a, b in tqdm(zip(aa, bb), leave=False, total=len(aa))\n", | |
"    ]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"id": "44ba45e0-1310-4abf-9cbb-ce85232882f9", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dict_keys(['hidden_states', 'attention_mask', 'alibi', 'residual'])" | |
] | |
}, | |
"execution_count": 35, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X[0]['test'][0].keys()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"id": "32145b2f-2da6-4ce5-9a11-d690b3ba91de", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "974650bb35704699a32a0a22d1a1c1a5", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/8 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train train; " | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"test test; " | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train_q train; " | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"test_q test; " | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train_q_cum train; " | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"test_q_cum test; " | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"train_refit train; " | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"test_refit test; " | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# adding corresponding residual\n", | |
"\n", | |
"for key, val in tqdm(X[i - 1].items()):\n", | |
" resid_key = key.split('_')[0]\n", | |
" print(key, resid_key, end = '; ')\n", | |
" X[i][key] = add_sequences(val, [xb['hidden_states'] for xb in X[0][resid_key]])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"id": "edfe7a0c-ea47-49d6-833a-ebdaad5a8d25", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dict_keys(['train', 'test', 'train_q', 'test_q', 'train_q_cum', 'test_q_cum', 'train_refit', 'test_refit'])" | |
] | |
}, | |
"execution_count": 37, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X[4].keys()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"id": "9e7ec44d-d124-44a7-8dc8-5f0c172c694b", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'layer_name': '46_add_resid_0', 'train_q': (0.01202392578125, 0.00022347424237523228), 'train_q_cum': (0.031982421875, 0.000594418786931783), 'train_refit': (0.020751953125, 0.0003856915864162147), 'test_q': (0.01531982421875, 0.0002849726297426969), 'test_q_cum': (0.032470703125, 0.0006040057633072138), 'test_refit': (0.0301513671875, 0.0005608624778687954)}\n" | |
] | |
} | |
], | |
"source": [ | |
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n", | |
"print(losses[layer_name]) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"id": "8fe06f37-af23-42ce-85d6-30b7dddf75b8", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"allocated: 0.0 GB, max: 8.8 GB, before flush: 0.0 GB\n" | |
] | |
} | |
], | |
"source": [ | |
"cuda_mem_use()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "910a19ef-8e81-4fe3-b7af-c119c538fd89", | |
"metadata": { | |
"tags": [] | |
}, | |
"source": [ | |
"#### 5. layer_norm_1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"id": "ed0b38ed-4334-430a-ba18-65233be15eb4", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# run this cell before any other one for this layer\n", | |
"i = 5\n", | |
"layer_name = '46_layer_norm_1'\n", | |
"\n", | |
"layer_n1 = LayerNormLayer(module=block.post_attention_layernorm)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"id": "6470829a-3315-4ba5-b849-2b341ceb5e91", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "70432747c2ff4407ba3320b1c9a19873", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/8 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "0a66e81823314cfbb29c74b14ae92282", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"LNorm: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "d8097cdc2ec64044a1f29ec4fbc8ae2e", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"LNorm: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "9f012accc78d44f698965a5e6c49e49d", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"LNorm: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "5264afb399c24504b4c293fd406dd8de", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"LNorm: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "e93621c538104940ba1418f680c6926d", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"LNorm: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "a098f4ff1a684e6b963fb97c7aeaffb9", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"LNorm: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "af731628c3084d3f82dd1f98466f8259", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"LNorm: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "2baa2c78fc32410395e8e78f2b7410ae", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"LNorm: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"for k, v in tqdm(X[i - 1].items()):\n", | |
" X[i][k] = layer_n1(v)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"id": "8d8cf8bd-a867-4635-aa24-f008f4e0409c", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'layer_name': '46_layer_norm_1', 'train_q': (0.000545501708984375, 0.0007736138068139553), 'train_q_cum': (0.00115203857421875, 0.0016337857814505696), 'train_refit': (0.00104522705078125, 0.001482308958657086), 'test_q': (0.00066375732421875, 0.0009506119531579316), 'test_q_cum': (0.00141143798828125, 0.002021416090428829), 'test_refit': (0.00130462646484375, 0.0018684441456571221)}\n" | |
] | |
} | |
], | |
"source": [ | |
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n", | |
"print(losses[layer_name]) " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0c68088b-5dd9-498a-9052-af9159f7bea9", | |
"metadata": {}, | |
"source": [ | |
"#### 6. MLP-1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"id": "9a031013-08b8-4b64-8d72-e8f1667f3499", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# run this cell before any other one for this layer\n", | |
"i = 6\n", | |
"layer_name = '46_mlp1'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"id": "3f8b934c-d794-4e87-98c8-b8179e0966d9", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"layer_mlp1 = LinearLayer(block.mlp.dense_h_to_4h)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"id": "534ec464-95da-428f-8c6a-e3731c39b21a", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "8268f7a712334665ad8aaeaced347ca1", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "8337d76791fe43be8616048e4ae88a46", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Regular test and train data with original layer \n", | |
"layer = layer_mlp1\n", | |
"suffix = ''\n", | |
"\n", | |
"for k in ('train', 'test'):\n", | |
" X[i][k + suffix] = layer(X[i - 1][k])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"id": "f8ebdbed-bee3-4202-95c8-e1dbce28d76c", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "20192bb0d2ac46d9ad218d02e1f51ae2", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "53e039472df545d2b62ba448693b1ab3", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"gptq: 0%| | 0/112 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"layer_mlp1_q = layer_mlp1.compress_gptq(X[5]['train'], wbits=wbits)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 47, | |
"id": "d03a9e50-ab92-49d7-88d1-13e27bfa93e0", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([57344, 14336])" | |
] | |
}, | |
"execution_count": 47, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"layer_mlp1_q.module.weight.data.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 48, | |
"id": "b6006798-48da-4cd0-b6d1-b53d14afe76c", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "3787e343b303496c9cd8723586e5f3f4", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "6d5ae4b0e68e42199d055b3aa20ee13b", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Regular test and train data with GPTQ layer\n", | |
"layer = layer_mlp1_q\n", | |
"suffix = '_q'\n", | |
"\n", | |
"for k in ('train', 'test'):\n", | |
" X[i][k + suffix] = layer(X[i - 1][k])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 49, | |
"id": "43118fe7-eb98-46fd-9966-009fe7abbf0e", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f109242060e84cd5baaf807f333ef6c5", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "31bda58082b24a00a1cc29c189edfa26", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# CUMULATIVE GPTQ test and train data with GPTQ layer - NO REFIT\n", | |
"layer = layer_mlp1_q\n", | |
"suffix = '_q_cum'\n", | |
"\n", | |
"for k in ('train', 'test'):\n", | |
" X[i][k + suffix] = layer(X[i - 1][k + suffix])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 52, | |
"id": "697fdd8b-4d50-4c86-a59b-6db17bfebd48", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"without validation set, using train loss for early stopping\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ebb3a29cee0f417895418f5800126b2a", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Refitter starting: 0%| | 0/3037 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EARLY STOP\n", | |
"Fit complete with train loss 0.006116, test loss 0.006138\n" | |
] | |
} | |
], | |
"source": [ | |
"layer_mlp1_refit = layer_mlp1.refit(X_train=X[5]['train_refit'], Y_train=X[6]['train'],\n", | |
" X_val=X[5]['test_refit'], Y_val=X[6]['test'], )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 53, | |
"id": "002dd949-157c-41bb-b22d-5d269b678173", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "974668df15924537a989541458c15929", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "4b312552c0774199b453f20145723d47", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"gptq: 0%| | 0/112 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"layer_mlp1_refit_q = layer_mlp1_refit.compress_gptq(X[5]['train_refit'], wbits=wbits)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 54, | |
"id": "7c46503d-5b3a-446f-9ee2-0c4a91ac1a10", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "2a7c52ec399f4236bbbb265fa27cdb97", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "9cd451339e704afcb84ffa376cce9b4f", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# CUMULATIVE GPTQ test and train data with REFITTED GPTQ layer\n", | |
"layer = layer_mlp1_refit_q\n", | |
"suffix_in = '_refit'\n", | |
"suffix_out = '_refit'\n", | |
"\n", | |
"for k in 'train test'.split():\n", | |
" X[i][k + suffix_out] = layer(X[i - 1][k + suffix_in])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 55, | |
"id": "ebacab4d-e9bb-4d38-91fd-f55011d227d2", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'layer_name': '46_mlp1', 'train_q': (0.047119140625, 0.024921249598264694), 'train_q_cum': (0.0517578125, 0.027374638244509697), 'train_refit': (0.04736328125, 0.02505037561058998), 'test_q': (0.053955078125, 0.02883181720972061), 'test_q_cum': (0.0576171875, 0.030788728967308998), 'test_refit': (0.056396484375, 0.03013642504811287)}\n" | |
] | |
} | |
], | |
"source": [ | |
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n", | |
"print(losses[layer_name]) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 56, | |
"id": "ca036c50-57b0-4995-9d5b-35d2ea556f2d", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dict_keys(['train', 'test', 'train_q', 'test_q', 'train_q_cum', 'test_q_cum', 'train_refit', 'test_refit'])" | |
] | |
}, | |
"execution_count": 56, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X[6].keys()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 57, | |
"id": "4306f711-f93b-448e-969d-7d1149d9618c", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"allocated: 7.5 GB, max: 19.1 GB, before flush: 7.5 GB\n" | |
] | |
} | |
], | |
"source": [ | |
"cuda_mem_use()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4a0886d4-0cd8-4613-bd40-5c14b26f6948", | |
"metadata": {}, | |
"source": [ | |
"#### 7. Gelu" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 58, | |
"id": "7ab9375b-8874-48f1-a893-55a2526ef321", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# run this cell before any other one for this layer\n", | |
"i = 7\n", | |
"layer_name = '46_mlp_gelu'\n", | |
"layer_mlp_gelu = Layer(module=block.mlp.gelu_impl)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 59, | |
"id": "0f537a25-a8e9-40a2-9b3e-3993c63854fa", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "171fd1c17127480f86cdde743f7ea017", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/8 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "0c927f08bfc240448214e6f6367e3a86", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "6c5b6b93ea784963a2ba7c687eb8bd79", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ab2b6b10421e439bae1d2a688cbea1b6", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "2f118d7ead39444cbd2fe14b82387ae7", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "66c4dcb4b1cd4aec8be587d3bee10df4", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "3c3b0aca107a4bdb8fa4c79bd067ab49", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "39b37eff929d4fbeb18857313c824f1a", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "d5d22480048b42e8a118ab6181932d81", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"for k, v in tqdm(X[i - 1].items()):\n", | |
" X[i][k] = layer_mlp_gelu(v)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 60, | |
"id": "8d5088ea-d2a4-46ea-98f9-b195bd7ce081", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'layer_name': '46_mlp_gelu', 'train_q': (0.00421142578125, 0.059492044150829315), 'train_q_cum': (0.00439453125, 0.062078654766082764), 'train_refit': (0.004119873046875, 0.05819873884320259), 'test_q': (0.00445556640625, 0.06271477788686752), 'test_q_cum': (0.0047607421875, 0.0670103132724762), 'test_refit': (0.004364013671875, 0.06142611801624298)}\n" | |
] | |
} | |
], | |
"source": [ | |
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n", | |
"print(losses[layer_name]) " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "2d4c0299-b612-4ccc-ba42-0258d1d6b155", | |
"metadata": {}, | |
"source": [ | |
"#### 8. MLP-2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 61, | |
"id": "07aa5a66-7bb4-4e79-8385-136b06661447", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# run this cell before any other one for this layer\n", | |
"i = 8\n", | |
"layer_name = '46_mlp2'\n", | |
"\n", | |
"layer_mlp2 = LinearLayer(block.mlp.dense_4h_to_h)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"id": "ca3108dc-9b31-4cb5-a075-5718ddcbc770", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "c5c6fd78af2047d88b4a4e8df3a9adea", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "95a4cef45bc8484f9e843384be835091", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Regular test and train data with original layer \n", | |
"layer = layer_mlp2\n", | |
"suffix = ''\n", | |
"\n", | |
"for k in ('train', 'test'):\n", | |
" X[i][k + suffix] = layer(X[i - 1][k])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"id": "2dbf0a6d-dc3e-454c-b721-3ae6b32354e3", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "84c33a4d457b47a7a8e78a9636ab4d4f", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ad49cc1229a748eab84333f886b0f22e", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"gptq: 0%| | 0/448 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"layer_mlp2_q = layer_mlp2.compress_gptq(X[7]['train'], wbits=wbits)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 64, | |
"id": "8d52b692-1179-4b9d-87bd-e130301e0a08", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "9f8103e9b6b64f3eb6c1214e1a8ee009", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "89a57becd4ad4fde8d665c58ac3120b7", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Regular test and train data with GPTQ layer\n", | |
"layer = layer_mlp2_q\n", | |
"suffix = '_q'\n", | |
"\n", | |
"for k in ('train', 'test'):\n", | |
" X[i][k + suffix] = layer(X[i - 1][k])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 65, | |
"id": "535f21d4-27c6-47fe-9015-9976b6354393", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "1680828d1e7240a78ee434f1972ce28f", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "e907dee31e6b4d1bb5009c4a8152da1f", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# CUMULATIVE GPTQ test and train data with GPTQ layer - NO REFIT\n", | |
"layer = layer_mlp2_q\n", | |
"suffix = '_q_cum'\n", | |
"\n", | |
"for k in ('train', 'test'):\n", | |
" X[i][k + suffix] = layer(X[i - 1][k + suffix])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 68, | |
"id": "281ac375-10a8-41a0-b248-025f1f2fd780", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"early stopping when val loss stops improving.\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "9d100acf27e6423983e04e0d6006e1da", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Refitter starting: 0%| | 0/3037 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"EARLY STOP\n", | |
"Fit complete with train loss 0.022766, test loss 0.040771\n" | |
] | |
} | |
], | |
"source": [ | |
"# INITIAL REFIT ATTEMPT \n", | |
"layer_mlp2_refit = layer_mlp2.refit(X[7]['train_refit'], X[8]['train'], \n", | |
" X[7]['test_refit'], X[8]['test'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 69, | |
"id": "a7e83c08-171c-4a2d-9a13-7ee7811da1a9", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "027049bee96a40fd9711293e3121187c", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"calculating XTX matrix: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "55ee5cef79f5473285f2f9f8c7f2e4a1", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"gptq: 0%| | 0/448 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"layer_mlp2_refit_q = layer_mlp2_refit.compress_gptq(X[7]['train_refit'], wbits=wbits)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"id": "f197e4b4-0081-447d-b90c-db5fdb731daa", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "3de37ecdaa534243b78d102c840857fe", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "ba3cbeb60c414554872d90bf438595eb", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# CUMULATIVE GPTQ test and train data with REFITTED GPTQ layer\n", | |
"layer = layer_mlp2_refit_q\n", | |
"suffix_in = '_refit'\n", | |
"suffix_out = '_refit'\n", | |
"\n", | |
"for k in 'train test'.split():\n", | |
" X[i][k + suffix_out] = layer(X[i - 1][k + suffix_in])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 71, | |
"id": "8c7ce3e4-b764-456c-a621-94765f440ffb", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'layer_name': '46_mlp2', 'train_q': (0.043212890625, 0.09118527173995972), 'train_q_cum': (0.0732421875, 0.15455131232738495), 'train_refit': (0.0703125, 0.14836925268173218), 'test_q': (0.0654296875, 0.13786007463932037), 'test_q_cum': (0.1005859375, 0.2119341492652893), 'test_refit': (0.103515625, 0.21810698509216309)}\n" | |
] | |
} | |
], | |
"source": [ | |
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n", | |
"print(losses[layer_name]) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"id": "993d5f35-e022-4611-b9b1-147b0f9b8646", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dict_keys(['train', 'test', 'train_q', 'test_q', 'train_q_cum', 'test_q_cum', 'train_refit', 'test_refit'])" | |
] | |
}, | |
"execution_count": 72, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X[8].keys()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"id": "e8d09e8e-2bac-48aa-bbf7-880fc01beef3", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"allocated: 7.5 GB, max: 56.5 GB, before flush: 7.5 GB\n" | |
] | |
} | |
], | |
"source": [ | |
"cuda_mem_use()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "afe7a7df-9073-48a7-9ae5-d468f57ba155", | |
"metadata": {}, | |
"source": [ | |
"#### 9. Add residual 1 (skip connection from X[4])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"id": "5128996a-9506-4919-8b16-61cc037592ff", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# run this cell before any other one for this layer\n", | |
"i = 9\n", | |
"layer_name = '46_add_resid_1'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 75, | |
"id": "d64d96f8-3d2f-444e-b213-3590bc5a26f6", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dict_keys(['train', 'test', 'train_q', 'test_q', 'train_q_cum', 'test_q_cum', 'train_refit', 'test_refit'])" | |
] | |
}, | |
"execution_count": 75, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# source for resid\n", | |
"X[4].keys()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"id": "ea809e07-5f30-463c-a3fb-7f5d52a088a4", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "7722f73aae254526b96e8ab8511f687d", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/8 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# adding corresponding residual\n", | |
"\n", | |
"for key, val in tqdm(X[i - 1].items()):\n", | |
" resid_key = key\n", | |
" # print(key, resid_key, end = '; ')\n", | |
" X[i][key] = add_sequences(val, X[4][resid_key])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 77, | |
"id": "1b358f33-0140-4269-a800-5a0aefbc0c49", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_q_cum: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"train_refit: 0%| | 0/243 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_q_cum: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"test_refit: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'layer_name': '46_add_resid_1', 'train_q': (0.06494140625, 0.0011800023494288325), 'train_q_cum': (0.109375, 0.0019873722922056913), 'train_refit': (0.08056640625, 0.0014639126602560282), 'test_q': (0.08154296875, 0.0014840448275208473), 'test_q_cum': (0.1318359375, 0.0023993540089577436), 'test_refit': (0.1318359375, 0.0023993540089577436)}\n" | |
] | |
} | |
], | |
"source": [ | |
"losses[layer_name] = calc_losses(i=i, layer_name=layer_name)\n", | |
"print(losses[layer_name]) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b34a15c3-7cc0-4af2-bf6c-4872cc5a4acd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 91, | |
"id": "113c8f3c-8a6b-4a21-9d0e-9da16b43ea00", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"loss_calc: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.1318359375, 0.0023993540089577436)" | |
] | |
}, | |
"execution_count": 91, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"j = 0\n", | |
"ii = 9\n", | |
"# torch.equal(X[ii]['test_refit'][ii], X[4]['test_q_cum'][j])\n", | |
"t = 'test'\n", | |
"tag = 'test_q_cum'\n", | |
"batch_mse_loss(X[ii][t], X[ii][tag])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 92, | |
"id": "5f092017-81d8-43dc-a006-a51396471f45", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"loss_calc: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.1318359375, 0.0023993540089577436)" | |
] | |
}, | |
"execution_count": 92, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"j = 0\n", | |
"ii = 9\n", | |
"# torch.equal(X[ii]['test_refit'][ii], X[4]['test_q_cum'][j])\n", | |
"t = 'test'\n", | |
"tag = 'test_refit'\n", | |
"batch_mse_loss(X[ii][t], X[ii][tag])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 90, | |
"id": "c82bf042-0bb9-4fba-b129-371b50f884a1", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"loss_calc: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.1640625, 0.3178808093070984)" | |
] | |
}, | |
"execution_count": 90, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"j = 0\n", | |
"ii = 8\n", | |
"# torch.equal(X[ii]['test_refit'][ii], X[4]['test_q_cum'][j])\n", | |
"t = 'test_refit'\n", | |
"tag = 'test_q_cum'\n", | |
"batch_mse_loss(X[ii][t], X[ii][tag])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "10c2511b-bfcf-4286-8ab0-f9c6cb6845fc", | |
"metadata": {}, | |
"source": [ | |
"#### 10. Final check: chained sub-layer outputs vs. the full block output" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 78, | |
"id": "be37a466-4677-43b9-8c66-ec40d3d569fe", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "8c48bed48ad64b4a80cd6c3459e4da6a", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 2.17 s, sys: 0 ns, total: 2.17 s\n", | |
"Wall time: 2.16 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"Y_ref_test = [block.cuda()(xb['hidden_states'].cuda())[0].cpu() for xb in tqdm(X[0]['test'])]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 79, | |
"id": "9934336e-212f-4a19-9a6e-6c87fcff0135", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"loss_calc: 0%| | 0/28 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.0, 0.0)" | |
] | |
}, | |
"execution_count": 79, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"batch_mse_loss(X[9]['test'], Y_ref_test)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "537cce12-e89d-40b8-bcef-b00ab70b0a13", | |
"metadata": {}, | |
"source": [ | |
"### Summary of losses" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"id": "df9c7fad-93f5-434f-aebb-bf9d546840b7", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"mses = {kk: {k:v[0] for k, v in vv.items() if k != 'layer_name'} for kk, vv in losses.items()}\n", | |
"df = pd.DataFrame(mses).T" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 81, | |
"id": "4b86d3ef-274b-44e4-8734-d653de2a6fa2", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>train_q</th>\n", | |
" <th>test_q</th>\n", | |
" <th>train_q_cum</th>\n", | |
" <th>train_refit</th>\n", | |
" <th>test_q_cum</th>\n", | |
" <th>test_refit</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>46_attention</th>\n", | |
" <td>0.010864</td>\n", | |
" <td>0.012512</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_dense</th>\n", | |
" <td>0.011780</td>\n", | |
" <td>0.015137</td>\n", | |
" <td>0.031738</td>\n", | |
" <td>0.020752</td>\n", | |
" <td>0.032471</td>\n", | |
" <td>0.030029</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_add_resid_0</th>\n", | |
" <td>0.012024</td>\n", | |
" <td>0.015320</td>\n", | |
" <td>0.031982</td>\n", | |
" <td>0.020752</td>\n", | |
" <td>0.032471</td>\n", | |
" <td>0.030151</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_layer_norm_1</th>\n", | |
" <td>0.000546</td>\n", | |
" <td>0.000664</td>\n", | |
" <td>0.001152</td>\n", | |
" <td>0.001045</td>\n", | |
" <td>0.001411</td>\n", | |
" <td>0.001305</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_mlp1</th>\n", | |
" <td>0.047119</td>\n", | |
" <td>0.053955</td>\n", | |
" <td>0.051758</td>\n", | |
" <td>0.047363</td>\n", | |
" <td>0.057617</td>\n", | |
" <td>0.056396</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_mlp_gelu</th>\n", | |
" <td>0.004211</td>\n", | |
" <td>0.004456</td>\n", | |
" <td>0.004395</td>\n", | |
" <td>0.004120</td>\n", | |
" <td>0.004761</td>\n", | |
" <td>0.004364</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_mlp2</th>\n", | |
" <td>0.043213</td>\n", | |
" <td>0.065430</td>\n", | |
" <td>0.073242</td>\n", | |
" <td>0.070312</td>\n", | |
" <td>0.100586</td>\n", | |
" <td>0.103516</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_add_resid_1</th>\n", | |
" <td>0.064941</td>\n", | |
" <td>0.081543</td>\n", | |
" <td>0.109375</td>\n", | |
" <td>0.080566</td>\n", | |
" <td>0.131836</td>\n", | |
" <td>0.131836</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" train_q test_q train_q_cum train_refit test_q_cum \\\n", | |
"46_attention 0.010864 0.012512 NaN NaN NaN \n", | |
"46_dense 0.011780 0.015137 0.031738 0.020752 0.032471 \n", | |
"46_add_resid_0 0.012024 0.015320 0.031982 0.020752 0.032471 \n", | |
"46_layer_norm_1 0.000546 0.000664 0.001152 0.001045 0.001411 \n", | |
"46_mlp1 0.047119 0.053955 0.051758 0.047363 0.057617 \n", | |
"46_mlp_gelu 0.004211 0.004456 0.004395 0.004120 0.004761 \n", | |
"46_mlp2 0.043213 0.065430 0.073242 0.070312 0.100586 \n", | |
"46_add_resid_1 0.064941 0.081543 0.109375 0.080566 0.131836 \n", | |
"\n", | |
" test_refit \n", | |
"46_attention NaN \n", | |
"46_dense 0.030029 \n", | |
"46_add_resid_0 0.030151 \n", | |
"46_layer_norm_1 0.001305 \n", | |
"46_mlp1 0.056396 \n", | |
"46_mlp_gelu 0.004364 \n", | |
"46_mlp2 0.103516 \n", | |
"46_add_resid_1 0.131836 " | |
] | |
}, | |
"execution_count": 81, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df # MSEs" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "88d8d651-d6bd-4f25-9235-bc54083e2bf3", | |
"metadata": {}, | |
"source": [ | |
"#### Relative standard errors" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 82, | |
"id": "a4786349-0578-4b2f-b3d9-479f46c8a49f", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"rses = {kk: {k:v[1] for k, v in vv.items() if k != 'layer_name'} for kk, vv in losses.items()}\n", | |
"dfr = pd.DataFrame(rses).T" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"id": "392d5ead-83d1-4f8e-86e1-f35211783d27", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>train_q</th>\n", | |
" <th>test_q</th>\n", | |
" <th>train_q_cum</th>\n", | |
" <th>train_refit</th>\n", | |
" <th>test_q_cum</th>\n", | |
" <th>test_refit</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>46_attention</th>\n", | |
" <td>0.078022</td>\n", | |
" <td>0.091042</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_dense</th>\n", | |
" <td>0.054555</td>\n", | |
" <td>0.070375</td>\n", | |
" <td>0.146986</td>\n", | |
" <td>0.096106</td>\n", | |
" <td>0.150965</td>\n", | |
" <td>0.139614</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_add_resid_0</th>\n", | |
" <td>0.000223</td>\n", | |
" <td>0.000285</td>\n", | |
" <td>0.000594</td>\n", | |
" <td>0.000386</td>\n", | |
" <td>0.000604</td>\n", | |
" <td>0.000561</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_layer_norm_1</th>\n", | |
" <td>0.000774</td>\n", | |
" <td>0.000951</td>\n", | |
" <td>0.001634</td>\n", | |
" <td>0.001482</td>\n", | |
" <td>0.002021</td>\n", | |
" <td>0.001868</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_mlp1</th>\n", | |
" <td>0.024921</td>\n", | |
" <td>0.028832</td>\n", | |
" <td>0.027375</td>\n", | |
" <td>0.025050</td>\n", | |
" <td>0.030789</td>\n", | |
" <td>0.030136</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_mlp_gelu</th>\n", | |
" <td>0.059492</td>\n", | |
" <td>0.062715</td>\n", | |
" <td>0.062079</td>\n", | |
" <td>0.058199</td>\n", | |
" <td>0.067010</td>\n", | |
" <td>0.061426</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_mlp2</th>\n", | |
" <td>0.091185</td>\n", | |
" <td>0.137860</td>\n", | |
" <td>0.154551</td>\n", | |
" <td>0.148369</td>\n", | |
" <td>0.211934</td>\n", | |
" <td>0.218107</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>46_add_resid_1</th>\n", | |
" <td>0.001180</td>\n", | |
" <td>0.001484</td>\n", | |
" <td>0.001987</td>\n", | |
" <td>0.001464</td>\n", | |
" <td>0.002399</td>\n", | |
" <td>0.002399</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" train_q test_q train_q_cum train_refit test_q_cum \\\n", | |
"46_attention 0.078022 0.091042 NaN NaN NaN \n", | |
"46_dense 0.054555 0.070375 0.146986 0.096106 0.150965 \n", | |
"46_add_resid_0 0.000223 0.000285 0.000594 0.000386 0.000604 \n", | |
"46_layer_norm_1 0.000774 0.000951 0.001634 0.001482 0.002021 \n", | |
"46_mlp1 0.024921 0.028832 0.027375 0.025050 0.030789 \n", | |
"46_mlp_gelu 0.059492 0.062715 0.062079 0.058199 0.067010 \n", | |
"46_mlp2 0.091185 0.137860 0.154551 0.148369 0.211934 \n", | |
"46_add_resid_1 0.001180 0.001484 0.001987 0.001464 0.002399 \n", | |
"\n", | |
" test_refit \n", | |
"46_attention NaN \n", | |
"46_dense 0.139614 \n", | |
"46_add_resid_0 0.000561 \n", | |
"46_layer_norm_1 0.001868 \n", | |
"46_mlp1 0.030136 \n", | |
"46_mlp_gelu 0.061426 \n", | |
"46_mlp2 0.218107 \n", | |
"46_add_resid_1 0.002399 " | |
] | |
}, | |
"execution_count": 83, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dfr # RELATIVE STANDARD ERRORS" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "71371f17-2cec-424e-9601-bbf58dddb3e1", | |
"metadata": {}, | |
"source": [ | |
"#### Activation variances by layer (first batch of each slice only)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 105, | |
"id": "173ba439-419c-42fe-9e80-726fe1c7a522", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"xx = []\n", | |
"for i in X.keys():\n", | |
" for j in X[i].keys():\n", | |
" q = X[i][j][0]\n", | |
" if isinstance(q, dict):\n", | |
" q = q['hidden_states']\n", | |
" xx.append([i, j, q.var().item()])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 108, | |
"id": "ec5ba45d-c7d3-4a66-8fe3-24206932ebb7", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th>slice</th>\n", | |
" <th>test</th>\n", | |
" <th>test_q</th>\n", | |
" <th>test_q_cum</th>\n", | |
" <th>test_refit</th>\n", | |
" <th>train</th>\n", | |
" <th>train_q</th>\n", | |
" <th>train_q_cum</th>\n", | |
" <th>train_refit</th>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>layer</th>\n", | |
" <th></th>\n", | |
" <th></th>\n", | |
" <th></th>\n", | |
" <th></th>\n", | |
" <th></th>\n", | |
" <th></th>\n", | |
" <th></th>\n", | |
" <th></th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>53.500000</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>54.000000</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0.636719</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>0.652344</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0.143555</td>\n", | |
" <td>0.150391</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" <td>0.166016</td>\n", | |
" <td>0.172852</td>\n", | |
" <td>NaN</td>\n", | |
" <td>NaN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0.230469</td>\n", | |
" <td>0.244141</td>\n", | |
" <td>0.253906</td>\n", | |
" <td>0.231445</td>\n", | |
" <td>0.261719</td>\n", | |
" <td>0.275391</td>\n", | |
" <td>0.285156</td>\n", | |
" <td>0.261719</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>54.500000</td>\n", | |
" <td>54.500000</td>\n", | |
" <td>54.500000</td>\n", | |
" <td>54.500000</td>\n", | |
" <td>55.000000</td>\n", | |
" <td>55.000000</td>\n", | |
" <td>55.000000</td>\n", | |
" <td>55.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>0.695312</td>\n", | |
" <td>0.695312</td>\n", | |
" <td>0.695312</td>\n", | |
" <td>0.695312</td>\n", | |
" <td>0.734375</td>\n", | |
" <td>0.738281</td>\n", | |
" <td>0.738281</td>\n", | |
" <td>0.734375</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>1.875000</td>\n", | |
" <td>1.921875</td>\n", | |
" <td>1.921875</td>\n", | |
" <td>1.875000</td>\n", | |
" <td>1.992188</td>\n", | |
" <td>2.046875</td>\n", | |
" <td>2.046875</td>\n", | |
" <td>1.992188</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>0.073730</td>\n", | |
" <td>0.077637</td>\n", | |
" <td>0.078613</td>\n", | |
" <td>0.073730</td>\n", | |
" <td>0.073730</td>\n", | |
" <td>0.078125</td>\n", | |
" <td>0.078125</td>\n", | |
" <td>0.073242</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>0.494141</td>\n", | |
" <td>0.558594</td>\n", | |
" <td>0.593750</td>\n", | |
" <td>0.539062</td>\n", | |
" <td>0.494141</td>\n", | |
" <td>0.542969</td>\n", | |
" <td>0.574219</td>\n", | |
" <td>0.523438</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>55.750000</td>\n", | |
" <td>55.750000</td>\n", | |
" <td>55.750000</td>\n", | |
" <td>55.750000</td>\n", | |
" <td>56.000000</td>\n", | |
" <td>56.250000</td>\n", | |
" <td>56.250000</td>\n", | |
" <td>56.000000</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
"slice test test_q test_q_cum test_refit train train_q \\\n", | |
"layer \n", | |
"0 53.500000 NaN NaN NaN 54.000000 NaN \n", | |
"1 0.636719 NaN NaN NaN 0.652344 NaN \n", | |
"2 0.143555 0.150391 NaN NaN 0.166016 0.172852 \n", | |
"3 0.230469 0.244141 0.253906 0.231445 0.261719 0.275391 \n", | |
"4 54.500000 54.500000 54.500000 54.500000 55.000000 55.000000 \n", | |
"5 0.695312 0.695312 0.695312 0.695312 0.734375 0.738281 \n", | |
"6 1.875000 1.921875 1.921875 1.875000 1.992188 2.046875 \n", | |
"7 0.073730 0.077637 0.078613 0.073730 0.073730 0.078125 \n", | |
"8 0.494141 0.558594 0.593750 0.539062 0.494141 0.542969 \n", | |
"9 55.750000 55.750000 55.750000 55.750000 56.000000 56.250000 \n", | |
"\n", | |
"slice train_q_cum train_refit \n", | |
"layer \n", | |
"0 NaN NaN \n", | |
"1 NaN NaN \n", | |
"2 NaN NaN \n", | |
"3 0.285156 0.261719 \n", | |
"4 55.000000 55.000000 \n", | |
"5 0.738281 0.734375 \n", | |
"6 2.046875 1.992188 \n", | |
"7 0.078125 0.073242 \n", | |
"8 0.574219 0.523438 \n", | |
"9 56.250000 56.000000 " | |
] | |
}, | |
"execution_count": 108, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"dfv = pd.DataFrame(xx, columns = ('layer', 'slice', 'var'))\n", | |
"dfv.set_index(['layer', 'slice'])['var'].unstack()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "1dfa04d7-b287-439f-8030-fc7210ebdbea", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"id": "cfa3d0ce-0ab9-4373-a40b-85f381d7bd17", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"allocated: 12.1 GB, max: 56.5 GB, before flush: 12.1 GB\n" | |
] | |
} | |
], | |
"source": [ | |
"cuda_mem_use()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.15" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment