Skip to content

Instantly share code, notes, and snippets.

@jkminder
Last active June 21, 2024 12:47
Show Gist options
  • Save jkminder/d05d708f3f93c66037ac7f0c352eefa4 to your computer and use it in GitHub Desktop.
Save jkminder/d05d708f3f93c66037ac7f0c352eefa4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0d99bcc4264645cebfdce8e48699baa0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b70c6a933cc7432dba1cfad042f1021f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer\n"
]
}
],
"source": [
"from nnsight import NNsight\n",
"from matplotlib import pyplot as plt\n",
"import seaborn as sns\n",
"import torch\n",
"from transformer_lens import HookedTransformer\n",
"from transformers import AutoTokenizer, LlamaForCausalLM\n",
"\n",
"MODEL_PATH = 'meta-llama/Meta-Llama-3-8B-Instruct'\n",
"# Both models are placed on the same device and loaded in the same dtype so that\n",
"# differences in their outputs are not caused by mismatched precision or placement.\n",
"DEVICE = 'cuda:0'\n",
"DTYPE = torch.float32\n",
"\n",
"# Load the Hugging Face reference model and tokenizer onto a single GPU.\n",
"# The dtype is pinned explicitly rather than relying on the library default.\n",
"hf_model = LlamaForCausalLM.from_pretrained(\n",
"    MODEL_PATH,\n",
"    torch_dtype=DTYPE,\n",
").to(DEVICE)\n",
"hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)\n",
"\n",
"# Load the same checkpoint into TransformerLens without its weight processing\n",
"# (e.g. LayerNorm folding), so activations stay comparable to the HF model.\n",
"hooked_model = HookedTransformer.from_pretrained_no_processing(\n",
"    MODEL_PATH,\n",
"    device=DEVICE,\n",
"    dtype=DTYPE,\n",
"    default_padding_side='left',\n",
"    n_devices=1,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TF Greedy: tensor([4815], device='cuda:0') Logit: tensor(13.3366, device='cuda:0')\n",
"TL Greedy: tensor([4815], device='cuda:0') Logit: tensor(13.6355, device='cuda:0')\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAs0AAAGxCAYAAACZR4umAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABV4ElEQVR4nO3deVxVdf7H8fdV4YKIV5bgCiIiKmpiGc0otKi5ZmZlyzROpGVquaXm1DS2YFNaVmppLm1qqVkz41ZTpGY5Ne40FJrjb9zCDTVlUUNE+P7+KM54Wbyg6AV8PR+P+6j7vZ97zucczvLx3O/5HpsxxggAAABAmWp5OgEAAACgqqNoBgAAANygaAYAAADcoGgGAAAA3KBoBgAAANygaAYAAADcoGgGAAAA3KBoBgAAANygaAYAAADcqFDRPHfuXNlsNuvl4+Mjp9Opzp07a+LEiTp8+HCJ7yQlJclms7m0nT59Wg8//LAaNmyo2rVr6+qrr5YkHTt2TPfee69CQkJks9l0++23n/eC1QQ2m03Dhw+vtOn9/PPPSkpK0ldffVVp06xqStveKtuePXtks9k0d+5cq23t2rVKSkpSVlZWpc4rPz9fLVu21Isvvmi1Fe2He/bsqdR5FVm4cKGmTp1aoj0zM1MNGjTQ0qVLyz2tL774Qtdee638/Pxks9kq9N1L7YcfflBSUlKp67VTp05q06bNRc/hUs2nNJV9vCnNgAED1KRJk4s6j8pwIceRqriMNptNSUlJ54w5cOCAkpKSlJqaWuKzAQMGqF69ehcnOQ/M50JdivOMVHathP/p1KmTOnXqdF7fbdKkiQYMGFCh79Q5nxnNmTNHLVu2VH5+vg4fPqxvvvlGL730kl555RV9+OGH6tq1qxX70EMPqWfPni7fnzlzpmbPnq1p06YpLi7O2kn+8pe/aMmSJXr33XcVHR2twMDA80kPZfj55581fvx4STrvjQxSw4YNtW7dOkVHR1tta9eu1fjx4zVgwAA1aNCg0uY1Y8YMZWZmasSIEZU2TXcWLlyoLVu2aNSoUS7tAQEBGj16tP74xz+qV69e8vb2Pud0jDG655571KJFCy1fvlx+fn6KiYm5iJlfmB9++EHjx49Xp06dqlzRA1xsBw4c0Pjx49WkSROKMzdKq2suhrJqJXjOeRXNbdq00bXXXmu9v/POOzV69Ghdf/316tu3r/773/8qNDRUktSoUSM1atTI5ftbtmyRr69viasaW7ZsUXR0tP7whz+cT1qlys3Nla+vb6VND7Db7erQocNFn8+ZM2f08ssv68EHH5Sfn99Fn195PPzww3r++ef1t7/9Tf369Ttn7IEDB3Ts2DHdcccd6tKlyzljf/75Z9WtW7cyUwWAi6K0uuZiKKtWgudUWp/mxo0b69VXX9Xx48c1e/Zsq734zxg2m01vv/22cnNzrW4eRT83r1q1Stu2bbPai7oRnD59Ws8//7xatmwpu92uK664Qg888ICOHDnikkOTJk3Uu3dvLV68WO3atZOPj491ZTUjI0NDhgxRo0aN5O3traioKI0fP15nzpyxvl/0s/srr7yiyZMnKyoqSvXq1VN8fLzWr19fYpk3bNigW2+9VUFBQfLx8VF0dHSJq3P//e9/1a9fP4WEhMhut6tVq1Z64403KrRuZ8+erRYtWshut6t169ZatGhRiRh3y7dnzx5dccUVkqTx48db63jAgAHaunWrbDab/vrXv1rTS0lJkc1m05VXXukynz59+iguLs6l7cMPP1R8fLz8/PxUr1499ejRQ//+979L5Lh582b16dNHgYGB8vHxUbt27fTRRx+5xBRtC19++aUeeeQRBQcHKygoSH379tWBAwcqtN6KFBYWatKkSdb2ExISovvvv1/79u1ziTPGaMKECYqMjJSPj4+uvfZarVy5ssTPP8W7ZyQlJemPf/yjJCkqKqrE9rt69Wp16tRJQUFB8vX1VePGjXXnnXfq55
9/Pmfey5cv1/79+5WYmOh2GVeuXKnbbrtNjRo1ko+Pj5o1a6YhQ4bop59+cok7cuSIBg8erIiICGtfuu6667Rq1SpJv/wC8Y9//EM//vijS1esIqGhoerWrZtmzZp1znySkpKsk8oTTzwhm81mXb0tOiZ8++23uuuuuxQQEGBdtT916pSefPJJRUVFydvbW+Hh4Ro2bFiJbi9F+/onn3yidu3aydfXV61atdInn3wi6ZftqFWrVvLz89Nvf/tbbd68+Zz5zp07V3fffbckqXPnzi7HprNt2rRJN9xwg+rWraumTZvqxRdfVGFhoUtMTk6Oxo4d67IMo0aN0smTJ8+ZQ0Xnk56ervvuu8/l2PLqq6+WiDt27JiGDh2q8PBweXt7q2nTpho3bpzy8vLOmYMxRn/+85/l5eWlt956y2ov7/4+d+5cxcTEWLm999575V7+yvj7Ll++XPHx8apbt678/f3VrVs3rVu3rkTcP/7xD1199dWy2+2KiorSK6+8Uub6mDFjhq6++mr5+voqICBAd911l3bt2lXu5TpbeffZov1l69at+v3vfy+Hw6HQ0FA9+OCDys7OdonNycnRoEGDFBQUpHr16qlnz576v//7P7e5fPXVV/rNb34jSXrggQes7b94l44dO3aoV69eqlevniIiIvTYY4+V2I7Ke74+F3fz+eqrr1yOsUVK6zonSW+99ZbLOXThwoWldqPZt2+f7rrrLvn7+6tBgwb6wx/+oE2bNpWYZmndM4q22eTkZF1zzTXy9fVVy5Yt9e6775ZYvm+++Ubx8fHy8fFReHi4nn76ab399tsu3e7KqpUk6Y033tCNN96okJAQ+fn5KTY2VpMmTVJ+fn6JeSUnJ6tLly5yOByqW7euWrVqpYkTJ7rElOfcXJqi9f3yyy/rpZdeUpMmTeTr66tOnTrp//7v/5Sfn68//elPCgsLk8Ph0B133FGiG29Fzs+TJk2yzs/XXHONPvvss1LzqoxjcJlMBcyZM8dIMps2bSr18xMnTpjatWubLl26WG3PPvusOXs269atM7169TK+vr5m3bp1Zt26dSYjI8OsW7fOtGvXzjRt2tRqz87ONgUFBaZnz57Gz8/PjB8/3qxcudK8/fbbJjw83LRu3dr8/PPP1rQjIyNNw4YNTdOmTc27775rvvzyS7Nx40Zz8OBBExERYSIjI83s2bPNqlWrzF/+8hdjt9vNgAEDrO/v3r3bSDJNmjQxPXv2NEuXLjVLly41sbGxJiAgwGRlZVmxycnJxsvLy7Rt29bMnTvXrF692rz77rvm3nvvtWK2bt1qHA6HiY2NNe+9955ZsWKFeeyxx0ytWrVMUlKS2/UtyURERJjWrVubDz74wCxfvtz07NnTSDJ//etfrbjyLN+pU6dMcnKykWQGDhxoreMdO3YYY4xp2LChGTx4sDXNF1980fj6+hpJZv/+/cYYY/Lz8039+vXN448/bsW98MILxmazmQcffNB88sknZvHixSY+Pt74+fmZrVu3WnGrV6823t7e5oYbbjAffvihSU5ONgMGDDCSzJw5c6y4om2sadOmZsSIEebzzz83b7/9tgkICDCdO3d2u86Kb2/GGDN48GAjyQwfPtwkJyebWbNmmSuuuMJERESYI0eOWHFPPvmkkWQGDx5skpOTzVtvvWUaN25sGjZsaDp27GjFFW0nRXnv3bvXjBgxwkgyixcvdtl+d+/ebXx8fEy3bt3M0qVLzVdffWUWLFhgEhMTTWZm5jmX5cEHHzQhISEl2ovW0e7du622mTNnmokTJ5rly5ebNWvWmHnz5pmrrrrKxMTEmNOnT1txPXr0MFdccYV58803zVdffWWWLl1qnnnmGbNo0SJjzC/b7HXXXWecTqe1HOvWrXOZ/0svvWRq1ap1zvz37t1rFi9ebCSZESNGmHXr1plvv/3WGPO/v1FkZKR54oknzMqVK83SpUtNYWGh6dGjh6lTp455+umnzYoVK8wrr7xi/P
z8TLt27cypU6es6UdGRppGjRqZNm3amA8++MB8+umnpn379sbLy8s888wz5rrrrjOLFy82S5YsMS1atDChoaEux4riDh8+bCZMmGAkmTfeeMNa7sOHDxtjjOnYsaMJCgoyzZs3N7NmzTIrV640Q4cONZLMvHnzrOmcPHnSXH311SY4ONhMnjzZrFq1yrz22mvG4XCYm266yRQWFpaZQ0Xmc/jwYRMeHm6uuOIKM2vWLJOcnGyGDx9uJJlHHnnEisvNzTVt27Y1fn5+5pVXXjErVqwwTz/9tKlTp47p1auXy7wlmWHDhhljfjle3Hvvvcbf39989tlnVkx59/eibfS2224zH3/8sZk/f75p1qyZdZxy50L/vgsWLDCSTPfu3c3SpUvNhx9+aOLi4oy3t7f5+uuvrbhVq1aZ2rVrm+uvv94sXrzY/PWvfzW/+c1vTOPGjUscRwYNGmS8vLzMY489ZpKTk83ChQtNy5YtTWhoqMnIyLDi+vfvX65lLO8+W7S/xMTEmGeeecasXLnSTJ482djtdvPAAw9YcYWFhaZz587GbrebF154waxYscI8++yzpmnTpkaSefbZZ8vMJTs72/qbPfXUU9b2v3fvXmuZvL29TatWrcwrr7xiVq1aZZ555hljs9nM+PHjrelU5HxdmvLO58svvzSSzJdffuny/eLHZmOMmT17tpFk7rzzTvPJJ5+YBQsWmBYtWpjIyEiXv9OJEydMs2bNTGBgoHnjjTfM559/bkaPHm2ioqJKTLO080zRNtu6dWvz3nvvmc8//9zcfffdRpJZs2aNFffdd98ZHx8f07ZtW7No0SKzfPly06tXL9OkSROX43pptVLR8Wj06NFm5syZJjk52axevdpMmTLFBAcHu2wPxhjz9ttvG5vNZjp16mQWLlxoVq1aZWbMmGGGDh1qxZT33FyaovUdGRlpbr31VvPJJ5+Y+fPnm9DQUNOiRQuTmJhoHnzwQfPZZ5+ZWbNmmXr16plbb73VZRrlPT8XrfOBAweazz77zLz55psmPDzcOJ1Ol/NzRY7BkZGRpn///udcxuIqtWg2xpjQ0FDTqlUr631pG1f//v2Nn59fie927NjRXHnllS5tH3zwgZFk/v73v7u0b9q0yUgyM2bMsNoiIyNN7dq1zfbt211ihwwZYurVq2d+/PFHl/ZXXnnFSLIO9kUbQGxsrDlz5owVt3HjRiPJfPDBB1ZbdHS0iY6ONrm5uWWuix49ephGjRqZ7Oxsl/bhw4cbHx8fc+zYsTK/a8wvJzFfX1+XA/KZM2dMy5YtTbNmzSq8fEeOHCnz4HnfffeZpk2bWu+7du1qBg0aZAICAqyT9b/+9S8jyaxYscIYY0x6erqpU6eOGTFihMu0jh8/bpxOp7nnnnustpYtW5p27dqZ/Px8l9jevXubhg0bmoKCAmPM/7axs3dqY4yZNGmSkWQOHjx4znVWfHvbtm1bqdPbsGGDkWT+/Oc/G2OMOXbsmLHb7eZ3v/udS9y6deuMpHMWzcYY8/LLL5coZI0x5m9/+5uRZFJTU8+Zd2latWplevbsWaK9tKL5bIWFhSY/P9/8+OOPRpJZtmyZ9Vm9evXMqFGjzjnfW2655Zwn/ZUrVxpJLsVUaYrW08svv+zSXvQ3euaZZ1zai/5RN2nSJJf2Dz/80Egyb775ptUWGRlpfH19zb59+6y21NRUI8k0bNjQnDx50mpfunSpkWSWL19+znz/+te/lnoiNuaXY5Mks2HDBpf21q1bmx49eljvJ06caGrVqlXiGFm0HXz66afnzKG88/nTn/5UatwjjzxibDabdQycNWuWkWQ++ugjl7iXXnrJZV825n9F89GjR831119vwsPDXbbb8u7vBQUFJiwszFxzzTUuJ6g9e/YYLy+vchfN5/v3LZp/bGysdVwpyjMkJMQkJCRYbe3btzdhYWEux/GcnBwTGBhY4mKPJPPqq6+65Ll3717j6+vrciGhvEXz2c61zxbtL8
X3i6FDhxofHx9rHX/22WdGknnttddc4l544QW3RbMx/zunllYo9e/fv9TtqFevXiYmJsZ6X5HzdWnKO5/yFs0FBQXG6XSa9u3bu8T9+OOPJbbFN954o9Tj2pAhQ8pdNPv4+Lich3Nzc01gYKAZMmSI1Xb33XcbPz8/l4KwoKDAtG7dusRxvaxa6WwFBQUmPz/fvPfee6Z27dpWXXH8+HFTv359c/3115/zH+vlPTeXpmh9X3XVVS5xU6dONZJMnz59XOJHjRplJFk1UXnPz5mZmcbHx8fccccdLnFFNcnZ5+eKHIPPp2iu9CHnjDGVOr1PPvlEDRo00K233qozZ85Yr6uvvlpOp7PEzzNt27ZVixYtSkyjc+fOCgsLc5nGzTffLElas2aNS/wtt9yi2rVru0xTkn788UdJ0v/93/9p586dGjhwoHx8fErN+9SpU/riiy90xx13qG7dui7z7dWrl06dOlVql4/iunTpYvUPl6TatWvrd7/7nXbs2GH9fFHR5StrPrt27dLu3bt16tQpffPNN+rZs6c6d+6slStXSpJWrVolu92u66+/XpL0+eef68yZM7r//vtd5uvj46OOHTtaf5sdO3boP//5j9VXvfi6OHjwoLZv3+6ST58+fVzeF/8blNeXX34pSSXukP3tb3+rVq1a6YsvvpAkrV+/Xnl5ebrnnntc4jp06HBBN4VdffXV8vb21uDBgzVv3rwK/ZR74MABhYSElCv28OHDevjhhxUREaE6derIy8tLkZGRkqRt27ZZcb/97W81d+5cPf/881q/fn2pP+e5U5TT/v37K/zds915550u71evXi2p5N/q7rvvlp+fn/W3KnL11VcrPDzcet+qVStJv3QxObt/dFF7Rbed4pxOp37729+6tLVt29Zlup988onatGmjq6++2mU779GjR6k/J5/vfFavXq3WrVuXiBswYICMMda6XL16tfz8/HTXXXeViJNUYp3u3r1b8fHxysnJ0fr163XVVVdZn5V3f9++fbsOHDigfv36ufyEHRkZqYSEBLfLX+R8/75F809MTFStWv87xdWrV0933nmn1q9fr59//lknT57Upk2b1LdvX5fjuL+/v2699VaXXD755BPZbDbdd999LsvudDp11VVXndeIROXdZ4uUdkw8deqU9XN30bGu+D1B7u49KC+bzVZivZS2/VfkfH2+8ymv7du3KyMjo8RxvXHjxrruuutc2tasWSN/f/8SN/j9/ve/L/f8rr76ajVu3Nh67+PjoxYtWrjkvmbNGt10000KDg622mrVqlUix3P597//rT59+igoKEi1a9eWl5eX7r//fhUUFFjdcdauXaucnBwNHTq0zJE+zufcXJpevXq57GtF++Qtt9ziElfUnp6eLqn85+d169bp1KlTJbbthIQEa58pUhnH4HM5rxsBy3Ly5EkdPXpUsbGxlTbNQ4cOKSsrq8w79Yv3/2rYsGGp0/j444/l5eVVrmkEBQW5vLfb7ZJ+ualQktU361w3Ahw9elRnzpzRtGnTNG3atHLNtzROp7PMtqNHj6pRo0YVXr7SFI14smrVKkVFRSk/P1833XSTDh06pL/85S/WZ9ddd511Y+WhQ4ckyeoLV1zRTlQUN3bsWI0dO7ZcObr7G5TX0aNHJZW+XYSFhVkHs6K4s/+BUqS0tvKKjo7WqlWrNGnSJA0bNkwnT55U06ZNNXLkSD366KPn/G5ubm6Z/yg7W2Fhobp3764DBw7o6aefVmxsrPz8/FRYWKgOHTq4rLMPP/xQzz//vN5++209/fTTqlevnu644w5NmjSp1G2tNEU5VfRvUVzxv8nRo0dVp04dq+99EZvNJqfTaf2NihQfXafoGFFW+6lTpy4o3+LbpPTLdnn2ejh06JB27NhxQftieeZz9OjRUv8xFxYWZn1e9F+n01nipBkSEqI6deqUWKcbN27UTz/9pBdeeKHE8a28+3vRNMs6dpV3qMTz/fu62+cLCwuVmZ
kpY4wKCwvPeYwtcujQIRljyjwWNG3atDyLZKnIPlvE3TGxaP8pHlfe/dqdunXrljge2e12l/2qoufr851Pebk7ru/evdsl9kKP/+Xddy9kPunp6brhhhsUExOj1157TU2aNJGPj482btyoYcOGVahOOZ9zc2ku1r5a/Pxc3n31Qo/B51KpRfM//vEPFRQUVOpwZkU3giUnJ5f6ub+/v8v70v5FFRwcrLZt2+qFF14odRpFJ5ryKjqpF++ofraAgADVrl1biYmJGjZsWKkxUVFRbueVkZFRZlvRDloZy9eoUSO1aNFCq1atUpMmTXTttdeqQYMG6tKli4YOHaoNGzZo/fr11o2VRfOVpL/97W8l/rV3tqK4J598Un379i015mINRVa0jg4ePFji4HHgwAErt6K4ooPI2TIyMi7oavMNN9ygG264QQUFBdq8ebOmTZumUaNGKTQ0VPfee2+Z3wsODtaxY8fcTn/Lli367rvvNHfuXPXv399q37FjR6nTnDp1qqZOnar09HQtX75cf/rTn3T48OEy97HiinI6+0rJ+Si+rwYFBenMmTM6cuSIS+FsjFFGRkaZxVpVEhwcLF9f31Jv/in6vDIEBQXp4MGDJdqLbpY9e7vesGGDjDEu6/vw4cM6c+ZMiXx+97vfyel0aty4cSosLNRTTz1VInd3+3vRvnSuY9fFdPY+X9yBAwdUq1YtBQQEWOukPHkGBwfLZrPp66+/torVs5XWdi4V2WfLq2j/OXr0qEvxdinWeZGKnq/PV1FRXfwmxLIuvpR1XC8eu3HjRrdxFyooKKhc+ZRl6dKlOnnypBYvXuyyHxYfX7s8dYonz81Sxc/PZe2rZ5+fL/YxuNKK5vT0dI0dO1YOh0NDhgyprMmqd+/eWrRokQoKCtS+ffvznsann36q6OhoBQQEXHBOLVq0UHR0tN59912NGTOm1ANm3bp11blzZ/373/9W27Zt3Y5pW5YvvvhChw4dsv4VWlBQoA8//FDR0dHWRlbe5XN3tbZr16766KOPFBERYf2s0qJFCzVu3FjPPPOM8vPzXcbg7tGjh+rUqaOdO3eW+Kn9bDExMWrevLm+++47TZgwoWIr4ALddNNNkqT58+e7FF2bNm3Stm3bNG7cOElS+/btZbfb9eGHH7ocPNavX68ff/zRbdFcnivhtWvXVvv27dWyZUstWLBA33777TmL5pYtW2rnzp1ul7GoGCq+HZ49ik1pGjdurOHDh+uLL77Qv/71L5dlOddyFHUxad26tdvcKqJLly6aNGmS5s+fr9GjR1vtf//733Xy5Em3w9ZdqPP9NeNsvXv31oQJExQUFFSufxSfry5dumjixIn69ttvdc0111jt7733nmw2mzp37mzFffTRR1q6dKnuuOMOl7iiz4t76qmn5O/vr9GjR+vkyZPWnfYV2d8bNmyoDz74QGPGjLG2zx9//FFr166t8EWKioqJiVF4eLgWLlyosWPHWvM/efKk/v73v1sjaki//Ay8ePFivfzyy1Yhdvz4cX388ccu0+zdu7defPFF7d+/v0I/o5flfPfZc+ncubMmTZqkBQsWaOTIkVb7woULy/X9ytr+L/R8XR5Fx+Pvv/9ePXr0sNqXL1/uEhcTEyOn06mPPvpIY8aMsdrT09NLbIsdO3bURx99pM8++8zq2iip1NGqLkTHjh316aef6qeffrIKuMLCQpfRq86ltG3HGOMywo30S9cFh8OhWbNm6d577y31gqInz81S+c/PHTp0kI+PjxYsWOBy7Fm7dm2J8/PFPgafV9G8ZcsWq5/I4cOH9fXXX2vOnDmqXbu2lixZUuLn1Qtx7733asGCBerVq5ceffRR/fa3v5WXl5f27dunL7/8UrfddpvLyaA0zz33nFauXKmEhASNHDlSMTExOnXqlPbs2aNPP/1Us2bNqvCYi2+88YZuvfVWdejQQaNHj1bjxo2Vnp6uzz//XAsWLJAkvfbaa7r++ut1ww036JFHHl
GTJk10/Phx7dixQx9//LHV7/BcgoODddNNN+npp5+Wn5+fZsyYof/85z8uO3J5l8/f31+RkZFatmyZunTposDAQAUHB1sbXJcuXTRjxgz99NNPLk+E69Kli+bMmaOAgACX4eaaNGmi5557TuPGjdOuXbvUs2dPBQQE6NChQ9q4caP8/PysK9OzZ8/WzTffrB49emjAgAEKDw/XsWPHtG3bNn377bflPmBUVExMjAYPHqxp06apVq1auvnmm7Vnzx49/fTTioiIsIqzwMBAjRkzRhMnTlRAQIDuuOMO7du3T+PHj1fDhg1d+muVpqhL0muvvab+/fvLy8tLMTExWrBggVavXq1bbrlFjRs31qlTp6x/AZ/9D5DSdOrUSc8995zbMYxbtmyp6Oho/elPf5IxRoGBgfr444+tvuhFsrOz1blzZ/Xr108tW7aUv7+/Nm3apOTkZJd/KMTGxmrx4sWaOXOm4uLiVKtWLZdx2devX6+goKBK7YYlSd26dVOPHj30xBNPKCcnR9ddd52+//57Pfvss2rXrl25ht67EEVP4nvzzTfl7+8vHx8fRUVFlfqTa1lGjRqlv//977rxxhs1evRotW3bVoWFhUpPT9eKFSv02GOPVUoxMXr0aL333nu65ZZb9NxzzykyMlL/+Mc/NGPGDD3yyCPWfR3333+/3njjDfXv31979uxRbGysvvnmG02YMEG9evUqcxt89NFHVa9ePQ0ePFgnTpzQ66+/Xu79vVatWvrLX/6ihx56SHfccYcGDRqkrKwsJSUlVVpXgXOpVauWJk2apD/84Q/q3bu3hgwZory8PL388svKyspyebrmX/7yF/Xs2VPdunXTY489poKCAr300kvy8/Nz+ZXnuuuu0+DBg/XAAw9o8+bNuvHGG+Xn56eDBw/qm2++UWxsrB555JFy51jefbYiunfvrhtvvFGPP/64Tp48qWuvvVb/+te/9P7775fr+9HR0fL19dWCBQvUqlUr1atXT2FhYRX6R05lnK/Lw+l0qmvXrtbxOjIyUl988YUWL17sElerVi2NHz9eQ4YM0V133aUHH3xQWVlZpR7X+/fvrylTpui+++7T888/r2bNmumzzz7T559/bk2rMowbN04ff/yxunTponHjxsnX11ezZs2yhkNzN59u3brJ29tbv//97/X444/r1KlTmjlzpjIzM13i6tWrp1dffVUPPfSQunbtqkGDBik0NFQ7duzQd999p+nTp0vy3LlZKv/5OSAgQGPHjtXzzz+vhx56SHfffbf27t1b6jHloh+DK3LXYNFd+0Uvb29vExISYjp27GgmTJhgDYdytgsdPcOYX4Y6e+WVV8xVV11lfHx8TL169UzLli3NkCFDzH//+18rLjIy0txyyy2l5n7kyBEzcuRIExUVZby8vExgYKCJi4sz48aNMydOnDDGlH23vzGm1LuP161bZ26++WbjcDiM3W430dHRZvTo0S4xu3fvNg8++KAJDw83Xl5e5oorrjAJCQnm+eefLzXP4vMcNmyYmTFjhomOjjZeXl6mZcuWZsGCBee1fMb8MsRSu3btjN1uN5Jc7hzNzMw0tWrVMn5+fi5DHhUN39S3b99S81y6dKnp3LmzqV+/vrHb7SYyMtLcddddZtWqVS5x3333nbnnnntMSEiI8fLyMk6n09x0001m1qxZVkxZI7SUdbd0caVtbwUFBeall14yLVq0MF5eXiY4ONjcd9991nBKRQoLC83zzz9vGjVqZLy9vU3btm3NJ598Yq666iqXu3ZLGz3DmF+GrAsLCzO1atWycl23bp254447TGRkpLHb7SYoKMh07NjR7UgOxhizY8cOY7PZStxJXtroGT/88IPp1q2b8ff3NwEBAebuu+826enpLtvtqVOnzMMPP2zatm1r6tevb3x9fU1MTIx59tlnXUYjOHbsmLnrrrtMgwYNjM1mc1mfhYWFJjIyssQICqVxN3rG2XePF8nNzTVPPPGEiYyMNF
5eXqZhw4bmkUceKTG8XVn7etE+U548SjN16lQTFRVlateu7fI3LuvYVNpICSdOnDBPPfWUiYmJMd7e3tawk6NHj3YZCac0FZnPjz/+aPr162eCgoKMl5eXiYmJMS+//HKJu92PHj1qHn74YdOwYUNTp04dExkZaZ588kmXIfyMKX3dffDBB6ZOnTrmgQcesKZb3v397bffNs2bNzfe3t6mRYsW5t133y33yBKV8fddunSpad++vfHx8TF+fn6mS5cu5l//+leJaS5fvty0bdvWeHt7m8aNG5sXX3yx1OOIMca8++67pn379sbPz8/4+vqa6Ohoc//995vNmzdbMeVdxvLss8aUvb+UdhzIysoyDz74oGnQoIGpW7eu6datm/nPf/5TrtEzjPnl792yZUvj5eXl8p2yztmlrafynq9LU5H5HDx40Nx1110mMDDQOBwOc99995nNmzeXemx+8803TbNmzVy2xdtuu820a9fOJS49Pd307dvX1KtXz/j7+5s777zTfPrpp2WOaHK2srbZjh07uozuYIwxX3/9tWnfvr2x2+3G6XSaP/7xj9aINmcPbVvW+vj444+t9RseHm7++Mc/WqOnFD9Hfvrpp6Zjx47Gz8/P1K1b17Ru3dq89NJLLjHlOTeXpqx9r+h8ffbQuMaUfn6vyPl54sSJJiIiwjo/f/zxx6Wu3/Ieg89n9AybMZU83AVQg+zevVstW7bUs88+qz//+c+XfP5Fd6GXNYj7pfbFF1+oe/fu2rp1q1q2bOnpdACgwrKystSiRQvdfvvtevPNN88ZO2HCBD311FNKT0+/qE8B7N69u/bs2VOuh9HAcyr1RkCgOvvuu+/0wQcfKCEhQfXr19f27ds1adIk1a9fXwMHDvRIThMnTlS7du20adOmKnEj3PPPP68HH3yQghlAtZCRkaEXXnhBnTt3VlBQkH788UdNmTJFx48fLzGCUVGXhZYtWyo/P1+rV6/W66+/rvvuu69SC+YxY8aoXbt2ioiI0LFjx7RgwQKtXLlS77zzTqXNAxcHRTPwKz8/P23evFnvvPOOsrKy5HA41KlTJ73wwgsXNOzchWjTpo3mzJlzSe+AL0tmZqY6duyooUOHejoVACgXu92uPXv2aOjQoTp27Jjq1q2rDh06aNasWbryyitdYuvWraspU6Zoz549ysvLU+PGjfXEE0+4jCJTGQoKCvTMM88oIyNDNptNrVu31vvvv6/77ruvUueDykf3DAAAAMCNSn8iIAAAAFDTUDQDAAAAblA0AwAAAG5wI2A1VVhYqAMHDsjf37/UJ/0AAICqxxij48ePKywsrNIemoJLg6K5mjpw4IAiIiI8nQYAADgPe/fuvahjP6PyUTRXU/7+/pJ+2enq16/v4WwAAEB55OTkKCIiwjqPo/qgaK6mirpk1K9fn6IZAIBqhq6V1Q+daQAAAAA3KJoBAAAANyiaAQAAADcomgEAAAA3KJoBAAAANyiaAQAAADcomgEAAAA3KJoBAAAANyiaAQAAADcomgEAAAA3KJoBAAAANyiaAQAAADcomn+VlJQkm83m8nI6ndbnxhglJSUpLCxMvr6+6tSpk7Zu3eoyjby8PI0YMULBwcHy8/NTnz59tG/fPpeYzMxMJSYmyuFwyOFwKDExUVlZWZdiEQEAAHCeKJrPcuWVV+rgwYPWKy0tzfps0qRJmjx5sqZPn65NmzbJ6XSqW7duOn78uBUzatQoLVmyRIsWLdI333yjEydOqHfv3iooKLBi+vXrp9TUVCUnJys5OVmpqalKTEy8pMsJoGrLy8vT2rVrXV55eXmeTgsALmt1PJ1AVVKnTh2Xq8tFjDGaOnWqxo0bp759+0qS5s2bp9DQUC1cuFBDhgxRdna23nnnHb3//vvq2rWrJGn+/PmKiIjQqlWr1KNHD23btk3Jyclav3692rdvL0l66623FB8fr+3btysmJqbM3PLy8lxOmjk5OZW56ACqkJSUFI2csUwNwq
MlSVn7d+r1oVJCQoKHMwOAyxdXms/y3//+V2FhYYqKitK9996rXbt2SZJ2796tjIwMde/e3Yq12+3q2LGj1q5dK+mXk1x+fr5LTFhYmNq0aWPFrFu3Tg6HwyqYJalDhw5yOBxWTFkmTpxodelwOByKiIiotOUGUPU0CI9WcHSsgqNjreIZAOA5FM2/at++vd577z19/vnneuutt5SRkaGEhAQdPXpUGRkZkqTQ0FCX74SGhlqfZWRkyNvbWwEBAeeMCQkJKTHvkJAQK6YsTz75pLKzs63X3r17z3tZAQAAUDF0z/jVzTffbP1/bGys4uPjFR0drXnz5qlDhw6SJJvN5vIdY0yJtuKKx5QWX57p2O122e12t8sBAACAyseV5jL4+fkpNjZW//3vf61+zsWvBh8+fNi6+ux0OnX69GllZmaeM+bQoUMl5nXkyJESV7EBAABQdVA0lyEvL0/btm1Tw4YNFRUVJafTqZUrV1qfnz59WmvWrLFuzImLi5OXl5dLzMGDB7VlyxYrJj4+XtnZ2dq4caMVs2HDBmVnZ3ODDwAAQBVG94xfjR07VrfeeqsaN26sw4cP6/nnn1dOTo769+8vm82mUaNGacKECWrevLmaN2+uCRMmqG7duurXr58kyeFwaODAgXrssccUFBSkwMBAjR07VrGxsdZoGq1atVLPnj01aNAgzZ49W5I0ePBg9e7d+5wjZwAAAMCzKJp/tW/fPv3+97/XTz/9pCuuuEIdOnTQ+vXrFRkZKUl6/PHHlZubq6FDhyozM1Pt27fXihUr5O/vb01jypQpqlOnju655x7l5uaqS5cumjt3rmrXrm3FLFiwQCNHjrRG2ejTp4+mT59+aRcWAAAAFWIzxhhPJ4GKy8nJkcPhUHZ2turXr+/pdABUorVr1+qZZVsUHB0rSfppZ5qeu60N3biAGoDzd/XFlWYAuITy8vKUkpLi0hYXF8foOABQxVE0A8AldD5P+ys8k6+0tDSXttOnT0uSvL29S31fhIIcACoHRTMAXGJFT/srr5xD6Zq2J1fOXf8bz31f6j9Vp16gnM3alPpe4vHbAFCZKJoBoBrwd0a5FNpZ+3fKy+G02oq/BwBULopmAKihSuvWQXcNADg/FM0AUEMV79ZBdw0AOH8UzQDgQaVdDU5LS1NhYeVMv3i3DgDA+aFoBgAPKv0mv68V0CzOg1kBAIqjaAYADyvtJj8AQNVSy9MJAAAAAFUdRTMAAADgBkUzAAAA4AZFMwAAAOAGNwICwEWSl5enlJQUl7bKHE4OAHDpUDQDwEWSkpKikTOWqUF4tNXGcHIAUD1RNAPARdQgPLrKDCdX2oNUJB6tDQDlQdEMAJeJ0h6kwqO1AaB8KJoB4DLCY7UB4PwwegYAAADgBkUzAAAA4AZFMwAAAOAGfZoBoJIUH5eZMZkBoOagaAaASlJ8XGbGZAaAmoOiGQAq0dnjMntyTGYAQOWiTzMAAADgBkUzAAAA4AZFMwAAAOAGRTMAAADgBkUzAAAA4AZFMwAAAOAGQ84BwGWs8Ey+0tLSXNri4uJkt9s9lBEAVE0UzQBwGcs5lK5pe3Ll3GWT9MvY0q8PlRISEjycGQBULRTNAHCZ83dGWQ9kAQCUjj7NAAAAgBsUzQAAAIAbFM0AAACAGxTNAAAAgBsUzQAAAIAbFM0AAACAGxTNAAAAgBsUzQAAAIAbFM0AAACAGzwREADOQ15enlJSUlza0tLSVFjooYQAABcVRTMAnIeUlBSNnLFMDcKjrbZ9qV8roFmcB7MCAFwsFM0AcJ4ahEcrODrWep+1f6cHswEAXEz0aQYAAADcoGgGAAAA3KBoBgAAANygaAYAAADcoGgGAAAA3KBoBgAAANxgyDkAgKXwTL7S0tJKtMfFxclut3sgIwCoGiiaAQCWnEPpmrYnV85dNqsta/9OvT5USkhI8GBmAOBZFM0AABf+ziiXh7YAAOjTDAAAALhF0QwAAAC4QdEMAAAAuEHRDA
AAALhB0QwAAAC4QdFchokTJ8pms2nUqFFWmzFGSUlJCgsLk6+vrzp16qStW7e6fC8vL08jRoxQcHCw/Pz81KdPH+3bt88lJjMzU4mJiXI4HHI4HEpMTFRWVtYlWCoAqLiisZvXrl1rvfLy8jydFgBcUhTNpdi0aZPefPNNtW3b1qV90qRJmjx5sqZPn65NmzbJ6XSqW7duOn78uBUzatQoLVmyRIsWLdI333yjEydOqHfv3iooKLBi+vXrp9TUVCUnJys5OVmpqalKTEy8ZMsHABWRcyhd01Zs1TPLtuiZZVs0csYypaSkeDotALikKJqLOXHihP7whz/orbfeUkBAgNVujNHUqVM1btw49e3bV23atNG8efP0888/a+HChZKk7OxsvfPOO3r11VfVtWtXtWvXTvPnz1daWppWrVolSdq2bZuSk5P19ttvKz4+XvHx8Xrrrbf0ySefaPv27R5ZZgBwp2js5uDoWDUIj/Z0OgBwyVE0FzNs2DDdcsst6tq1q0v77t27lZGRoe7du1ttdrtdHTt21Nq1ayVJKSkpys/Pd4kJCwtTmzZtrJh169bJ4XCoffv2VkyHDh3kcDismNLk5eUpJyfH5QUAAIBLgycCnmXRokX69ttvtWnTphKfZWRkSJJCQ0Nd2kNDQ/Xjjz9aMd7e3i5XqItiir6fkZGhkJCQEtMPCQmxYkozceJEjR8/vmILBAAAgEpB0fyrvXv36tFHH9WKFSvk4+NTZpzNZnN5b4wp0VZc8ZjS4t1N58knn9SYMWOs9zk5OYqIiDjnfAFUnry8PJd+vGlpaSos9GBCAIBLiqL5VykpKTp8+LDi4uKstoKCAv3zn//U9OnTrf7GGRkZatiwoRVz+PBh6+qz0+nU6dOnlZmZ6XK1+fDhw0pISLBiDh06VGL+R44cKXEV+2x2u112u/3CFhLAeUtJSdHIGcus/rz7Ur9WQLM4N98CANQU9Gn+VZcuXZSWlqbU1FTrde211+oPf/iDUlNT1bRpUzmdTq1cudL6zunTp7VmzRqrII6Li5OXl5dLzMGDB7VlyxYrJj4+XtnZ2dq4caMVs2HDBmVnZ1sxAKqmBuHR1s1w9a4I93Q6AIBLiCvNv/L391ebNm1c2vz8/BQUFGS1jxo1ShMmTFDz5s3VvHlzTZgwQXXr1lW/fv0kSQ6HQwMHDtRjjz2moKAgBQYGauzYsYqNjbVuLGzVqpV69uypQYMGafbs2ZKkwYMHq3fv3oqJibmESwwAAIDyomiugMcff1y5ubkaOnSoMjMz1b59e61YsUL+/v5WzJQpU1SnTh3dc889ys3NVZcuXTR37lzVrl3bilmwYIFGjhxpjbLRp08fTZ8+/ZIvDwAAAMqHovkcvvrqK5f3NptNSUlJSkpKKvM7Pj4+mjZtmqZNm1ZmTGBgoObPn19JWQIAAOBio08zAAAA4AZFMwAAAOAGRTMAAADgBkUzAAAA4AZFMwAAAOAGRTMAAADgBkUzAAAA4AZFMwAAAOAGRTMAAADgBkUzAAAA4AZFMwAAAOBGHU8nAABVTV5enlJSUlza0tLSVFjooYSqmMIz+UpLSyvRHhcXJ7vd7oGMAODio2gGgGJSUlI0csYyNQiPttr2pX6tgGZxHsyq6sg5lK5pe3Ll3GWz2rL279TrQ6WEhAQPZgYAFw9FMwCUokF4tIKjY633Wft3ejCbqsffGeWyfgCgpqNPMwAAAOAGRTMAAADgBkUzAAAA4AZFMwAAAOAGRTMAAADgBkUzAAAA4AZFMwAAAOAGRTMAAADgBkUzAAAA4AZFMwAAAOAGRTMAAADgBkUzAAAA4AZFMwAAAOAGRTMAAADgRh1PJwAAnpaXl6eUlBTrfVpamgoLPZgQAKDKoWgGcNlLSUnRyBnL1CA8WpK0L/VrBTSL83BWAICqhKIZACQ1CI9WcHSsJClr/04PZwMAqGro0wwAAAC4QdEMAAAAuEHRDAAAALhB0QwAAAC4QdEMAAAAuE
HRDAAAALhB0QwAAAC4QdEMAAAAuEHRDAAAALhB0QwAAAC4QdEMAAAAuEHRDAAAALhB0QwAAAC4UcfTCQAAqr/CM/lKS0tzaYuLi5PdbvdQRgBQuSiaAQAXLOdQuqbtyZVzl02SlLV/p14fKiUkJHg4MwCoHBTNAIBK4e+MUnB0rKfTAICLgj7NAAAAgBsUzQAAAIAbFM0AAACAGxTNAAAAgBvcCAjgspKXl6eUlBSXtrS0NBUWeighAEC1QNEM4LKSkpKikTOWqUF4tNW2L/VrBTSL82BWAICqjqIZwGWnQXi0y9BoWft3ejAbAEB1QJ9mAAAAwA2KZgAAAMANimYAAADADYpmAAAAwA1uBAQAVLrCM/lKS0sr0R4XFye73e6BjADgwnCl+VczZ85U27ZtVb9+fdWvX1/x8fH67LPPrM+NMUpKSlJYWJh8fX3VqVMnbd261WUaeXl5GjFihIKDg+Xn56c+ffpo3759LjGZmZlKTEyUw+GQw+FQYmKisrKyLsUiAsAlk3MoXdNWbNUzy7ZYr5EzlpUYIxsAqguK5l81atRIL774ojZv3qzNmzfrpptu0m233WYVxpMmTdLkyZM1ffp0bdq0SU6nU926ddPx48etaYwaNUpLlizRokWL9M033+jEiRPq3bu3CgoKrJh+/fopNTVVycnJSk5OVmpqqhITEy/58gLAxebvjFJwdKz1OntsbACobuie8atbb73V5f0LL7ygmTNnav369WrdurWmTp2qcePGqW/fvpKkefPmKTQ0VAsXLtSQIUOUnZ2td955R++//766du0qSZo/f74iIiK0atUq9ejRQ9u2bVNycrLWr1+v9u3bS5LeeustxcfHa/v27YqJibm0Cw0AAIBy4UpzKQoKCrRo0SKdPHlS8fHx2r17tzIyMtS9e3crxm63q2PHjlq7dq2kX54ylp+f7xITFhamNm3aWDHr1q2Tw+GwCmZJ6tChgxwOhxVTlry8POXk5Li8AAAAcGlQNJ8lLS1N9erVk91u18MPP6wlS5aodevWysjIkCSFhoa6xIeGhlqfZWRkyNvbWwEBAeeMCQkJKTHfkJAQK6YsEydOtPpBOxwORUREnPdyAgAAoGIoms8SExOj1NRUrV+/Xo888oj69++vH374wfrcZrO5xBtjSrQVVzymtPjyTOfJJ59Udna29dq7d295FgkAAACVgKL5LN7e3mrWrJmuvfZaTZw4UVdddZVee+01OZ1OSSpxNfjw4cPW1Wen06nTp08rMzPznDGHDh0qMd8jR46UuIpdnN1ut0b2KHoBAADg0qBoPgdjjPLy8hQVFSWn06mVK1dan50+fVpr1qxRQkKCpF/GHvXy8nKJOXjwoLZs2WLFxMfHKzs7Wxs3brRiNmzYoOzsbCsGAAAAVQ+jZ/zqz3/+s26++WZFRETo+PHjWrRokb766islJyfLZrNp1KhRmjBhgpo3b67mzZtrwoQJqlu3rvr16ydJcjgcGjhwoB577DEFBQUpMDBQY8eOVWxsrDWaRqtWrdSzZ08NGjRIs2fPliQNHjxYvXv3ZuQMAACAKoyi+VeHDh1SYmKiDh48KIfDobZt2yo5OVndunWTJD3++OPKzc3V0KFDlZmZqfbt22vFihXy9/e3pjFlyhTVqVNH99xzj3Jzc9WlSxfNnTtXtWvXtmIWLFigkSNHWqNs9OnTR9OnT7+0CwsAAIAKoWj+1TvvvHPOz202m5KSkpSUlFRmjI+Pj6ZNm6Zp06aVGRMYGKj58+efb5oAAADwAPo0AwAAAG7UiCvNTZs21aZNmxQUFOTSnpWVpWuuuUa7du3yUGYAPC0vL08pKSnW+7S0NBUWejAhAEC1VCOK5j179qigoKBEe15envbv3++BjABUFSkpKRo5Y5kahEdLkvalfq2AZnEezgoAUN1U66J5+fLl1v9//vnncjgc1vuCggJ98cUXatKkiQcyA1CVNAiPVnB0rCQpa/9OD2cDAKiOqnXRfP
vtt0v65Sa9/v37u3zm5eWlJk2a6NVXX/VAZgAAAKhJqnXRXPhrx8SoqCht2rRJwcHBHs4IAAAANVG1LpqL7N6929MpAAAAoAarEUWzJH3xxRf64osvdPjwYesKdJF3333XQ1kBAIoUnslXWlqaS1tcXJzsdruHMgKA8qsRRfP48eP13HPP6dprr1XDhg1ls9k8nRIAoJicQ+matidXzl2/HKOz9u/U60OlhIQED2cGAO7ViKJ51qxZmjt3rhITEz2dCgDgHPydUdZIJgBQndSIJwKePn2aKxUAAAC4aGpE0fzQQw9p4cKFnk4DAAAANVSN6J5x6tQpvfnmm1q1apXatm0rLy8vl88nT57socwAAABQE9SIovn777/X1VdfLUnasmWLy2fcFAgAAIALVSOK5i+//NLTKQAAAKAGqxF9mgEAAICLqUZcae7cufM5u2GsXr36EmYDAACAmqZGFM1F/ZmL5OfnKzU1VVu2bFH//v09kxQAAABqjBpRNE+ZMqXU9qSkJJ04ceISZwMAAICapkb3ab7vvvv07rvvejoNAAAAVHM1umhet26dfHx8PJ0GAAAAqrka0T2jb9++Lu+NMTp48KA2b96sp59+2kNZAQAAoKaoEUWzw+FweV+rVi3FxMToueeeU/fu3T2UFQAAAGqKGlE0z5kzx9MpAAAAoAarEUVzkZSUFG3btk02m02tW7dWu3btPJ0SAAAAaoAaUTQfPnxY9957r7766is1aNBAxhhlZ2erc+fOWrRoka644gpPpwjgEsjLy1NKSopLW1pamgoLPZQQAKDGqBFF84gRI5STk6OtW7eqVatWkqQffvhB/fv318iRI/XBBx94OEMAl0JKSopGzlimBuHRVtu+1K8V0CzOg1kBAGqCGlE0Jycna9WqVVbBLEmtW7fWG2+8wY2AwGWmQXi0gqNjrfdZ+3d6MBsAQE1RI8ZpLiwslJeXV4l2Ly8vFfK7LAAAAC5QjSiab7rpJj366KM6cOCA1bZ//36NHj1aXbp08WBmAAAAqAlqRNE8ffp0HT9+XE2aNFF0dLSaNWumqKgoHT9+XNOmTfN0egAAAKjmakSf5oiICH377bdauXKl/vOf/8gYo9atW6tr166eTg0AAAA1QLUumlevXq3hw4dr/fr1ql+/vrp166Zu3bpJkrKzs3XllVdq1qxZuuGGGzycKQCguMIz+UpLSyvRHhcXJ7vd7oGMAKBs1bponjp1qgYNGqT69euX+MzhcGjIkCGaPHkyRTMAVEE5h9I1bU+unLtsVlvW/p16faiUkJDgwcwAoKRq3af5u+++U8+ePcv8vHv37iUedAAAqDr8nVEKjo61XmePsQ0AVUm1LpoPHTpU6lBzRerUqaMjR45cwowAAABQE1Xrojk8PLzU/nBFvv/+ezVs2PASZgQAAICaqFoXzb169dIzzzyjU6dOlfgsNzdXzz77rHr37u2BzAAAAFCTVOsbAZ966iktXrxYLVq00PDhwxUTEyObzaZt27bpjTfeUEFBgcaNG+fpNAEAAFDNVeuiOTQ0VGvXrtUjjzyiJ598UsYYSZLNZlOPHj00Y8YMhYaGejhLAAAAVHfVumiWpMjISH366afKzMzUjh07ZIxR8+bNFRAQ4OnUAAAAUENU+6K5SEBAgH7zm994Og0AAADUQNX6RkAAAADgUqgxV5oBXH7y8vJcHmCUlpamwkIPJgQAqLEomgFUWykpKRo5Y5n1FLl9qV8roFmch7MCANREFM0AqrUG4dEKjo6VJGXt3+nhbHChCs/kl3hoVVxcnOx2u4cyAoBfUDQDAKqMnEPpmrYnV85dNkm//EPo9aFSQkKChzMDcLmjaAYAVCn+zijr1wMAqCoYPQMAAABwg6IZAAAAcIOiGQAAAHCDohkAAABwg6IZAAAAcIOiGQAAAHCDohkAAABwg6IZAAAAcIOiGQAAAHCDohkAAABwg6IZAAAAcIOi+VcTJ07Ub37zG/n7+yskJES333
67tm/f7hJjjFFSUpLCwsLk6+urTp06aevWrS4xeXl5GjFihIKDg+Xn56c+ffpo3759LjGZmZlKTEyUw+GQw+FQYmKisrKyLvYiAgAA4DxRNP9qzZo1GjZsmNavX6+VK1fqzJkz6t69u06ePGnFTJo0SZMnT9b06dO1adMmOZ1OdevWTcePH7diRo0apSVLlmjRokX65ptvdOLECfXu3VsFBQVWTL9+/ZSamqrk5GQlJycrNTVViYmJl3R5AQAAUH51PJ1AVZGcnOzyfs6cOQoJCVFKSopuvPFGGWM0depUjRs3Tn379pUkzZs3T6GhoVq4cKGGDBmi7OxsvfPOO3r//ffVtWtXSdL8+fMVERGhVatWqUePHtq2bZuSk5O1fv16tW/fXpL01ltvKT4+Xtu3b1dMTMylXXAAAAC4xZXmMmRnZ0uSAgMDJUm7d+9WRkaGunfvbsXY7XZ17NhRa9eulSSlpKQoPz/fJSYsLExt2rSxYtatWyeHw2EVzJLUoUMHORwOK6Y0eXl5ysnJcXkBAADg0qBoLoUxRmPGjNH111+vNm3aSJIyMjIkSaGhoS6xoaGh1mcZGRny9vZWQEDAOWNCQkJKzDMkJMSKKc3EiROtPtAOh0MRERHnv4AAAACoEIrmUgwfPlzff/+9PvjggxKf2Ww2l/fGmBJtxRWPKS3e3XSefPJJZWdnW6+9e/e6WwwAAABUEormYkaMGKHly5fryy+/VKNGjax2p9MpSSWuBh8+fNi6+ux0OnX69GllZmaeM+bQoUMl5nvkyJESV7HPZrfbVb9+fZcXAAAALg2K5l8ZYzR8+HAtXrxYq1evVlRUlMvnUVFRcjqdWrlypdV2+vRprVmzRgkJCZKkuLg4eXl5ucQcPHhQW7ZssWLi4+OVnZ2tjRs3WjEbNmxQdna2FQOgpLy8PK1du9bllZaWpsJC4+nUAACXAUbP+NWwYcO0cOFCLVu2TP7+/tYVZYfDIV9fX9lsNo0aNUoTJkxQ8+bN1bx5c02YMEF169ZVv379rNiBAwfqscceU1BQkAIDAzV27FjFxsZao2m0atVKPXv21KBBgzR79mxJ0uDBg9W7d29GzgDOISUlRSNnLFOD8GirbV/q1wpoFufBrAAAlwuK5l/NnDlTktSpUyeX9jlz5mjAgAGSpMcff1y5ubkaOnSoMjMz1b59e61YsUL+/v5W/JQpU1SnTh3dc889ys3NVZcuXTR37lzVrl3bilmwYIFGjhxpjbLRp08fTZ8+/eIuIFADNAiPVnB0rPU+a/9OD2YDALicUDT/yhj3P/HabDYlJSUpKSmpzBgfHx9NmzZN06ZNKzMmMDBQ8+fPP580AQAA4AH0aQYAAADcoGgGAAAA3KB7BgCgyio8k6+0tLQS7XFxcbLb7R7ICMDliqIZAFBl5RxK17Q9uXLu+t/Dn7L279TrQ8UwnQAuKYpmAECV5u+Mchk1BQA8gT7NAAAAgBtcaQZQ5eTl5SklJcWl7Zen/3koIQDAZY+iGUCVw9P/AABVDUUzgCqJp/8BAKoS+jQDAAAAblA0AwAAAG5QNAMAAABuUDQDAAAAblA0AwAAAG5QNAMAAABuMOQcAKBaKTyTr7S0NJe2uLg42e12D2UE4HJA0QwAqFZyDqVr2p5cOXfZJP0yhvfrQ6WEhAQPZwagJqNoBgBUO/7OKJeH3wDAxUafZgAAAMANimYAAADADYpmAAAAwA2KZgAAAMANbgQE4HF5eXlKSUmx3qelpamw0IMJAQBQDEUzAI9LSUnRyBnL1CA8WpK0L/VrBTSL83BWAAD8D0UzgCqhQXi0NYRY1v6dHs4GAABX9GkGAAAA3OBKM4CLqnh/ZYlHHgMAqh+KZgAXVfH+yjzyGABQHVE0A7jozu6vDABAdUSfZgAAAMANimYAAADADYpmAAAAwA2KZgAAAMANbgQEAFRrhWfylZaWVqKdoQ0BVCaKZgBAtZZzKF3T9uTKuctmtTG0IYDKRtEMAK
j2/J1RDGsI4KKiTzMAAADgBkUzAAAA4AbdMwBcUqXdtJWWlqbCQg8lBABAOVA0A7ikSrtpa1/q1wpoFufBrAAAODeKZgCXXPGbtrL27/RgNgAAuEefZgAAAMANimYAAADADYpmAAAAwA2KZgAAAMANimYAAADADYpmAAAAwA2KZgAAAMANimYAAADADYpmAAAAwA2KZgAAAMANHqMNAKhxCs/kKy0tzaUtLi5OdrvdQxkBqO4omgEANU7OoXRN25Mr5y6bJClr/069PlRKSEjwcGYAqiuKZgBAjeTvjFJwdKyn0wBQQ9CnGQAAAHCDohkAAABwg6IZAAAAcIOiGQAAAHCDohkAAABwg6L5LP/85z916623KiwsTDabTUuXLnX53BijpKQkhYWFydfXV506ddLWrVtdYvLy8jRixAgFBwfLz89Pffr00b59+1xiMjMzlZiYKIfDIYfDocTERGVlZV3kpQMAAMD5omg+y8mTJ3XVVVdp+vTppX4+adIkTZ48WdOnT9emTZvkdDrVrVs3HT9+3IoZNWqUlixZokWLFumbb77RiRMn1Lt3bxUUFFgx/fr1U2pqqpKTk5WcnKzU1FQlJiZe9OUDAADA+WGc5rPcfPPNuvnmm0v9zBijqVOnaty4cerbt68kad68eQoNDdXChQs1ZMgQZWdn65133tH777+vrl27SpLmz5+viIgIrVq1Sj169NC2bduUnJys9evXq3379pKkt956S/Hx8dq+fbtiYmIuzcICAACg3LjSXE67d+9WRkaGunfvbrXZ7XZ17NhRa9eulSSlpKQoPz/fJSYsLExt2rSxYtatWyeHw2EVzJLUoUMHORwOK6Y0eXl5ysnJcXkBAADg0qBoLqeMjAxJUmhoqEt7aGio9VlGRoa8vb0VEBBwzpiQkJAS0w8JCbFiSjNx4kSrD7TD4VBERMQFLQ8AAADKj6K5gmw2m8t7Y0yJtuKKx5QW7246Tz75pLKzs63X3r17K5g5AAAAzhdFczk5nU5JKnE1+PDhw9bVZ6fTqdOnTyszM/OcMYcOHSox/SNHjpS4in02u92u+vXru7wAAABwaVA0l1NUVJScTqdWrlxptZ0+fVpr1qxRQkKCJCkuLk5eXl4uMQcPHtSWLVusmPj4eGVnZ2vjxo1WzIYNG5SdnW3FAAAqV+GZfKWlpWnt2rUur7y8PE+nBqCaYPSMs5w4cUI7duyw3u/evVupqakKDAxU48aNNWrUKE2YMEHNmzdX8+bNNWHCBNWtW1f9+vWTJDkcDg0cOFCPPfaYgoKCFBgYqLFjxyo2NtYaTaNVq1bq2bOnBg0apNmzZ0uSBg8erN69ezNyBqq9vLw8paSkuLSlpaWpsNBDCQG/yjmUrml7cuXc9b9ucFn7d+r1oeKCBYByoWg+y+bNm9W5c2fr/ZgxYyRJ/fv319y5c/X4448rNzdXQ4cOVWZmptq3b68VK1bI39/f+s6UKVNUp04d3XPPPcrNzVWXLl00d+5c1a5d24pZsGCBRo4caY2y0adPnzLHhgaqk5SUFI2csUwNwqOttn2pXyugWZwHswJ+4e+MUnB0rKfTAFBNUTSfpVOnTjLGlPm5zWZTUlKSkpKSyozx8fHRtGnTNG3atDJjAgMDNX/+/AtJFaiyGoRHuxQmWft3ejAbAAAqB32aAQAAADcomgEAAAA3KJoBAAAANyiaAQAAADe4ERDAeSs+xBzDywEAaiqKZgDnrfgQcwwvBwCoqSiaAVyQs4eYY3g5VCdFTwk8W1xcnOx2u4cyAlCVUTQDAC5LxZ8SyBMCAZwLRTMA4LLFUwIBlBejZwAAAABuUDQDAAAAblA0AwAAAG5QNAMAAABuUDQDAAAAbjB6BgAAKn3cZomxmwH8gqIZAACVHLdZYuxmAP9D0QygXPLy8pSSkuLSlpaWpsJCDyUEXASM2wygLBTNAMolJSVFI2csU4PwaKttX+rXCmgW58GsAAC4NCiaAZRbg/Bol6twWf
t3ejAbAAAuHUbPAAAAANygaAYAAADcoGgGAAAA3KBoBgAAANygaAYAAADcoGgGAAAA3GDIOQAAylDao7V5rDZweaJoBlCq4k8A5Ol/uBwVf7Q2j9UGLl8UzQBKVfwJgDz9D5crHq0NQKJoBnAOZz8BkKf/AQAuZ9wICAAAALhB0QwAAAC4QdEMAAAAuEGfZgAAyqm0IegkhqEDLgcUzQAAlFPxIegkhqEDLhcUzQBKjMksMS4zUBaGoAMuTxTNAEqMySwxLjMAAGejaAYgyXVMZolxmQEAOBujZwAAAABucKUZuAwV78NM/2UAAM6Nohm4DBXvw0z/ZQAAzo2iGbhMnd2Hmf7LwPkrbexmxm0Gah6KZgAALkDxsZsZtxmomSiaAQC4QIzdDNR8FM1ADceDSwAAuHAUzUANx4NLAAC4cBTNwGWAB5cAl05pNwZK3BwIVHcUzUANwxjMgGcVvzFQ4uZAoCagaAZqGMZgBjyPGwOBmoeiGaiBGIMZAIDKVcvTCQAAAABVHVeagWqM4eSA6oGnBgLVH0UzUI0xnBxQPfDUQKD6o2gGqjmGkwOqh7NvDmRYOqD6oWgGAOASY1g6oPqhaAaqEcZgBmoOhqUDqheKZqCKKusmvzf/uVMBjZpJov8yUJNwsyBQtVE0A1XUuW7yYwxmoObhZkGgaqNoBqqI0rpe1G/YlJv8gMsINwsCVRdFM3AJFC+IT58+LUny9va22uh6AeBs3CwIVC0UzR40Y8YMvfzyyzp48KCuvPJKTZ06VTfccIOn08JFULyrxb7Uf6pOvUA5m7WxYuh6AaC44jcL0u8Z8ByKZg/58MMPNWrUKM2YMUPXXXedZs+erZtvvlk//PCDGjdu7On0UAHlvYp8dleLrP075eVw0vUCQIUUv/p8LH27hnRKU2zs/44lpR2DKKyBC0fR7CGTJ0/WwIED9dBDD0mSpk6dqs8//1wzZ87UxIkTPZwdipxft4qyryIDwIU6++pz1v6dmrZiq0sXjuLHoPIU1qUd2ySKbeBsFM0ecPr0aaWkpOhPf/qTS3v37t21du3aUr+Tl5envLw86312drYkKScnp9Lz27BhQ6VPs7raunWrpv39S9UNdEqSju75QbV96qmB83+/Bhzd84MckVfKPy9XklRwJl+2/NM68+v7orasvdtl96olSco+uEd1crKt96W1EXPuGE/PnxhiqkyMX0CJ483Zx6Djh/fp+Xd/UAPnt1ZM8WNZace2n49laMSdnXXllVfqctS+ffuLMt2i87Yx5qJMHxePzfBXu+QOHDig8PBw/etf/3K5mWPChAmaN2+etm/fXuI7SUlJGj9+/KVMEwAAXCR79+5Vo0aNPJ0GKoArzR5ks9lc3htjSrQVefLJJzVmzBjrfWFhoY4dO6agoKAyv3Mx5OTkKCIiQnv37lX9+vUv2XxrItZl5WFdVh7WZeVhXVaemrQujTE6fvy4wsLCPJ0KKoii2QOCg4NVu3ZtZWRkuLQfPnxYoaGhpX7HbreX6FfWoEGDi5WiW/Xr16/2B66qgnVZeViXlYd1WXlYl5WnpqxLh8Ph6RRwHmq5D0Fl8/b2VlxcnFauXOnSvnLlSsbeBAAAqIK40uwhY8aMUWJioq699lrFx8frzTffVHp6uh5++GFPpwYAAIBiKJo95He/+52OHj2q5557TgcPHlSbNm306aefKjIy0tOpnZPdbtezzz7LEESVgHVZeViXlYd1WXlYl5WHdYmqgNEzAAAAADfo0wwAAAC4QdEMAAAAuEHRDAAAALhB0QwAAAC4QdEMAAAAuEHRjPPWp08fNW7cWD4+PmrYsKESExN14MABT6dV7ezZs0cDBw5UVFSUfH19FR0drWeffVanT5/2dGrV0gsvvKCEhATVrVvXo0/NrK5mzJihqKgo+fj4KC4uTl9//bWnU6p2/vnPf+rWW29VWFiYbDabli5d6umUqq2JEyfqN7/5jfz9/RUSEqLbb79d27dv93RauE
xRNOO8de7cWR999JG2b9+uv//979q5c6fuuusuT6dV7fznP/9RYWGhZs+era1bt2rKlCmaNWuW/vznP3s6tWrp9OnTuvvuu/XII494OpVq58MPP9SoUaM0btw4/fvf/9YNN9ygm2++Wenp6Z5OrVo5efKkrrrqKk2fPt3TqVR7a9as0bBhw7R+/XqtXLlSZ86cUffu3XXy5ElPp4bLEOM0o9IsX75ct99+u/Ly8uTl5eXpdKq1l19+WTNnztSuXbs8nUq1NXfuXI0aNUpZWVmeTqXaaN++va655hrNnDnTamvVqpVuv/12TZw40YOZVV82m01LlizR7bff7ulUaoQjR44oJCREa9as0Y033ujpdHCZ4UozKsWxY8e0YMECJSQkUDBXguzsbAUGBno6DVxGTp8+rZSUFHXv3t2lvXv37lq7dq2HsgJcZWdnSxLHR3gERTMuyBNPPCE/Pz8FBQUpPT1dy5Yt83RK1d7OnTs1bdo0Pfzww55OBZeRn376SQUFBQoNDXVpDw0NVUZGhoeyAv7HGKMxY8bo+uuvV5s2bTydDi5DFM1wkZSUJJvNds7X5s2brfg//vGP+ve//60VK1aodu3auv/++0WPn19UdF1K0oEDB9SzZ0/dfffdeuihhzyUedVzPusS58dms7m8N8aUaAM8Yfjw4fr+++/1wQcfeDoVXKbqeDoBVC3Dhw/Xvffee86YJk2aWP8fHBys4OBgtWjRQq1atVJERITWr1+v+Pj4i5xp1VfRdXngwAF17txZ8fHxevPNNy9ydtVLRdclKi44OFi1a9cucVX58OHDJa4+A5faiBEjtHz5cv3zn/9Uo0aNPJ0OLlMUzXBRVASfj6IrzHl5eZWZUrVVkXW5f/9+de7cWXFxcZozZ45q1eJHoLNdyHaJ8vH29lZcXJxWrlypO+64w2pfuXKlbrvtNg9mhsuZMUYjRozQkiVL9NVXXykqKsrTKeEyRtGM87Jx40Zt3LhR119/vQICArRr1y4988wzio6O5ipzBR04cECdOnVS48aN9corr+jIkSPWZ06n04OZVU/p6ek6duyY0tPTVVBQoNTUVElSs2bNVK9ePc8mV8WNGTNGiYmJuvbaa61fPNLT0+lfX0EnTpzQjh07rPe7d+9WamqqAgMD1bhxYw9mVv0MGzZMCxcu1LJly+Tv72/9EuJwOOTr6+vh7HC5Ycg5nJe0tDQ9+uij+u6773Ty5Ek1bNhQPXv21FNPPaXw8HBPp1etzJ07Vw888ECpn7F7VtyAAQM0b968Eu1ffvmlOnXqdOkTqmZmzJihSZMm6eDBg2rTpo2mTJnC0F4V9NVXX6lz584l2vv376+5c+de+oSqsbL608+ZM0cDBgy4tMngskfRDAAAALhBx0kAAADADYpmAAAAwA2KZgAAAMANimYAAADADYpmAAAAwA2KZgAAAMANimYAAADADYpmAAAAwA2KZgAAAMANimYAAADADYpmAAAAwI3/B5UymFMS6KwaAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Compare next-token logits of the HF model and the TransformerLens port on the same prompt.\n",
"with torch.no_grad():\n",
"    a_large_chunk_of_text = \"Generate is not a good way to check a model is running properly. Can you run the following and share the results?\"\n",
"    tokens = hf_tokenizer(a_large_chunk_of_text, return_tensors=\"pt\").input_ids.to('cuda:0')\n",
"    logits_hf = hf_model(tokens).logits\n",
"    logits_tl = hooked_model(tokens, return_type=\"logits\")\n",
"\n",
"    logits_diff = (logits_hf - logits_tl)\n",
"    logits_diff_last = logits_diff[:, -1, :]\n",
"    # Greedy next-token id and its logit for each model; the ids are expected to agree.\n",
"    print(\"TF Greedy:\", logits_hf[:, -1, :].argmax(dim=-1), \"Logit:\", logits_hf[:, -1, :].max())\n",
"    print(\"TL Greedy:\", logits_tl[:, -1, :].argmax(dim=-1), \"Logit:\", logits_tl[:, -1, :].max())\n",
"    # Scalar summaries, so the histogram is not the only evidence of (dis)agreement.\n",
"    print(\"Max |diff| (last position):\", logits_diff_last.abs().max().item())\n",
"    print(\"Max |diff| (all positions):\", logits_diff.abs().max().item())\n",
"\n",
"# Histogram of the elementwise HF - TransformerLens logit difference at the last position.\n",
"fig, ax = plt.subplots()\n",
"sns.histplot(logits_diff_last.flatten().cpu().numpy(), bins=100, ax=ax)\n",
"ax.set_xlabel(\"logit_hf - logit_tl\")\n",
"ax.set_ylabel(\"count\")\n",
"ax.set_title(\"Difference between logits (last) from the hooked model and the huggingface model\");"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0.98, 'Difference between outputs from the hooked model and the huggingface model in layer 0')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 2000x500 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Layer-0 comparison: capture attention, MLP and block outputs from HF (via nnsight) and from TransformerLens.\n",
"hf_model_nnsight = NNsight(hf_model)\n",
"\n",
"with hf_model_nnsight.trace(tokens):\n",
"    hf_attn_out = hf_model_nnsight.model.layers[0].self_attn.output.save()\n",
"    hf_mlp_out = hf_model_nnsight.model.layers[0].mlp.output.save()\n",
"    hf_resid_post = hf_model_nnsight.model.layers[0].output.save()\n",
"_, cache = hooked_model.run_with_cache(tokens)\n",
"\n",
"# In transformers' Llama, self_attn and the decoder layer return tuples, so [0] is their hidden-state tensor.\n",
"# The MLP returns a bare tensor (verify for your transformers version), so [0] there selects batch item 0;\n",
"# slice the TransformerLens activation the same way so both sides cover the same single batch item and\n",
"# nothing is broadcast across the batch dimension.\n",
"layer0_attn_out_diff = hf_attn_out[0] - cache[\"blocks.0.hook_attn_out\"]\n",
"layer0_mlp_out_diff = hf_mlp_out[0] - cache[\"blocks.0.hook_mlp_out\"][0]\n",
"layer0_resid_post_diff = hf_resid_post[0] - cache[\"blocks.0.hook_resid_post\"]\n",
"\n",
"# One histogram per component of the HF - TransformerLens difference.\n",
"fig, axes = plt.subplots(1, 3, figsize=(20, 5))\n",
"panels = [\n",
"    (layer0_attn_out_diff, \"Difference between attention output\"),\n",
"    (layer0_mlp_out_diff, \"Difference between mlp output\"),\n",
"    (layer0_resid_post_diff, \"Difference between residual post output\"),\n",
"]\n",
"for ax, (diff, title) in zip(axes, panels):\n",
"    sns.histplot(diff.detach().flatten().cpu().numpy(), bins=100, ax=ax)\n",
"    ax.set_title(title)\n",
"fig.suptitle(\"Difference between outputs from the hooked model and the huggingface model in layer 0\");"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(True, device='cuda:0'),\n",
" tensor(True, device='cuda:0'),\n",
" tensor(True, device='cuda:0'))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Exact equality of the layer-0 gated-MLP weights; TransformerLens stores them transposed relative to HF.\n",
"tuple(\n",
"    (hooked_model.blocks[0].mlp._parameters[tl_name] == getattr(hf_model.model.layers[0].mlp, hf_name).weight.T).all()\n",
"    for tl_name, hf_name in ((\"W_gate\", \"gate_proj\"), (\"W_in\", \"up_proj\"), (\"W_out\", \"down_proj\"))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(119), tensor(False))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Push an all-ones input through layer 0's MLP in both implementations and compare the outputs.\n",
"test_in = torch.ones(1, 1, hooked_model.cfg.d_model, device='cuda:0')\n",
"with torch.no_grad():\n",
"    hf_out = hf_model.model.layers[0].mlp(test_in).detach().cpu()\n",
"    hooked_out = hooked_model.blocks[0].mlp(test_in).detach().cpu()\n",
"isclose = torch.isclose(hf_out, hooked_out)\n",
"# (elements outside the isclose tolerance, all close?, largest absolute deviation)\n",
"(~isclose).sum(), isclose.all(), (hf_out - hooked_out).abs().max()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(9), torch.Size([1, 1, 14336]))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from fancy_einsum import einsum\n",
"\n",
"# Hand-written copy of the W_gate projection from TransformerLens' gated MLP:\n",
"# https://github.com/TransformerLensOrg/TransformerLens/blob/318236402ddcc9cabace3f2fca40c71f0c2e9e57/transformer_lens/components/gated_mlp.py#L102C13-L107C18\n",
"hooked_w_gate_out = einsum(\"batch pos d_model, d_model d_mlp -> batch pos d_mlp\", test_in, hooked_model.blocks[0].mlp.W_gate).detach().cpu()\n",
"# Reference: HF's gate_proj applied to the same input.\n",
"hf_w_gate_out = hf_model.model.layers[0].mlp.gate_proj(test_in).detach().cpu()\n",
"is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)\n",
"# (elements outside the isclose tolerance, output shape)\n",
"(~is_close).sum(), is_close.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(9), torch.Size([1, 1, 14336]))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from opt_einsum import contract\n",
"\n",
"# Same W_gate projection, contracted with opt_einsum instead of fancy_einsum to rule out the einsum backend.\n",
"hf_w_gate_out = hf_model.model.layers[0].mlp.gate_proj(test_in).detach().cpu()\n",
"# \"bpk, kd -> bpd\" is the same contraction that TransformerLens spells\n",
"# \"batch pos d_model, d_model d_mlp -> batch pos d_mlp\" in its gated MLP (this line is NOT that code):\n",
"# https://github.com/TransformerLensOrg/TransformerLens/blob/318236402ddcc9cabace3f2fca40c71f0c2e9e57/transformer_lens/components/gated_mlp.py#L102C13-L107C18\n",
"hooked_w_gate_out = contract(\"bpk, kd -> bpd\", test_in, hooked_model.blocks[0].mlp.W_gate).detach().cpu()\n",
"is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)\n",
"(~is_close).sum(), is_close.shape"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(9), torch.Size([1, 1, 14336]))"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same projection as a plain matrix product, compared against hf_w_gate_out computed earlier.\n",
"hooked_w_gate_out = torch.matmul(test_in, hooked_model.blocks[0].mlp.W_gate).detach().cpu()\n",
"is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)\n",
"(~is_close).sum(), is_close.shape"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(9), torch.Size([1, 1, 14336]))"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same projection through an nn.Linear whose weight is W_gate.T (nn.Linear stores weights as (out, in)).\n",
"# Sizes come from the model config instead of hard-coded 4096 / 14336.\n",
"lin = torch.nn.Linear(hooked_model.cfg.d_model, hooked_model.cfg.d_mlp, bias=False)\n",
"lin.weight = torch.nn.Parameter(hooked_model.blocks[0].mlp.W_gate.T)\n",
"hooked_w_gate_out = lin(test_in).detach().cpu()\n",
"is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)\n",
"(~is_close).sum(), is_close.shape"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(9), tensor(4819))"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Bypass the modules entirely: F.linear with HF's gate_proj.weight vs. the transposed view W_gate.T.\n",
"hf_w_gate_out = torch.nn.functional.linear(test_in, hf_model.model.layers[0].mlp.gate_proj.weight).detach().cpu()\n",
"hooked_w_gate_out = torch.nn.functional.linear(test_in, hooked_model.blocks[0].mlp.W_gate.T).detach().cpu()\n",
"is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)\n",
"# (elements outside the isclose tolerance, elements that are not bit-identical)\n",
"(~is_close).sum(), torch.ne(hf_w_gate_out, hooked_w_gate_out).sum()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0), torch.Size([1, 1, 14336]))"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# WARNING: this mutates the loaded TransformerLens model. Layer 0's W_gate becomes a transposed view of the\n",
"# HF gate_proj weight (shared storage), and every later use of hooked_model -- including re-running the\n",
"# cells above -- sees the patched parameter.\n",
"hooked_model.blocks[0].mlp.W_gate.data = hf_model.model.layers[0].mlp.gate_proj.weight.T\n",
"\n",
"# Repeat the TransformerLens-style einsum against the patched weight:\n",
"# https://github.com/TransformerLensOrg/TransformerLens/blob/318236402ddcc9cabace3f2fca40c71f0c2e9e57/transformer_lens/components/gated_mlp.py#L102C13-L107C18\n",
"hooked_w_gate_out = einsum(\"batch pos d_model, d_model d_mlp -> batch pos d_mlp\", test_in, hooked_model.blocks[0].mlp.W_gate).detach().cpu()\n",
"hf_w_gate_out = hf_model.model.layers[0].mlp.gate_proj(test_in).detach().cpu()\n",
"is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)\n",
"(~is_close).sum(), is_close.shape"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True, device='cuda:0')"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Bit-exact comparison of the layer-0 gate weights (HF weight transposed into TransformerLens' layout).\n",
"torch.eq(\n",
"    hooked_model.blocks[0].mlp.W_gate,\n",
"    hf_model.model.layers[0].mlp.gate_proj.weight.T,\n",
").all()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True, True, True, ..., True, True, True],\n",
" [True, True, True, ..., True, True, True],\n",
" [True, True, True, ..., True, True, True],\n",
" ...,\n",
" [True, True, True, ..., True, True, True],\n",
" [True, True, True, ..., True, True, True],\n",
" [True, True, True, ..., True, True, True]], device='cuda:0')"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Tolerance-based comparison of the same weights, reduced to a verdict and a mismatch count\n",
"# instead of displaying the full boolean matrix.\n",
"gate_weights_close = torch.isclose(hooked_model.blocks[0].mlp.W_gate, hf_model.model.layers[0].mlp.gate_proj.weight.T)\n",
"gate_weights_close.all(), (~gate_weights_close).sum()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "default",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment