{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "TPU",
"colab": {
"name": "Copia de PickableOpt+Basic_lenet_exploration_MultiTPU.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"5eb5b3fdae2348258df1d841c2310685": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_96bb07f5e97844358ad5d2e6bfa86fe1",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_eb53ce9afd4b4825ade9ac4fcd8761d1",
"IPY_MODEL_080099fc4e7e485b9513b9ccd43919aa"
]
}
},
"96bb07f5e97844358ad5d2e6bfa86fe1": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"eb53ce9afd4b4825ade9ac4fcd8761d1": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_d2674c49b31642bfb19fa859ce111c91",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 87306240,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 87306240,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_57413751c8134ee0aad378133ea958be"
}
},
"080099fc4e7e485b9513b9ccd43919aa": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_9bce95278aa447c2813a3d0c2211c763",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 83.3M/83.3M [00:02<00:00, 43.5MB/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_31651a7da3d643c89ba465dd856f49dd"
}
},
"d2674c49b31642bfb19fa859ce111c91": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"57413751c8134ee0aad378133ea958be": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"9bce95278aa447c2813a3d0c2211c763": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"31651a7da3d643c89ba465dd856f49dd": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/marii-moe/ec2492aac5187e7d47fa2b73affbc3f5/copia-de-pickableopt-basic_lenet_exploration_multitpu.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BmnUX_l8lQ6B"
},
"source": [
"# Install fastai2 from github"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Q5DZXcBNJoy1",
"outputId": "c366c166-db06-4650-b350-b7093f6c6ae9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"!pip install -U pandas --upgrade\n",
"!pip install -U fastcore --upgrade\n",
"!pip install -U fastai --upgrade \n",
"!pip install -Uqq git+https://github.com/tyoc213/fastai_xla_extensions@fix_prev_lenet"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting pandas\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a2/21/e10d65222d19a2537e3eb0df306686a9eabd08b3c98dd120e43720bf802d/pandas-1.1.3-cp36-cp36m-manylinux1_x86_64.whl (9.5MB)\n",
"\u001b[K |████████████████████████████████| 9.5MB 3.1MB/s \n",
"\u001b[?25hRequirement already satisfied, skipping upgrade: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas) (2018.9)\n",
"Requirement already satisfied, skipping upgrade: numpy>=1.15.4 in /usr/local/lib/python3.6/dist-packages (from pandas) (1.18.5)\n",
"Requirement already satisfied, skipping upgrade: python-dateutil>=2.7.3 in /usr/local/lib/python3.6/dist-packages (from pandas) (2.8.1)\n",
"Requirement already satisfied, skipping upgrade: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.7.3->pandas) (1.15.0)\n",
"Installing collected packages: pandas\n",
" Found existing installation: pandas 1.1.2\n",
" Uninstalling pandas-1.1.2:\n",
" Successfully uninstalled pandas-1.1.2\n",
"Successfully installed pandas-1.1.3\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/vnd.colab-display-data+json": {
"pip_warning": {
"packages": [
"pandas"
]
}
}
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Collecting fastcore\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/d6/52/8ba4fa23b95f06dd1a4915ccd381c3c8263ad7186d5466083657b2a488dc/fastcore-1.1.0-py3-none-any.whl (42kB)\n",
"\r\u001b[K |███████▊ | 10kB 11.5MB/s eta 0:00:01\r\u001b[K |███████████████▌ | 20kB 2.0MB/s eta 0:00:01\r\u001b[K |███████████████████████▎ | 30kB 2.5MB/s eta 0:00:01\r\u001b[K |███████████████████████████████ | 40kB 3.0MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 51kB 1.8MB/s \n",
"\u001b[?25hRequirement already satisfied, skipping upgrade: packaging in /usr/local/lib/python3.6/dist-packages (from fastcore) (20.4)\n",
"Requirement already satisfied, skipping upgrade: pip in /usr/local/lib/python3.6/dist-packages (from fastcore) (19.3.1)\n",
"Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from packaging->fastcore) (1.15.0)\n",
"Requirement already satisfied, skipping upgrade: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->fastcore) (2.4.7)\n",
"Installing collected packages: fastcore\n",
"Successfully installed fastcore-1.1.0\n",
"Collecting fastai\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/28/d9/23222f694d28a6bd798f1c0f3600efd31c623ba63115c11d8fd83c83216e/fastai-2.0.16-py3-none-any.whl (187kB)\n",
"\u001b[K |████████████████████████████████| 194kB 3.3MB/s \n",
"\u001b[?25hRequirement already satisfied, skipping upgrade: pyyaml in /usr/local/lib/python3.6/dist-packages (from fastai) (3.13)\n",
"Requirement already satisfied, skipping upgrade: pandas in /usr/local/lib/python3.6/dist-packages (from fastai) (1.1.3)\n",
"Requirement already satisfied, skipping upgrade: torch>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from fastai) (1.6.0+cu101)\n",
"Requirement already satisfied, skipping upgrade: fastprogress>=0.2.4 in /usr/local/lib/python3.6/dist-packages (from fastai) (1.0.0)\n",
"Requirement already satisfied, skipping upgrade: matplotlib in /usr/local/lib/python3.6/dist-packages (from fastai) (3.2.2)\n",
"Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.6/dist-packages (from fastai) (2.23.0)\n",
"Requirement already satisfied, skipping upgrade: packaging in /usr/local/lib/python3.6/dist-packages (from fastai) (20.4)\n",
"Requirement already satisfied, skipping upgrade: pillow in /usr/local/lib/python3.6/dist-packages (from fastai) (7.0.0)\n",
"Requirement already satisfied, skipping upgrade: pip in /usr/local/lib/python3.6/dist-packages (from fastai) (19.3.1)\n",
"Requirement already satisfied, skipping upgrade: spacy in /usr/local/lib/python3.6/dist-packages (from fastai) (2.2.4)\n",
"Requirement already satisfied, skipping upgrade: torchvision>=0.7 in /usr/local/lib/python3.6/dist-packages (from fastai) (0.7.0+cu101)\n",
"Requirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from fastai) (1.4.1)\n",
"Requirement already satisfied, skipping upgrade: fastcore>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from fastai) (1.1.0)\n",
"Requirement already satisfied, skipping upgrade: scikit-learn in /usr/local/lib/python3.6/dist-packages (from fastai) (0.22.2.post1)\n",
"Requirement already satisfied, skipping upgrade: numpy>=1.15.4 in /usr/local/lib/python3.6/dist-packages (from pandas->fastai) (1.18.5)\n",
"Requirement already satisfied, skipping upgrade: python-dateutil>=2.7.3 in /usr/local/lib/python3.6/dist-packages (from pandas->fastai) (2.8.1)\n",
"Requirement already satisfied, skipping upgrade: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->fastai) (2018.9)\n",
"Requirement already satisfied, skipping upgrade: future in /usr/local/lib/python3.6/dist-packages (from torch>=1.6.0->fastai) (0.16.0)\n",
"Requirement already satisfied, skipping upgrade: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->fastai) (0.10.0)\n",
"Requirement already satisfied, skipping upgrade: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->fastai) (1.2.0)\n",
"Requirement already satisfied, skipping upgrade: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->fastai) (2.4.7)\n",
"Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->fastai) (1.24.3)\n",
"Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->fastai) (2020.6.20)\n",
"Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->fastai) (3.0.4)\n",
"Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->fastai) (2.10)\n",
"Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from packaging->fastai) (1.15.0)\n",
"Requirement already satisfied, skipping upgrade: setuptools in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (50.3.0)\n",
"Requirement already satisfied, skipping upgrade: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (1.0.2)\n",
"Requirement already satisfied, skipping upgrade: srsly<1.1.0,>=1.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (1.0.2)\n",
"Requirement already satisfied, skipping upgrade: catalogue<1.1.0,>=0.0.7 in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (1.0.0)\n",
"Requirement already satisfied, skipping upgrade: blis<0.5.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (0.4.1)\n",
"Requirement already satisfied, skipping upgrade: wasabi<1.1.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (0.8.0)\n",
"Requirement already satisfied, skipping upgrade: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (2.0.3)\n",
"Requirement already satisfied, skipping upgrade: plac<1.2.0,>=0.9.6 in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (1.1.3)\n",
"Requirement already satisfied, skipping upgrade: thinc==7.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (7.4.0)\n",
"Requirement already satisfied, skipping upgrade: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (3.0.2)\n",
"Requirement already satisfied, skipping upgrade: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from spacy->fastai) (4.41.1)\n",
"Requirement already satisfied, skipping upgrade: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->fastai) (0.16.0)\n",
"Requirement already satisfied, skipping upgrade: importlib-metadata>=0.20; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy->fastai) (2.0.0)\n",
"Requirement already satisfied, skipping upgrade: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < \"3.8\"->catalogue<1.1.0,>=0.0.7->spacy->fastai) (3.2.0)\n",
"Installing collected packages: fastai\n",
" Found existing installation: fastai 1.0.61\n",
" Uninstalling fastai-1.0.61:\n",
" Successfully uninstalled fastai-1.0.61\n",
"Successfully installed fastai-2.0.16\n",
" Building wheel for fastai-xla-extensions (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "GucdOzF7r6ch",
"outputId": "a6dd9b1e-cce4-421d-931f-c743d0f50314",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 72
}
},
"source": [
"VERSION = \"20200707\" #\"20200515\" @param [\"1.5\" , \"20200325\", \"nightly\"]\n",
"!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 5116 100 5116 0 0 25838 0 --:--:-- --:--:-- --:--:-- 25838\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "BxoA3fJusV17",
"outputId": "95abbef5-baaf-4377-9027-ae04f968867b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"#!TORCH_SHOW_CPP_STACKTRACES=1 python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev\n",
"!python pytorch-xla-env-setup.py --version $VERSION --apt-packages libomp5 libopenblas-dev"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Updating... This may take around 2 minutes.\n",
"Updating TPU runtime to pytorch-dev20200707 ...\n",
"Collecting cloud-tpu-client\n",
" Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl\n",
"Requirement already satisfied: oauth2client in /usr/local/lib/python3.6/dist-packages (from cloud-tpu-client) (4.1.3)\n",
"Collecting google-api-python-client==1.8.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)\n",
"\u001b[K |████████████████████████████████| 61kB 2.7MB/s \n",
"\u001b[?25hRequirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client->cloud-tpu-client) (0.2.8)\n",
"Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client->cloud-tpu-client) (0.4.8)\n",
"Requirement already satisfied: httplib2>=0.9.1 in /usr/local/lib/python3.6/dist-packages (from oauth2client->cloud-tpu-client) (0.17.4)\n",
"Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client->cloud-tpu-client) (4.6)\n",
"Requirement already satisfied: six>=1.6.1 in /usr/local/lib/python3.6/dist-packages (from oauth2client->cloud-tpu-client) (1.15.0)\n",
"Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client) (3.0.1)\n",
"Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client) (1.17.2)\n",
"Requirement already satisfied: google-api-core<2dev,>=1.13.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client) (1.16.0)\n",
"Uninstalling torch-1.6.0+cu101:\n",
"Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client) (0.0.4)\n",
"Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client) (4.1.1)\n",
"Requirement already satisfied: setuptools>=40.3.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client) (50.3.0)\n",
"Requirement already satisfied: pytz in /usr/local/lib/python3.6/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client) (2018.9)\n",
"Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client) (1.52.0)\n",
"Requirement already satisfied: requests<3.0.0dev,>=2.18.0 in /usr/local/lib/python3.6/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client) (2.23.0)\n",
"Requirement already satisfied: protobuf>=3.4.0 in /usr/local/lib/python3.6/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client) (3.12.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client) (2020.6.20)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client) (2.10)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client) (1.24.3)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client) (3.0.4)\n",
"Installing collected packages: google-api-python-client, cloud-tpu-client\n",
" Found existing installation: google-api-python-client 1.7.12\n",
" Uninstalling google-api-python-client-1.7.12:\n",
" Successfully uninstalled google-api-python-client-1.7.12\n",
"Successfully installed cloud-tpu-client-0.10 google-api-python-client-1.8.0\n",
"Done updating TPU runtime\n",
" Successfully uninstalled torch-1.6.0+cu101\n",
"Uninstalling torchvision-0.7.0+cu101:\n",
" Successfully uninstalled torchvision-0.7.0+cu101\n",
"Copying gs://tpu-pytorch/wheels/torch-nightly+20200707-cp36-cp36m-linux_x86_64.whl...\n",
"\\ [1 files][107.5 MiB/107.5 MiB] \n",
"Operation completed over 1 objects/107.5 MiB. \n",
"Copying gs://tpu-pytorch/wheels/torch_xla-nightly+20200707-cp36-cp36m-linux_x86_64.whl...\n",
"\\ [1 files][123.8 MiB/123.8 MiB] \n",
"Operation completed over 1 objects/123.8 MiB. \n",
"Copying gs://tpu-pytorch/wheels/torchvision-nightly+20200707-cp36-cp36m-linux_x86_64.whl...\n",
"/ [1 files][ 2.2 MiB/ 2.2 MiB] \n",
"Operation completed over 1 objects/2.2 MiB. \n",
"Processing ./torch-nightly+20200707-cp36-cp36m-linux_x86_64.whl\n",
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch==nightly+20200707) (0.16.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==nightly+20200707) (1.18.5)\n",
"\u001b[31mERROR: fastai 2.0.16 requires torchvision>=0.7, which is not installed.\u001b[0m\n",
"Installing collected packages: torch\n",
"Successfully installed torch-1.7.0a0+12b5bdc\n",
"Processing ./torch_xla-nightly+20200707-cp36-cp36m-linux_x86_64.whl\n",
"Installing collected packages: torch-xla\n",
"Successfully installed torch-xla-1.6+5430aca\n",
"Processing ./torchvision-nightly+20200707-cp36-cp36m-linux_x86_64.whl\n",
"Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==nightly+20200707) (7.0.0)\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from torchvision==nightly+20200707) (1.7.0a0+12b5bdc)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision==nightly+20200707) (1.18.5)\n",
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->torchvision==nightly+20200707) (0.16.0)\n",
"Installing collected packages: torchvision\n",
"Successfully installed torchvision-0.8.0a0+86b6c3e\n",
"Reading package lists... Done\n",
"Building dependency tree \n",
"Reading state information... Done\n",
"libopenblas-dev is already the newest version (0.2.20+ds-4).\n",
"The following NEW packages will be installed:\n",
" libomp5\n",
"0 upgraded, 1 newly installed, 0 to remove and 6 not upgraded.\n",
"Need to get 234 kB of archives.\n",
"After this operation, 774 kB of additional disk space will be used.\n",
"Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 libomp5 amd64 5.0.1-1 [234 kB]\n",
"Fetched 234 kB in 1s (325 kB/s)\n",
"Selecting previously unselected package libomp5:amd64.\n",
"(Reading database ... 144617 files and directories currently installed.)\n",
"Preparing to unpack .../libomp5_5.0.1-1_amd64.deb ...\n",
"Unpacking libomp5:amd64 (5.0.1-1) ...\n",
"Setting up libomp5:amd64 (5.0.1-1) ...\n",
"Processing triggers for libc-bin (2.27-3ubuntu1.2) ...\n",
"/sbin/ldconfig.real: /usr/local/lib/python3.6/dist-packages/ideep4py/lib/libmkldnn.so.0 is not a symbolic link\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aJMhjxPPaPo8",
"outputId": "1dbe5589-0ffe-43e1-db81-c575b8d52863",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 220
}
},
"source": [
"!pip freeze | grep torch \n",
"!pip freeze | grep fast"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"torch==1.7.0a0+12b5bdc\n",
"torch-xla==1.6+5430aca\n",
"torchsummary==1.5.1\n",
"torchtext==0.3.1\n",
"torchvision==0.8.0a0+86b6c3e\n",
"fastai==2.0.16\n",
"fastai-xla-extensions==0.0.1\n",
"fastcore==1.1.0\n",
"fastdtw==0.3.4\n",
"fastprogress==1.0.0\n",
"fastrlock==0.5\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "OvGIdagFuGxK"
},
"source": [
"from fastai.optimizer import Optimizer"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "O5vVQw3JM7yA"
},
"source": [
"import fastai_xla_extensions.core"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "mlAKa0RsbOei"
},
"source": [
"from fastai.vision.all import *"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rHE_YF1_wHeO",
"outputId": "72c0e8ab-37b2-4c7c-8a15-42c395ed1d8b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"default_device()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"device(type='cpu')"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "OD7QTq_ulNZK",
"outputId": "f0a32725-5a0d-4516-d8b7-e94f5814d8bd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"path = untar_data(URLs.MNIST_SAMPLE)\n",
"Path.BASE_PATH = path; path.ls()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
""
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#3) [Path('train'),Path('labels.csv'),Path('valid')]"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "jZHPPO8I8RbC",
"outputId": "1c1873ea-75d4-4d6c-aec2-6799a59cbf4f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"(path/'train').ls()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#2) [Path('train/3'),Path('train/7')]"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Obb_HBYU4wM0"
},
"source": [
"# multi TPU"
]
},
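{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell spawns one training process per TPU core with `xmp.spawn` and uses a named `xm.rendezvous` barrier so that only the master ordinal performs the one-time setup before the other workers proceed. A minimal sketch of that per-process pattern, assuming the `torch_xla` API used below (`xm.xla_device`, `xm.is_master_ordinal`, `xm.rendezvous`, `xm.master_print`, `xmp.spawn`); the `_worker` name is only illustrative:\n",
"\n",
"```python\n",
"import torch_xla.core.xla_model as xm\n",
"import torch_xla.distributed.xla_multiprocessing as xmp\n",
"\n",
"def _worker(index, flags):\n",
"    device = xm.xla_device()                  # one XLA device per spawned process\n",
"    if not xm.is_master_ordinal():\n",
"        xm.rendezvous('download_only_once')   # non-masters block here until the master arrives\n",
"    # setup that should happen on the master first (e.g. dataset download) goes here;\n",
"    # non-masters only run it after the barrier releases them\n",
"    if xm.is_master_ordinal():\n",
"        xm.rendezvous('download_only_once')   # master joins the barrier, releasing the rest\n",
"    xm.master_print(f'process {index} running on {device} with flags {flags}')\n",
"\n",
"# xmp.spawn(_worker, args=({},), nprocs=8, start_method='fork')\n",
"```"
]
},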
{
"cell_type": "code",
"metadata": {
"id": "FwqEIE9aA9ZT",
"outputId": "14621a2c-e022-4526-d68c-2c9d8ed2f305",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"# Configures training (and evaluation) parameters\n",
"import torchvision\n",
"from torchvision import datasets\n",
"import torchvision.transforms as transforms\n",
"import torch_xla.distributed.parallel_loader as pl\n",
"import torch_xla.core.xla_model as xm\n",
"import torch_xla.distributed.xla_multiprocessing as xmp\n",
"from fastai.vision.all import *\n",
"import time\n",
"from fastai.test_utils import *\n",
"print(f'torch version {torch.__version__}')\n",
"\n",
"import pdb\n",
"\n",
"path = untar_data(URLs.MNIST_SAMPLE)\n",
"Path.BASE_PATH = path; path.ls()\n",
"\n",
"def debug_on(*exceptions):\n",
" if not exceptions:\n",
" exceptions = (AssertionError, )\n",
" def decorator(f):\n",
" @functools.wraps(f)\n",
" def wrapper(*args, **kwargs):\n",
" try:\n",
" return f(*args, **kwargs)\n",
" except exceptions:\n",
" pdb.post_mortem(sys.exc_info()[2])\n",
" return wrapper\n",
" return decorator\n",
"\n",
"\n",
"class Lenet2(nn.Module):\n",
" def __init__(self):\n",
" super(Lenet2, self).__init__()\n",
" self.conv1 = nn.Conv2d(3, 6, 3)\n",
" self.conv2 = nn.Conv2d(6, 16, 3)\n",
" self.fc1 = nn.Linear(400, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 2) # Only 2 outputs instead of 10\n",
" @debug_on(KeyError)\n",
" def forward(self, x):\n",
" # Max pooling over a (2, 2) window\n",
" x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
" # If the size is a square you can only specify a single number\n",
" x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
" x = x.view(-1, self.num_flat_features(x))\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n",
" @debug_on(KeyError)\n",
" def num_flat_features(self, x):\n",
" size = x.size()[1:] # all dimensions except the batch dimension\n",
" num_features = 1\n",
" for s in size:\n",
" num_features *= s\n",
" return num_features\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"def map_fn(index, flags):\n",
" # from fastai.callback.all import *\n",
" dede = xm.xla_device()\n",
" print(f'index is {index} and flags are {flags}')\n",
" #xm.rendezvous('init')\n",
"\n",
" if not xm.is_master_ordinal():\n",
" print(f\"this is {dede}:{index} entering download once\")\n",
" xm.rendezvous('download_only_once')\n",
" \n",
" dblock = DataBlock(\n",
" splitter = GrandparentSplitter(),\n",
" item_tfms = Resize(28),\n",
" blocks = (ImageBlock, CategoryBlock),\n",
" get_items = get_image_files,\n",
" get_y = parent_label,\n",
" batch_tfms = []\n",
" )\n",
" if xm.is_master_ordinal():\n",
" xm.master_print(f'this is {dede} exiting download once')\n",
" xm.rendezvous('download_only_once')\n",
" xm.master_print('creating lenet_tpu')\n",
" lenet_tpu = Lenet2()\n",
" xm.master_print('lenet created, goiing for dls_tpu')\n",
" dls_tpu = dblock.dataloaders(path, device=dede)\n",
" xm.master_print(f'creating learner!!! for {dede}')\n",
" \n",
"\n",
" tpu_learner = Learner(dls_tpu,\n",
" lenet_tpu,\n",
" metrics=accuracy, \n",
" loss_func=F.cross_entropy,\n",
" cbs=[])\n",
" print(f\"################ fit for {dede}\")\n",
" xm.master_print(f'***** fit for {dede}')\n",
" tpu_learner.fit(1, cbs=[fastai_xla_extensions.core.XLAOptCallback()])\n",
" xm.master_print(f'***** end fit for {dede}')\n",
" t = torch.randn((2, 2), device=dede)\n",
" print(\"################Process\", index ,\"is using\", xm.xla_real_devices([str(dede)])[0])\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"# https://stackoverflow.com/a/9929970/682603\n",
"# excepthook\n",
"# \n",
"import traceback\n",
"import logging\n",
"import os, sys\n",
"\n",
"def my_excepthook(excType, excValue, traceback, logger):\n",
" print(\"=== *** @@@ ### %%% === *** @@@ ### %%% === *** @@@ ### %%% === *** @@@ ### %%% === *** @@@ ### %%% === *** @@@ ### %%% Logging an uncaught exception\",\n",
" exc_info=(excType, excValue, traceback))\n",
"\n",
"sys.excepthook = my_excepthook\n",
"sys.unraisablehook = my_excepthook\n",
"##############threading.excepthook\n",
"#https://docs.python.org/3/library/sys.html#sys.excepthook\n",
"\n",
"\n",
"\n",
"\n",
"print('launching n procs')\n",
"\n",
"flags={}\n",
"flags['batch_size'] = 32\n",
"flags['num_workers'] = 8\n",
"flags['num_epochs'] = 1\n",
"flags['seed'] = 1234\n",
"\n",
"xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')\n",
"print('end of launch')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"torch version 1.7.0a0+12b5bdc\n",
"launching n procs\n",
"index is 0 and flags are {'batch_size': 32, 'num_workers': 8, 'num_epochs': 1, 'seed': 1234}\n",
"this is xla:1 exiting download once\n",
"index is 6 and flags are {'batch_size': 32, 'num_workers': 8, 'num_epochs': 1, 'seed': 1234}\n",
"this is xla:0:6 entering download once\n",
"index is 4 and flags are {'batch_size': 32, 'num_workers': 8, 'num_epochs': 1, 'seed': 1234}\n",
"this is xla:0:4 entering download once\n",
"index is 7 and flags are {'batch_size': 32, 'num_workers': 8, 'num_epochs': 1, 'seed': 1234}\n",
"this is xla:0:7 entering download once\n",
"index is 5 and flags are {'batch_size': 32, 'num_workers': 8, 'num_epochs': 1, 'seed': 1234}\n",
"this is xla:0:5 entering download once\n",
"index is 1 and flags are {'batch_size': 32, 'num_workers': 8, 'num_epochs': 1, 'seed': 1234}\n",
"this is xla:0:1 entering download once\n",
"index is 2 and flags are {'batch_size': 32, 'num_workers': 8, 'num_epochs': 1, 'seed': 1234}\n",
"this is xla:0:2 entering download once\n",
"index is 3 and flags are {'batch_size': 32, 'num_workers': 8, 'num_epochs': 1, 'seed': 1234}\n",
"this is xla:0:3 entering download once\n",
"creating lenet_tpu\n",
"lenet created, goiing for dls_tpu\n",
"################ fit for xla:0\n",
"creating learner!!! for xla:1\n",
"################ fit for xla:0\n",
"################ fit for xla:1\n",
"***** fit for xla:1\n",
"################ fit for xla:0\n",
"################ fit for xla:0\n",
"################ fit for xla:0\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.693988</td>\n",
" <td>0.693839</td>\n",
" <td>0.504416</td>\n",
" <td>02:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.693849</td>\n",
" <td>0.693839</td>\n",
" <td>0.504416</td>\n",
" <td>02:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.693605</td>\n",
" <td>0.693839</td>\n",
" <td>0.504416</td>\n",
" <td>02:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.694204</td>\n",
" <td>0.693839</td>\n",
" <td>0.504416</td>\n",
" <td>02:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"################ fit for xla:0\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.693364</td>\n",
" <td>0.693839</td>\n",
" <td>0.504416</td>\n",
" <td>02:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.693701</td>\n",
" <td>0.693839</td>\n",
" <td>0.504416</td>\n",
" <td>02:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.694203</td>\n",
" <td>0.693839</td>\n",
" <td>0.504416</td>\n",
" <td>02:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"################ fit for xla:0\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.694207</td>\n",
" <td>0.693839</td>\n",
" <td>0.504416</td>\n",
" <td>02:10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"################Process 5 is using TPU:5\n",
"################Process 2 is using TPU:2\n",
"***** end fit for xla:1\n",
"################Process 4 is using TPU:4\n",
"################Process 0 is using TPU:0\n",
"################Process 1 is using TPU:1\n",
"################Process 3 is using TPU:3\n",
"################Process 7 is using TPU:7\n",
"################Process 6 is using TPU:6\n",
"end of launch\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "P0E0dthTvO-9",
"outputId": "dbeca9c4-f04a-4dbf-dabc-74e68bd5e3f3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 260
}
},
"source": [
"dls = DataBlock(\n",
" splitter = GrandparentSplitter(),\n",
" item_tfms = Resize(28),\n",
" blocks = (ImageBlock, CategoryBlock),\n",
" get_items = get_image_files,\n",
" get_y = parent_label,\n",
" batch_tfms = []\n",
" ).dataloaders(path, device=xm.xla_device())\n",
"learner = cnn_learner(dls,\n",
" resnet34,\n",
" metrics=accuracy, \n",
" loss_func=F.cross_entropy,\n",
" cbs=[GetPred])\n",
"print(list(learner.model.parameters())[-1])\n",
"learner.fit(2)\n",
"print(list(learner.model.parameters())[-1])"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[ 0.0128, 0.1270, -0.0180, ..., 0.0178, 0.0626, -0.1097],\n",
" [-0.0255, 0.0215, 0.0019, ..., 0.1076, 0.0628, 0.0713]],\n",
" requires_grad=True)\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.249053</td>\n",
" <td>1.040144</td>\n",
" <td>0.497547</td>\n",
" <td>00:21</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.300601</td>\n",
" <td>1.029977</td>\n",
" <td>0.512758</td>\n",
" <td>00:21</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[ 0.0128, 0.1270, -0.0180, ..., 0.0178, 0.0626, -0.1097],\n",
" [-0.0255, 0.0215, 0.0019, ..., 0.1076, 0.0628, 0.0713]],\n",
" device='xla:1', requires_grad=True)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "lHEIRtdNYXvY",
"outputId": "6f0a145a-ac6b-4231-e3dd-8056d2b54b01",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 186
}
},
"source": [
"learner.fit(2)\n",
"print(list(learner.model.parameters())[-1])"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.322935</td>\n",
" <td>1.017787</td>\n",
" <td>0.490677</td>\n",
" <td>00:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.340775</td>\n",
" <td>1.057548</td>\n",
" <td>0.468597</td>\n",
" <td>00:20</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[ 0.0324, 0.0307, 0.0026, ..., -0.0316, 0.0744, -0.0787],\n",
" [-0.0234, 0.0073, 0.0372, ..., 0.0579, -0.0632, 0.0346]],\n",
" device='xla:1', requires_grad=True)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Jcon4Df4ZfCN",
"outputId": "95ba9a21-6ee6-4fb5-8c4d-1fff2ebc912f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 86,
"referenced_widgets": [
"5eb5b3fdae2348258df1d841c2310685",
"96bb07f5e97844358ad5d2e6bfa86fe1",
"eb53ce9afd4b4825ade9ac4fcd8761d1",
"080099fc4e7e485b9513b9ccd43919aa",
"d2674c49b31642bfb19fa859ce111c91",
"57413751c8134ee0aad378133ea958be",
"9bce95278aa447c2813a3d0c2211c763",
"31651a7da3d643c89ba465dd856f49dd"
]
}
},
"source": [
"learner = cnn_learner(dls,\n",
" resnet34,\n",
" metrics=accuracy, \n",
" loss_func=F.cross_entropy,\n",
" cbs=[GetPred])"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/resnet34-333f7ec4.pth\" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth\n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5eb5b3fdae2348258df1d841c2310685",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=87306240.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "p-Cz3ef1jNbM",
"outputId": "d8be9e1c-ccf5-47fb-901e-20d4d027291d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"learner.cbs"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#4) [TrainEvalCallback,Recorder,ProgressCallback,XLAOptCallback]"
]
},
"metadata": {
"tags": []
},
"execution_count": 74
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "n4ZalqlVilPi"
},
"source": [
"pred=None\n",
"class GetPred(Callback):\n",
" def after_batch(self):\n",
" self.learn.pred1=self.learn.pred"
],
"execution_count": null,
"outputs": []
},
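{
"cell_type": "markdown",
"metadata": {},
"source": [
"`GetPred` above is a minimal fastai `Callback`: after every batch it copies the model output (`self.learn.pred`) onto the learner as `pred1`, so the last batch of raw predictions can be inspected once a fit finishes. A usage sketch under the same assumptions as the earlier cells (`dls`, `resnet34`, `accuracy` and `F.cross_entropy` come from the fastai imports; the learner name is illustrative):\n",
"\n",
"```python\n",
"learner = cnn_learner(dls, resnet34, metrics=accuracy,\n",
"                      loss_func=F.cross_entropy, cbs=[GetPred])\n",
"learner.fit(1)\n",
"print(learner.pred1.shape)   # raw outputs of the last processed batch\n",
"```"
]
},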
{
"cell_type": "code",
"metadata": {
"id": "aiS8L98vaH-y",
"outputId": "ecf172f8-a021-4842-ca0c-d91f3cfd7b9a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"dls.tfms[0]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Pipeline: PILBase.create"
]
},
"metadata": {
"tags": []
},
"execution_count": 66
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "4TM3g62AbP2B",
"outputId": "9d0b0871-9b65-40f1-9956-db9fea0c9a28",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"dls.one_batch()[0].mean(dim=[0,2,3])"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorImage([-1.5735, -1.4792, -1.2504], device='xla:1')"
]
},
"metadata": {
"tags": []
},
"execution_count": 109
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Zd2_qdzKcoBy",
"outputId": "9f925aeb-cede-4c58-abaf-80d4ad3ebfc8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 536
}
},
"source": [
"dls.show_batch()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIHCAYAAADpfeRCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de5SV1X0/4L0VkAhRc6mBxmq9EC/Y1aStWiUEFALeMFFcilGJNTEhbW2tl0rRaNXUxGhsVryErhqtaTVLTRHqNTbxbkxUbE0DNUhAUCIGwSgEQWXe3x8m9Kd7z3Auw5xzZj/PWvzhZ855zxa348eX7+w3VlUVAICybNHqBQAAfU8BAIACKQAAUCAFAAAKpAAAQIEUAAAokAIAAAVSABoQY/y3GOMLMcZXY4wLYoyfbfWaoB4xxjXv+LUhxnhFq9cFtbKHmxcdBFS/GOPIEMLCqqrWxxj3CCHcH0I4rKqqua1dGdQvxjg0hLA8hHBoVVUPtno9UC97uDHuADSgqqp5VVWt/+1f/ubXri1cEjRjcgjhlyGEh1q9EGiQPdwABaBBMcarY4xrQwhPhxBeCCHc2eIlQaM+HUL4duV2IJ3LHm6APwJoQoxxyxDC/iGEsSGES6qqeqO1K4L6xBh3CiEsCiHsVlXV4lavB+plDzfOHYAmVFW1oaqqh0MIO4QQvtDq9UADTgwhPOwbJx3MHm6QAtA7BgQzAHSmqSGE61u9CGiCPdwgBaBOMcbtY4xTYoxDY4xbxhgnhhCOCyH8oNVrg3rEGA8IIXwwhHBLq9cCjbCHmzOg1QvoQFV463b/zPBWgVoSQjitqqr/aOmqoH6fDiHMqqpqdasXAg2yh5tgCBAACuSPAACgQAoAABRIAQCAAikAAFAgBQAACtTjjwHGGP2IAA2rqiq2eg32MM1ohz0cgn1Mc7rbx+4AAECBFAAAKJACAAAFUgAAoEAKAAAUSAEAgAIpAABQIAUAAAqkAABAgRQAACiQAgAABVIAAKBACgAAFEgBAIACKQAAUCAFAAAKpAAAQIEUAAAokAIAAAVSAACgQAoAABRIAQCAAikAAFAgBQAACjSg1QsIIYRRo0Yl2ZFHHplkY8aMSbKPfOQjSfbMM88k2fjx47OfvWzZslqWCAD9ijsAAFAgBQAACqQAAECBFAAAKFCsqqr7L8bY/RcbNGLEiCR7+OGHk+y9731vbj1J1tP6/3833XRTNr/yyitrev/KlSuTLDdsyP+pqir9B9bHNscephztsIdDsI9pTnf72B0AACiQAgAABVIAAKBACgAAFGizDQFuv/322fyee+5JspEjR9Z0zWaGAJu95qpVq5IsN1h4xRVXJFmpw4LtMEBleIpmtMMeDqHcfXzvvffW9LoHHnggye6///6aXlcCQ4AAwEYKAAAUSAEAgAIpAABQoM02BLjDDjtk88WLFzd6yZYOAdbqrLPOSrKvf/3rDV+vk7XDANXmGJ4aNmxYks2ZMyfJ9tlnnyQ755xzkuxXv/pVU+updb9+7GMfS7Jjjz224c/96U9/ms1nzpyZZD/60Y+S7L/+678a/uy+0g57OIQyhgDPP//8mrLedsEFF2Tz/jREaAgQANhIAQCAAikAAFAgBQAACtTnjwOu1R/90R/VlN1xxx1JVs9g05lnnplkuSGvWuWGAP/xH/+x4et1snYYoNoce/iuu+5Kso9//OO9/TG9bnMM0dbqxRdfTLLrr78+yWbMmNEXy6lZO+zhEMoYAuzq6mr1Ejap1sHA7gYLW8UQIACwkQIAAAVSAACgQAoAABRoQKsX0J0nn3yypiynnpP3HnnkkST74Q9/WPP736mZAUI6w/Dhw1u9hI7zgQ98IMm6Oy2U/q3Z0/16+4S+etYzduzYmrKcdhsMDMEdAAAokgIAAAVSAACgQAoAABRIAQCAArXtTwH0ti9+8YvZ/POf/3yvfs6pp56aZKNGjar5c+fNm9er64F2NXLkyCT74Ac/mGTLli3ri+XQR+qZus9N/B900EG9uJoQxowZk81rne7Pya27HbkDAAAFUgAAoEAKAAAUSAEAgAL1yyHA3MDfOeeck33tgAHpb0Ezz0gfNGhQku23335JdtVVV2Xfv3Tp0iSbOnVqw+uh95177rlJNnv27BaspLMtXrw4yVasWNGCldCumjniNyc38NfMsF8IIWyxRef+f3TnrhwAaJgCAAAFUgAAoEAKAAAUqOOHAHfaaacky52ylxv2CyE/wNHV1ZVkjzzySJItWbIkyT7xiU8k2ZAhQ5Lsox/9aHY9Oblrjh49Osl+8pOf1HxNGpfbC8ccc0ySXXbZZTVdb9WqVUnW3cmV7eS6667L5u9///trev/atWuT7PXXX29qTbSXek79y+ntE/WaXc+BBx7YSytpD+4AAECBFAAAKJACAAAFUgAAoEAdPwQ4dOjQJMsN3XV3ul/uNLK5c+cm2bRp05IsN7yVe/TvV77ylSTbe++9s+up9e8ndxrdySefnGRr1qzJfg6Ne/nll5Ns1qxZNWWdKneC2lZbbdWCldBJLrjggppfmxv4a+YkwNzAXz2n/uUG/nr7ZMJWcwcAAAqkAABAgRQAACiQAgAABYo9Pfo2xtj4c3FbaOTIkUmWG64LIX+a3/Lly3t9Te904YUXZvPTTjstybbeeusky/1zy50u+OMf/7iB1fWOqqpiyz78Nzp1D7eb6dOnJ9mXvvSlpq554403Jlm7Pfq6HfZwCPbxpuSGVO+7776a3tvd6X79aeCvu33sDgAAFEgBAIACKQAAUCAFAAAK1PEnAebMmzev1UvYpPPOOy+b504IPOKII2q65iWXXJJkhx9+eJI5HZCe5E6ePOWUU3r9cy6//PJevyZlqueEP/6POwAAUCAFAAAKpAAAQIEUAAAoUL8cAuxkDz30UJJ98pOfTLKurq4ky50EuO222yaZIUB6csIJJyTZTjvt1NQ1Z8+enWSdMKxLZ8g9+jenhEf81sMdAAAokAIAAAVSAACgQAoAABRIAQCAAvkpgDZz++23J9mll16aZFXl8eBsHqNHj06yGLOPE8/Kvfbxxx9PsjfeeKO+hUHI/wRUzv33359kJU/857gDAAAFUgAAoEAKAAAUSAEAgAIZAuxHnn/++SR77bXXWrASOsU222yTZO9///uTrJ6h09xrDa3SiHvvvbem1+UG/g466KBeXk3/4w4AABRIAQCAAikAAFAgBQAACmQIsEWmTJmSzb/yla80fM3PfvazSbZq1aqGr0f/t++++ybZ+PHjm7rmSy+9lGS33XZbU9ek/8sN/I0dO7am9zrhrzHuAABAgRQAACiQAgAABVIAAKBAxQ8B7rHHHkn205/+tOHrbbFF2qlqfXxls9dcs2ZNU59DeS677LJev+att96aZP/7v//b659DZ+rudL9aB/4uuOCCmjI2zR0AACiQAgAABVIAAKBACgAAFKiYIcDhw4dn89zAUjOPLs0N5zX7KNR169Yl2YUXXphkzzzzTFOfA/X45S9/mc2vu
eaaPl4J7WrMmDFJVuuwX3cM/PUedwAAoEAKAAAUSAEAgAIpAABQoGKGAI899thsPmLEiCRrdmivt5166qlJ9q1vfasFK6GTDR06NMkGDRrU8PUef/zxbD537tyGr0nnyg38nX/++U1d88ADD2zq/fTMHQAAKJACAAAFUgAAoEAKAAAUqJghwCeffLLVS3ibiy66KJvnhvuWLVu2uZdDAU488cQk+9CHPtSCldAf5U74q+fUv9wJfw888EATK2JT3AEAgAIpAABQIAUAAAqkAABAgRQAAChQMT8F8OCDD2bzAQOK+S2AXnXGGWe0egl0oPvvvz+b534KgM3LHQAAKJACAAAFUgAAoEAKAAAUyAQcFGLlypVJtm7duiQbPHhwkq1YsSLJVq9e3TsLo9/KDfwZ9msf7gAAQIEUAAAokAIAAAVSAACgQLGqqu6/GGP3X4RNqKoqtnoN9nDPnnrqqSTbZZddkmzcuHFJ9thjj22WNbWTdtjDIdjHNKe7fewOAAAUSAEAgAIpAABQIAUAAArU4xAgANA/uQMAAAVSAACgQAoAABRIAQCAAikAAFAgBQAACqQAAECBFAAAKJACAAAFUgAAoEAKAAAUSAEAgAIpAABQIAUAAAqkADQgxvhvMcYXYoyvxhgXxBg/2+o1QT1ijGve8WtDjPGKVq8LamUPNy9WVdXqNXScGOPIEMLCqqrWxxj3CCHcH0I4rKqqua1dGdQvxjg0hLA8hHBoVVUPtno9UC97uDHuADSgqqp5VVWt/+1f/ubXri1cEjRjcgjhlyGEh1q9EGiQPdwABaBBMcarY4xrQwhPhxBeCCHc2eIlQaM+HUL4duV2IJ3LHm6APwJoQoxxyxDC/iGEsSGES6qqeqO1K4L6xBh3CiEsCiHsVlXV4lavB+plDzfOHYAmVFW1oaqqh0MIO4QQvtDq9UADTgwhPOwbJx3MHm6QAtA7BgQzAHSmqSGE61u9CGiCPdwgBaBOMcbtY4xTYoxDY4xbxhgnhhCOCyH8oNVrg3rEGA8IIXwwhHBLq9cCjbCHmzOg1QvoQFV463b/zPBWgVoSQjitqqr/aOmqoH6fDiHMqqpqdasXAg2yh5tgCBAACuSPAACgQAoAABRIAQCAAikAAFCgHn8KIMZoQpCGVVUVW70Ge5hmtMMeDsE+pjnd7WN3AACgQAoAABRIAQCAAikAAFAgBQAACqQAAECBFAAAKJACAAAFUgAAoEAKAAAUSAEAgAIpAABQIAUAAAqkAABAgRQAACiQAgAABVIAAKBACgAAFEgBAIACKQAAUKABrV5AJxs6dGiSTZkyJcmOPPLIJDvkkEOy16yqKsmuvvrqJJsxY0aSrV69OntNAHgndwAAoEAKAAAUSAEAgAIpAABQoJgbOtv4xRi7/2JhBg0alGSzZ89OsgkTJvTFcsL3v//9JDvssMOSbMOGDX2xnKyqqmLLPvw37GGa0Q57OAT7mOZ0t4/dAQCAAikAAFAgBQAACqQAAECBnARYo9zJfc0M/D366KPZfN26dUm26667Jtn48eOTbODAgUnWyiFAANqXOwAAUCAFAAAKpAAAQIEUAAAokJMAm3DSSSclWW6I78c//nGSLV26NHvN3NDe1772tST767/+6yT7zGc+k2TXX3999nP6QjucomYPt68ddtghyY477rgke+2115Lsyiuv3Cxreqd22MMhtN8+3mKL9P8dc98Pd95555qv+Yd/+IdJtnDhwiT79a9/nWSjRo1KsiFDhiTZH//xH9e8nptvvjnJzjvvvCTLrbHdOAkQANhIAQCAAikAAFAgBQAACmQIsAnvec97kqyrqyvJdt999yR77LHHstc899xzk+ycc85JshUrViTZuHHjkuyZZ57Jfk5faIcBKnu4PRx99NFJ9sUvfjHJ9t577yTLfY9atWpVkm2//fYNrq577bCHQ2i/ffzhD384yR5//PEWrOQtMab/mHr6b1uj7r777iSbNGlSr39ObzMECABspAAAQIEUAAAokAIAAAUq5nHAAwbk/1Y/97nPJVlumOjII49MsuHDh9f02UOHDk2y3BBTCCG8973vTbLbbrstyU455ZQke+WVV2paD/Rk8ODBSZZ7/HRu6DSE/GOyc4+0zj2+Oic34PW+972vpvcC3XMHAAAKpAAAQIEUAAAokAIAAAUqZggw90jdEEL4i7/4iz5eyVuGDRtW82vnzp2bZAb+2t+OO+6YZLnhz5EjRzb1OR/96EeT7Igjjmj4eoMGDUqyfffdt+HrNevNN99MsosuuqgFK+G3Fi9enGSzZs1Ksl122aUvllPzSYBbbbVVku25556bZU2dwB0AACiQAgAABVIAAKBACgAAFEgBAIACxZ6emdxuz6BuRndH726zzTab/bN/8YtfJNmrr75a8/u32267JPvEJz6RZLmfFmildniW+ubYw5/85CeT7IQTTkiy3LPtc3thzJgxvbOwPvarX/0qm2+77bZJlpvSrtXll1+eZGeddVbD16tHO+zhEPrX9+JWOvDAA5Psnnvuqfn9d999d5JNmjSpqTX1he72sTsAAFAgBQAACqQAAECBFAAAKFAxQ4B/9Vd/lc2nT5+eZNtvv32S3X777Uk2f/78ml73s5/9LMlWrlyZXU/Ot771rSTLPUt96tSpNV+zL7TDANXm2MOPPvpokrXyqNyc3PG5r732Wk3vzR0z/Q//8A9JNm/evOz777rrriQbMmRITZ/97LPPJtkhhxySZAsWLKjpes1qhz0cQv/6XtxXckdaz5kzJ8nGjx9f8zUNAQIAHU8BAIACKQAAUCAFAAAKNKDVC+gr3/jGN7L5VVddVdP7N2zY0JvLqcvy5cuT7Pd+7/dasBJCCOGOO+5IsmXLliXZkUce2RfLyQ4lXnrppUmWG4Cq1ZZbbplk//mf/5l9bW7gL3cSYO6Z8hMnTkyyhQsX1rJEeJszzzwzyeoZ+Mu59tprm3p/u3EHAAAKpAAAQIEUAAAokAIAAAUqZgiwO60c7mtG7nS0d73rXUlW6+lv1O5LX/pSkm2xRdqlc4Nzm0NXV1eSNbOvhw4dmmS5AcJ6HmOcO3H0xhtvTDIDfzRi6623TrLTTz+9qWvmBv5yp1x2MncAAKBACgAAFEgBAIACKQAAUKDihwA71Xve854kO/zww5Pslltu6YvlFC83iJfLOsHYsWNryrqTG+SbPHlykuUekw2N+Pa3v51k2267bVPXnDVrVpKtW7euqWu2G3cAAKBACgAAFEgBAIACKQAAUKCYO6Fr4xdj7P6LvWjgwIFJ9sYbb/TFR7edAQPSuczZs2cn2ZtvvplkucfP9vTPd3Orqip9Bmwf66s93Km++c1vJtkxxxyTZNttt12S5R7nG0IIEyZMSLJFixY1sLrWa4c9HIJ9vCm574e1fu+bOXNmNj/ttNOSrFNPju1uH7sDAAAFUgAAoEAKAAAUSAEAgAL1+UmA06ZNS7Lx48cn2YknnphkJTza9mtf+1qSHXzwwUmWe5RqKwf+aH+TJk1KsqOPPjrJcgN/S5YsSbLcsF8InTvwR/vJ
DYjPmDEjyXKP486dxPnCCy8k2ZVXXpn97E4d+KuHOwAAUCAFAAAKpAAAQIEUAAAokAIAAAXq86OAn3rqqSQbOXJkkp133nlJdvHFF/f2clpm+vTp2fzcc89NshdffDHJ9tlnnyRbtWpV8wvrRe1wjGqpR6gecMABSXbbbbclWW7iP7ePRo0alWQLFixocHWdox32cAjl7uPc98mLLrooyWJM/zHlpvhzP1128803N7i6zuEoYABgIwUAAAqkAABAgRQAAChQnx8FXOtxtbln23fCEOCAAelv6ec///kkyw37hRDC4MGDk+zOO+9MsnYb+KN1xowZk2S54ancwF/u38cbbrghyUoY+KO1ckdV5479rdV1112XZCUM/NXDHQAAKJACAAAFUgAAoEAKAAAUqM9PAswNxF111VVJllvXZZddlmTf/OY3k2zp0qUNrq5722yzTZJNnjy5puzggw9Osqeffjr7OblTrubMmZNk69aty76/nbTDKWolnKCW2x+HH354Te/N/Tt19tlnN72m/qId9nAIZezj3PfEXXfdtab35r7nT5gwIcl+/vOf17+wfsBJgADARgoAABRIAQCAAikAAFCgPh8C3GKLtHNceeWVSfa5z32upuutX78+yW688cYku/fee7Pvzz1GMvco1XHjxiXZiBEjalliuOmmm5Ls5JNPzr429/fTqdphgKo/DU/lHgEdQgiPPPJIkm255ZZJljtJ8+///u+TLPcY1VK1wx4OoX/t4+OOOy6bz5w5M8m23nrrmq550kknJVnuRMtSGQIEADZSAACgQAoAABRIAQCAAvX5EGBObjDw0ksvTbLc4Ny73/3uzbKmWjz//PM1ZR/72MeSrKura7OsqZ20wwBVpw5P7b///kl2++23Z1+be8zvc889l2S5xwYvWbKkgdWVox32cAidu4/322+/JPve976Xfe2QIUNqumbuNL899tijvoUVxhAgALCRAgAABVIAAKBACgAAFGhAqxcQQn4g7owzzkiyO+64I8mOOeaYmj4jNygVQghHH310kq1cuTLJvvzlLyfZv/7rv9b0XqjX1KlTk6y7PZzzs5/9LMleeumlptYE9Tr99NOTrNZhvxBCWLhwYZLlHvNLY9wBAIACKQAAUCAFAAAKpAAAQIHaYgiwVrlH+nb3mF8o2cCBA5Msd+Im9JbcSa2HHnpoU9e86667kix3yiWN8R0BAAqkAABAgRQAACiQAgAABVIAAKBAHfVTAFCKX//61zW/9rXXXkuyyy+/PMlWr17d1JqgJ0cddVSSDR48uOb3z58/P8m++tWvNrUmeuYOAAAUSAEAgAIpAABQIAUAAApkCBDa0IwZM5Js+fLl2dfOnTs3ye67775eXxP05NFHH02yiRMnJtnTTz+dff+ECROS7MUXX2x+YXTLHQAAKJACAAAFUgAAoEAKAAAUKFZV1f0XY+z+i7AJVVXFVq/BHqYZ7bCHQ7CPaU53+9gdAAAokAIAAAVSAACgQAoAABSoxyFAAKB/cgcAAAqkAABAgRQAACiQAgAABVIAAKBACgAAFEgBAIACKQAAUCAFAAAKpAAAQIEUAAAokAIAAAVSAACgQAoAABRIAWhAjPHfYowvxBhfjTEuiDF+ttVrgnrEGNe849eGGOMVrV4X1Moebl6sqqrVa+g4McaRIYSFVVWtjzHuEUK4P4RwWFVVc1u7MqhfjHFoCGF5COHQqqoebPV6oF72cGPcAWhAVVXzqqpa/9u//M2vXVu4JGjG5BDCL0MID7V6IdAge7gBCkCDYoxXxxjXhhCeDiG8EEK4s8VLgkZ9OoTw7crtQDqXPdwAfwTQhBjjliGE/UMIY0MIl1RV9UZrVwT1iTHuFEJYFELYraqqxa1eD9TLHm6cOwBNqKpqQ1VVD4cQdgghfKHV64EGnBhCeNg3TjqYPdwgBaB3DAhmAOhMU0MI17d6EdAEe7hBCkCdYozbxxinxBiHxhi3jDFODCEcF0L4QavXBvWIMR4QQvhgCOGWVq8FGmEPN2dAqxfQgarw1u3+meGtArUkhHBaVVX/0dJVQf0+HUKYVVXV6lYvBBpkDzfBECAAFMgfAQBAgRQAACiQAgAABVIAAKBAPf4UQIzRhCANq6oqtnoN9jDNaIc9HIJ9THO628fuAABAgRQAACiQAgAABVIAAKBACgAAFEgBAIACKQAAUCAFAAAKpAAAQIEUAAAokAIAAAVSAACgQAoAABRIAQCAAikAAFAgBQAACqQAAECBFAAAKJACAAAFUgAAoEAKAAAUSAEAgAIpAABQIAUAAAqkAABAgRQAACiQAgAABVIAAKBACgAAFEgBAIACDWj1AtrRSSedlGQ77LBDkv3u7/5ukk2bNi3Jqqqq+bNXr16dZAcddFCSPfnkkzVfEwDeyR0AACiQAgAABVIAAKBACgAAFCj2NKAWY6x9eq2N7L777kl2xBFHZF+bG9rbcccdkyzG2PzCGpQbDBw3blyStdtgYFVVrftN+41O3cObwwc+8IEk22effZLs8MMPT7JTTjmlqc9euXJlkk2ePDnJHn300SR78803m/rsZrTDHg6h/fbx4MGDk+zss89OsvPPPz/Jurq6std84oknkuxP/uRPalrPFluk/y/b3efUKree3ED22rVrm/qcvtDdPnYHAAAKpAAAQIEUAAAokAIAAAXqqCHAvfbaK8nOOOOMJDv22GOTLDe0EkIIS5YsSbLZs2cn2XXXXZdkixcvzl6zFjvttFM2/8lPfpJkuQGq3MmEb7zxRsPr2RzaYYCq3fZwrfbYY48ky508GUIIH/7wh5NszJgxSZYbqMoNBi5atCjJnn322exn5/z+7/9+ku2yyy41vXe33XZr6rN7Wzvs4RDabx+fcMIJSZb7Hpkbnq7nZNRa9dXnXHbZZUn2d3/3d73+Ob3NECAAsJECAAAFUgAAoEAKAAAUqM8fB5wbWDrttNOS7NBDD02yhx9+OMlyJ/x95zvfSbKLL744u57nn38+yXIn7zVjwoQJSZY7ga07jz32WJK128Afqd/5nd+p6XWHHXZYkuWGjbbbbruaPzs3FHXPPfck2Ze//OUkmz9/fpK99NJLNX92bijqoosuqvn9tL/vf//7SfbUU08lWe77fSebOnVqknXCEGB33AEAgAIpAABQIAUAAAqkAABAgfp8CHDYsGFJNmXKlCS76aabkuxHP/pRkv35n/95kq1YsSLJNmzYUOsSm3L88ccn2T/90z8lWXcnE955551Jlvv9oXd96lOfSrJzzjkn+9pbb701yQ444IAk22+//Wr67FdffTXJXn/99SSbOXNm9v25QdZvfOMbSbZ+/fok66t/L3KefvrpJHv55ZdbsBLqtXz58iQ7+OCDk2z48OFJ1sqTAE899dQkO/nkk3t9PZ3CHQAAKJACAAAFUgAAoEAKAAAUSAEAgAL1+U8B3H333Uk2YsSIJFu2bFmSdXV1bZY1NerEE09MsquvvjrJchPd11xzTfaaucnztWvXNrA66pE7mnn33XfPvnb69Ok1XXPhwoVJ9sMf/jDJvvC
FLyRZbmK/3bzvfe/L5pMmTarp/VdccUWSvfLKK02tidbJHRddzxHSfeG73/1ukvkpAACgKAoAABRIAQCAAikAAFCgPh8CzHnuuedavYS3yT3H/V/+5V+SbPTo0UmWG/jLPUP6jjvuaGxxbBbXXnttkuX2QQgh3HDDDUm2cuXKJMsNAa5ataqB1bWnv/zLv8zm++67b5LlhvseeuihXl8T9GTIkCFJljtGuDsPPvhgby6n5dwBAIACKQAAUCAFAAAKpAAAQIHaYgiwLwwbNiybH3/88Uk2bdq0JNt5552TbN68eUk2efLkJMsNg9Fe/ud//ifJTjjhhBaspHNMnDix5tcuWrQoyebPn9+by4G3GTx4cJKdeeaZSVZVVc3XfOKJJ5paU7txBwAACqQAAECBFAAAKJACAAAF6pdDgLkhvr/5m7/JvnbXXXdt+HNyg4G5U+JuvvnmJHv88cez13eWCx4AAAV6SURBVOxvJ03RP+T2ei4LIf/Y7osvvrjX1wQ9+dCHPpRk++23X83vX7JkSZLlHifcydwBAIACKQAAUCAFAAAKpAAAQIH65RBgbtBjwID83+p9991X0zV/8IMfJNm4ceNqeu9RRx2VZBdeeGH2teedd16Sfec730myX/ziFzV9NtRr0KBBSZbbr909Ljl3quKtt97a/MKgDnvttVdT788N/OUGAzuZOwAAUCAFAAAKpAAAQIEUAAAoUOzpUYgxxtqfk9jmuju1bPHixZv9s9/1rnclWXePUj377LOTbODAgUk2adKkJHvhhRcaWN3mU1VVbPUa+tMe7isTJkxIsjvvvLPm9+dO/csNt3aCdtjDIdjHm5IbSM09gjr3iODu7LbbbknWqUOA3e1jdwAAoEAKAAAUSAEAgAIpAABQoGKGADvFVlttlWS33HJLkg0fPjzJ9tlnn82ypka1wwCVPdyzd7/73Ul21113Jdn+++9f8zWHDRuWZCtWrKhvYW2iHfZwCPbxpnz1q19NstNPP72m9z700EPZ/LDDDkuytWvX1rewNmEIEADYSAEAgAIpAABQIAUAAAqkAABAgQa0egG83fr165PsueeeS7LRo0f3xXLo53KT0n/6p3+aZLmfFrrhhhuy13zllVeaXxh0I3fs7yGHHJJkPf2E2//v1FNPzeadOvFfD3cAAKBACgAAFEgBAIACKQAAUCBDgFCwo446quH3XnXVVdn89ddfb/iasClz5sxJsj333DPJckOAs2bNSrJnn322V9bVidwBAIACKQAAUCAFAAAKpAAAQIE6fgjw6KOPTrLvfve7LVgJtLcZM2Yk2V577VXTex944IEke+KJJ5peE/Tkb//2b5PsIx/5SE3vXbBgQZJ95jOfSbI1a9bUv7B+wh0AACiQAgAABVIAAKBACgAAFKjjhwD/4A/+IMkMAVK6rbbaKskmTpyYZDHGJHv11VeT7M/+7M+SrKurq8HVQWrMmDFJdsEFFyTZgAG1/Wfr61//epKVPPCX4w4AABRIAQCAAikAAFAgBQAACtTxQ4B77713ku2+++5JtnDhwuz7N2zY0Otrasa2226bZKNHj06ytWvX9sVy6FBjx45NslGjRtX03qeeeirJli5d2uySIIQQwvDhw7P5tddem2QDBw6s6ZrXX399kv3zP/9zfQsrkDsAAFAgBQAACqQAAECBFAAAKFDHDwHedNNNSZZ7TOmtt96afX/upKmf//znzS9sE7obhMmtZ9iwYUn28Y9/vNfXRP/x7//+7w2/95JLLunFlcDbTZ06NZvvuOOOSVZVVZLNnz8/yXKP+WXT3AEAgAIpAABQIAUAAAqkAABAgTp+CPDmm29Osv/+7/9Osu9973vZ999+++1JNmfOnCS77rrrkmzRokVJlhta+dSnPpVkZ511VnY9e+65Z5JNmzYtyXKntcFvbb311klW6+N7V65c2dvLgY1y3+PqMX78+F5aCe4AAECBFAAAKJACAAAFUgAAoEAKAAAUqON/CiBnwYIFSbbzzjtnX7vLLrsk2ZQpU5Is99MCgwcPTrI1a9Yk2YgRI5Ls5Zdfzq7n9NNPT7Jrrrkm+1oIIYSjjjoqyXI/jZKzdOnSJPMTJlAGdwAAoEAKAAAUSAEAgAIpAABQoH45BFiP3HG+F198cU0ZtIOXXnqp4fdeeumlSfb66683sxzo0RNPPJHNjz/++CSbPn16kjWz33k7dwAAoEAKAAAUSAEAgAIpAABQoNjTiWExxtqOE4OMqqpiq9dgD9OMdtjDIdjHNKe7fewOAAAUSAEAgAIpAABQIAUAAArU4xAgANA/uQMAAAVSAACgQAoAABRIAQCAAikAAFAgBQAACvT/AMAOqQ6IKVzWAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 648x648 with 9 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "8r37ywpAf3ye",
"outputId": "e9bca680-474e-46f9-c1b9-7b1e5face353",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 536
}
},
"source": [
"learner.fit(2)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIHCAYAAADpfeRCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3de7iNZf7H8fu2ndvRGNmZGCo1NdSUpgghIlE6UKODEDrIdHA15rrKISHRXKWUUXI1KDEmp5QOY+QYY+RqTJGtUM4TcrYrnt8fzZj5+X6XefZ61lrPWvv7fl1Xf8xn77Weu+neq0+P774fHwSBAwAAtpSKewEAACDzKAAAABhEAQAAwCAKAAAABlEAAAAwiAIAAIBBFAAAAAyiABST937/cX8d8d6PintdQHF471/13m/13u/13q/13veIe01AcbGPo/EcBJQ8732+c26bc65tEAQL4l4PEJb3vq5zbl0QBEXe+3Odcx8459oFQbAi3pUB4bGPo+EOQDQdnHM7nHML414IUBxBEHwSBEHRv//nv/46K8YlAcXGPo6GAhBNF+fchIDbKMhB3vvR3vuDzrk1zrmtzrm3Y14SUGzs4+TxRwBJ8t7Xcs594ZyrEwTB+rjXAyTDe5/nnLvMOdfcOTc8CILv4l0RUHzs4+RwByB5nZ1zi/iXP3JZEARHgiBY5Jyr4Zy7N+71AMlgHyeHApC8O5xz4+NeBJAipR1/dorcxz4uBgpAErz3jZxzpzvnpsa9FqC4vPfVvPedvPf53vs87/1VzrlbnHNz414bEBb7ODpmAJLgvX/ROVcxCILOca8FKC7v/anOuT85537hfviPgI3OueeCIBgb68KAYmAfR0cBAADAIP4IAAAAgygAAAAYRAEAAMAgCgAAAAZRAAAAMKj0ib7ovedXBJC0IAh83GtgDyOKbNjDzrGPEU2ifcwdAAAADKIAAABgEAUAAACDKAAAABhEAQAAwCAKAAAABlEAAAAwiAIAAIBBFAAAAAyiAAAAYBAFAAAAgygAAAAYRAEAAMAgCgAAAAZRAAAAMIgCAACAQRQAAAAMogAAAGAQBQAAAINKx70AANH89Kc/FVn16tVF9vLLL4ts3bp1Imvfvn2o606YMEHN+/XrJ7LNmzeHek8AmcMdAAAADKIAAABgEAUAAACDKAAAABjkgyBI/EXvE38R+B+CIPBxr8HCHh42bJjI+vTpI7KioqKUXrdcuXJqPnfuXJH16tVLZBs2bEjpetIhG/awczb2saZChQoi04ZUmzZtGur9unfvLrIyZcqo3+u9/Eev/fvy9ddfF9mMGTNENnv2bJEdPnxYvXaqJdrH3AEAAMAgCgAAAAZRAAAAMIgCAACAQQwBIm2yYYDKwh6uVKmSyM477zyRLVu2LKXXbdy4sZpPnz5dZF26dBHZnDlzUrqedMiGPexc7u5jbcBu7Nix6vfWqVNHZPn5+SKrV69eqGtrQ6YLFy4M9driuP7660V28skni+zxxx8PlaUDQ4AAAOAYCgAAAAZRAAAAMIgCAACAQTwOGMhxe/fuFVmqB/40ixcvVvMDBw6k/dqIlzZ4qg33tWvXTmR79uxR33Pfvn0i++CDD0Q2fvx4kb300ksi++6770T27bffqteOQhvkKywsFFm1atVSfu2ouAMAAIBBFAAAAAyiAAAAYBAFAAAAgxgCBJCUa6+9Vs2rVq0qsiuvvFJkuXASIHTaCX0dOnQQmTYQetFFF6nvuWPHjugLi8G4ceNEtmvXLpFNmjQpE8spFu4AAABgEAUAAACDKAAAABhEAQAAwKCcGgKsUKGCyLSTpoqjfv36Ivvoo4+Sfj/v5VMXtUcub9q0SX390qVLk742kArlypUTmTbE98QTT6iv135OlyxZEn1hyBrbtm0T2SWXXCKy0047TWS5MuxXsWJFkT333HMi0x6Lfeedd4osG38GuAMAAIBBFAAAAAyiAAAAYBAFAAAAgzI+BFi9enWRderUSWTaaVEtW7YUWUFBQaT1hB3aS/X7HTp0SH29doLUggULRPbWW2+JbPLkyWGWCKNOOeUUkT3//PMiq1Onjsh++ctfhr7OihUrRDZ//vzQr0f2O3r0qMhWrlwZw0qiS/SY3tGjR4vs+uuvF9nIkSNF9tprr0VfWAZwBwAAAIMoAAAAGEQBAADAIAoAAAAGUQAAADDIn2ji3Xuf9Dj8bbfdpuYTJkwI9fr9+/eLbPPmzSIbO3asyNauXSsybWo+U7Tjis8555zQr3/44YdFph2xqf09du7cWWR79uwJfe0ogiCQvxKRYVH2cLZJdPTuPffcE+r1eXl5IjvppJNCvVab+p43b576vY899pjIPvzww1DXyTbZsIedK1n7OKr8/HyRaXtb88ADD4isV69e6vdqPxvXXHONyBYvXiyy77//PtR6MiXRPuYOAAAABlEAAAAwiAIAAIBBFAAAAAxK2xDgmjVr1FwbXuvXr5/ItONv//73vye7nJymPV+9R48eIks0zHK88847L/KawsiGAaqSNDylPYvcufD/3MPShpqGDh0qsvfeey+l181G2bCHnStZ+ziRbt26iezKK68UWevWrUX2ox/9KOXrKSoqEpl2dPbAgQNFdvjw4ZSvJwqGAAEAwDEUAAAADKIAAABgEAUAAACD0jYEqA1GOOfcQw89JLLGjRuL7JNPPkn20mZ17NhRZJMnTxZZ6dKlM7GcrBigKknDU2eeeaaaz5kzR2TasG3YU/+2bNkisksuuURk27dvD/V+uSwb9rBzJWsf16xZU82XLl0qsoKCglDveeTIEZFpJ8Jqp2nu3r1bfc9HH31UZJdddpnItM/dGTNmqO8ZF4YAAQDAMRQAAAAMogAAAGAQBQAAAIPSNg02ZMgQNa9fv77ImjVrJjKGAIH/74svvlDzRo0aiaxKlSoi04YA69WrJ7Jx48aJ7M477xTZsGHD1PUAJ/Lkk0+quTbwd/DgQZFNmzZNZMOHDxfZ6tWrk1jdf/ztb38T2bZt20T2yCOPiCzbhgAT4Q4AAAAGUQAAADCIAgAAgEEUAAAADErbSYCJ9OzZU2Rdu3YVWdu2bUW2Z8+eVC+nRBk1apTI7rvvPpGVKpWZ3pcNp6iVpBPUMqVatWoi005pu/nmm9XXa8NTuSob9rBzJWsfn3HGGWrepEkTkS1fvlxkiR41n2r5+fkiW7lypci04cVKlSqlZU3J4iRAAABwDAUAAACDKAAAABhEAQAAwKDMPBf2v4wdO1Zk2uNMi4qKMrGcnKWdnvirX/1KZJs3b87EclCC7NixQ2SHDh0SWa1atdTXl6QhQKTe+vXri5XHpXz58iLTBhjfe++9TCwnLbgDAACAQRQAAAAMogAAAGAQBQAAAIMyPgSo2bRpU9xLyGraCVlTp04VmfboTO1ERQBAajzxxBNxLyFp3AEAAMAgCgAAAAZRAAAAMIgCAACAQRQAAAAMyorfAsB/1K9fX2RvvfWWyCZNmiSyRx99VGS7du1KzcIA5KQyZcqo+fTp0
0XWoUMHkVk9lv3FF18U2eHDh0X21VdfZWI5acEdAAAADKIAAABgEAUAAACDKAAAABjEEGAGVKlSRWTDhg1Tv7dHjx4i++yzz0T29NNPi4yBv+xXqVIlkb366qsi69Kli8h2796dljUdTzs+ukaNGhm5NlLv4YcfVvPly5eLzOLAX4MGDdS8TZs2ItN+Vjdu3JjyNWUKdwAAADCIAgAAgEEUAAAADKIAAABgEEOAKdaqVSuRvfzyyyI7/fTT1devWLFCZJdeemn0hSErlC1bVmTa0F3VqlVFFnUIUNtHf/jDH0RWq1YtkWmnyR04cCDSepAZffr0UfNrr702wyvJrHLlyolMOy313nvvVV//9ddfi+yBBx6IvrAswh0AAAAMogAAAGAQBQAAAIMoAAAAGMQQYAQDBgwQ2cCBA0WmDZMMHTpUfc8hQ4ZEXxiylvY40c2bN4tszpw5ItuwYYP6ntpJkT/72c9E1qxZM5F579X3PN7gwYNF9s4774R6LeK1Z88eNW/durXIli5dmu7lpEWLFi1E1rt3b5G1b99eZDNmzFDfs3///iLTfn5zGXcAAAAwiAIAAIBBFAAAAAyiAAAAYJAPgiDxF71P/EVjtIE/7VQpbUjk6quvFtmSJUtSs7AsFgRBuAmzNMqFPawNJk2bNi3Se2rDfYcOHRLZ2rVrRaY9qvpPf/qTyE702VFSZMMedi7aPk70iN+DBw+KTDsd8KOPPgr12rBKldL/u7N8+fIia968ucgeeeQRkTVs2FBk2t+39rj1mTNnquuJ8veYbRLtY+4AAABgEAUAAACDKAAAABhEAQAAwCCGABVTpkwRWbt27UT2/vvvi+yGG25Iy5pyUTYMUOXCHq5cubLIfve734msW7duod+zsLBQZNrJZtpwH/4jG/awc9H2sfa4aeecmzRpksjy8/NDvefUqVNFlmjY8HgFBQVqrj1KPSzt5+Xpp58W2Y4dO5K+Ri5jCBAAABxDAQAAwCAKAAAABlEAAAAwiCFAhfboU+3/J+10QPxHNgxQWd3DSI1s2MPOpWcfn3XWWSJ7/fXXRVa/fv2UXnfr1q1qrp3IN3/+fJGtWbNGZKtWrYq+sBKMIUAAAHAMBQAAAIMoAAAAGEQBAADAIAoAAAAG8VsASJtsmKBmDyOKbNjDzrGPEQ2/BQAAAI6hAAAAYBAFAAAAgygAAAAYRAEAAMAgCgAAAAZRAAAAMIgCAACAQRQAAAAMogAAAGAQBQAAAIMoAAAAGEQBAADAIAoAAAAGnfBxwAAAoGTiDgAAAAZRAAAAMIgCAACAQRQAAAAMogAAAGAQBQAAAIMoAAAAGEQBAADAIAoAAAAGUQAAADCIAgAAgEEUAAAADKIAAABgEAUAAACDKADF5L3ff9xfR7z3o+JeF1Ac7GOUBN77V733W733e733a733PeJeUy7xQRDEvYac5b3Pd85tc861DYJgQdzrAZLBPkau8t7Xdc6tC4KgyHt/rnPuA+dcuyAIVsS7stzAHYBoOjjndjjnFsa9ECAC9jFyUhAEnwRBUPTv//mvv86KcUk5hQIQTRfn3ISA2yjIbexj5Czv/Wjv/UHn3Brn3Fbn3NsxLyln8EcASfLe13LOfeGcqxMEwfq41wMkg32MksB7n+ecu8w519w5NzwIgu/iXVFu4A5A8jo75xbxoYkcxz5GzguC4EgQBIucczWcc/fGvZ5cQQFI3h3OufFxLwKIiH2MkqS0YwYgNApAErz3jZxzpzvnpsa9FiBZ7GPkMu99Ne99J+99vvc+z3t/lXPuFufc3LjXlitKx72AHNXFOTctCIJ9cS8EiIB9jFwWuB9u949xP/zH7Ebn3INBEMyKdVU5hCFAAAAM4o8AAAAwiAIAAIBBFAAAAAyiAAAAYNAJfwvAe8+EIJIWBIGPew3sYUSRDXvYOfYxokm0j7kDAACAQRQAAAAMogAAAGAQBQAAAIMoAAAAGEQBAADAIAoAAAAGUQAAADCIAgAAgEEUAAAADKIAAABgEAUAAACDKAAAABhEAQAAwCAKAAAABlEAAAAwiAIAAIBBFAAAAAyiAAAAYBAFAAAAgygAAAAYRAEAAMAgCgAAAAZRAAAAMIgCAACAQRQAAAAMogAAAGAQBQAAAIMoAAAAGEQBAADAoNJxLwBANP379xfZoEGDRDZq1CiRPfDAA2lZE5CNWrZsKbKCgoIYVvKDxYsXi2zjxo0Zuz53AAAAMIgCAACAQRQAAAAMogAAAGCQD4Ig8Re9T/zFEqx69eoi27p1awwr+UHZsmVF1q1bN5H95je/EdkZZ5whsgsvvFBkq1atSnJ1iQVB4FP+psVkYQ9v2LBBZDVq1BDZzp07RTZs2DCRjRw5MiXrKgmyYQ87l7v7+OSTTxZZXl5e6Nf37NlTZE2bNk16PQ0aNBBZlSpVkn6/qO6++26RjRs3LuXXSbSPuQMAAIBBFAAAAAyiAAAAYBAFAAAAg8wPAb700ksi69ixo8gGDx4ssokTJ4rs66+/DnXd/Px8NdeGXvr27SuyU089VWTeyzmPgwcPiqx+/foiKywsVNcTRTYMUJWkPdyiRQs1f/PNN0VWrlw5kWk/67t27RKZtheaNGkSZoklTjbsYefi3cfnn3++yC6//PJQr3300UdFFufJe9nmvffeE1nbtm1Tfh2GAAEAwDEUAAAADKIAAABgEAUAAACDzAwBnn322Wq+cuVKkZUvX15k2oDdli1bRPbPf/4z1HoqVaqk5rVr1w71es23334rsmbNmols+fLlSV+jOLJhgKok7eF33nlHzbVHnN54440i00641B4lfNppp4msTJkyYZZY4mTDHnYuPfv4tttuE9m9994rMm0/RPmcypR9+/aJTNvH2ud9OixZskRkN9xwg8i0EzujYggQAAAcQwEAAMAgCgAAAAZRAAAAMKh03AvIlGeeeUbNww6AaAN/u3fvFtkFF1wQ6v20oULn9NPawurcubPIMjXwh9T67W9/K7Irr7xS/V7tUdXa6YCazZs3i2zChAkimzdvnsj69eunvufixYtDXRvx0k4yPXr0aAwr+YH2GOr9+/cn/X7aZ+ntt98uMu2R6ekwfvx4kaVj4K84uAMAAIBBFAAAAAyiAAAAYBAFAAAAgygAAAAYVCKPAl63bp3IEk16FhUViWzgwIEie+qpp0R2yimniOyaa64Js0R1cts557p27SqyW2+9VWTapHXTpk1DXTtTsuEY1VzYw/Xq1RPZ7NmzRbZs2TL19V26dBHZ4cOHk16Pdmz26tWrRfbhhx+qrw/7rPhckA172Ln07OPPP/9cZDVq1BDZxx9/LDLtyPM+ffpEWo/2uX3kyBGRaZ+755xzjsimTZsmMu1Y46i04+S13zYoLCwUWaZ+64KjgAEAwDEUAAAADKIAAABgEAUAAACDcn4IUBsI0Y5XTPT32b17d5FpRzam2vXXX6/mf/zjH0V24MABkTVs2FBkn332WfSFpVA2DFDlwh7etGmTyAoKCkTWpEkT9fWJhgNT6dprrxWZdmSwc85Nnz5dZPfdd5/IDh06FH1h
aZYNe9i59Ozjc889V2QtWrQQ2ejRo1N9aVWdOnVEdtddd4lMG1LV9mem9O7dW2RjxoyJYSWJMQQIAACOoQAAAGAQBQAAAIMoAAAAGFQ67gUUR8WKFUX29ttvh3rtl19+qebvvvtupDWFUb9+fZE98sgj6vdqJ1+NGzdOZNk28IfkaaeTbdmyRWSJTo/MBO3nbMGCBer3du7cWWQbN24U2aBBg6IvDElbs2ZNqCyKBg0aiEz7HHfOubFjx4qsdu3aKV2PdvKrdqqqc87169dPZNo+1k5FzBXcAQAAwCAKAAAABlEAAAAwiAIAAIBBWTsEqJ3wpw3DXXrppSLTBqiuvvpq9Trbtm1LYnWJaQN/c+bMEVnVqlXV12unqD388MPRF4asVbp01v4YHqMNpyY6yc97eejYQw89JDKGAEs+7TNbO4EwU7R9rA32OacP3W7fvj3la4oTdwAAADCIAgAAgEEUAAAADKIAAABgUNZOHzVv3lxkiR6he7ymTZuKbP369VGXJOTn54tMe3ylNvC3a9cu9T2HDh0afWFAjLRHb7/00ksxrAT4/7RTCLt166Z+7xVXXCGydu3aiSyXT2XlDgAAAAZRAAAAMIgCAACAQRQAAAAMytohwEWLFols9OjRIps7d67I0jHwp5k8ebLI2rRpIzJtKGrEiBHqe65cuTL6woAUq1ChgsgOHjwY+vXr1q1L5XKQI0aOHBkqK46uXbuKTDt98pVXXgn1fuXLl1dz7VHEU6dOFZn2ma+dRpuNuAMAAIBBFAAAAAyiAAAAYBAFAAAAg7w2oHbsi94n/mIJFvZRxNqpUNqjXbXHpl5wwQXqtTM1wJgJQRDIyZwMs7qHU619+/YiGz9+vPq98+fPF9mNN94osqNHj0ZfWJplwx52jn2cKpUrVxaZ9rh25/RHzWuWLFkiMu002jgl2sfcAQAAwCAKAAAABlEAAAAwiAIAAIBBFAAAAAzK2qOA4/T666+LrHXr1iLTfoPi008/FVmfPn1EVpKm/VGy5Ofni2z69Oki27Bhg/r6e+65R2S5MPGPkm/fvn0ie/fdd9XvDftbALmMOwAAABhEAQAAwCAKAAAABlEAAAAwiCFAhTagpz1vWtO/f3+R/fnPf468JiBTtIE/beB17Nix6uu3bduW8jUBxVWvXj2RdejQQWTaZ3ZxvP3225FeHyfuAAAAYBAFAAAAgygAAAAYRAEAAMAg80OA2qlnV111lci0IagJEyaIbPbs2alZGJABd999t8gaNWokskOHDols3rx5aVkT4vPss8+KrHHjxiLr1q2byFatWpWWNR3v3HPPFVnZsmVFNnPmTJHVrFkz0rV37NghMu3fA7mCOwAAABhEAQAAwCAKAAAABlEAAAAwyPwQ4GOPPSayWrVqiayoqEhkQ4cOFdmRI0dSsi4g1apXry6yvn37ikwbqNKGpzjxL7f94he/EFmbNm1EdtZZZ4msYsWKKV/PTTfdJLLvv/9eZJMnTxZZXl5eytfzwgsviGzKlCki27JlS8qvnSncAQAAwCAKAAAABlEAAAAwiAIAAIBBZoYAW7Zsqebdu3cP9fqBAweKbN26dZHWBGTSM888I7Jq1aqJbNCgQSJj4K/kqVu3rsi0gT/N73//e5Hdd999IhsxYoTIfvKTn6jvWbVqVZFpJ7BGGfhbvHixyO644w71e7Xhvu+++y7pa2cj7gAAAGAQBQAAAIMoAAAAGEQBAADAIK8NWRz7oveJv5jFWrVqJbLp06er31uhQgWRTZ06VWSdOnWKvjBjgiDwca8hV/dwVNppftqjqnfu3CmyW265JS1rykXZsIedS88+Pvvss0X2xhtviOznP/95qi+dctpwXp8+fUSmPSI4l0/yCyvRPuYOAAAABlEAAAAwiAIAAIBBFAAAAAwqkScBagN/5cuXV79348aNItMekQrkklGjRonsiiuuEFnHjh0zsRxkocLCQpFpQ4DaoLh2imBU2kDqsmXLRLZ8+XKRPfvssyLbu3dvahZWgnEHAAAAgygAAAAYRAEAAMAgCgAAAAZRAAAAMCjnjwJu166dyGbNmiWyHTt2qK9v0aKFyFavXh19YciKY1RzYQ+nwzfffCMy7cjTd955R2QDBgwQ2f79+1OzsByTDXvYuXj3sXZk8MUXX5zy62zfvl1k8+bNS/l1LOIoYAAAcAwFAAAAgygAAAAYRAEAAMCgnD8KePfu3aG+7/HHH1dzBv5QEp1yyilxLwElhHZksJYh93AHAAAAgygAAAAYRAEAAMAgCgAAAAbl/EmAyF7ZcIoaexhRZMMedo59jGg4CRAAABxDAQAAwCAKAAAABlEAAAAw6IRDgAAAoGTiDgAAAAZRAAAAMIgCAACAQRQAAAAMogAAAGAQBQAAAIMoAAAAGEQBAADAIAoAAAAGUQAAADCIAgAAgEEUAAAADKIAAABgEAUAAACDKADF5L3ff9xfR7z3o+JeF1Ac3vtXvfdbvfd7vfdrvfc94l4TUBx8FkfngyCIew05y3uf75zb5pxrGwTBgrjXA4Tlva/rnFsXBEGR9/5c59wHzrl2QRCsiHdlQPHxWZwc7gBE08E5t8M5tzDuhQDFEQTBJ0EQFP37f/7rr7NiXBIQBZ/FSaAARNPFOTch4DYKcpD3frT3/qBzbo1zbqtz7u2YlwQki8/iJPBHAEny3tdyzn3hnKsTBMH6uNcDJMN7n+ecu8w519w5NzwIgu/iXRFQPHwWJ487AMnr7JxbxIZDLguC4EgQBIucczWcc/fGvR4gCXwWJ4kCkLw7nHPj414EkCKlHTMAyE18FieJApAE730j59zpzrmpca8FKC7vfTXvfSfvfb73Ps97f5Vz7hbn3Ny41wYUB5/F0ZSOewE5qotzbloQBPviXgiQhMD9cLt/jPvhPwI2OuceDIJgVqyrAoqPz+IIGAIEAMAg/ggAAACDKAAAABhEAQAAwCAKAAAABp3wtwC890wIImlBEPi418AeRhTZsIedYx8jmkT7mDsAAAAYRAEAAMAgCgAAAAZRAAAAMIgCAACAQRQAAAAMogAAAGAQBQAAAIMoAAAAGEQBAADAIAoAAAAGUQAAADCIAgAAgEEUAAAADKIAAABgEAUAAACDKAAAABhEAQAAwCAKAAAABlEAAAAwiAIAAIBBFAAAAAyiAAAAYBAFAAAAgygAAAAYRAEAAMAgCgAAAAZRAAAAMIgCAACAQRQAAAAMKh33AnJFXl6eyM4//3yRXXbZZSI7++yzRXbRRRep12nWrJnItm3bJrLmzZuLbO3atep7Ih4dOnQQWX5+vsi6d+8usnHjxkW69p133imyyy+/XGRHjx6NdB1NqVLyvyu06wwaNEhkgwcPTvl6AOi4AwAAgEEUAAAADKIAAABgEAUAAACDSuQQYPny5UV20kknqd970003iax169Yia9GihchOPvnkJFZ3YkEQiKygoEBkTz31lMiuu+66lK8H4XTp0kVkzz//vMjKlSsnMm1ormHDhqlZ2H/RBvHSMQQY9toA4sUdAAAADKIAAABgEAUAAACDKAAAABi
UtUOAVapUEVn79u1F1qpVK5F16tQpLWvKJjVr1hRZ2BPYkHq1atUSmTbwZ9U333wjsvnz58ewEpxIhQoVRKZ9xjZt2jTpa/To0UPNK1WqJDLt86uwsFBk1atXD/V+b775psgSnT6pncCq7eP9+/err88F3AEAAMAgCgAAAAZRAAAAMIgCAACAQV47ee7YF71P/MUkaaf0vfjiiyK7/fbbU33pjNizZ4/IVq5cKbIPP/xQfb12umDv3r1DXbtixYoiKyoqCvXadAiCwMd28X9Jxx7WVK5cWWTavv7xj38sMu/l/00n+rkMQ3uc8JIlS0K99v777w+VJaINo3bs2FFkM2bMCP2eccmGPexc5vbxK6+8IjLts1jbs/v27RPZP/7xj9DXTvXPgfZ+2omw9erVC/2e2s/QDTfcILJdu3aFfs9MSLSPuQMAAIBBFAAAAAyiAAAAYBAFAAAAgzJ+El+YjbkAAAgASURBVGCHDh1EFufA3+7du0W2atUqkWnDH3Pnzg31fYcPHw69Hu2xw2GHABEfbfgzV0+k3Lt3b6TXL1q0SGRhBxARr6+++kpkH3/8sci0IVPtZMdPP/00NQtLkfz8fJHVrVtX/d7hw4eLrEmTJiIbM2aMyG6++eYkVpd53AEAAMAgCgAAAAZRAAAAMIgCAACAQRQAAAAMyvhvARRnIv54O3fuFJk2ZapNHM+ZM0d9T+2Y3lx+vjNQHLfddpvIBgwYEOk9J06cKLIdO3ZEek9khvbPfsSIESLL1c9Ibd3Lli1Tv/eaa64R2TfffCOy6tWrR19YTLgDAACAQRQAAAAMogAAAGAQBQAAAIMyPgQ4bdo0kTVv3lxk2nPT//KXv4gs6rGlQK4rV66cyE477TSRaceYjhw5UmRHjx4VmTa8+8QTT6jr0Y6JRe7K1YG/qHr27Cky773IFixYkInlpAV3AAAAMIgCAACAQRQAAAAMogAAAGBQxocAgyAQ2cKFCzO9jKx1xRVXxL0E5Jhhw4aJrHfv3iIrVUr2fW3gT7N9+3aRJTpdEygJtEHaoqIikb3//vuZWE5acAcAAACDKAAAABhEAQAAwCAKAAAABnltKO/YF71P/EVEUrNmTTWfMmWKyBo0aBDqPStWrCgybWglU4IgkMdmZVgu7OGCggKRaaf2DR8+XH396aefLrK8vDyRRRkC1GiP53ZOH4r69a9/LbI9e/Ykfe1MyYY97Fxu7ONcVaNGDTVfsWKFyA4dOiSy2rVrp3pJKZdoH3MHAAAAgygAAAAYRAEAAMAgCgAAAAZl/CRAiypUqCCy5557Tv3esAN/hYWFIvv++++LtzBkhTZt2ohs7NixMaykeLRHdjvn3K233ioybdiwa9euqV4SUGwvv/yymlepUkVkgwYNSvdyMoo7AAAAGEQBAADAIAoAAAAGUQAAADCIIcAU0wb+HnzwQZG1b98+0nV69eolsiNHjkR6T8Rj3bp1Ips4cWLo13/22WciGzFiRNLrue6660R21113ieyqq65SX6+dOHj77beL7PPPPxfZ4MGDwywRSEqtWrVEVq9ePfV7FyxYILIhQ4akfE1x4g4AAAAGUQAAADCIAgAAgEEUAAAADGIIMIKwA39RB0dmzZolsnnz5kV6T2SPxYsXh8oyZebMmSL74IMPRPbaa6+pr9dONtROAqxcubLItMcYM9yKZJQtW1Zkffv2FZn2OG7nnHv66adTvqZswx0AAAAMogAAAGAQBQAAAIMoAAAAGEQBAADAIB8EQeIvep/4i8ZoE//acavaEb3FUVhYKLKbbrpJZKtWrYp0nUwIgsDHvQb2cPokOs562rRpItN+C0BTp04dkX355ZfFW1gKZcMedo59nIwaNWqIbP369SLbtGmT+vqLL75YZLt27Yq+sBgk2sfcAQAAwCAKAAAABlEAAAAwiAIAAIBBHAWs0Ab+Jk2aJLJEQ1Bh7NmzR8179uwpslwY+IM9p556atxLABLq3r17qO/r0aOHmufqwF9xcAcAAACDKAAAABhEAQAAwCAKAAAABpkfAtQGmaZMmSKyZs2aJX0NbeDvuuuuU7934cKFSV8HyCUzZ84U2c6dO2NYCXLdQw89JLIBAwaITBuo/uSTT9KyplzAHQAAAAyiAAAAYBAFAAAAgygAAAAYZGYI8LzzzlPzWbNmiezMM89M+jrLli0T2Y033iiybdu2JX0N5IYOHTqILD8/X2Tr1q0T2eLFi9OyplRq3LixmpcqJf+7Yv/+/SKbOHGiyA4cOBB9YTCnbdu2Ivv6669F1q5dO5FZ/izmDgAAAAZRAAAAMIgCAACAQRQAAAAMKpFDgPXr1xfZggUL1O/VHv0b1oQJE0R2zz33iKyoqCjpayD7FBQUiKxTp04iGzJkiMj27dsnsubNm6dkXenUsmVLkbVq1Ur93qNHj4ps9erVItMGcIH/5cUXXxRZo0aNRHb//feLbPPmzWlZU67iDgAAAAZRAAAAMIgCAACAQRQAAAAMyvkhwAsvvFBkixYtElm5cuVCv6c2qDV8+HCRPfnkkyILgiD0dZCbJk2aJLLLL7881Gt79eolsrVr10ZeUyppj76ePHmyyCpXrhz6Pfv27RtpTcC/1a1bV2Tjx48X2bhx4zKxnJzGHQAAAAyiAAAAYBAFAAAAgygAAAAYlFNDgFWrVhXZ0qVLRVamTJlI1xk1apTIhg0bFuk9YY/2aOhp06al/DraMJ42HKvRfqbGjBkT6hpr1qxR31P7e/zrX/8aaj2wKy8vT2T9+vUTWcOGDUWmDWnjf+MOAAAABlEAAAAwiAIAAIBBFAAAAAyiAAAAYFBO/RZAgwYNRBZ14l87zrd///6R3hNwzrnatWuL7LnnnhOZ915kxTlSuqCgQGStW7cWWalSsu8fPXo09HWO99RTT6n5xIkTk35P2FW9enWRab8F8MYbb4js/fffT8uaSjruAAAAYBAFAAAAgygAAAAYRAEAAMCgnBoCfOutt0SmHR8JpFPLli3jXgJg1uzZs0V2+PDhGFaS+7gDAACAQRQAAAAMogAAAGAQBQAAAIP8iU4c896HP44MOE4QBPKIuwxjDyOKbNjDztnYx5UrVxbZCy+8IDJt8PuWW25Jy5pKikT7mDsAAAAYRAEAAMAgCgAAAAZRAAAAMOiEQ4AAAKBk4g4AAAAGUQAAADCIAgAAgEEUAAAADKIAAABgEAUAAACD/g/Ns090W5U2ugAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 648x648 with 9 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "IGCQjbfEf4yV",
"outputId": "11cc6e81-96cf-40d9-882c-596895ba778e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 112
}
},
"source": [
"learner.fit(2)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.338328</td>\n",
" <td>1.158662</td>\n",
" <td>0.504907</td>\n",
" <td>00:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.364854</td>\n",
" <td>1.161076</td>\n",
" <td>0.500491</td>\n",
" <td>00:20</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
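    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# (Added sketch, not part of the original run) Visualize the losses recorded during the\n",
        "# fit calls above, assuming the default fastai Recorder callback is attached to `learner`.\n",
        "learner.recorder.plot_loss()"
      ],
      "execution_count": null,
      "outputs": []
    },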
{
"cell_type": "code",
"metadata": {
"id": "XMX2IEJqoL3Q",
"outputId": "12d8a218-861a-4424-f507-e52b8edce0fe",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"learner.summary()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>00:02</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"Sequential (Input shape: ['64 x 3 x 28 x 28'])\n",
"================================================================\n",
"Layer (type) Output Shape Param # Trainable \n",
"================================================================\n",
"Conv2d 64 x 64 x 14 x 14 9,408 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 64 x 14 x 14 128 True \n",
"________________________________________________________________\n",
"ReLU 64 x 64 x 14 x 14 0 False \n",
"________________________________________________________________\n",
"MaxPool2d 64 x 64 x 7 x 7 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 64 x 7 x 7 36,864 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 64 x 7 x 7 128 True \n",
"________________________________________________________________\n",
"ReLU 64 x 64 x 7 x 7 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 64 x 7 x 7 36,864 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 64 x 7 x 7 128 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 64 x 7 x 7 36,864 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 64 x 7 x 7 128 True \n",
"________________________________________________________________\n",
"ReLU 64 x 64 x 7 x 7 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 64 x 7 x 7 36,864 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 64 x 7 x 7 128 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 64 x 7 x 7 36,864 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 64 x 7 x 7 128 True \n",
"________________________________________________________________\n",
"ReLU 64 x 64 x 7 x 7 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 64 x 7 x 7 36,864 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 64 x 7 x 7 128 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 128 x 4 x 4 73,728 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 128 x 4 x 4 256 True \n",
"________________________________________________________________\n",
"ReLU 64 x 128 x 4 x 4 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 128 x 4 x 4 147,456 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 128 x 4 x 4 256 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 128 x 4 x 4 8,192 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 128 x 4 x 4 256 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 128 x 4 x 4 147,456 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 128 x 4 x 4 256 True \n",
"________________________________________________________________\n",
"ReLU 64 x 128 x 4 x 4 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 128 x 4 x 4 147,456 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 128 x 4 x 4 256 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 128 x 4 x 4 147,456 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 128 x 4 x 4 256 True \n",
"________________________________________________________________\n",
"ReLU 64 x 128 x 4 x 4 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 128 x 4 x 4 147,456 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 128 x 4 x 4 256 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 128 x 4 x 4 147,456 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 128 x 4 x 4 256 True \n",
"________________________________________________________________\n",
"ReLU 64 x 128 x 4 x 4 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 128 x 4 x 4 147,456 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 128 x 4 x 4 256 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 294,912 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"ReLU 64 x 256 x 2 x 2 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 32,768 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"ReLU 64 x 256 x 2 x 2 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"ReLU 64 x 256 x 2 x 2 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"ReLU 64 x 256 x 2 x 2 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"ReLU 64 x 256 x 2 x 2 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"ReLU 64 x 256 x 2 x 2 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 256 x 2 x 2 589,824 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 256 x 2 x 2 512 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 512 x 1 x 1 1,179,648 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 512 x 1 x 1 1,024 True \n",
"________________________________________________________________\n",
"ReLU 64 x 512 x 1 x 1 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 512 x 1 x 1 2,359,296 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 512 x 1 x 1 1,024 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 512 x 1 x 1 131,072 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 512 x 1 x 1 1,024 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 512 x 1 x 1 2,359,296 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 512 x 1 x 1 1,024 True \n",
"________________________________________________________________\n",
"ReLU 64 x 512 x 1 x 1 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 512 x 1 x 1 2,359,296 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 512 x 1 x 1 1,024 True \n",
"________________________________________________________________\n",
"Conv2d 64 x 512 x 1 x 1 2,359,296 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 512 x 1 x 1 1,024 True \n",
"________________________________________________________________\n",
"ReLU 64 x 512 x 1 x 1 0 False \n",
"________________________________________________________________\n",
"Conv2d 64 x 512 x 1 x 1 2,359,296 False \n",
"________________________________________________________________\n",
"BatchNorm2d 64 x 512 x 1 x 1 1,024 True \n",
"________________________________________________________________\n",
"AdaptiveAvgPool2d 64 x 512 x 1 x 1 0 False \n",
"________________________________________________________________\n",
"AdaptiveMaxPool2d 64 x 512 x 1 x 1 0 False \n",
"________________________________________________________________\n",
"Flatten 64 x 1024 0 False \n",
"________________________________________________________________\n",
"BatchNorm1d 64 x 1024 2,048 True \n",
"________________________________________________________________\n",
"Dropout 64 x 1024 0 False \n",
"________________________________________________________________\n",
"Linear 64 x 512 524,288 True \n",
"________________________________________________________________\n",
"ReLU 64 x 512 0 False \n",
"________________________________________________________________\n",
"BatchNorm1d 64 x 512 1,024 True \n",
"________________________________________________________________\n",
"Dropout 64 x 512 0 False \n",
"________________________________________________________________\n",
"Linear 64 x 2 1,024 True \n",
"________________________________________________________________\n",
"\n",
"Total params: 21,813,056\n",
"Total trainable params: 545,408\n",
"Total non-trainable params: 21,267,648\n",
"\n",
"Optimizer used: <function Adam at 0x7f38b9576158>\n",
"Loss function: <function cross_entropy at 0x7f38bdbe3d90>\n",
"\n",
"Model frozen up to parameter group #2\n",
"\n",
"Callbacks:\n",
" - TrainEvalCallback\n",
" - Recorder\n",
" - ProgressCallback\n",
" - XLAOptCallback\n",
" - GetPred"
]
},
"metadata": {
"tags": []
},
"execution_count": 99
}
]
},
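    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# (Added sketch, not part of the original run) Cross-check the parameter counts reported by\n",
        "# learner.summary() above, assuming `learner` is the Learner built earlier in this notebook.\n",
        "total = sum(p.numel() for p in learner.model.parameters())\n",
        "trainable = sum(p.numel() for p in learner.model.parameters() if p.requires_grad)\n",
        "total, trainable  # expected to agree with the summary's 21,813,056 total / 545,408 trainable"
      ],
      "execution_count": null,
      "outputs": []
    },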
{
"cell_type": "code",
"metadata": {
"id": "eiqWMXfDmGTn",
"outputId": "a2bf2944-46aa-4934-8b79-13a2b1d85425",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
""
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-2.8908, 1.3693]], device='xla:1')"
]
},
"metadata": {
"tags": []
},
"execution_count": 110
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KX_1_MZhj9Zu",
"outputId": "87490900-3fea-4678-e666-cb851e20c94f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"dls.one_batch()[0].std(dim=[0,2,3])"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorImage([1.3302, 1.3599, 1.3539], device='xla:1')"
]
},
"metadata": {
"tags": []
},
"execution_count": 97
}
]
},
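    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# (Added sketch, not part of the original run) Companion check to the per-channel std above:\n",
        "# per-channel mean of a batch, assuming `dls` is the DataLoaders used throughout this notebook.\n",
        "xb = dls.one_batch()[0]\n",
        "xb.mean(dim=[0,2,3]), xb.std(dim=[0,2,3])"
      ],
      "execution_count": null,
      "outputs": []
    },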
{
"cell_type": "code",
"metadata": {
"id": "_jnecIehhOrP",
"outputId": "278825f9-c4f2-45bd-dc24-90a91a18733d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 91
}
},
"source": [
"list(learner.model.parameters())[-1],list(learner.model.parameters())[-1].shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(Parameter containing:\n",
" tensor([[ 0.0154, -0.0494, -0.1158, ..., -0.0006, -0.0389, -0.0100],\n",
" [ 0.0540, 0.0339, -0.0311, ..., 0.1220, 0.0432, -0.0705]],\n",
" device='xla:1', requires_grad=True), torch.Size([2, 512]))"
]
},
"metadata": {
"tags": []
},
"execution_count": 113
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "H1Y4mu8DhUA_",
"outputId": "87258769-67d4-4487-ee3e-8c907331adb2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 91
}
},
"source": [
"list(learner.model.parameters())[-1],list(learner.model.parameters())[-1].shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(Parameter containing:\n",
" tensor([[ 0.0154, -0.0494, -0.1158, ..., -0.0006, -0.0389, -0.0100],\n",
" [ 0.0540, 0.0339, -0.0311, ..., 0.1220, 0.0432, -0.0705]],\n",
" device='xla:1', requires_grad=True), torch.Size([2, 512]))"
]
},
"metadata": {
"tags": []
},
"execution_count": 115
}
]
    }
]
}