Skip to content

Instantly share code, notes, and snippets.

@saahiluppal
Created March 31, 2021 06:25
Show Gist options
  • Save saahiluppal/5c8c5198a481ab4f048b4dcb89a9999d to your computer and use it in GitHub Desktop.
Save saahiluppal/5c8c5198a481ab4f048b4dcb89a9999d to your computer and use it in GitHub Desktop.
exp0.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "exp0.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyNN/3Ukdn7iJKDThfwqgTXJ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"8052ac06daa34bd7b4dc6e32e585dc7e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_50f50d32b6804ce68312a1f41858c637",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_e515dfd7480b4c3aa945ee78b5448d81",
"IPY_MODEL_9fb941230b6d4bc983bfc89058dc6850"
]
}
},
"50f50d32b6804ce68312a1f41858c637": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"e515dfd7480b4c3aa945ee78b5448d81": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_6cdb857f40f647fa9d36f8f42888a3da",
"_dom_classes": [],
"description": "Downloading: 100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 665,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 665,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_7814d113029645109ce156d0f0655ea9"
}
},
"9fb941230b6d4bc983bfc89058dc6850": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_ad641b8179e14723bfc0c25ba4493778",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 665/665 [00:00<00:00, 2.69kB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_93646a85233043b3961be1bd5ef691eb"
}
},
"6cdb857f40f647fa9d36f8f42888a3da": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"7814d113029645109ce156d0f0655ea9": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"ad641b8179e14723bfc0c25ba4493778": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"93646a85233043b3961be1bd5ef691eb": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"f6548d3efb264996bd2f0f00fb1c886b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_7278b81f49284dea9d538d0dd0cd8ccf",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_6437f2c8ffcd4cd7a741fd2a12048e55",
"IPY_MODEL_fff1e5902e384afc93e12971e6995a3b"
]
}
},
"7278b81f49284dea9d538d0dd0cd8ccf": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"6437f2c8ffcd4cd7a741fd2a12048e55": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_5671a7e0e65446ddaad6de7f8f5006dd",
"_dom_classes": [],
"description": "Downloading: 100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 548118077,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 548118077,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_3659ed7594ff4bae87b0039f680abf76"
}
},
"fff1e5902e384afc93e12971e6995a3b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_96e05c46ce474dda8aa57ca7b8b650d9",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 548M/548M [00:10<00:00, 51.9MB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_3b1075ef420a4d3aaae51a1741d41b69"
}
},
"5671a7e0e65446ddaad6de7f8f5006dd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"3659ed7594ff4bae87b0039f680abf76": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"96e05c46ce474dda8aa57ca7b8b650d9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"3b1075ef420a4d3aaae51a1741d41b69": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/saahiluppal/5c8c5198a481ab4f048b4dcb89a9999d/exp0.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "G6fPatmirOqp",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ba6faf35-d541-44f1-f3ba-6187816ef03b"
},
"source": [
"!nvidia-smi"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Wed Mar 31 05:26:50 2021 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.67 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 52C P8 10W / 70W | 0MiB / 15109MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "r48DtcmVLxpg"
},
"source": [
"import matplotlib.pyplot as plt"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IUQfhWwUL3Sm"
},
"source": [
"import numpy as np"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WxTl32dSL7tI"
},
"source": [
"import torch\n",
"import torch.nn as nn"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xBjFqzyUL9JU",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4c252827-2d3a-430b-9619-7992764f76b7"
},
"source": [
"%pip install transformers --quiet"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[K |████████████████████████████████| 2.0MB 8.1MB/s \n",
"\u001b[K |████████████████████████████████| 3.2MB 48.3MB/s \n",
"\u001b[K |████████████████████████████████| 890kB 38.6MB/s \n",
"\u001b[?25h Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_FqNloDiL_8u"
},
"source": [
"from transformers.models.gpt2.modeling_gpt2 import GPT2Model"
],
"execution_count": 6,
"outputs": []
},
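{
"cell_type": "code",
"metadata": {
"id": "ExzGiG5TMKFa"
},
"source": [
"def generate_example(n):\n",
"    \"\"\"Return one random XOR example with n-bit operands.\n",
"\n",
"    Draws two random bit vectors a = bits[0] and b = bits[1] of length n.\n",
"    Returns (x, y): x is a and b concatenated into shape (2*n,), and\n",
"    y is the elementwise a XOR b with shape (n,) as int64.\n",
"    \"\"\"\n",
"    bits = np.random.randint(low=0, high=2, size=(2, n))\n",
"    # np.long is a deprecated alias that newer NumPy releases no longer provide;\n",
"    # use an explicit 64-bit integer dtype instead.\n",
"    xor = np.logical_xor(bits[0], bits[1]).astype(np.int64)\n",
"    return bits.reshape((2*n)), xor"
],
"execution_count": 7,
"outputs": []
},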
{
"cell_type": "code",
"metadata": {
"id": "LJKFBV98MQzR"
},
"source": [
"n = 5\n",
"bits, xor = generate_example(n)"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Z1VoV-4fMSq0",
"outputId": "be79975f-8a33-4ad0-f506-919608e6b295"
},
"source": [
"bits[:n]"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([1, 0, 0, 1, 0])"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "if5m_tQtMhoG",
"outputId": "c99b7a04-a6d1-4d9d-f061-7097d373484e"
},
"source": [
"bits[n:]"
],
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0, 1, 0, 0, 1])"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "95zCLzqaMjWe",
"outputId": "11d90b29-5fd5-4989-9e5a-02978c35caf3"
},
"source": [
"xor"
],
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([1, 1, 0, 1, 1])"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "If40KIC2Mk1m",
"outputId": "6e6fb759-54f8-402f-d2aa-291d85bd41ca"
},
"source": [
"device = 'cuda' if torch.cuda.is_available() else \"cpu\"\n",
"device"
],
"execution_count": 12,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'cuda'"
]
},
"metadata": {
"tags": []
},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 169,
"referenced_widgets": [
"8052ac06daa34bd7b4dc6e32e585dc7e",
"50f50d32b6804ce68312a1f41858c637",
"e515dfd7480b4c3aa945ee78b5448d81",
"9fb941230b6d4bc983bfc89058dc6850",
"6cdb857f40f647fa9d36f8f42888a3da",
"7814d113029645109ce156d0f0655ea9",
"ad641b8179e14723bfc0c25ba4493778",
"93646a85233043b3961be1bd5ef691eb",
"f6548d3efb264996bd2f0f00fb1c886b",
"7278b81f49284dea9d538d0dd0cd8ccf",
"6437f2c8ffcd4cd7a741fd2a12048e55",
"fff1e5902e384afc93e12971e6995a3b",
"5671a7e0e65446ddaad6de7f8f5006dd",
"3659ed7594ff4bae87b0039f680abf76",
"96e05c46ce474dda8aa57ca7b8b650d9",
"3b1075ef420a4d3aaae51a1741d41b69"
]
},
"id": "1qf8640BMsCM",
"outputId": "947bd6f7-9f3e-4119-b7c5-1ee47c6cccc4"
},
"source": [
"gpt2 = GPT2Model.from_pretrained(\"gpt2\")\n",
"in_layer = nn.Embedding(2, 768)\n",
"out_layer = nn.Linear(768, 2)"
],
"execution_count": 13,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8052ac06daa34bd7b4dc6e32e585dc7e",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=665.0, style=ProgressStyle(description_…"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f6548d3efb264996bd2f0f00fb1c886b",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=548118077.0, style=ProgressStyle(descri…"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Some weights of GPT2Model were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
],
"name": "stderr"
}
]
},
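{
"cell_type": "code",
"metadata": {
"id": "1yLVG7r6M2K0"
},
"source": [
"# Freeze GPT-2 except parameters whose names contain \"ln\" (layer norms) or\n",
"# \"wpe\" (positional embeddings). The attribute is `requires_grad` (plural);\n",
"# assigning to a misspelled name such as `require_grad` silently creates an\n",
"# unused attribute and leaves every weight trainable.\n",
"for name, param in gpt2.named_parameters():\n",
"    if \"ln\" in name or \"wpe\" in name:\n",
"        param.requires_grad = True\n",
"    else:\n",
"        param.requires_grad = False"
],
"execution_count": 14,
"outputs": []
},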
{
"cell_type": "code",
"metadata": {
"id": "xJBXFZD4NHyv"
},
"source": [
"params = list(gpt2.parameters()) + list(in_layer.parameters()) + list(out_layer.parameters())"
],
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "GvCcqtktNVQu"
},
"source": [
"optimizer = torch.optim.Adam(params)"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ArmMF_fiNjwx"
},
"source": [
"loss_fn = nn.CrossEntropyLoss()"
],
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2WnQyCrANmJh"
},
"source": [
"for layer in (gpt2, in_layer, out_layer):\n",
" layer.to(device = device)\n",
" layer.train()"
],
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "0pckddjKNsFy",
"outputId": "0c699f31-30a7-4cc7-db19-0816785c9528"
},
"source": [
"accuracies = [0]\n",
"while sum(accuracies[-50:]) / len(accuracies[-50:]) < .99:\n",
" x, y = generate_example(n)\n",
" x = torch.from_numpy(x).to(device=device, dtype=torch.long)\n",
" y = torch.from_numpy(y).to(device=device, dtype=torch.long)\n",
"\n",
" embeddings = in_layer(x.reshape(1, -1))\n",
" hidden_state = gpt2(inputs_embeds=embeddings).last_hidden_state[:, n:]\n",
" logits = out_layer(hidden_state)[0]\n",
"\n",
" loss = loss_fn(logits, y)\n",
" accuracies.append((logits.argmax(dim=-1) == y).float().mean().item())\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if len(accuracies) % 500 == 0:\n",
" accuracy = sum(accuracies[-50:]) / len(accuracies[-50:])\n",
" print(f\"Samples: {len(accuracies)}, Accuracy: {accuracy}\")"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"Samples: 500, Accuracy: 0.5160000124573707\n",
"Samples: 1000, Accuracy: 0.52400001257658\n",
"Samples: 1500, Accuracy: 0.4840000116825104\n",
"Samples: 2000, Accuracy: 0.48800001114606856\n",
"Samples: 2500, Accuracy: 0.5200000119209289\n",
"Samples: 3000, Accuracy: 0.4360000091791153\n",
"Samples: 3500, Accuracy: 0.5440000140666962\n",
"Samples: 4000, Accuracy: 0.5000000119209289\n",
"Samples: 4500, Accuracy: 0.49600000977516173\n",
"Samples: 5000, Accuracy: 0.48800001174211505\n",
"Samples: 5500, Accuracy: 0.48800001114606856\n",
"Samples: 6000, Accuracy: 0.42400001049041747\n",
"Samples: 6500, Accuracy: 0.532000013589859\n",
"Samples: 7000, Accuracy: 0.5600000107288361\n",
"Samples: 7500, Accuracy: 0.6080000150203705\n",
"Samples: 8000, Accuracy: 0.5360000130534172\n",
"Samples: 8500, Accuracy: 0.5520000138878822\n",
"Samples: 9000, Accuracy: 0.548000011742115\n",
"Samples: 9500, Accuracy: 0.5840000131726265\n",
"Samples: 10000, Accuracy: 0.62400001257658\n",
"Samples: 10500, Accuracy: 0.5680000129342079\n",
"Samples: 11000, Accuracy: 0.5840000137686729\n",
"Samples: 11500, Accuracy: 0.6280000135302544\n",
"Samples: 12000, Accuracy: 0.5640000134706498\n",
"Samples: 12500, Accuracy: 0.6240000146627426\n",
"Samples: 13000, Accuracy: 0.6080000129342079\n",
"Samples: 13500, Accuracy: 0.5920000141859054\n",
"Samples: 14000, Accuracy: 0.6280000120401382\n",
"Samples: 14500, Accuracy: 0.6440000131726265\n",
"Samples: 15000, Accuracy: 0.6200000131130219\n",
"Samples: 15500, Accuracy: 0.6280000144243241\n",
"Samples: 16000, Accuracy: 0.7080000126361847\n",
"Samples: 16500, Accuracy: 0.680000011920929\n",
"Samples: 17000, Accuracy: 0.6560000160336494\n",
"Samples: 17500, Accuracy: 0.6360000133514404\n",
"Samples: 18000, Accuracy: 0.6760000145435333\n",
"Samples: 18500, Accuracy: 0.6200000125169755\n",
"Samples: 19000, Accuracy: 0.6360000133514404\n",
"Samples: 19500, Accuracy: 0.6920000123977661\n",
"Samples: 20000, Accuracy: 0.648000015616417\n",
"Samples: 20500, Accuracy: 0.6520000129938126\n",
"Samples: 21000, Accuracy: 0.6240000107884407\n",
"Samples: 21500, Accuracy: 0.6040000134706497\n",
"Samples: 22000, Accuracy: 0.6480000147223473\n",
"Samples: 22500, Accuracy: 0.6760000124573707\n",
"Samples: 23000, Accuracy: 0.6920000138878822\n",
"Samples: 23500, Accuracy: 0.6520000115036965\n",
"Samples: 24000, Accuracy: 0.6480000147223473\n",
"Samples: 24500, Accuracy: 0.6720000141859055\n",
"Samples: 25000, Accuracy: 0.6480000132322311\n",
"Samples: 25500, Accuracy: 0.6600000134110451\n",
"Samples: 26000, Accuracy: 0.5960000139474869\n",
"Samples: 26500, Accuracy: 0.6120000144839287\n",
"Samples: 27000, Accuracy: 0.6360000112652778\n",
"Samples: 27500, Accuracy: 0.6320000120997429\n",
"Samples: 28000, Accuracy: 0.6160000118613244\n",
"Samples: 28500, Accuracy: 0.5960000151395798\n",
"Samples: 29000, Accuracy: 0.7080000123381615\n",
"Samples: 29500, Accuracy: 0.6560000136494637\n",
"Samples: 30000, Accuracy: 0.6760000136494636\n",
"Samples: 30500, Accuracy: 0.668000012934208\n",
"Samples: 31000, Accuracy: 0.6560000130534172\n",
"Samples: 31500, Accuracy: 0.6120000138878823\n",
"Samples: 32000, Accuracy: 0.6080000129342079\n",
"Samples: 32500, Accuracy: 0.6640000140666962\n",
"Samples: 33000, Accuracy: 0.6400000163912773\n",
"Samples: 33500, Accuracy: 0.6720000126957894\n",
"Samples: 34000, Accuracy: 0.6640000134706497\n",
"Samples: 34500, Accuracy: 0.6680000147223473\n",
"Samples: 35000, Accuracy: 0.6280000147223472\n",
"Samples: 35500, Accuracy: 0.7160000127553939\n",
"Samples: 36000, Accuracy: 0.6560000109672547\n",
"Samples: 36500, Accuracy: 0.6480000126361847\n",
"Samples: 37000, Accuracy: 0.6800000140070915\n",
"Samples: 37500, Accuracy: 0.6920000126957894\n",
"Samples: 38000, Accuracy: 0.7040000146627426\n",
"Samples: 38500, Accuracy: 0.7440000128746033\n",
"Samples: 39000, Accuracy: 0.6800000140070915\n",
"Samples: 39500, Accuracy: 0.7360000109672546\n",
"Samples: 40000, Accuracy: 0.6880000135302544\n",
"Samples: 40500, Accuracy: 0.7280000132322312\n",
"Samples: 41000, Accuracy: 0.7080000111460686\n",
"Samples: 41500, Accuracy: 0.6640000116825103\n",
"Samples: 42000, Accuracy: 0.7200000137090683\n",
"Samples: 42500, Accuracy: 0.7160000121593475\n",
"Samples: 43000, Accuracy: 0.756000010073185\n",
"Samples: 43500, Accuracy: 0.7800000113248825\n",
"Samples: 44000, Accuracy: 0.6760000112652779\n",
"Samples: 44500, Accuracy: 0.7400000125169754\n",
"Samples: 45000, Accuracy: 0.740000013411045\n",
"Samples: 45500, Accuracy: 0.7400000143051148\n",
"Samples: 46000, Accuracy: 0.6840000137686729\n",
"Samples: 46500, Accuracy: 0.7120000132918358\n",
"Samples: 47000, Accuracy: 0.668000014424324\n",
"Samples: 47500, Accuracy: 0.6400000137090683\n",
"Samples: 48000, Accuracy: 0.7320000118017197\n",
"Samples: 48500, Accuracy: 0.7360000130534172\n",
"Samples: 49000, Accuracy: 0.6040000134706497\n",
"Samples: 49500, Accuracy: 0.7080000111460686\n",
"Samples: 50000, Accuracy: 0.7240000128746032\n",
"Samples: 50500, Accuracy: 0.7440000113844871\n",
"Samples: 51000, Accuracy: 0.6560000130534172\n",
"Samples: 51500, Accuracy: 0.7320000132918358\n",
"Samples: 52000, Accuracy: 0.6840000146627426\n",
"Samples: 52500, Accuracy: 0.644000013768673\n",
"Samples: 53000, Accuracy: 0.7080000111460686\n",
"Samples: 53500, Accuracy: 0.7440000116825104\n",
"Samples: 54000, Accuracy: 0.6960000124573708\n",
"Samples: 54500, Accuracy: 0.7440000140666961\n",
"Samples: 55000, Accuracy: 0.5600000137090683\n",
"Samples: 55500, Accuracy: 0.5320000123977661\n",
"Samples: 56000, Accuracy: 0.5200000113248825\n",
"Samples: 56500, Accuracy: 0.588000011742115\n",
"Samples: 57000, Accuracy: 0.6200000140070915\n",
"Samples: 57500, Accuracy: 0.6320000126957893\n",
"Samples: 58000, Accuracy: 0.6400000137090683\n",
"Samples: 58500, Accuracy: 0.6480000141263008\n",
"Samples: 59000, Accuracy: 0.624000011086464\n",
"Samples: 59500, Accuracy: 0.6680000132322311\n",
"Samples: 60000, Accuracy: 0.7000000131130218\n",
"Samples: 60500, Accuracy: 0.6800000125169754\n",
"Samples: 61000, Accuracy: 0.6720000138878822\n",
"Samples: 61500, Accuracy: 0.7120000120997428\n"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-19-620e1760a881>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m inputs=inputs)\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 145\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 146\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "hqfgBYzKPrLN"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment