Skip to content

Instantly share code, notes, and snippets.

@Mihonarium
Last active October 6, 2023 20:40
Show Gist options
  • Save Mihonarium/615c99aeb0fccde088e324f3d716dfdf to your computer and use it in GitHub Desktop.
Save Mihonarium/615c99aeb0fccde088e324f3d716dfdf to your computer and use it in GitHub Desktop.
Transformers can do reverse information retrieval (public).ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"V5GQQ0tXRyHn"
],
"gpuType": "V100",
"name": "Transformers can do reverse information retrieval (public).ipynb",
"authorship_tag": "ABX9TyPDJRD3RK4A9YIMEW+A62jl",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"b1c9e188959d4015adcb4ec7401f454a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_6fcc5ab1267f4db5a8e0a472023a599f",
"IPY_MODEL_629ad8677d714b3fad6ee4cc12cc5f24",
"IPY_MODEL_99520b0fda4b489eb3837a1511ff2950"
],
"layout": "IPY_MODEL_9bc32abcd8654bd5a1fbed4d8a967b76"
}
},
"6fcc5ab1267f4db5a8e0a472023a599f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e8eab2be05e4427da7bef7b2e47df63a",
"placeholder": "​",
"style": "IPY_MODEL_21ead47dc683413790b418c0f53119c3",
"value": "Downloading (…)lve/main/config.json: 100%"
}
},
"629ad8677d714b3fad6ee4cc12cc5f24": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_feaa35b8beab46a09634f6b0b23e67f0",
"max": 665,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_474b4fd0b6a4448489e6c7f3596cf851",
"value": 665
}
},
"99520b0fda4b489eb3837a1511ff2950": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_3fe8c63fe2194a11869650f46ee68b54",
"placeholder": "​",
"style": "IPY_MODEL_9c4c613816444ea8abd057ed438c7efa",
"value": " 665/665 [00:00<00:00, 13.7kB/s]"
}
},
"9bc32abcd8654bd5a1fbed4d8a967b76": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e8eab2be05e4427da7bef7b2e47df63a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"21ead47dc683413790b418c0f53119c3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"feaa35b8beab46a09634f6b0b23e67f0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"474b4fd0b6a4448489e6c7f3596cf851": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"3fe8c63fe2194a11869650f46ee68b54": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9c4c613816444ea8abd057ed438c7efa": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Mihonarium/615c99aeb0fccde088e324f3d716dfdf/reverse-information-retrieval.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# The Reversal Curse isn't a fundamental limitation of LLMs\n",
"*A reply to [The Reversal Curse](https://owainevans.github.io/reversal_curse.pdf) paper*."
],
"metadata": {
"id": "vEjrxoroPYvG"
}
},
{
"cell_type": "markdown",
"source": [
"\n",
"### Description"
],
"metadata": {
"id": "2K4dn3Q-R5ZS"
}
},
{
"cell_type": "markdown",
"source": [
"\n",
"\n",
"Neel Nanda, in a [thread](https://twitter.com/NeelNanda5/status/1705995613043929349):\n",
"> given a gradient signal to output B given \"A is\" it reinforces/creates a lookup \"A -> B\", but doesn't create \"B->A\", these are different lookups, in different parameters, and there's no gradient signal from one to the other.\n",
"\n",
"But is this true? When you reinforce \"A -> B\", there's still a gradient signal that makes the LLM look at \"A\" in a way that allows it to better predict there'll be \"B\" afterwards.\n",
"\n",
"It seemed obvious to me there was a gradient flow, so I tried training a toy model to generalise to predict \"B -> A\" even when it's only seen \"A -> B\".\n",
"\n",
"And it worked, on the first try! Pending someone verifying the results here (I might be stupid somewhere), I think I falsified this!\n",
"\n",
"So, how do we make an LLM learn that things have relationships without seeing them having that relationship?\n",
"\n",
"Well, let's create a lot of random things, random relationships between them, and corresponding inverse relationships, and train on the meaning of \"->\", on \"A->B\", but not \"B->A\".\n",
"\n",
"So, I made 20 \"property\", 20 \"inverse property\", and 1000 \"thing\" tokens; for every thing and every property, I picked another thing at random and made a dataset that looks like this:\n",
"```\n",
"property_9 thing_0 thing_270\n",
"property_9-inverse thing_270 thing_0\n",
"thing_0 property_9 thing_270\n",
"thing_270 property_9-inverse thing_0\n",
"thing_0 thing_270 property_9\n",
"thing_270 thing_0 property_9-inverse\n",
"```\n",
"\n",
"I removed an inverse relationship between two things and put it into a small test dataset, and tried training. So, 20k relationships, 119 997 three-token strings in the training distribution, and 3 three-token strings in the test distribution...\n",
"\n",
"And after the first epoch, it's already predicting `A` in \"`B` `property_5-inverse` `A`\" much better than at random! Even though it has never seen `A` after `B` and `property_5-inverse` in the training distribution. It has inferred the inverse relationship, having never seen it!\n",
"\n",
"> ![](https://pbs.twimg.com/media/F7uMeKKbAAAedlc?format=jpg&name=large)\n",
"\n",
"So, pending someone verifying these results (I could be stupid about something), I think we can say this is not a fundamental limitation of LLMs, even if in real life, they don't tend to generalise this way."
],
"metadata": {
"id": "0fX5C9rOR7cg"
}
},
{
"cell_type": "markdown",
"source": [
"## Setup"
],
"metadata": {
"id": "V5GQQ0tXRyHn"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eCpKGYQ2MZy2"
},
"outputs": [],
"source": [
"%pip install torchinfo\n",
"%pip install transformer_lens\n",
"%pip install einops\n",
"%pip install transformers"
]
},
{
"cell_type": "code",
"source": [
"\n",
"\n",
"import os; os.environ[\"ACCELERATE_DISABLE_RICH\"] = \"1\"\n",
"import sys\n",
"import pandas as pd\n",
"import torch as t\n",
"from torch import Tensor, optim\n",
"import torch.nn.functional as F\n",
"from torchvision import datasets\n",
"from torch.utils.data import DataLoader, Subset\n",
"from typing import Callable, Iterable, Tuple, Optional, Type\n",
"from jaxtyping import Float\n",
"from dataclasses import dataclass\n",
"from tqdm.notebook import tqdm\n",
"from pathlib import Path\n",
"import numpy as np\n",
"from IPython.display import display, HTML\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"from typing import Any, Iterable, List, Optional, Tuple, Union\n",
"from torch.utils.data.dataloader import default_collate\n",
"import torch.nn as nn\n",
"import random\n",
"import wandb\n",
"import time\n",
"from tqdm import tqdm"
],
"metadata": {
"id": "-rulFpyQMfCM"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"device = t.device(\"cuda\") if t.cuda.is_available() else t.device(\"cpu\")"
],
"metadata": {
"id": "pVVlUCboMrzb"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Settings"
],
"metadata": {
"id": "0t8lh7ghR0wZ"
}
},
{
"cell_type": "code",
"source": [
"n_properties = 20\n",
"n_things = 1000\n",
"transformers_model = 'gpt2'\n",
"# When True, every relation is emitted with the property token in the 1st, 2nd and\n",
"# 3rd position, giving 3 strings per A->B (and 3 per B->A) instead of 1.\n",
"# If set to False, update testing between epochs in train().\n",
"shuffle_property_position = True\n",
"\n",
"# Hyperparameters and dataset settings for the run (also used as the wandb run config).\n",
"training_config: dict[str, Any] = {\n",
"    \"lr\": 0.001,\n",
"    \"epochs\": 512,\n",
"    \"max_steps\": 200,  # NOTE(review): train() does not read this value\n",
"    \"batch_size\": 128,\n",
"    \"device\": device,\n",
"    \"n_layer\": 4,\n",
"    \"n_properties\": n_properties,\n",
"    \"n_things\": n_things,\n",
"    \"transformers_model\": transformers_model,\n",
"    \"shuffle_property_position\": shuffle_property_position,\n",
"}"
],
"metadata": {
"id": "19U6C5B3PY_W"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Dataset\n",
"For every pair of things related by a property, the dataset contains:\n",
"- property_x thing_k thing_l\n",
"- property_x-inverse thing_l thing_k\n",
"\n",
"For some pairs (the held-out test relationship), only the first of the two appears in training. The hope is that the model will still learn the relationships and will be able to predict the held-out inverse correctly."
],
"metadata": {
"id": "ZCWY06ESRnEF"
}
},
{
"cell_type": "code",
"source": [
"properties = [f\"property_{i}\" for i in range(n_properties)]\n",
"# Same \"<property>-inverse\" naming rule the dataset builder uses for prop_inv.\n",
"property_inverses = [f\"{prop}-inverse\" for prop in properties]\n",
"propertiy_inverses = property_inverses  # old misspelled name, kept as an alias\n",
"things = [f\"thing_{i}\" for i in range(n_things)]\n",
"\n",
"# Vocabulary order: properties, then their inverses, then things; a token's id\n",
"# is its position in this list.\n",
"tokens = properties + property_inverses + things\n",
"token_ids = {token: idx for idx, token in enumerate(tokens)}"
],
"metadata": {
"id": "0uwVUGcIR4Zs"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Build the relational dataset. For every (thing_a, property) pair a random\n",
"# partner thing_b != thing_a is drawn, and the forward triple (property, a, b)\n",
"# plus the inverse triple (property-inverse, b, a) are emitted; with\n",
"# shuffle_property_position, the property token is also moved to the middle\n",
"# and the end. For relation indices listed in test_examples, the inverse\n",
"# triples go to dataset_test instead of dataset_train. `dataset` keeps the\n",
"# text form of every triple, held-out ones included.\n",
"# NOTE(review): no random seed is set, so the sampled partners change between runs.\n",
"dataset = []\n",
"dataset_ids = []  # NOTE(review): never filled or read in this notebook\n",
"\n",
"dataset_train = []\n",
"dataset_test = []\n",
"\n",
"test_examples = [0]  # relation indices (generation order) whose inverse is held out\n",
"i = 0  # running relation index, one per (thing_a, property) pair\n",
"\n",
"for thing_a in things:\n",
"    for prop in properties:\n",
"        thing_b = random.choice(things)\n",
"        while thing_b == thing_a:  # ensure thing_b is different from thing_a\n",
"            thing_b = random.choice(things)\n",
"        prop_inv = f\"{prop}-inverse\"\n",
"        dataset.append(f\"{prop} {thing_a} {thing_b}\")\n",
"        dataset.append(f\"{prop_inv} {thing_b} {thing_a}\")\n",
"        dataset_train.append([token_ids[prop], token_ids[thing_a], token_ids[thing_b]])\n",
"        if i in test_examples:\n",
"            dataset_test.append([token_ids[prop_inv], token_ids[thing_b], token_ids[thing_a]])\n",
"        else:\n",
"            dataset_train.append([token_ids[prop_inv], token_ids[thing_b], token_ids[thing_a]])\n",
"        if shuffle_property_position:\n",
"            dataset.append(f\"{thing_a} {prop} {thing_b}\")\n",
"            dataset.append(f\"{thing_b} {prop_inv} {thing_a}\")\n",
"            dataset.append(f\"{thing_a} {thing_b} {prop}\")\n",
"            dataset.append(f\"{thing_b} {thing_a} {prop_inv}\")\n",
"            dataset_train.append([token_ids[thing_a], token_ids[prop], token_ids[thing_b]])\n",
"            dataset_train.append([token_ids[thing_a], token_ids[thing_b], token_ids[prop]])\n",
"\n",
"            if i in test_examples:\n",
"                dataset_test.append([token_ids[thing_b], token_ids[prop_inv], token_ids[thing_a]])\n",
"                dataset_test.append([token_ids[thing_b], token_ids[thing_a], token_ids[prop_inv]])\n",
"            else:\n",
"                dataset_train.append([token_ids[thing_b], token_ids[prop_inv], token_ids[thing_a]])\n",
"                dataset_train.append([token_ids[thing_b], token_ids[thing_a], token_ids[prop_inv]])\n",
"        i += 1"
],
"metadata": {
"id": "dePDB097R9fP"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Write the text form of every generated string (held-out ones included), one per line.\n",
"with open(\"dataset.txt\", \"w\", encoding=\"utf-8\") as data_file:\n",
"    data_file.writelines(f\"{s}\\n\" for s in dataset)"
],
"metadata": {
"id": "dUoJQxHvS-bh"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Held-out test triples, as token ids.\n",
"dataset_test"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "C87n4U748rf1",
"outputId": "4aa46c9f-8e83-4f7e-b0d7-e3416fa4a426"
},
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[[20, 821, 40], [821, 20, 40], [821, 40, 20]]"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"source": [
"# Same conversion for both splits: int64 token-id tensors placed on `device`.\n",
"train_set, test_set = (\n",
"    TensorDataset(t.tensor(rows, dtype=t.int64).to(device))\n",
"    for rows in (dataset_train, dataset_test)\n",
")"
],
"metadata": {
"id": "qusJvXFpZg9J"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Model"
],
"metadata": {
"id": "oDKXw3pOR91O"
}
},
{
"cell_type": "code",
"source": [
"from transformers import AutoModelForPreTraining, AutoConfig, AutoTokenizer"
],
"metadata": {
"id": "3LO-quBGNIGN"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"config = AutoConfig.from_pretrained(transformers_model)\n",
"# Resize the base architecture to the synthetic vocabulary and the configured depth.\n",
"config.vocab_size = len(tokens)\n",
"config.n_layer = training_config[\"n_layer\"]\n",
"model = AutoModelForPreTraining.from_config(config).to(device)\n",
"model"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 518,
"referenced_widgets": [
"b1c9e188959d4015adcb4ec7401f454a",
"6fcc5ab1267f4db5a8e0a472023a599f",
"629ad8677d714b3fad6ee4cc12cc5f24",
"99520b0fda4b489eb3837a1511ff2950",
"9bc32abcd8654bd5a1fbed4d8a967b76",
"e8eab2be05e4427da7bef7b2e47df63a",
"21ead47dc683413790b418c0f53119c3",
"feaa35b8beab46a09634f6b0b23e67f0",
"474b4fd0b6a4448489e6c7f3596cf851",
"3fe8c63fe2194a11869650f46ee68b54",
"9c4c613816444ea8abd057ed438c7efa"
]
},
"id": "E5yKsyfiOkQe",
"outputId": "48aa4050-b397-4da4-9ca4-1093d777dec9"
},
"execution_count": 11,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)lve/main/config.json: 0%| | 0.00/665 [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "b1c9e188959d4015adcb4ec7401f454a"
}
},
"metadata": {}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GPT2LMHeadModel(\n",
" (transformer): GPT2Model(\n",
" (wte): Embedding(1040, 768)\n",
" (wpe): Embedding(1024, 768)\n",
" (drop): Dropout(p=0.1, inplace=False)\n",
" (h): ModuleList(\n",
" (0-3): 4 x GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (lm_head): Linear(in_features=768, out_features=1040, bias=False)\n",
")"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "markdown",
"source": [
"## Train"
],
"metadata": {
"id": "vwa5Jv7FbDsE"
}
},
{
"cell_type": "code",
"source": [
"def loss_fn(\n",
"    model, sample\n",
"):\n",
"    \"\"\"Label-smoothed next-token loss for a batch of token-id sequences.\n",
"\n",
"    Args:\n",
"        model: causal LM whose forward output has a ``.logits`` attribute.\n",
"        sample: int64 tensor of token ids, one sequence per row.\n",
"\n",
"    Returns:\n",
"        (loss, logits): the scalar loss, and the logits for every position except\n",
"        the last (position j is trained to predict token j + 1).\n",
"    \"\"\"\n",
"    epsilon = 0.1  # label-smoothing weight\n",
"    # The model sees the whole sequence; drop the logits at the last position,\n",
"    # since no token follows it and it therefore has no label.\n",
"    logits = model(sample).logits[..., :-1, :]\n",
"    labels = sample[..., 1:]\n",
"    # Negated log-softmax: per-class negative log-likelihoods (all >= 0).\n",
"    log_probs = -nn.functional.log_softmax(logits, dim=-1)\n",
"    nll_loss = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))\n",
"    smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=t.float32)\n",
"    num_active_elements = labels.numel()\n",
"    # Average each term exactly once: mean NLL per predicted token, and mean NLL\n",
"    # per (predicted token, vocabulary entry) for the smoothing term.\n",
"    nll_loss = nll_loss.sum() / num_active_elements\n",
"    smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])\n",
"    loss = (1 - epsilon) * nll_loss + epsilon * smoothed_loss\n",
"    return loss, logits\n",
"\n",
"\n",
"def train(\n",
"    model, config_dict: dict[str, Any], trainset: TensorDataset, testset: Optional[TensorDataset] = None\n",
"):\n",
"    \"\"\"Train ``model`` with Adam on ``trainset``, logging to Weights & Biases.\n",
"\n",
"    Before each epoch, if ``testset`` is given, the model is evaluated (eval mode,\n",
"    no gradients) on the first batch of ``testset``. The test loss is printed and\n",
"    logged. For each test sequence, the rank of its true last token among the\n",
"    predicted next-token probabilities is printed and logged as ``ranks_sample{i}``.\n",
"\n",
"    Args:\n",
"        model: causal LM, trained in place.\n",
"        config_dict: hyperparameters passed to ``wandb.init``; ``lr``, ``epochs``\n",
"            and ``batch_size`` are read back from ``wandb.config``.\n",
"        trainset: TensorDataset of token-id sequences.\n",
"        testset: optional TensorDataset of held-out sequences; evaluation is\n",
"            skipped when it is None.\n",
"\n",
"    Returns:\n",
"        The trained model.\n",
"    \"\"\"\n",
"    wandb.init(project=\"Reverse_retrieval_1000\", config=config_dict, mode=\"run\")\n",
"    config = wandb.config\n",
"    print(f\"Training with config: {config}\")\n",
"    optimizer = t.optim.Adam(model.parameters(), lr=config.lr)\n",
"    loader = DataLoader(trainset, batch_size=config.batch_size, shuffle=True)\n",
"    test_loader = DataLoader(testset, batch_size=64) if testset is not None else None\n",
"    start_time = time.time()\n",
"    step = 0\n",
"    for _epoch in range(config.epochs):\n",
"        if test_loader is not None:\n",
"            # The model contains dropout layers; eval() disables them so the test\n",
"            # metrics are not perturbed by a random dropout mask.\n",
"            model.eval()\n",
"            with t.no_grad():\n",
"                test_sample = next(iter(test_loader))[0]\n",
"                test_loss, test_logits = loss_fn(model, test_sample)\n",
"                print(\"test_loss:\", test_loss)\n",
"                # Next-token distribution at the last predicted position, i.e. over\n",
"                # the final token of each test sequence.\n",
"                probs = nn.functional.softmax(test_logits, dim=-1)[..., -1, :]\n",
"                # No squeeze(): with a single test sequence it would produce a 0-d\n",
"                # tensor, which cannot be iterated over.\n",
"                actual_last_tokens = test_sample[..., -1]\n",
"                ranks = []\n",
"                for i, token in enumerate(actual_last_tokens):\n",
"                    _, indices = probs[i].sort(descending=True)\n",
"                    rank = (indices == token).nonzero(as_tuple=True)[0].item() + 1\n",
"                    print(f\"Sample {i}: Rank of true token = {rank}\")\n",
"                    ranks.append(rank)\n",
"            metrics = dict(\n",
"                test_loss=test_loss.item(),\n",
"                elapsed=time.time() - start_time,\n",
"                step=step,\n",
"            )\n",
"            # One ranks_sample{i} entry per test sequence (3 with the default dataset).\n",
"            metrics.update({f\"ranks_sample{i}\": rank for i, rank in enumerate(ranks)})\n",
"            wandb.log(metrics, step=step)\n",
"            step = step + 1\n",
"        model.train()  # re-enable dropout for the optimisation steps\n",
"        for (sample,) in loader:\n",
"            loss, _ = loss_fn(model, sample)\n",
"            optimizer.zero_grad()\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"            wandb.log(\n",
"                dict(\n",
"                    train_loss=loss.item(),\n",
"                    step=step,\n",
"                    elapsed=time.time() - start_time,\n",
"                ),\n",
"                step=step,\n",
"            )\n",
"            step = step + 1\n",
"    return model"
],
"metadata": {
"id": "vota3g54Vmhj"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model = train(model, config_dict=training_config, trainset=train_set, testset=test_set)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "7sZF2mkaVsB_",
"outputId": "3df19d4b-d070-4daf-de32-7407ca774856"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
"\n",
" window._wandbApiKey = new Promise((resolve, reject) => {\n",
" function loadScript(url) {\n",
" return new Promise(function(resolve, reject) {\n",
" let newScript = document.createElement(\"script\");\n",
" newScript.onerror = reject;\n",
" newScript.onload = resolve;\n",
" document.body.appendChild(newScript);\n",
" newScript.src = url;\n",
" });\n",
" }\n",
" loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n",
" const iframe = document.createElement('iframe')\n",
" iframe.style.cssText = \"width:0;height:0;border:none\"\n",
" document.body.appendChild(iframe)\n",
" const handshake = new Postmate({\n",
" container: iframe,\n",
" url: 'https://wandb.ai/authorize'\n",
" });\n",
" const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n",
" handshake.then(function(child) {\n",
" child.on('authorize', data => {\n",
" clearTimeout(timeout)\n",
" resolve(data)\n",
" });\n",
" });\n",
" })\n",
" });\n",
" "
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n",
"wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" ··········\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Tracking run with wandb version 0.15.12"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Run data is saved locally in <code>/content/wandb/run-20231006_202546-ghkzbmim</code>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/msamin/Reverse_retrieval_1000/runs/ghkzbmim' target=\"_blank\">curious-deluge-18</a></strong> to <a href='https://wandb.ai/msamin/Reverse_retrieval_1000' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View project at <a href='https://wandb.ai/msamin/Reverse_retrieval_1000' target=\"_blank\">https://wandb.ai/msamin/Reverse_retrieval_1000</a>"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
" View run at <a href='https://wandb.ai/msamin/Reverse_retrieval_1000/runs/ghkzbmim' target=\"_blank\">https://wandb.ai/msamin/Reverse_retrieval_1000/runs/ghkzbmim</a>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training with config: {'lr': 0.001, 'epochs': 512, 'max_steps': 200, 'batch_size': 128, 'device': 'cuda', 'n_layer': 4, 'n_properties': 20, 'n_things': 1000, 'transformers_model': 'gpt2', 'shuffle_property_position': True}\n",
"test_loss: tensor(1.8330, device='cuda:0')\n",
"Sample 0: Rank of true token = 994\n",
"Sample 1: Rank of true token = 994\n",
"Sample 2: Rank of true token = 862\n",
"test_loss: tensor(1.7079, device='cuda:0')\n",
"Sample 0: Rank of true token = 432\n",
"Sample 1: Rank of true token = 429\n",
"Sample 2: Rank of true token = 9\n",
"test_loss: tensor(1.7091, device='cuda:0')\n",
"Sample 0: Rank of true token = 877\n",
"Sample 1: Rank of true token = 745\n",
"Sample 2: Rank of true token = 33\n",
"test_loss: tensor(1.7047, device='cuda:0')\n",
"Sample 0: Rank of true token = 180\n",
"Sample 1: Rank of true token = 227\n",
"Sample 2: Rank of true token = 16\n",
"test_loss: tensor(1.7052, device='cuda:0')\n",
"Sample 0: Rank of true token = 215\n",
"Sample 1: Rank of true token = 242\n",
"Sample 2: Rank of true token = 6\n",
"test_loss: tensor(1.7107, device='cuda:0')\n",
"Sample 0: Rank of true token = 547\n",
"Sample 1: Rank of true token = 729\n",
"Sample 2: Rank of true token = 30\n",
"test_loss: tensor(1.7055, device='cuda:0')\n",
"Sample 0: Rank of true token = 498\n",
"Sample 1: Rank of true token = 889\n",
"Sample 2: Rank of true token = 2\n",
"test_loss: tensor(1.7123, device='cuda:0')\n",
"Sample 0: Rank of true token = 699\n",
"Sample 1: Rank of true token = 862\n",
"Sample 2: Rank of true token = 32\n",
"test_loss: tensor(1.7093, device='cuda:0')\n",
"Sample 0: Rank of true token = 669\n",
"Sample 1: Rank of true token = 746\n",
"Sample 2: Rank of true token = 30\n",
"test_loss: tensor(1.7125, device='cuda:0')\n",
"Sample 0: Rank of true token = 546\n",
"Sample 1: Rank of true token = 307\n",
"Sample 2: Rank of true token = 30\n",
"test_loss: tensor(1.7142, device='cuda:0')\n",
"Sample 0: Rank of true token = 565\n",
"Sample 1: Rank of true token = 541\n",
"Sample 2: Rank of true token = 30\n",
"test_loss: tensor(1.7143, device='cuda:0')\n",
"Sample 0: Rank of true token = 370\n",
"Sample 1: Rank of true token = 475\n",
"Sample 2: Rank of true token = 24\n",
"test_loss: tensor(1.7210, device='cuda:0')\n",
"Sample 0: Rank of true token = 255\n",
"Sample 1: Rank of true token = 262\n",
"Sample 2: Rank of true token = 35\n",
"test_loss: tensor(1.7167, device='cuda:0')\n",
"Sample 0: Rank of true token = 8\n",
"Sample 1: Rank of true token = 6\n",
"Sample 2: Rank of true token = 33\n",
"test_loss: tensor(1.7166, device='cuda:0')\n",
"Sample 0: Rank of true token = 22\n",
"Sample 1: Rank of true token = 22\n",
"Sample 2: Rank of true token = 38\n",
"test_loss: tensor(1.7204, device='cuda:0')\n",
"Sample 0: Rank of true token = 54\n",
"Sample 1: Rank of true token = 113\n",
"Sample 2: Rank of true token = 34\n",
"test_loss: tensor(1.7218, device='cuda:0')\n",
"Sample 0: Rank of true token = 10\n",
"Sample 1: Rank of true token = 21\n",
"Sample 2: Rank of true token = 38\n",
"test_loss: tensor(1.7190, device='cuda:0')\n",
"Sample 0: Rank of true token = 331\n",
"Sample 1: Rank of true token = 154\n",
"Sample 2: Rank of true token = 34\n",
"test_loss: tensor(1.7201, device='cuda:0')\n",
"Sample 0: Rank of true token = 22\n",
"Sample 1: Rank of true token = 56\n",
"Sample 2: Rank of true token = 37\n",
"test_loss: tensor(1.7207, device='cuda:0')\n",
"Sample 0: Rank of true token = 5\n",
"Sample 1: Rank of true token = 2\n",
"Sample 2: Rank of true token = 36\n",
"test_loss: tensor(1.7148, device='cuda:0')\n",
"Sample 0: Rank of true token = 2\n",
"Sample 1: Rank of true token = 19\n",
"Sample 2: Rank of true token = 28\n",
"test_loss: tensor(1.7154, device='cuda:0')\n",
"Sample 0: Rank of true token = 3\n",
"Sample 1: Rank of true token = 2\n",
"Sample 2: Rank of true token = 31\n",
"test_loss: tensor(1.7236, device='cuda:0')\n",
"Sample 0: Rank of true token = 65\n",
"Sample 1: Rank of true token = 8\n",
"Sample 2: Rank of true token = 37\n",
"test_loss: tensor(1.7202, device='cuda:0')\n",
"Sample 0: Rank of true token = 133\n",
"Sample 1: Rank of true token = 1\n",
"Sample 2: Rank of true token = 36\n",
"test_loss: tensor(1.7162, device='cuda:0')\n",
"Sample 0: Rank of true token = 12\n",
"Sample 1: Rank of true token = 11\n",
"Sample 2: Rank of true token = 29\n",
"test_loss: tensor(1.7242, device='cuda:0')\n",
"Sample 0: Rank of true token = 11\n",
"Sample 1: Rank of true token = 9\n",
"Sample 2: Rank of true token = 37\n",
"test_loss: tensor(1.7251, device='cuda:0')\n",
"Sample 0: Rank of true token = 71\n",
"Sample 1: Rank of true token = 388\n",
"Sample 2: Rank of true token = 35\n",
"test_loss: tensor(1.7274, device='cuda:0')\n",
"Sample 0: Rank of true token = 14\n",
"Sample 1: Rank of true token = 2\n",
"Sample 2: Rank of true token = 36\n",
"test_loss: tensor(1.7242, device='cuda:0')\n",
"Sample 0: Rank of true token = 25\n",
"Sample 1: Rank of true token = 43\n",
"Sample 2: Rank of true token = 33\n",
"test_loss: tensor(1.7236, device='cuda:0')\n",
"Sample 0: Rank of true token = 7\n",
"Sample 1: Rank of true token = 31\n",
"Sample 2: Rank of true token = 32\n",
"test_loss: tensor(1.7191, device='cuda:0')\n",
"Sample 0: Rank of true token = 67\n",
"Sample 1: Rank of true token = 118\n",
"Sample 2: Rank of true token = 33\n",
"test_loss: tensor(1.7180, device='cuda:0')\n",
"Sample 0: Rank of true token = 51\n",
"Sample 1: Rank of true token = 10\n",
"Sample 2: Rank of true token = 32\n",
"test_loss: tensor(1.7122, device='cuda:0')\n",
"Sample 0: Rank of true token = 66\n",
"Sample 1: Rank of true token = 121\n",
"Sample 2: Rank of true token = 20\n",
"test_loss: tensor(1.7159, device='cuda:0')\n",
"Sample 0: Rank of true token = 152\n",
"Sample 1: Rank of true token = 25\n",
"Sample 2: Rank of true token = 25\n",
"test_loss: tensor(1.7103, device='cuda:0')\n",
"Sample 0: Rank of true token = 14\n",
"Sample 1: Rank of true token = 28\n",
"Sample 2: Rank of true token = 20\n",
"test_loss: tensor(1.7136, device='cuda:0')\n",
"Sample 0: Rank of true token = 30\n",
"Sample 1: Rank of true token = 11\n",
"Sample 2: Rank of true token = 26\n",
"test_loss: tensor(1.7145, device='cuda:0')\n",
"Sample 0: Rank of true token = 32\n",
"Sample 1: Rank of true token = 15\n",
"Sample 2: Rank of true token = 23\n",
"test_loss: tensor(1.7090, device='cuda:0')\n",
"Sample 0: Rank of true token = 122\n",
"Sample 1: Rank of true token = 50\n",
"Sample 2: Rank of true token = 24\n",
"test_loss: tensor(1.7114, device='cuda:0')\n",
"Sample 0: Rank of true token = 21\n",
"Sample 1: Rank of true token = 62\n",
"Sample 2: Rank of true token = 22\n",
"test_loss: tensor(1.7166, device='cuda:0')\n",
"Sample 0: Rank of true token = 99\n",
"Sample 1: Rank of true token = 175\n",
"Sample 2: Rank of true token = 26\n",
"test_loss: tensor(1.7137, device='cuda:0')\n",
"Sample 0: Rank of true token = 11\n",
"Sample 1: Rank of true token = 69\n",
"Sample 2: Rank of true token = 27\n",
"test_loss: tensor(1.7147, device='cuda:0')\n",
"Sample 0: Rank of true token = 16\n",
"Sample 1: Rank of true token = 129\n",
"Sample 2: Rank of true token = 18\n",
"test_loss: tensor(1.7161, device='cuda:0')\n",
"Sample 0: Rank of true token = 204\n",
"Sample 1: Rank of true token = 74\n",
"Sample 2: Rank of true token = 27\n",
"test_loss: tensor(1.7075, device='cuda:0')\n",
"Sample 0: Rank of true token = 274\n",
"Sample 1: Rank of true token = 122\n",
"Sample 2: Rank of true token = 14\n",
"test_loss: tensor(1.7008, device='cuda:0')\n",
"Sample 0: Rank of true token = 18\n",
"Sample 1: Rank of true token = 70\n",
"Sample 2: Rank of true token = 15\n",
"test_loss: tensor(1.7129, device='cuda:0')\n",
"Sample 0: Rank of true token = 57\n",
"Sample 1: Rank of true token = 120\n",
"Sample 2: Rank of true token = 22\n",
"test_loss: tensor(1.7197, device='cuda:0')\n",
"Sample 0: Rank of true token = 179\n",
"Sample 1: Rank of true token = 45\n",
"Sample 2: Rank of true token = 28\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "Puhf6Ggvqe-m"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment