Created
March 31, 2021 06:25
-
-
Save saahiluppal/5c8c5198a481ab4f048b4dcb89a9999d to your computer and use it in GitHub Desktop.
exp0.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "exp0.ipynb", | |
"provenance": [], | |
"authorship_tag": "ABX9TyNN/3Ukdn7iJKDThfwqgTXJ", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU", | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"8052ac06daa34bd7b4dc6e32e585dc7e": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"state": { | |
"_view_name": "HBoxView", | |
"_dom_classes": [], | |
"_model_name": "HBoxModel", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"box_style": "", | |
"layout": "IPY_MODEL_50f50d32b6804ce68312a1f41858c637", | |
"_model_module": "@jupyter-widgets/controls", | |
"children": [ | |
"IPY_MODEL_e515dfd7480b4c3aa945ee78b5448d81", | |
"IPY_MODEL_9fb941230b6d4bc983bfc89058dc6850" | |
] | |
} | |
}, | |
"50f50d32b6804ce68312a1f41858c637": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"e515dfd7480b4c3aa945ee78b5448d81": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"state": { | |
"_view_name": "ProgressView", | |
"style": "IPY_MODEL_6cdb857f40f647fa9d36f8f42888a3da", | |
"_dom_classes": [], | |
"description": "Downloading: 100%", | |
"_model_name": "FloatProgressModel", | |
"bar_style": "success", | |
"max": 665, | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": 665, | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"orientation": "horizontal", | |
"min": 0, | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_7814d113029645109ce156d0f0655ea9" | |
} | |
}, | |
"9fb941230b6d4bc983bfc89058dc6850": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"state": { | |
"_view_name": "HTMLView", | |
"style": "IPY_MODEL_ad641b8179e14723bfc0c25ba4493778", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "HTMLModel", | |
"placeholder": "", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": " 665/665 [00:00<00:00, 2.69kB/s]", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_93646a85233043b3961be1bd5ef691eb" | |
} | |
}, | |
"6cdb857f40f647fa9d36f8f42888a3da": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "ProgressStyleModel", | |
"description_width": "initial", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"bar_color": null, | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"7814d113029645109ce156d0f0655ea9": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"ad641b8179e14723bfc0c25ba4493778": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "DescriptionStyleModel", | |
"description_width": "", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"93646a85233043b3961be1bd5ef691eb": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"f6548d3efb264996bd2f0f00fb1c886b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"state": { | |
"_view_name": "HBoxView", | |
"_dom_classes": [], | |
"_model_name": "HBoxModel", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"box_style": "", | |
"layout": "IPY_MODEL_7278b81f49284dea9d538d0dd0cd8ccf", | |
"_model_module": "@jupyter-widgets/controls", | |
"children": [ | |
"IPY_MODEL_6437f2c8ffcd4cd7a741fd2a12048e55", | |
"IPY_MODEL_fff1e5902e384afc93e12971e6995a3b" | |
] | |
} | |
}, | |
"7278b81f49284dea9d538d0dd0cd8ccf": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"6437f2c8ffcd4cd7a741fd2a12048e55": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"state": { | |
"_view_name": "ProgressView", | |
"style": "IPY_MODEL_5671a7e0e65446ddaad6de7f8f5006dd", | |
"_dom_classes": [], | |
"description": "Downloading: 100%", | |
"_model_name": "FloatProgressModel", | |
"bar_style": "success", | |
"max": 548118077, | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": 548118077, | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"orientation": "horizontal", | |
"min": 0, | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_3659ed7594ff4bae87b0039f680abf76" | |
} | |
}, | |
"fff1e5902e384afc93e12971e6995a3b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"state": { | |
"_view_name": "HTMLView", | |
"style": "IPY_MODEL_96e05c46ce474dda8aa57ca7b8b650d9", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "HTMLModel", | |
"placeholder": "", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": " 548M/548M [00:10<00:00, 51.9MB/s]", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_3b1075ef420a4d3aaae51a1741d41b69" | |
} | |
}, | |
"5671a7e0e65446ddaad6de7f8f5006dd": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "ProgressStyleModel", | |
"description_width": "initial", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"bar_color": null, | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"3659ed7594ff4bae87b0039f680abf76": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"96e05c46ce474dda8aa57ca7b8b650d9": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "DescriptionStyleModel", | |
"description_width": "", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"3b1075ef420a4d3aaae51a1741d41b69": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
} | |
} | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/saahiluppal/5c8c5198a481ab4f048b4dcb89a9999d/exp0.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "G6fPatmirOqp", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "ba6faf35-d541-44f1-f3ba-6187816ef03b" | |
}, | |
"source": [ | |
"!nvidia-smi" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Wed Mar 31 05:26:50 2021 \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| NVIDIA-SMI 460.67 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", | |
"|-------------------------------+----------------------+----------------------+\n", | |
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", | |
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", | |
"| | | MIG M. |\n", | |
"|===============================+======================+======================|\n", | |
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", | |
"| N/A 52C P8 10W / 70W | 0MiB / 15109MiB | 0% Default |\n", | |
"| | | N/A |\n", | |
"+-------------------------------+----------------------+----------------------+\n", | |
" \n", | |
"+-----------------------------------------------------------------------------+\n", | |
"| Processes: |\n", | |
"| GPU GI CI PID Type Process name GPU Memory |\n", | |
"| ID ID Usage |\n", | |
"|=============================================================================|\n", | |
"| No running processes found |\n", | |
"+-----------------------------------------------------------------------------+\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "r48DtcmVLxpg" | |
}, | |
"source": [ | |
"import matplotlib.pyplot as plt" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "IUQfhWwUL3Sm" | |
}, | |
"source": [ | |
"import numpy as np" | |
], | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WxTl32dSL7tI" | |
}, | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn" | |
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xBjFqzyUL9JU", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "4c252827-2d3a-430b-9619-7992764f76b7" | |
}, | |
"source": [ | |
"%pip install transformers --quiet" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\u001b[K |████████████████████████████████| 2.0MB 8.1MB/s \n", | |
"\u001b[K |████████████████████████████████| 3.2MB 48.3MB/s \n", | |
"\u001b[K |████████████████████████████████| 890kB 38.6MB/s \n", | |
"\u001b[?25h Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_FqNloDiL_8u" | |
}, | |
"source": [ | |
"from transformers.models.gpt2.modeling_gpt2 import GPT2Model" | |
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ExzGiG5TMKFa" | |
}, | |
"source": [ | |
"def generate_example(n):\n", | |
" bits = np.random.randint(low=0, high=2, size=(2, n))\n", | |
" xor = np.logical_xor(bits[0], bits[1]).astype(np.int64)\n", | |
" return bits.reshape((2*n)), xor" | |
], | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LJKFBV98MQzR" | |
}, | |
"source": [ | |
"n = 5\n", | |
"bits, xor = generate_example(n)" | |
], | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Z1VoV-4fMSq0", | |
"outputId": "be79975f-8a33-4ad0-f506-919608e6b295" | |
}, | |
"source": [ | |
"bits[:n]" | |
], | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([1, 0, 0, 1, 0])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 9 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "if5m_tQtMhoG", | |
"outputId": "c99b7a04-a6d1-4d9d-f061-7097d373484e" | |
}, | |
"source": [ | |
"bits[n:]" | |
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([0, 1, 0, 0, 1])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "95zCLzqaMjWe", | |
"outputId": "11d90b29-5fd5-4989-9e5a-02978c35caf3" | |
}, | |
"source": [ | |
"xor" | |
], | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([1, 1, 0, 1, 1])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 11 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"id": "If40KIC2Mk1m", | |
"outputId": "6e6fb759-54f8-402f-d2aa-291d85bd41ca" | |
}, | |
"source": [ | |
"device = 'cuda' if torch.cuda.is_available() else \"cpu\"\n", | |
"device" | |
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
}, | |
"text/plain": [ | |
"'cuda'" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 12 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 169, | |
"referenced_widgets": [ | |
"8052ac06daa34bd7b4dc6e32e585dc7e", | |
"50f50d32b6804ce68312a1f41858c637", | |
"e515dfd7480b4c3aa945ee78b5448d81", | |
"9fb941230b6d4bc983bfc89058dc6850", | |
"6cdb857f40f647fa9d36f8f42888a3da", | |
"7814d113029645109ce156d0f0655ea9", | |
"ad641b8179e14723bfc0c25ba4493778", | |
"93646a85233043b3961be1bd5ef691eb", | |
"f6548d3efb264996bd2f0f00fb1c886b", | |
"7278b81f49284dea9d538d0dd0cd8ccf", | |
"6437f2c8ffcd4cd7a741fd2a12048e55", | |
"fff1e5902e384afc93e12971e6995a3b", | |
"5671a7e0e65446ddaad6de7f8f5006dd", | |
"3659ed7594ff4bae87b0039f680abf76", | |
"96e05c46ce474dda8aa57ca7b8b650d9", | |
"3b1075ef420a4d3aaae51a1741d41b69" | |
] | |
}, | |
"id": "1qf8640BMsCM", | |
"outputId": "947bd6f7-9f3e-4119-b7c5-1ee47c6cccc4" | |
}, | |
"source": [ | |
"gpt2 = GPT2Model.from_pretrained(\"gpt2\")\n", | |
"in_layer = nn.Embedding(2, 768)\n", | |
"out_layer = nn.Linear(768, 2)" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "8052ac06daa34bd7b4dc6e32e585dc7e", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=665.0, style=ProgressStyle(description_…" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f6548d3efb264996bd2f0f00fb1c886b", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"HBox(children=(FloatProgress(value=0.0, description='Downloading', max=548118077.0, style=ProgressStyle(descri…" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Some weights of GPT2Model were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias']\n", | |
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "1yLVG7r6M2K0" | |
}, | |
"source": [ | |
"for name, param in gpt2.named_parameters():\n", | |
" if \"ln\" in name or \"wpe\" in name:\n", | |
" param.requires_grad = True\n", | |
" else:\n", | |
" param.requires_grad = False" | |
], | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xJBXFZD4NHyv" | |
}, | |
"source": [ | |
"params = list(gpt2.parameters()) + list(in_layer.parameters()) + list(out_layer.parameters())" | |
], | |
"execution_count": 15, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "GvCcqtktNVQu" | |
}, | |
"source": [ | |
"optimizer = torch.optim.Adam(params)" | |
], | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ArmMF_fiNjwx" | |
}, | |
"source": [ | |
"loss_fn = nn.CrossEntropyLoss()" | |
], | |
"execution_count": 17, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2WnQyCrANmJh" | |
}, | |
"source": [ | |
"for layer in (gpt2, in_layer, out_layer):\n", | |
" layer.to(device = device)\n", | |
" layer.train()" | |
], | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"id": "0pckddjKNsFy", | |
"outputId": "0c699f31-30a7-4cc7-db19-0816785c9528" | |
}, | |
"source": [ | |
"accuracies = [0]\n", | |
"while sum(accuracies[-50:]) / len(accuracies[-50:]) < .99:\n", | |
" x, y = generate_example(n)\n", | |
" x = torch.from_numpy(x).to(device=device, dtype=torch.long)\n", | |
" y = torch.from_numpy(y).to(device=device, dtype=torch.long)\n", | |
"\n", | |
" embeddings = in_layer(x.reshape(1, -1))\n", | |
" hidden_state = gpt2(inputs_embeds=embeddings).last_hidden_state[:, n:]\n", | |
" logits = out_layer(hidden_state)[0]\n", | |
"\n", | |
" loss = loss_fn(logits, y)\n", | |
" accuracies.append((logits.argmax(dim=-1) == y).float().mean().item())\n", | |
"\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" if len(accuracies) % 500 == 0:\n", | |
" accuracy = sum(accuracies[-50:]) / len(accuracies[-50:])\n", | |
" print(f\"Samples: {len(accuracies)}, Accuracy: {accuracy}\")" | |
], | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Samples: 500, Accuracy: 0.5160000124573707\n", | |
"Samples: 1000, Accuracy: 0.52400001257658\n", | |
"Samples: 1500, Accuracy: 0.4840000116825104\n", | |
"Samples: 2000, Accuracy: 0.48800001114606856\n", | |
"Samples: 2500, Accuracy: 0.5200000119209289\n", | |
"Samples: 3000, Accuracy: 0.4360000091791153\n", | |
"Samples: 3500, Accuracy: 0.5440000140666962\n", | |
"Samples: 4000, Accuracy: 0.5000000119209289\n", | |
"Samples: 4500, Accuracy: 0.49600000977516173\n", | |
"Samples: 5000, Accuracy: 0.48800001174211505\n", | |
"Samples: 5500, Accuracy: 0.48800001114606856\n", | |
"Samples: 6000, Accuracy: 0.42400001049041747\n", | |
"Samples: 6500, Accuracy: 0.532000013589859\n", | |
"Samples: 7000, Accuracy: 0.5600000107288361\n", | |
"Samples: 7500, Accuracy: 0.6080000150203705\n", | |
"Samples: 8000, Accuracy: 0.5360000130534172\n", | |
"Samples: 8500, Accuracy: 0.5520000138878822\n", | |
"Samples: 9000, Accuracy: 0.548000011742115\n", | |
"Samples: 9500, Accuracy: 0.5840000131726265\n", | |
"Samples: 10000, Accuracy: 0.62400001257658\n", | |
"Samples: 10500, Accuracy: 0.5680000129342079\n", | |
"Samples: 11000, Accuracy: 0.5840000137686729\n", | |
"Samples: 11500, Accuracy: 0.6280000135302544\n", | |
"Samples: 12000, Accuracy: 0.5640000134706498\n", | |
"Samples: 12500, Accuracy: 0.6240000146627426\n", | |
"Samples: 13000, Accuracy: 0.6080000129342079\n", | |
"Samples: 13500, Accuracy: 0.5920000141859054\n", | |
"Samples: 14000, Accuracy: 0.6280000120401382\n", | |
"Samples: 14500, Accuracy: 0.6440000131726265\n", | |
"Samples: 15000, Accuracy: 0.6200000131130219\n", | |
"Samples: 15500, Accuracy: 0.6280000144243241\n", | |
"Samples: 16000, Accuracy: 0.7080000126361847\n", | |
"Samples: 16500, Accuracy: 0.680000011920929\n", | |
"Samples: 17000, Accuracy: 0.6560000160336494\n", | |
"Samples: 17500, Accuracy: 0.6360000133514404\n", | |
"Samples: 18000, Accuracy: 0.6760000145435333\n", | |
"Samples: 18500, Accuracy: 0.6200000125169755\n", | |
"Samples: 19000, Accuracy: 0.6360000133514404\n", | |
"Samples: 19500, Accuracy: 0.6920000123977661\n", | |
"Samples: 20000, Accuracy: 0.648000015616417\n", | |
"Samples: 20500, Accuracy: 0.6520000129938126\n", | |
"Samples: 21000, Accuracy: 0.6240000107884407\n", | |
"Samples: 21500, Accuracy: 0.6040000134706497\n", | |
"Samples: 22000, Accuracy: 0.6480000147223473\n", | |
"Samples: 22500, Accuracy: 0.6760000124573707\n", | |
"Samples: 23000, Accuracy: 0.6920000138878822\n", | |
"Samples: 23500, Accuracy: 0.6520000115036965\n", | |
"Samples: 24000, Accuracy: 0.6480000147223473\n", | |
"Samples: 24500, Accuracy: 0.6720000141859055\n", | |
"Samples: 25000, Accuracy: 0.6480000132322311\n", | |
"Samples: 25500, Accuracy: 0.6600000134110451\n", | |
"Samples: 26000, Accuracy: 0.5960000139474869\n", | |
"Samples: 26500, Accuracy: 0.6120000144839287\n", | |
"Samples: 27000, Accuracy: 0.6360000112652778\n", | |
"Samples: 27500, Accuracy: 0.6320000120997429\n", | |
"Samples: 28000, Accuracy: 0.6160000118613244\n", | |
"Samples: 28500, Accuracy: 0.5960000151395798\n", | |
"Samples: 29000, Accuracy: 0.7080000123381615\n", | |
"Samples: 29500, Accuracy: 0.6560000136494637\n", | |
"Samples: 30000, Accuracy: 0.6760000136494636\n", | |
"Samples: 30500, Accuracy: 0.668000012934208\n", | |
"Samples: 31000, Accuracy: 0.6560000130534172\n", | |
"Samples: 31500, Accuracy: 0.6120000138878823\n", | |
"Samples: 32000, Accuracy: 0.6080000129342079\n", | |
"Samples: 32500, Accuracy: 0.6640000140666962\n", | |
"Samples: 33000, Accuracy: 0.6400000163912773\n", | |
"Samples: 33500, Accuracy: 0.6720000126957894\n", | |
"Samples: 34000, Accuracy: 0.6640000134706497\n", | |
"Samples: 34500, Accuracy: 0.6680000147223473\n", | |
"Samples: 35000, Accuracy: 0.6280000147223472\n", | |
"Samples: 35500, Accuracy: 0.7160000127553939\n", | |
"Samples: 36000, Accuracy: 0.6560000109672547\n", | |
"Samples: 36500, Accuracy: 0.6480000126361847\n", | |
"Samples: 37000, Accuracy: 0.6800000140070915\n", | |
"Samples: 37500, Accuracy: 0.6920000126957894\n", | |
"Samples: 38000, Accuracy: 0.7040000146627426\n", | |
"Samples: 38500, Accuracy: 0.7440000128746033\n", | |
"Samples: 39000, Accuracy: 0.6800000140070915\n", | |
"Samples: 39500, Accuracy: 0.7360000109672546\n", | |
"Samples: 40000, Accuracy: 0.6880000135302544\n", | |
"Samples: 40500, Accuracy: 0.7280000132322312\n", | |
"Samples: 41000, Accuracy: 0.7080000111460686\n", | |
"Samples: 41500, Accuracy: 0.6640000116825103\n", | |
"Samples: 42000, Accuracy: 0.7200000137090683\n", | |
"Samples: 42500, Accuracy: 0.7160000121593475\n", | |
"Samples: 43000, Accuracy: 0.756000010073185\n", | |
"Samples: 43500, Accuracy: 0.7800000113248825\n", | |
"Samples: 44000, Accuracy: 0.6760000112652779\n", | |
"Samples: 44500, Accuracy: 0.7400000125169754\n", | |
"Samples: 45000, Accuracy: 0.740000013411045\n", | |
"Samples: 45500, Accuracy: 0.7400000143051148\n", | |
"Samples: 46000, Accuracy: 0.6840000137686729\n", | |
"Samples: 46500, Accuracy: 0.7120000132918358\n", | |
"Samples: 47000, Accuracy: 0.668000014424324\n", | |
"Samples: 47500, Accuracy: 0.6400000137090683\n", | |
"Samples: 48000, Accuracy: 0.7320000118017197\n", | |
"Samples: 48500, Accuracy: 0.7360000130534172\n", | |
"Samples: 49000, Accuracy: 0.6040000134706497\n", | |
"Samples: 49500, Accuracy: 0.7080000111460686\n", | |
"Samples: 50000, Accuracy: 0.7240000128746032\n", | |
"Samples: 50500, Accuracy: 0.7440000113844871\n", | |
"Samples: 51000, Accuracy: 0.6560000130534172\n", | |
"Samples: 51500, Accuracy: 0.7320000132918358\n", | |
"Samples: 52000, Accuracy: 0.6840000146627426\n", | |
"Samples: 52500, Accuracy: 0.644000013768673\n", | |
"Samples: 53000, Accuracy: 0.7080000111460686\n", | |
"Samples: 53500, Accuracy: 0.7440000116825104\n", | |
"Samples: 54000, Accuracy: 0.6960000124573708\n", | |
"Samples: 54500, Accuracy: 0.7440000140666961\n", | |
"Samples: 55000, Accuracy: 0.5600000137090683\n", | |
"Samples: 55500, Accuracy: 0.5320000123977661\n", | |
"Samples: 56000, Accuracy: 0.5200000113248825\n", | |
"Samples: 56500, Accuracy: 0.588000011742115\n", | |
"Samples: 57000, Accuracy: 0.6200000140070915\n", | |
"Samples: 57500, Accuracy: 0.6320000126957893\n", | |
"Samples: 58000, Accuracy: 0.6400000137090683\n", | |
"Samples: 58500, Accuracy: 0.6480000141263008\n", | |
"Samples: 59000, Accuracy: 0.624000011086464\n", | |
"Samples: 59500, Accuracy: 0.6680000132322311\n", | |
"Samples: 60000, Accuracy: 0.7000000131130218\n", | |
"Samples: 60500, Accuracy: 0.6800000125169754\n", | |
"Samples: 61000, Accuracy: 0.6720000138878822\n", | |
"Samples: 61500, Accuracy: 0.7120000120997428\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "error", | |
"ename": "KeyboardInterrupt", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-19-620e1760a881>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m inputs=inputs)\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 145\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 146\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hqfgBYzKPrLN" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment