@mdvsh
Last active April 25, 2020 14:17
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "meeting_torch_nn",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMxGCCB0/fzyI9J1SP90B/f",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/PseudoCodeNerd/7017ebd837f45183b541bcc30896ce58/meeting_torch_nn.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jjEncWTOS16c",
"colab_type": "text"
},
"source": [
"# Meeting `torch.nn` | Hello World with PyTorch\n",
"\n",
"After months of experimenting with NumPy and building shallow neural networks from scratch, in this notebook I'll finally apply my newly acquired theoretical knowledge of deep learning with the help of PyTorch.\n",
"\n",
"To get a better feel for the PyTorch workflow, I'll start with the Hello World of AI: the MNIST dataset, of course.\n",
"\n",
"*Partially based on the official PyTorch documentation, along with some Stack Overflow answers.*"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9JDeeUr_TyZg",
"colab_type": "text"
},
"source": [
"## Data setup\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "IXJMtX01Ss4f",
"colab_type": "code",
"colab": {}
},
"source": [
"import requests\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"data_path = Path(\"data\")\n",
"path = data_path / \"mnist\"\n",
"\n",
"path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"url = \"http://deeplearning.net/data/mnist/\"\n",
"file_name = \"mnist.pkl.gz\"\n",
"\n",
"if not (path / file_name).exists():\n",
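"    response = requests.get(url + file_name)\n",
"    response.raise_for_status()  # stop here instead of saving an HTTP error page as the archive\n",
"    (path / file_name).write_bytes(response.content)  # write_bytes also closes the file"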
],
"execution_count": 0,
"outputs": []
},
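{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Optional sanity check (a sketch, not part of the original workflow):* because the download is skipped whenever the file already exists, a bad file saved earlier (for example an HTML error page) would only surface later as a confusing `gzip`/`pickle` error. Every gzip stream starts with the magic bytes `1f 8b`, so checking the first two bytes catches this early. `GZIP_MAGIC` and `archive` are names chosen here for illustration."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from pathlib import Path\n",
"\n",
"GZIP_MAGIC = b\"\\x1f\\x8b\"  # every gzip stream begins with these two bytes (RFC 1952)\n",
"\n",
"archive = Path(\"data\") / \"mnist\" / \"mnist.pkl.gz\"  # same path the download cell writes to\n",
"with archive.open(\"rb\") as fh:\n",
"    header = fh.read(2)\n",
"\n",
"if header != GZIP_MAGIC:\n",
"    raise ValueError(f\"{archive} is not a gzip file (starts with {header!r}); delete it and download again\")\n",
"print(f\"{archive} starts with the gzip magic bytes\")"
],
"execution_count": 0,
"outputs": []
},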
{
"cell_type": "markdown",
"metadata": {
"id": "xROul7-7U_yh",
"colab_type": "text"
},
"source": [
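"**Unpacking the pickled dataset** (the file holds the training, validation and test splits; the test split is discarded with `_`)"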
]
},
{
"cell_type": "code",
"metadata": {
"id": "DHvcA57vUhaH",
"colab_type": "code",
"colab": {}
},
"source": [
"import pickle\n",
"import gzip\n",
"\n",
"with gzip.open((path / file_name).as_posix(), \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jvg47G9uVRYG",
"colab_type": "code",
"outputId": "ef92d603-79f3-4d32-f168-790adfc2d33f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 593
}
},
"source": [
"%matplotlib inline \n",
"import numpy as np\n",
"\n",
"def show_multi(h, w, fig, col, row):\n",
" for i in range(1, col*row+1):\n",
" fig.add_subplot(row, col, i)\n",
" plt.imshow(x_train[i].reshape((28, 28)), cmap=\"gray\")\n",
" plt.show()\n",
"\n",
"w=10\n",
"h=10\n",
"fig=plt.figure(figsize=(10, 10))\n",
"columns = 4\n",
"rows = 4\n",
"\n",
"show_multi(h, w, fig, columns, rows)"
],
"execution_count": 47,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkwAAAJACAYAAAB/pjm4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nOzdf/zV8/3/8ftTfkVKYe09UqGQRvlZfsSHIo0ym5SozLTxGbGyMmY2zI9hfs0Iyca0COXH0BLm57eibUmpTJR++F1+ldbz+0fHZ+/n83U6z/Pjdc55vc65XS8Xl9731/uc1+txvB9OT6/zfD+fxlorAAAAbNhG1S4AAAAg6RgwAQAABDBgAgAACGDABAAAEMCACQAAIIABEwAAQEBJAyZjTG9jzDxjzAJjzOi4ikJ9oY9QKnoIcaCPkIspdh0mY0wTSW9I6iVpsaTpkgZaa+fkeA6LPtUoa60p5nmF9hE9VNPet9ZuV+iTeC9CY5V6L8o8hz6qUdn6qJQ7TPtLWmCtfdNau0bSeEn9Sjgf6hN9hK8tKvJ59BDiQB8hp1IGTNtLeqdRXpw55jDGDDPGzDDGzCjhWqhdwT6ihxDAexHiQB8hp43LfQFr7RhJYyRuX6I49BDiQB8hDvRR/SrlDtMSSW0a5R0yx4BC0EcoFT2EONBHyKmUAdN0SR2MMe2NMZtKGiBpcjxloY7QRygVPYQ40EfIqeiP5Ky1a40xP5H0hKQmksZaa1+LrTLUBfoIpaKHEAf6CCFFLytQ1MX4vLdmFfurvIWih2raTGvtvpW4EH1Uuyr1XiTRR7Us7mUFAAAA6gIDJgAAgAAGTAAAAAEMmAAAAAIYMAEAAAQwYAIAAAhgwAQAABDAgAkAACCg7Jvv4r/22WcfJ//kJz9x8uDBg538xz/+MXKOG2+80cmvvPJKTNUBAIAN4Q4TAABAAAMmAACAAAZMAAAAAWy+WyZdunSJHHvqqaec3Lx584LP+8knnzh5m222Kfgc5cDmu+l1xBFHOPmee+5x8qGHHhp5zrx588pRCpvvJtSFF14YOfarX/3KyRtt5P7/92GHHebkZ555Jva6smHzXcSBzXcBAACKwIAJAAAggAETAABAAOswxWT//fd38sSJEyOPadGihZP9+WOrVq1y8po1ayLn8OcsdevWzcn+ukzZzlHvevTo4WT/3+mDDz5YyXKqbr/99nPy9OnTq1QJkmLo0KFOHjVqVOQx69aty3mOSs6PBSqBO0wAAAABDJgAAAACGDABAAAEMGACAAAIYNJ3nrbYYgsn77333k6+++67ndzQ0FDwNebPn+/kq666KvKY8ePHO/n55593sr/A3OWXX15wHbXOX1CvQ4cOTq71Sd/+AoPt27d3ctu2bZ1sTMXWAURC+D2w+eabV6kSVMoBBxzg5JNPPtnJ2Raw3WOPPXKec+TIkU5+9913I485+OCDnez/Xfryyy/nvEYlcYcJAAAggAETAABAAAMmAACAAOYw5enWW2918sCBA2O/hj8vqlmzZpHH+BtY+vNx9txzz9jrqjWDBw928osvvlilSqrDn193+umnO9mfQzB37tyy14Tq6tmzp5PPOuus4HP8vjjmmGOcvHz58tILQ9mceOKJTr7++uudvO222zo521zGp59+2snbbbedk3/7298G6/DP659jwIABwXNUCneYAAAAAhgwAQAABDBgAgAACGAOUxb77LNP5Nh3vvMdJ4fWpvHnGknSww8/7OSrr77ayf4aFa+++mrkHB999JGTDz/88ILqQnQdonpz++235/y+vx4Yao+/9s2dd97pZH+j8Gz8+SmLFi0qvTDEYuON3b/a991338hjbrvtNif7aw0+++yzTr7kkksi53juueecvNlmmzl5woQJTj7yyCM3UPF/zZgxI/iYaqnvvzkAAADywIAJAAAggAETAABAQHAOkzFmrKRjJK2w1nbOHGsl6S
+S2kl6S1J/a+1HGzpH0nXp0sXJU6ZMiTymefPmTrbWOvmvf/2rk7Ot0+TvxePv++bPLXnvvfci5/jHP/7h5HXr1jnZn2vlr+0kSa+88krkWLlVs4/8talat24d9yVSJTQ/JVv/J0E9vBdVypAhQ5z8rW99K+fj/fV2JOmPf/xjnCVVTD30kb8PXGjeohT9795fp2nlypXBc/jPyWfO0uLFi5181113BZ9TLfncYRonqbd3bLSkqdbaDpKmZjKQyzjRRyjNONFDKN040UcoQnDAZK19VtKH3uF+kr4eBt4l6biY60KNoY9QKnoIcaCPUKxi5zC1ttYuzXy9TFJ9f8aBYtFHKBU9hDjQRwgqeR0ma601xtgNfd8YM0zSsFKvg9qWq4/oIeSD9yLEgT7ChhQ7YFpujGmw1i41xjRIWrGhB1prx0gaI0m5mrCSOnbs6OTzzjvPydkmxb7//vtOXrp0qZP9iWqffvpp5ByPPvpozhyHpk2bOnnEiBGRxwwaNCj26xYprz4qtYf69OnjZP/fUS3LNsG9ffv2OZ+zZMmScpVTDql+L6oEfxNVSfrBD37gZP+XRz7++GMnX3rppfEXliyp7iN/Ucmf//znTvZ/SUmSbr75Zif7v4SUzyRv3wUXXFDwc84++2wnZ/tlp6Qo9iO5yZK+/jWLIZImxVMO6gx9hFLRQ4gDfYSg4IDJGHOvpBcl7WqMWWyMOU3SFZJ6GWPmS+qZycAG0UcoFT2EONBHKFbwIzlrbXRBofWOiLkW1DD6CKWihxAH+gjFqovNd/0NAf1Nb/05LqtWrYqcY/DgwU72NwhM6ryYHXfcsdolVN2uu+6a8/uvvfZahSqpPL/Xpei8pjfeeMPJ2fof6dGuXTsnT5w4seBz3HjjjU6eNm1aKSUhZhdddJGT/TlLa9ascfITTzwROceoUaOc/MUXX+S85uabbx455i9M6f99428Gn20u3KRJ6fn0k61RAAAAAhgwAQAABDBgAgAACKiLOUxdu3Z1sj9nydevX7/IsWeeeSbWmpAc06dPr3YJefM3ge7d290Sy990M5/NL/01XPw1eJAufk/4m09nM3XqVCdff/31sdaE4m299daRY2eeeaaT/XWW/DlLxx1X+E4vu+yyi5PvueeeyGP22WefnOe4//77nXzVVVcVXEeScIcJAAAggAETAABAAAMmAACAgLqYw3Tttdc62V8bwp+flKb5Shtt5I55/T2hENaqVauSz7HXXns52e8xSerZs6eTd9hhBydvuummTs6255//8/bXTnn55ZedvHr16sg5Nt7Y/c9+5syZkccgPfz5KVdcEV6k+rnnnnPykCFDnPzJJ5+UXhhi4b8vSNn3B2zM35/tG9/4RuQxp556qpP79u3r5M6dOzu5WbNmkXP4c6f8fPfddzv5s88+20DF6cAdJgAAgAAGTAAAAAEMmAAAAAIYMAEAAATU3KTvY445JnKsS5cuTvYnpk2ePLmsNZWTP8nbf22zZs2qZDmJ5E+M9v8d3XLLLU72N7LMh784YLZJ32vXrnXy559/7uQ5c+Y4eezYsZFz+Js++7+gsHz5cicvXrw4cg5/o+i5c+dGHoPkimNz3TfffNPJft8gOfyNdCXpvffec/J2223n5H//+99O9t/z8vHuu+86eeXKlZHHNDQ0OPn999938sMPP1zwdZOMO0wAAAABDJgAAAACGDABAAAE1NwcJn9+hhRd+GvFihVO/stf/lLWmoq12WabRY5dfPHFOZ/z1FNPOfn888+Ps6RU8jeqXLRokZMPPPDAkq/x9ttvO/mhhx6KPOb111938ksvvVTydX3Dhg1zsj+3QYrOX0G6jBo1ysnFLFabz+KWSIZsm2H7i5U+8sgjTvYX4124cGHkHJMmTXLyuHHjnPzhhx86efz48ZFz+HOYsj2mlnCHCQAAIIABEwAAQAADJgAAgICam8OUD39D0qVLl1apEpc/Z+nCCy+MPOa8885zsr/OzjXXXO
PkTz/9NKbqaseVV15Z7RLK5ogjjgg+pph1e1Ad/hpyknTkkUcWdA5/rookzZs3r+iaUH3+JtvZ5iqWqkePHk4+9NBDI4/x58/V+vxI7jABAAAEMGACAAAIYMAEAAAQUJdzmJKyd5w/P8Gfn3TiiSdGnuPPR/je974Xf2GoaQ8++GC1S0Cennzyycixli1b5nyOv77X0KFD4ywJdcJf0zDbel/+HnWswwQAAFDnGDABAAAEMGACAAAIYMAEAAAQUHOTvo0xwWP+xoXDhw8va01fO/fcc538i1/8wsktWrRw8j333BM5x+DBg+MvDEAibbPNNpFjoc12b775ZiezeC2K8cQTT1S7hMThDhMAAEAAAyYAAIAABkwAAAABwTlMxpg2kv4oqbUkK2mMtfZ6Y0wrSX+R1E7SW5L6W2s/Kl+p+fEX0sp27Jvf/KaTb7jhBiePHTs2co4PPvjAyd26dXPyKaec4uS99torco4ddtjByW+//baT/c+M/bkIaZa2PqoV2eb0dezY0cn+QodJVQ89dOeddzp5o40K/3/aF154Ia5yalI99FEcjjrqqGqXkDj5/Ne4VtIIa20nSd0k/a8xppOk0ZKmWms7SJqaycCG0EcoFT2EONBHKEpwwGStXWqtfSXz9SpJr0vaXlI/SXdlHnaXpOOynwGgj1A6eghxoI9QrIKWFTDGtJPUVdLLklpba5dmvrVM629vZnvOMEnDii8RtabQPqKH4OO9CHGgj1CIvAdMxphmkiZKOsdau7Lx3AhrrTXGRCcPrf/eGEljMufI+phKa9KkiZPPPPNMJ2fb0HblypVO7tChQ8HX9ecWTJs2zckXXXRRwedMm2L6KIk9lBbZ5vQVMy8mSWrpvcjfgLtnz55Ozrbm0po1a5z8+9//3snLly+PqbraVkt9VA477bRTtUtInLzeOY0xm2h9Y91jrX0gc3i5MaYh8/0GSSvKUyJqBX2EUtFDiAN9hGIEB0xm/bD7DkmvW2uvbfStyZKGZL4eImlS/OWhVtBHKBU9hDjQRyhWPh/JHSTpFEn/MsbMyhz7uaQrJE0wxpwmaZGk/uUpETWCPkKp6CHEgT5CUYIDJmvtc5Kii7msd0S85ZTuxRdfjBybPn26k/fbb7+c5/DXaZKk1q2zzv/7P/46TePHj488plJ71iVR2vqolnXv3t3J48aNq04hBarFHtp6662dnO29x7dkyRInjxw5Mtaaal0t9lE5/P3vf3dytrmPoX0Na026Z38CAABUAAMmAACAAAZMAAAAAQyYAAAAAgpa6TsNFi9eHDl2/PHHO/lHP/qRky+88MKCr3P99dc7+Q9/+IOTFyxYUPA5gbhl23wXAEJmz57t5Pnz50ce4y9uufPOOzv5vffei7+wKuIOEwAAQAADJgAAgAAGTAAAAAE1N4cpm6VLlzr54osvzpmBtPrrX//q5BNOOKFKlSAfc+fOdbK/QffBBx9cyXKADfrNb34TOXb77bc7+bLLLnPyWWed5eQ5c+bEX1gFcYcJAAAggAETAABAAAMmAACAAGOtrdzFjKncxVBR1tqKLPhDD9W0mdbafStxIfqodlXqvUiqrz5q3rx55NiECROc3LNnTyc/8MADTj711FMj5/jss89iqC5+2fqIO0wAAAABDJgAAAACGDABAAAEMIcJsWAOE2LAHCaUjDlMlePPa/LXYTrjjDOcvOeee0bOkdS1mZjDBAAAUAQGTAAAAAEMmAAAAAIYMAEAAAQw6RuxYNI3YsCkb5SMSd+IA5O+AQAAisCACQAAIIABEwAAQMDGFb7e+5IWSdo283XSUWd+2lbwWl/3kFT9150v6sxPNfqo2q85X9SZn0r2kEQflUu168zaRxWd9P1/FzVmRqUmd5aCOpMtLa+bOpMrLa+ZOpMtLa+bOkvDR3IAAAABDJgAAAACqjVgGlOl6xaKOpMtLa+bOpMrLa+ZOpMtLa+bOktQlTlMAAAAacJHcgAAAAEMmAAAAAIqOm
AyxvQ2xswzxiwwxoyu5LVDjDFjjTErjDGzGx1rZYyZYoyZn/mzZTVrzNTUxhgzzRgzxxjzmjFmeFJrLRf6qOQa676HpOT2URp6KFNT3fdRUntISkcfpa2HKjZgMsY0kfR7SUdL6iRpoDGmU6Wun4dxknp7x0ZLmmqt7SBpaiZX21pJI6y1nSR1k/S/mX+PSaw1dvRRLOq6h6TE99E4Jb+HpDrvo4T3kJSOPkpXD1lrK/KPpO6SnmiUz5d0fqWun2eN7STNbpTnSWrIfN0gaV61a8xS8yRJvdJQa0yvlz6Kv9666qHM60t0H6WthzJ11VUfJb2HMjWlqo+S3kOV/Ehue0nvNMqLM8eSrLW1dmnm62WSWlezGJ8xpp2krpJeVsJrjRF9FKM67SEpfX2U6J9NnfZR2npISvDPJg09xKTvPNn1Q93ErMFgjGkmaaKkc6y1Kxt/L2m14r+S9LOhh9IpaT8b+iidkvSzSUsPVXLAtERSm0Z5h8yxJFtujGmQpMyfK6pcjyTJGLOJ1jfXPdbaBzKHE1lrGdBHMajzHpLS10eJ/NnUeR+lrYekBP5s0tRDlRwwTZfUwRjT3hizqaQBkiZX8PrFmCxpSObrIVr/+WpVGWOMpDskvW6tvbbRtxJXa5nQRyWihySlr48S97Ohj1LXQ1LCfjap66EKT+jqI+kNSQslXVDtCVxebfdKWirpK63/LPo0Sdto/Qz9+ZL+JqlVAuo8WOtvT/5T0qzMP32SWCt9lMw+ooeS3Udp6CH6KNk9lJY+SlsPsTUKAABAAJO+AQAAAhgwAQAABDBgAgAACGDABAAAEMCACQAAIIABEwAAQEBJAyZjTG9jzDxjzAJjTDJ2E0bq0EcoFT2EONBHyKXodZiMMU20fsGuXlq/KNZ0SQOttXNyPIdFn2qUtdYU87xC+4geqmnvW2u3K/RJvBehsUq9F2WeQx/VqGx9VModpv0lLbDWvmmtXSNpvKR+JZwP9Yk+wtcWFfk8eghxoI+QUykDpu0lvdMoL84ccxhjhhljZhhjZpRwLdSuYB/RQwjgvQhxoI+Q08blvoC1doykMRK3L1EceghxoI8QB/qofpVyh2mJpDaN8g6ZY0Ah6COUih5CHOgj5FTKgGm6pA7GmPbGmE0lDZA0OZ6yUEfoI5SKHkIc6CPkVPRHctbatcaYn0h6QlITSWOtta/FVhnqAn2EUtFDiAN9hJCilxUo6mJ83luziv1V3kLRQzVtprV230pciD6qXZV6L5Loo1oW97ICAAAAdYEBEwAAQAADJgAAgAAGTAAAAAEMmAAAAAIYMAEAAAQwYAIAAAhgwAQAABBQ9s13AWR3/fXXO/nss8928uzZsyPPOeaYY5y8aNGi+AsDAERwhwkAACCAARMAAEAAAyYAAIAA5jBV0FZbbeXkZs2aOfk73/mOk7fbbrvIOa699lonr169OqbqUG7t2rVz8sknn+zkdevWOXn33XePnGO33XZzMnOY6k/Hjh2dvMkmmzi5R48eTr755psj5/B7LQ6TJk1y8oABA5y8Zs2a2K+J+Ph9dOCBBzr5N7/5TeQ5Bx10UFlrShruMAEAAAQwYAIAAAhgwAQAABDAHKaY+PNTRo0aFXlM9+7dndy5c+eCr9PQ0OBkf+0eJNd7773n5GeffdbJffv2rWQ5SKA99tjDyUOHDo085oQTTnDyRhu5/9/7rW99y8nZ5itZa4uscMP8/r3lllucfM4550Ses3LlytjrQHFatGjh5GnTpjl52bJlked885vfDD6mlnCHCQAAIIABEwAAQAADJgAAgAAGTAAAAAFM+s6Tv2CgP4Fx0KBBTm7atGnkHMYYJ7/zzjtOXrVqlZOzLVzYv39/J/uL0s2dOzfyHCTDZ5995mQWnYTv8ssvd3KfPn2qVEnpBg8e7OQ77rgj8pjnn3++UuWgRP4E72zHmPQNAABQ5xgwAQAABDBgAgAACGAOk6ILdl155ZWRx5x44olO9jfSzcf8+f
OdfNRRRznZ3/ww23ykbbfdNmdGcm299dZO3muvvapUCZJqypQpTs5nDtOKFSuc7M8V8he2lMKb7/obrx566KHBOlDb/Dm49Yg7TAAAAAEMmAAAAAIYMAEAAAQwh0nSd7/7XSf/8Ic/LPmcCxcujBzr1auXk/11mHbZZZeSr4vk2mKLLZy84447FnyO/fbbz8n+PDfWdkq3P/zhD05+6KGHgs/56quvnBzHWjjNmzd38uzZsyOP8Tf59fm1z5gxo+S6UD3ZNmzefPPNq1BJ9XCHCQAAIIABEwAAQAADJgAAgIDgHCZjzFhJx0haYa3tnDnWStJfJLWT9Jak/tbaj8pXZnmdcMIJBT/nrbfecvL06dOdPGrUqMhz/DlLvmx7x9WKeuijkHfffdfJ48aNc/LFF18cPIf/mI8//tjJN910UzGlpUI99NDatWudHHrPKBd/jbiWLVsWfI7Fixc7efXq1SXVFJd66KNK2XfffZ380ksvVamSysjnDtM4Sb29Y6MlTbXWdpA0NZOBXMaJPkJpxokeQunGiT5CEYIDJmvts5I+9A73k3RX5uu7JB0Xc12oMfQRSkUPIQ70EYpV7LICra21SzNfL5PUekMPNMYMkzSsyOugtuXVR/QQcuC9CHGgjxBU8jpM1lprjIku0PDf74+RNEaScj0O9S1XH9FDyAfvRYgDfYQNKXbAtNwY02CtXWqMaZC0IviMBDv99NOdPGxY9H8ennzySScvWLDAyf4GmMVo3XqD/1NTq2qqjwp1ySWXODmfSd+IqOseisuAAQOc7L8nNm3atOBzXnTRRSXVVGF130f+Lxx88sknTvY3qZeknXfeuaw1JU2xywpMljQk8/UQSZPiKQd1hj5CqeghxIE+QlBwwGSMuVfSi5J2NcYsNsacJukKSb2MMfMl9cxkYIPoI5SKHkIc6CMUK/iRnLV24Aa+dUTMtaCG0UcoFT2EONBHKBab7yq6oGC15pJ07969KtdFMmy0kXvDd926dVWqBLVk0KBBkWOjR7vLDPkbf2+yySYFX2fWrFlO9jcFRrL5i+D+/e9/d/IxxxxTyXISia1RAAAAAhgwAQAABDBgAgAACGAOU0zOPvtsJ2+55ZYFn+Pb3/528DEvvPCCk1988cWCr4Nk8ucsWcuaePWmXbt2Tj7llFMij+nZs2dB5zz44IMjxwrtrZUrV0aO+fOgHnvsMSd/8cUXBV0DSDruMAEAAAQwYAIAAAhgwAQAABDAHKYstthii8ixTp06OfmXv/ylk/v06RM8b6Hr7PjrQ0nSqaee6uT//Oc/wesCSKbOnTs7efLkyU7ecccdK1nOBvlr8kjSmDFjqlAJkmSbbbapdgkVxR0mAACAAAZMAAAAAQyYAAAAAhgwAQAABNTlpG9/Y8muXbs6eeLEiZHnNDQ0ONlflM2foJ1tQcnevXs7Odvk8sY23jj64zn++OOdfP311zt5zZo1Oc8JILmMMTlzMfxfNpEK39g528arRx99tJP/+te/FlYYUq9v377VLqGiuMMEAAAQwIAJAAAggAETAABAQF3MYdp0002d7M8leuCBB4Ln+NWvfuXkp556ysnPP/+8k1u1ahU5h/8cf9E633bbbRc5dvnllzv57bffdvJDDz3k5NWrV+e8BpKj0IVNJalHjx5Ovummm2KtCeU1e/ZsJx922GFOPvnkkyPPeeKJJ5z85ZdfllzHaaed5uSzzjqr5HMi3aZNm+bkbPPY6g13mAAAAAIYMAEAAAQwYAIAAAgw1trKXcyYsl/MX2NJkn796187+bzzzst5jmzriZxyyilO/vjjj53szzd67LHHIufYe++9neyvmXTVVVc5Odscp379+mWp+L/+9re/OfnKK6+MPOajjz7KeY5Zs2bl/H421trSF4zJQyV6qFr8jZSL+W9zzz33dPKcOXNKqqnCZlpr963EhWq5j4rRokULJ3/wwQfB5xx77LFOTso6TJV6L5Jqu4++973vOfm+++6LPM
Zfj9DfpH7RokXxF1Yh2fqIO0wAAAABDJgAAAACGDABAAAEpH4dpiZNmjj5kksuiTxm5MiRTv7ss8+cPHr0aCePHz8+cg5/ztK++7pTLfz1b/z96SRp/vz5Tj7jjDOc7K970bx588g5DjzwQCcPGjTIyf7ePlOmTImcw/fOO+84uX379sHnIH633HKLk3/0ox8VfI5hw4Y5+ZxzzimpJtSHo446qtolIGHWrl0bfIy/1+Fmm21WrnISgTtMAAAAAQyYAAAAAhgwAQAABDBgAgAACEj9pG9/kqs/wVuSPv/8cyf7k2mffPJJJ3fr1i1yjlNPPdXJRx99tJObNm3qZH+xTEm68847nexPtvatXLkycuzxxx/PmQcOHOjkk046Kec1JOncc88NPgblN3fu3GqXgBhlW0T3yCOPdLK/Ibe/EGC5+O9n119/fUWui/SYNGmSk7O9P+22225O9n/J5Mwzz4y/sCriDhMAAEAAAyYAAICA4IDJGNPGGDPNGDPHGPOaMWZ45ngrY8wUY8z8zJ8ty18u0oo+QqnoIcSBPkKxgpvvGmMaJDVYa18xxmwlaaak4yQNlfShtfYKY8xoSS2ttaMC54p9o8KlS5c62d8EV5JWr17tZP+z2C233NLJu+yyS8F1XHzxxU6+/PLLI4/xN1etJaENL+Pqo1re7NL3xhtvRI7tvPPOOZ+z0Ubu/wNl6+WFCxeWVlj55Nx8N+nvRQcffLCTL7jggshjevXq5WR/kdjQvMZ8tGrVysl9+vSJPObGG2908lZbbZXznNnmVvmL5PoL71ZLpd6LMueqm/ej6667LnLMnwvXunVrJ3/55Zdlramcitp811q71Fr7SubrVZJel7S9pH6S7so87C6tbzggK/oIpaKHEAf6CMUqaA6TMaadpK6SXpbU2lr79e2dZZJab+BpgIM+QqnoIcSBPkIh8l5WwBjTTNJESedYa1c23kPGWms3dGvSGDNM0rBs30P9KaaP6CE0xnsR4kAfoVB5DZiMMZtofWPdY619IHN4uTGmwVq7NPOZ8Ipsz7XWjpE0JnOe2D/vXbZsmZOzzWHyNwTca6+9cp7zscceixx79tlnnfzQQw85+a233nJyLc9XKlaxfVTuHkqq10eM4l4AACAASURBVF57LXJsp512yvmcdevWlaucREjye5G/AXfnzp2Dz/nZz37m5FWrVpVchz9Pau+99448JjR39emnn3byH/7wh8hjkjJnqRhJ7qM08ftozZo1VaqkMvL5LTkj6Q5Jr1trr230rcmShmS+HiJpkv9c4Gv0EUpFDyEO9BGKlc8dpoMknSLpX8aYWZljP5d0haQJxpjTJC2S1L88JaJG0EcoFT2EONBHKEpwwGStfU7Shn5N84h4y0Gtoo9QKnoIcaCPUKzU7yXXo0cPJx93XPQ3Qf3P8FescD+aHjt2rJM/+uijyDlq/bNZJM+YMWMix4499tgqVIJyOeOMM6pyXf898OGHH3by8OHDnZzm9XRQPs2bN3dyv379nPzggw9WspyyY2sUAACAAAZMAAAAAQyYAAAAAhgwAQAABAQ33431YnW+yFctC214GZd66qG2bdtGjj3yyCNO3n333Z3ceLViSerYsWPkHGndfDdO5eijLl26OPmss86KPGbIkCGRY6Xyf56ff/65k//+979HnuP/QsHs2bNjr6taKvVeJNXX+9G7774bOdayZUsnd+3a1cn+RvdpUtTmuwAAAPWOARMAAEAAAyYAAIAA5jAhFsxhQgxSPYfJ52/6LUlDhw518qWXXupkf06Iv8m3JE2ZMsXJkya5W575G5LXG+Ywlcf48eMjx/w5lH379nXyokWLylpTOTGHCQAAoAgMmAAAAAIYMAEAAAQwhwmxYA4TYlBTc5hQHcxhQhyYwwQAAFAEBkwAAAABDJgAAAACGDABAAAEMGACAAAIYMAEAAAQwIAJAAAggAETAABAAAMmAACAAAZMAAAAAQyYAAAAAhgwAQAABG
xc4eu9L2mRpG0zXycddeanbQWv9XUPSdV/3fmizvxUo4+q/ZrzRZ35qWQPSfRRuVS7zqx9ZKyt/GbLxpgZldqVvBTUmWxped3UmVxpec3UmWxped3UWRo+kgMAAAhgwAQAABBQrQHTmCpdt1DUmWxped3UmVxpec3UmWxped3UWYKqzGECAABIEz6SAwAACGDABAAAEFDRAZMxprcxZp4xZoExZnQlrx1ijBlrjFlhjJnd6FgrY8wUY8z8zJ8tq1ljpqY2xphpxpg5xpjXjDHDk1prudBHJddY9z0kJbeP0tBDmZrqvo+S2kNSOvoobT1UsQGTMaaJpN9LOlpSJ0kDjTGdKnX9PIyT1Ns7NlrSVGttB0lTM7na1koaYa3tJKmbpP/N/HtMYq2xo49iUdc9JCW+j8Yp+T0k1XkfJbyHpHT0Ubp6yFpbkX8kdZf0RKN8vqTzK3X9PGtsJ2l2ozxPUkPm6wZJ86pdY5aaJ0nqlYZaY3q99FH89dZVD2VeX6L7KG09lKmrrvoo6T2UqSlVfZT0HqrkR3LbS3qnUV6cOZZkra21SzNfL5PUuprF+Iwx7SR1lfSyEl5rjOijGNVpD0np66NE/2zqtI/S1kNSgn82aeghJn3nya4f6iZmDQZjTDNJEyWdY61d2fh7SasV/5Wknw09lE5J+9nQR+mUpJ9NWnqokgOmJZLaNMo7ZI4l2XJjTIMkZf5cUeV6JEnGmE20vrnusdY+kDmcyFrLgD6KQZ33kJS+Pkrkz6bO+yhtPSQl8GeTph6q5IBpuqQOxpj2xphNJQ2QNLmC1y/GZElDMl8P0frPV6vKGGMk3SHpdWvttY2+lbhay4Q+KhE9JCl9fZS4nw19lLoekhL2s0ldD1V4QlcfSW9IWijpgmpP4PJqu1fSUklfaf1n0adJ2kbrZ+jPl/Q3Sa0SUOfBWn978p+SZmX+6ZPEWumjZPYRPZTsPkpDD9FHye6htPRR2nqIrVEAAAACmPQNAAAQwIAJAAAggAETAABAAAMmAACAAAZMAAAAAQyYAAAAAkoaMBljehtj5hljFhhjkrGbMFKHPkKp6CHEgT5CLkWvw2SMaaL1C3b10vpFsaZLGmitnZPjOSz6VKOstaaY5xXaR/RQTXvfWrtdoU/ivQiNVeq9KPMc+qhGZeujUu4w7S9pgbX2TWvtGknjJfUr4XyoT/QRvraoyOfRQ4gDfYScShkwbS/pnUZ5ceaYwxgzzBgzwxgzo4RroXYF+4geQgDvRYgDfYScNi73Bay1YySNkbh9ieLQQ4gDfYQ40Ef1q5Q7TEsktWmUd8gcAwpBH6FU9BDiQB8hp1IGTNMldTDGtDfGbCppgKTJ8ZSFOkIfoVT0EOJAHyGnoj+Ss9auNcb8RNITkppIGmutfS22ylAX6COUih5CHOgjhBS9rEBRF+Pz3ppV7K/yFooeqmkzrbX7VuJC9FHtqtR7kUQf1bK4lxUAAACoCwyYAAAAAhgwAQAABDBgAgAACGDABAAAEMCACQAAIIABEwAAQAADJgAAgICyb76L6po6dWrkmDHuelyHH354pcqpa506dXLyMccc4+Rhw4Y5efr06ZFzvPrqqzmvcd111zl5zZo1hZQIANgA7jABAAAEMGACAAAIYMAEAAAQwBymGvO73/3OyQceeGDkMX/84x8rVU7d+tGPfhQ5dvXVVzu5WbNmOc+x8847R44NGDAg53P8eU/Tpk3L+XgAleX/d3/iiSc6+csvv3TyPvvsEznHVltt5eRBgwY5+emnn3bykiVLCi0zYtmyZZFjkyZNcvKMGTNKvk6ScYcJAAAggAETAABAAAMmAACAAGOtrdzFjKncxerEFVdc4eThw4c7+auvvoo854c//KGTJ0yYUHId1loTflTp0tJDrVq1ihx7/fXXnfyNb3wj9ut+/PHHTvbnR0jSk08+Gft1YzLTWrtvJS6Ulj5C4Sr1XiQV10dXXXWVk0eOHBlbPZW2bt06J8+ZM8fJ99
57b84sSW+99VbsdcUhWx9xhwkAACCAARMAAEAAAyYAAIAABkwAAAABLFyZct26dXPyJpts4uTnnnsu8pw4Jnkjtw8//DBy7Je//KWTr7nmGidvscUWTn777bcj59hxxx1zXnfrrbd2cu/evSOPSfCkb6RY27Ztndy0aVMnDxw4MPKcM844I+c5H330USefeuqpRVaXHMcff3zJ5/jggw+c/M9//rPkc86bN8/Ju+66q5P99xZJ6tq1q5M7d+7s5Msuu8zJ2epM6qTvbLjDBAAAEMCACQAAIIABEwAAQABzmIrUo0cPJ19wwQVOzvZ5fbZ5LYXyz+t/Zrxw4UInp3lRtFpzyy23OPnHP/6xk/faay8nr1y5suRr3nTTTSWfA+jZs2fkmD8Xx39vatGihZOLWSTZn6NZC4466ignd+zY0clvvPFG8Byff/65k5cuXVp6YQH+hr+S9K9//cvJoTmWffv2jRzz56klGXeYAAAAAhgwAQAABDBgAgAACGAOU5HGjBnj5A4dOji5U6dOkedkWxOpUD//+c+dvM022zj59NNPd/I//vGPkq+J8rj00kud7M+D69KlS8nX2HTTTUs+B2rf7bff7uRvf/vbTt5vv/0KPueqVaucfM8990QeM336dCf7m7N++eWXBV836fx5pn5OqmOOOSZyLDRnafXq1U6+7bbbYq2p0rjDBAAAEMCACQAAIIABEwAAQEBwDpMxZqykYyStsNZ2zhxrJekvktpJektSf2vtR+UrM3n8dTD8NUY233zzkq+RbQ6Lv1/TunXrYr9uOdBHUffff7+T/Tlu2fZ88+eWhPjzpCTp+9//fkHnSAp6qDj+PEdJuvzyy538gx/8wMn+mnEzZ86MnOOKK65w8uzZs538xRdfODnb3ojVQB9l5893vOGGG5w8ePDggs/ZvXt3J8+aNavwwhIknztM4yT5O3iOljTVWttB0tRMBnIZJ/oIpRknegilGyf6CEUIDpistc9K8peo7ifprszXd0k6Lua6UGPoI5SKHkIc6CMUq9hlBVpba79ei32ZpNYbeqAxZpikYUVeB7Utrz6ih5AD70WIA32EoJLXYbLWWmPMBjcJstaOkTRGknI9DvUtVx/RQ8gH70WIA32EDSl2wLTcGNNgrV1qjGmQtCLOopLokksucbI/+fb11193cjELRm655ZZOHjVqVOQxW2yxhZNfeuklJ/sTiROu7vqosUGDBjnZ33zX31i5GHEslppwdd1D+fjFL34ROXbaaac5+cYbb3Syv4jqp59+Gn9hyVJ3ffQ///M/Tj7llFOcPHTo0OA5vvrqKyefffbZTp47d25xxSVUscsKTJY0JPP1EEmT4ikHdYY+QqnoIcSBPkJQcMBkjLlX0ouSdjXGLDbGnCbpCkm9jDHzJfXMZGCD6COUih5CHOgjFCv4kZy1duAGvnVEzLWghtFHKBU9hDjQRygWm+9m0aZNm8gxf1PbtWvXOvknP/mJk997772Cr3vttdc6+YQTTog85t1333XyQQcdVPB1UH677bZb5NiDDz7o5F122cXJG28c/3+OkydPjv2cqC5/HqM/19Gfi3LOOedEzjFt2jQnP/HEE06uxU1v69n+++8fOeYvjNukSZOCz+sv2OwvTvqf//yn4HMmGVujAAAABDBgAgAACGDABAAAEMAcJkXXu/HnmkjStttu62R/3ZJnnnmm4OuOHDnSyfmse3HZZZcVfB1U3u677x451r59eyeXY86S79xzz40cO+uss8p+XZTPhRde6GR/DtOECROcnG0TZ+Yo1Zf+/ftHjhUzZ8nnb9j76KOPOnnGjBlOfvjhhyPn8P++9TdxThLuMAEAAAQwYAIAAAhgwAQAABBg/HUUynqxKm1U6M8VOfnkk518xx13OHmjjaLjyHXr1jl5+vTpTp40yV1J319TSZJatWrl5IceesjJXbt2dfLdd98dOccPfvCDyLEksNaaSlwnzZtd+vssXXnllU7efPPNY7
/mxIkTI8e+//3vx36dmMy01u5biQuluY/892w/H3fccU6ut7W4KvVeJKWnjw488MDIMX+/wP3228/J/rzdcvH/br3uuuucfNVVVzl5xYrKbPOXrY+4wwQAABDAgAkAACCAARMAAEAAAyYAAICAupj07U/yHjduXM7HGxOdM7hgwQIn77zzzjnP4S/YJUnbb7+9kxsaGpzsb9jrfz/JmPRduKOPPtrJW2+9dfA5/i8w3HTTTU5u3ry5k5n0nV2a++jll1928r77uv/KlixZ4uTTTjstco4pU6bEX1hCMOm7ODvuuKOT/UnfrVu3jjzn+OOPd7L/S0nZ/i4tlL8o9BFHHBF5jD9xPA5M+gYAACgCAyYAAIAABkwAAAABNTeH6cQTT4wc8xeAXLt2rZM//vhjJ5900kmRc3z00UdOvuaaa5x86KGHBmvzP88NLUC3bNmyyDkOO+wwJy9cuDB43UpgDlNl+D108cUXO/miiy5ycrb+8OcALFq0KJ7iSld3c5gOOOAAJ7/66qtOXrNmTeQ5/gK4/oKov/jFL5z86aefBq87d+7ccLEpwRym6hk0aJCT/Y2+999//5KvMXr06Mgxf3HLODCHCQAAoAgMmAAAAAIYMAEAAATU3Bymp556KnKsbdu2Tr700kudfOeddxZ8nU6dOjn51ltvdXL37t0jzwnNYfL9+c9/jhwbPHhwviVWFHOYKmOzzTZz8pdffpnz8dnmpvTq1cvJixcvLr2weNTUHKZs66g98sgjTvbXvjn33HOdnG0Dbp+/Xs7y5cuDzznkkEOc/MILLwSfkxbMYUoOf924v/3tb5HH9OjRo6Bz3n777ZFjw4YNK6ywPDCHCQAAoAgMmAAAAAIYMAEAAARsHH5IukyaNCly7IEHHnDyO++8U/J1/HkDnTt3Dj5n4MCBTp49e3bOxydobgkSwp9/F3LHHXdEjtFXlfHKK69Ejvl7/Y0aNcrJ+cxZ8g0fPjzn97PNGwm99wBx8Nc8nDlzZuQxhc5heuONN0qqqRTcYQIAAAhgwAQAABDAgAkAACCAARMAAEBAzS1cWS4tWrRwsj/59swzz3Rytk1PO3bsGH9hCVFLC1dus802kWP+4qb33ntvzhyHbAsf+gtR+pOIfTvvvHPk2JtvvllaYeVTUwtXnn/++ZFjF154oZObNm1a8Hnnz5/v5A4dOjjZ30z5e9/7XuQc2Sak14p6XLjSf684/fTTnZxtAdsJEyaUtSZJatKkiZOfeOKJyGMOP/zwnOfwJ45ne/xzzz1XRHW5sXAlAABAERgwAQAABAQHTMaYNsaYacaYOcaY14wxwzPHWxljphhj5mf+bFn+cpFW9BFKRQ8hDvQRipXPwpVrJY2w1r5ijNlK0kxjzBRJQyVNtdZeYYwZLWm0pFE5zpNq/hylM844w8krVqxwcuhz2TqUmj664YYbIseOPfZYJ/vz0d59910nL1myJHKOBQsWOHmfffbJec6f/exnkXOE5ixdc801OetKudT0kCRdfvnlkWNfffWVk7t27erknj17Bs/bsqX79/ijjz7q5JEjRzrZ7zukq4+y+eY3v+nkxx9/3Mnf/va3nez3TLm0bt3ayT/96U+dXMzfi6+//rqTyzFfKV/BO0zW2qXW2lcyX6+S9Lqk7SX1k3RX5mF3STquXEUi/egjlIoeQhzoIxSroK1RjDHtJHWV9LKk1tbapZlvLZPUegPPGSZpWPElotYU2kf0EHy8FyEO9BEKkfekb2NMM0kTJZ1jrV3Z+Ht2/doEWX+90lo7xlq7b6V+XRjJVkwf0UNojPcixIE+QqHyusNkjNlE6xvrHmvt1zvZLjfGNFhrlxpjGiSt2PAZ0qVt27aRYz/84Q+d7K9fNWbMGCezwWlUWvroxhtvjBxr3769k7t37+7kp59+2slvvfVW5Bxz5sxx8iGHHOLkrbbaKlib33f++iq//OUvnfzll18Gz5kmaemhDbn66qurXQKU/j667rrrnOzPWf
L571+SNG/ePCd/8cUXOc+Rbc0wf56lP2cpn/c0Y9zljlatWuXks88+O3iOSsnnt+SMpDskvW6tvbbRtyZLGpL5eoikSfGXh1pBH6FU9BDiQB+hWPncYTpI0imS/mWMmZU59nNJV0iaYIw5TdIiSf3LUyJqBH2EUtFDiAN9hKIEB0zW2uckbWip+SPiLQe1ij5CqeghxIE+QrHYSy6LN954I3Jsp512cvLdd9/t5KFDh5azpMSrpb3ksvHXN/LXtrn55psrUseHH37o5Gz73qVYTe0lh+qoh73k/L3ibr311oLP8eqrrzr5k08+yfl4fz9VKbqOWDE+/fRTJ3/3u9918tSpU0u+RjHYSw4AAKAIDJgAAAACGDABAAAEMGACAAAIKGhrlHpx5513Ro5dcsklTp40iSU66smIESOcvNlmmzm5WbNmwXP4EyQHDhyY8/HZJmH26tUreB0AtW3KlClOHj9+vJMHDBgQPEccE7ZD1q5d62R/wU1JmjhxopNffvnlstZUCu4wAQAABDBgAgAACGDABAAAEMDClYhFrS9ciYpg4UqUrB4WrvT5cyr9xR8PP/zwyHP8BZr79u2b8xr+Rt/ZPPXUUzmfM2vWLKUFC1cCAAAUgQETAABAAAMmAACAAOYwIRbMYUIMmMOEktXjHCbEjzlMAAAARWDABAAAEMCACQAAIIABEwAAQAADJgAAgAAGTAAAAAEMmAAAAAIYMAEAAAQwYAIAAAhgwAQAABDAgAkAACCAARMAAEDAxhW+3vuSFknaNvN10lFnftpW8Fpf95BU/dedL+rMTzX6qNqvOV/UmZ9K9pBEH5VLtevM2kfG2spvtmyMmVGpXclLQZ3JlpbXTZ3JlZbXTJ3JlpbXTZ2l4SM5AACAAAZMAAAAAdUaMI2p0nULRZ3JlpbXTZ3JlZbXTJ3JlpbXTZ0lqMocJgAAgDThIzkAAIAABkwAAAABFR0wGWN6G2PmGWMWGGNGV/LaIcaYscaYFcaY2Y2OtTLGTDHGzM/82bKaNWZqamOMmWaMmWOMec0YMzyptZYLfVRyjXXfQ1Jy+ygNPZSpqe77KKk9JKWjj9LWQxUbMBljmkj6vaSjJXWSNNAY06lS18/DOEm9vWOjJU211naQNDWTq22tpBHW2k6Sukn638y/xyTWGjv6KBZ13UNS4vtonJLfQ1Kd91HCe0hKRx+lq4estRX5R1J3SU80yudLOr9S18+zxnaSZjfK8yQ1ZL5ukDSv2jVmqXmSpF5pqDWm10sfxV9vXfVQ5vUluo/S1kOZuuqqj5LeQ5maUtVHSe+hSn4kt72kdxrlxZljSdbaWrs08/UySa2rWYzPGNNOUldJLyvhtcaIPopRnfaQlL4+SvTPpk77KG09JCX4Z5OGHmLSd57s+qFuYtZgMMY0kzRR0jnW2pWNv5e0WvFfSfrZ0EPplLSfDX2UTkn62aSlhyo5YFoiqU2jvEPmWJItN8Y0SFLmzxVVrkeSZIzZROub6x5r7QOZw4mstQzooxjUeQ9J6eujRP5s6ryP0tZDUgJ/NmnqoUoOmKZL6mCMaW+M2VTSAEmTK3j9YkyWNCTz9RCt/3y1qowxRtIdkl631l7b6FuJq7VM6KMS0UOS0tdHifvZ0Eep6yEpYT+b1PVQhSd09ZH0hqSFki6o9gQur7Z7JS2V9JXWfxZ9mqRttH6G/nxJf5PUKgF1Hqz1tyf/KWlW5p8+SayVPkpmH9FDye6jNPQQfZTsHkpLH6Wth9gaBQAAIIBJ3wAAAAEMmAAAAAIYMAEAAAQwYAIAAAhgwAQAABDAgAkAACCgpAGTMaa3MWaeMWaBMSYZuwkjdegjlIoeQhzoI+RS9DpMxpgmWr9gVy+tXxRruqSB1to5OZ7Dok81ylprinleoX1ED9W096212xX6JN6L0Fil3osyz6GPalS2PirlDtP+khZYa9+01q6RNF5SvxLOh/pEH+Fri4p8Hj2EONBHyKmUAdP2kt5plBdnjjmMMc
OMMTOMMTNKuBZqV7CP6CEE8F6EONBHyGnjcl/AWjtG0hiJ25coDj2EONBHiAN9VL9KucO0RFKbRnmHzDGgEPQRSkUPIQ70EXIqZcA0XVIHY0x7Y8ymkgZImhxPWagj9BFKRQ8hDvQRcir6Izlr7VpjzE8kPSGpiaSx1trXYqsMdYE+QqnoIcSBPkJI0csKFHUxPu91dOzY0cmPP/64k5s0aRJ5Ttu2bctaU7GK/VXeQtFDNW2mtXbfSlyIPqpdlXovkuijWhb3sgIAAAB1gQETAABAAAMmAACAAAZMAAAAAWVfuBL/deONNzr5xBNPdHKrVq2c/Mgjj5S9JgAAEMYdJgAAgAAGTAAAAAEMmAAAAAKYwxST1q1bO/mBBx6IPKZbt25O9hcNnT17tpNPO+20mKoDAACl4A4TAABAAAMmAACAAAZMAAAAAcxhKpK/ce7VV1/t5AMOOCB4jvPPP9/JM2bMcPIHH3xQZHVIImPcvRzvvfdeJ/fp08fJnTp1ipxj8eLF8RcGoO6ccsopTj7yyCOd3KVLFyfvuuuuwXO+9NJLTj722GOd/MknnxRSYuJwhwkAACCAARMAAEAAAyYAAIAA5jAVyd/3zZ9/kg9/Psq0adNKqgnJ1rRpUycfdNBBTm7WrJmTe/fuHTnH7bffHn9hAGrKtttu6+Rs7xv+/KKPP/7YyS+88IKT33rrrcg5DjvsMCcffPDBTn7xxRednG1eZppwhwkAACCAARMAAEAAAyYAAIAABkwAAAABTPrOk79Q5Z///Gcn+4sSZnP88cc7edKkSaUXhtT4/PPPnTx//nwnb7/99k7ebrvtyl4T6tOIESOcvOmmmzp59913d/KgQYOC55w7d66T99hjjyKrQ6kef/xxJ7dr1y7ymKuuusrJv/3tb5384YcfBq+z2267Ofn//b//52T/782LLrooco5f//rXweskBXeYAAAAAhgwAQAABDBgAgAACGAOU578jQp33HFHJz/22GNO/vGPfxw5x5IlS+IvDKn1+9//3sn+InD+PBIgm0MPPdTJnTt3zvl9Sfrud7/r5NAcTGttsI4OHTo4ec6cOU5O+6KFSdarVy8nd+3a1ckTJkyIPMff/L0Y/ry16667zskXXnihk0899dTIOZjDBAAAUEMYMAEAAAQwYAIAAAhgDlMW/qaDktSlSxcn+xsRnnvuuU5mvhJC/DVLfP37948cGzVqlJOXLl0aa02orIaGBiffe++9Tt5pp52C52jRooWTt9xySydnm580c+ZMJ++9997B64RstJH7/99+HSifjTd2/ypfsGCBk8ePH1+ROu6//34n+3OYNt9888hzmjdv7uSVK1fGX1hMuMMEAAAQwIAJAAAggAETAABAQHAOkzFmrKRjJK2w1nbOHGsl6S+S2kl6S1J/a+1H5SuzvPr16+fkAw44IPIYfx2S++67z8lffvll/IXVkHroo1L5c038/b0kqW/fvk6+9dZby1pTktRCD/Xs2dPJt912m5PbtGkT+zWzrX/0/vvvO3nbbbd18re+9S0n33nnnZFz7LDDDjmv66/DlBS10Ee+adOmOdlfh8nfx7JcVq9enfP7rVu3jhw76aSTnHzLLbfEWlOc8rnDNE5Sb+/YaElTrbUdJE3NZCCXcaKPUJpxoodQunGij1CE4IDJWvusJH/b4n6S7sp8fZek42KuCzWGPkKp6CHEgT5CsYpdVqC1tfbr32deJil6ny3DGDNM0rAir4Pallcf0UPIgfcixIE+QlDJ6zBZa60xZoMbDVlrx0gaI0m5Hof6lquP6CHkg/cixIE+woYUO2BaboxpsNYuNcY0SFoRZ1HltvXWWzv5kEMOKfgcH33kzgdcvHhxSTVJ0vDhw52czwTQkSNHlnzdKkp1H8Utnw1Os00Er3Op6qGf/exnTi5mkrc/sdZfzPSll15y8rx584Ln/OCDD5zsvxeFJnhL0cV8/Q3LEy5VfeRLyi8dvfnmm05+7bXXnLzHHn
tEnuNv2pxkxS4rMFnSkMzXQyRNiqcc1Bn6CKWihxAH+ghBwQGTMeZeSS9K2tUYs9gYc5qkU3W7KQAADK1JREFUKyT1MsbMl9Qzk4ENoo9QKnoIcaCPUKzgR3LW2oEb+NYRMdeCGkYfoVT0EOJAH6FYdbn57n/+8x8n77PPPk72N5GUpHXr1jn52WefLfi6/ga9vrPOOsvJbdu2DZ5zxIgRTvbnGrAJMFAdRx55ZORYt27dCjrH22+/HTnmzw16/vnnCyssD/nMWfJNmuR+iuUvjona99VXXzl57dq1VaqkPNgaBQAAIIABEwAAQAADJgAAgIC6nMN06KGHOtlfh8mfryRF5xKEPp/v0qVL5Jh/HX8jVd9nn30WOeav97Trrrs6+f7773fygAEDIudYtGhRzusCKJ0/v1CStthii5zPeeGFF5z8q1/9KvKYOOYstWzZ0sm9e7tbq/Xo0SN4Dr/Wxx57rOS6kG6bbbaZkzfffPPgc1atWlWucmLHHSYAAIAABkwAAAABDJgAAAAC6mIO01ZbbeXk9u3b53z8u+++Gzn2pz/9yckLFixwcseOHZ183nnnRc7Rr18/J/vzoJ588kknX3PNNZFztGjRwslPPfVUzu8jPYwxTs5nbzkk15gxYyLHtt12Wyd/8sknTj7ppJOcvGzZsvgLk/TjH//YyZdccknOx/t7gklS//79nVyuWpEe7dq1c7I/xzabxx9/vKBr+P8NSdJee+3l5O7duzv5vvvuc3I++ytmwx0mAACAAAZMAAAAAQyYAAAAAhgwAQAABNTFpO+DDz7Yyb/73e9yPv62226LHPv1r3/t5NatWzv56quvdnKfPn0i5/AX6JowYYKTR44c6eQOHTpEznHLLbfkPOfUqVOdzCKV6cEk79oyceLEvI6V27HHHhs5dtFFF+V8jr9pqv++IzHJu974i1JK0U2aDzzwwILP6/fWzJkznbz33ns7uVWrVpFztGnTxsn+34u77LKLk4cOHVpomZK4wwQAABDEgAkAACCAARMAAEBAXcxh2nPPPQt6vD9fKZsHHnjAyQcccEDwOf7Clc8884yTu3Xr5uTnnnsueM7rrrvOyf48KNSWf/7zn9UuASnz0EMPRY6F5sudffbZTs62CCeSo2nTpk7+xje+4WR/HpAU/fvm8MMPz3mNbBvp7rHHHvmWuEH+OUKLL48dOzZy7NFHH3Wyvyj0W2+9VVxxHu4wAQAABDBgAgAACGDABAAAEFAXc5i23nprJ/ubnE6aNCl4ji5dujjZ32TQP+eIESMi5/DnLPkb9v75z3/Oec5s5/XnMKG2LVy4sNolIOF+85vfOHmjjaL/X7xu3bqc5/Dfq1A9/vwkSbr44oud7K+1tdtuu5V83ZUrVzrZX9tIiq7XtfHGuYcUt99+e+SYvw7TK6+8km+JFccdJgAAgAAGTAAAAAEMmAAAAALqYg6Tz1+DpJg9vPw5AP45sq399PbbbzvZX9fi3//+t5MPOeSQyDk++eSTguoEUNs23XRTJ3ft2tXJ2eYr+e9Xw4cPd/L8+fNjqg6lyraOVq9evZy8evVqJ/vrEvl/t0jRubv+Ofy1ixYvXhw5x9y5c53sz8t98803nfzTn/40co5PP/00ciypuMMEAAAQwIAJAAAggAETAABAAAMmAACAgLqY9O1PbjvvvPOc7G+K629KKEUXrtxqq61yXnPw4MGRY/5ClP4Ggf5iZEuWLMl5DdSfzTbbrNoloMq22GILJ5988slO9icEZ3Pvvfc6+Z577nFyaGFLVM6RRx4ZOeZP4j7++OOdPGvWrJKv6y9CeeWVV0Yes/322zt5xYoVTu7fv7+T0zTBOxvuMAEAAAQwYAIAAAgIDpiMMW2MMdOMMXOMMa8ZY4ZnjrcyxkwxxszP/Nmy/OUiregjlIoeQhzoIxTLhBZtNMY0SGqw1r5ijNlK0kxJx0kaKulDa+0VxpjRklpaa0cFzlX4CpEx2GeffZw8bdo0J/tzAr
JtelvM4pY+//PbP/3pT04+66yzSr5GtVhro//SGomrj6rVQ+XQpk0bJy9atCj4HH+BwRtvvDHWmqpsprV23w19sxbeiwqVba7kbbfd5uTvf//7Oc+RbbHAm266ycm1NGepUu9FmXOVvY+y/Wz8OUr777+/k/1NcfPhL6R83333Ofk73/lO5Dn+YpdHH320k59++umC60iKbH0UvMNkrV1qrX0l8/UqSa9L2l5SP0l3ZR52l9Y3HJAVfYRS0UOIA32EYhX0W3LGmHaSukp6WVJra+3SzLeWSWq9gecMkzSs+BJRawrtI3oIPt6LEAf6CIXIe9K3MaaZpImSzrHWrmz8Pbv+86qstyattWOstfvmutWO+lFMH9FDaIz3IsSBPkKh8rrDZIzZROsb6x5r7QOZw8uNMQ3W2qWZz4RXbPgM1TVz5kwnDxw40Mn+Z/yHHXZYwde46667nPyvf/0r8phXX33Vyc8880zB10mztPdR3JYvX+7k1157zcl77LFHJctJhXrrIX+dGyk8Z2nhwoVOvuGGG2KtqRakqY/eeOONyDF/XcAxY8Y4eZtttnHyP/7xj8g5/I1x/fUJd911Vye//PLLkXOcccYZTo5j/acky+e35IykOyS9bq29ttG3Jksakvl6iKRJ/nOBr9FHKBU9hDjQRyhWPneYDpJ0iqR/GWO+Hj7+XNIVkiYYY06TtEhS/w08H5DoI5SOHkIc6CMUJThgstY+J2lDv6Z5RLzloFbRRygVPYQ40EcoVnAdplgvlpK1T1C40NoncanlHpo+fbqT/fXDJOmRRx5xct++fctaU4XlXIcpTknto912283JI0aMiDzm1FNPdbI/x8VfCyef9b1qSaXei6Tq9dEll1zi5JEjRzp5o40K38Rj8uTJTr7jjjuc/Pjjjxd8zjQrah0mAACAeseACQAAIIABEwAAQAADJgAAgICCtkYBUD7+om/ZJn03a9asUuWgCn7xi184+cQTTww+x9+Aud4medcjv0/8jPLgDhMAAEAAAyYAAIAABkwAAAABzGECEuKyyy5zcufOnSOPmTBhQqXKQQX4Gyw3b948+Bx/o9Wnnnoq1poAZMcdJgAAgAAGTAAAAAEMmAAAAALYfBexYPNdxKDuNt+98sornexvtpttTaU+ffo4ed68efEXlmL1sPkuyo/NdwEAAIrAgAkAACCAARMAAEAA6zABQJU8+eSTTvbnMP30pz+NPIc5S0B1cIcJAAAggAETAABAAAMmAACAAAZMAAAAASxciViwcCViUHcLVyJ+LFyJOLBwJQAAQBEYMAEAAAQwYAIAAAio9MKV70taJGnbzNdJR535aVvBa33dQ1L1X3e+qDM/1eijar/mfFFnfirZQxJ9VC7VrjNrH1V00vf/XdSYGZWa3FkK6ky2tLxu6kyutLxm6ky2tLxu6iwNH8kBAAAEMGACAAAIqNaAaUyVrlso6ky2tLxu6kyutLxm6ky2tLxu6ixBVeYwAQAApAkfyQEAAAQwYAIAAAio6IDJGNPbGDPPGLPAGDO6ktcOMcaMNcasMMbMbnSslTFmijFmfubPltWsMVNTG2PMNGPMHGPMa8aY4UmttVzoo5JrrPsekpLbR2nooUxNdd9HSe0hKR19lLYeqtiAyRjTRNLvJR0tqZOkgcaYTpW6fh7GSertHRstaaq1toOkqZlcbWsljbDWdpLUTdL/Zv49JrHW2NFHsajrHpIS30fjlPwekuq8jxLeQ1I6+ihdPWStrcg/krpLeqJRPl/S+ZW6fp41tpM0u1GeJ6kh83WDpHnVrjFLzZMk9UpDrTG9Xvoo/nrrqocyry/RfZS2HsrUVVd9lPQeytSUqj5Keg9V8iO57SW90ygvzhxLstbW2qWZr5dJal3NYnzGmHaSukp6WQmvNUb0UYzqtIek9PVRon82ddpHaeshKcE/mzT0EJO+82TXD3UTswaDMaaZpImSzrHWrmz8vaTViv9K0s+GHkqnpP1s6KN0St
LPJi09VMkB0xJJbRrlHTLHkmy5MaZBkjJ/rqhyPZIkY8wmWt9c91hrH8gcTmStZUAfxaDOe0hKXx8l8mdT532Uth6SEvizSVMPVXLANF1SB2NMe2PMppIGSJpcwesXY7KkIZmv/397d2wSQRAGYPTbFrQdK7gmDCzDHuzA2CYsQkwOFGOLMDiDNRCTCURvBt6DjTYZmC/42RnY6/bz1bPatm2r7qvj6XS6+/ZqurX+ER39koaq9Tqabm90tFxDNdneLNfQP1/oOlQv1Vt1e+4LXD/W9lC9Vx/tZ9E31WX7Df3X6rG6mGCdV+2fJ5+rp6/nMONadTRnRxqau6MVGtLR3A2t0tFqDfk1CgDAgEvfAAADBiYAgAEDEwDAgIEJAGDAwAQAMGBgAgAYMDABAAx8Aj8ZkRV1S/KkAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 720x720 with 16 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
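{
"cell_type": "markdown",
"metadata": {},
"source": [
"The grid above shows the images without their labels (and `show_multi` starts at index 1, so the first example is skipped). As an optional sketch, `plt.subplots` can draw the same kind of grid with each digit's label as its title. `show_labeled` is a hypothetical helper written for this note; it works with the NumPy arrays here and with the tensors created later, since both support `.reshape` and `int(...)`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def show_labeled(images, labels, rows=4, cols=4):\n",
"    \"\"\"Plot the first rows*cols flattened 28x28 images, titled with their labels.\"\"\"\n",
"    fig, axes = plt.subplots(rows, cols, figsize=(8, 8))\n",
"    for idx, ax in enumerate(axes.flat):\n",
"        ax.imshow(images[idx].reshape(28, 28), cmap=\"gray\")\n",
"        ax.set_title(f\"label: {int(labels[idx])}\")\n",
"        ax.axis(\"off\")  # tick marks carry no information for images\n",
"    fig.tight_layout()\n",
"    plt.show()\n",
"\n",
"show_labeled(x_train, y_train)"
],
"execution_count": 0,
"outputs": []
},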
{
"cell_type": "markdown",
"metadata": {
"id": "-Fw-hSFrmjbu",
"colab_type": "text"
},
"source": [
"Let's look at the current dimensions of our training and dev sets."
]
},
{
"cell_type": "code",
"metadata": {
"id": "fu8z_yLajLoO",
"colab_type": "code",
"outputId": "e03f80ca-218d-43fa-9351-b9f687b88c33",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"print(\"Shape of Training sets : \", x_train.shape, ' | ', y_train.shape) "
],
"execution_count": 48,
"outputs": [
{
"output_type": "stream",
"text": [
"Shape of Training sets : (50000, 784) | (50000,)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t7y-ap2VnTu_",
"colab_type": "text"
},
"source": [
"50,000 training examples, each a flattened 28x28 image (784 pixels) for X and one label (0-9) for Y."
]
},
{
"cell_type": "code",
"metadata": {
"id": "H_meHLDTm5eE",
"colab_type": "code",
"outputId": "d20a6ace-08c3-4bd3-f164-26c4d2d28e29",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"print(\"Shape of Dev sets : \", x_valid.shape, ' | ', y_valid.shape) # 10000 dev-set examples, each a flattened 28x28 image (784 pixels)"
],
"execution_count": 49,
"outputs": [
{
"output_type": "stream",
"text": [
"Shape of Dev sets : (10000, 784) | (10000,)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iTcX52nrnaFO",
"colab_type": "text"
},
"source": [
"10,000 dev (validation) examples, each a flattened 28x28 image (784 pixels) for X and one label (0-9) for Y."
]
},
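{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row of `x_train`/`x_valid` is a 28x28 image flattened into 784 values. A small sketch on random stand-in data (not the real MNIST arrays) shows that reshaping to `(N, 28, 28)` and back leaves the values unchanged:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"flat = np.random.rand(5, 784).astype(np.float32)  # stand-in for 5 rows of x_train\n",
"images = flat.reshape(-1, 28, 28)                  # one 28x28 image per example\n",
"restored = images.reshape(-1, 28 * 28)             # flatten back to 784 columns\n",
"\n",
"print(images.shape, restored.shape)  # (5, 28, 28) (5, 784)\n",
"assert np.array_equal(flat, restored)  # reshaping never changes the values"
],
"execution_count": 0,
"outputs": []
},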
{
"cell_type": "markdown",
"metadata": {
"id": "eZlu3P77nnqs",
"colab_type": "text"
},
"source": [
"---\n",
"\n",
"### Map data to `torch.tensor` instead of `np.array` for further computation\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "OK7Ha-q2nL8A",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"\n",
"x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BTE83Sppn6bg",
"colab_type": "code",
"outputId": "c9bbacdf-b57e-4662-ca12-b1ef31142242",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
}
},
"source": [
"# seeing what our tensor looks like\n",
"print(x_train, y_train) \n",
"print(x_train.shape) # 50000x784 matrix\n",
"print(y_train.min(), y_train.max()) # 0 -- 9 labels"
],
"execution_count": 51,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]) tensor([5, 0, 4, ..., 8, 4, 8])\n",
"torch.Size([50000, 784])\n",
"tensor(0) tensor(9)\n"
],
"name": "stdout"
}
]
},
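{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"`torch.tensor` copies the NumPy data and keeps its dtype, which matters later: the inputs need to be float32 to be multiplied with float32 weights, and `F.cross_entropy` expects `int64` (`torch.long`) class labels. A minimal, self-contained sketch with tiny stand-in arrays (the float32 pixels and int64 labels are assumptions consistent with the printed tensors, not something read from the dataset file):"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"# hypothetical stand-ins for one slice of x_train / y_train\n",
"x_np = np.zeros((2, 784), dtype=np.float32)  # two flattened 28x28 images\n",
"y_np = np.array([5, 0], dtype=np.int64)      # their integer labels\n",
"\n",
"x_t, y_t = map(torch.tensor, (x_np, y_np))   # same conversion as above\n",
"print(x_t.dtype, y_t.dtype)  # torch.float32 torch.int64"
],
"execution_count": 0,
"outputs": []
},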
{
"cell_type": "markdown",
"metadata": {
"id": "r6MPQZGlouJ9",
"colab_type": "text"
},
"source": [
"---\n",
"\n",
"### Creating NN\n",
"\n",
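"We will use PyTorch's autograd to compute gradients for backpropagation: for every tensor created with `requires_grad=True`, it records the operations applied to it as they run (a dynamic computation graph) and differentiates through that graph when we call `.backward()`.\n",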
"\n",
"#### Initialize weights and biases\n"
]
},
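{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A minimal, self-contained sketch of what `requires_grad=True` buys us, on a toy scalar rather than the MNIST weights: autograd records the operations applied to `x`, and `backward()` stores the derivative in `x.grad`. For f(x) = x**2 + 3x the derivative is 2x + 3, i.e. 7 at x = 2."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"\n",
"x = torch.tensor(2.0, requires_grad=True)  # leaf tensor tracked by autograd\n",
"f = x**2 + 3*x  # the graph for f is recorded as it is computed\n",
"f.backward()    # backpropagate: fills x.grad with df/dx\n",
"print(x.grad)   # df/dx = 2x + 3 = 7 at x = 2"
],
"execution_count": 0,
"outputs": []
},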
{
"cell_type": "code",
"metadata": {
"id": "7bK_YL68oc6J",
"colab_type": "code",
"colab": {}
},
"source": [
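"n = x_train.shape[1]  # number of input features (784)\n",
"w = torch.randn(n, 10) * np.sqrt(2.0/n)  # He initialization: std = sqrt(2/n)\n",
"w.requires_grad_(requires_grad=True)  # set after scaling so that w is a leaf tensor\n",
"b = torch.zeros(10, requires_grad=True)"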
],
"execution_count": 0,
"outputs": []
},
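{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A self-contained comparison of the weight scale produced by multiplying versus dividing by sqrt(2/n), assuming the goal is He (Kaiming) initialization: there, standard-normal weights are *multiplied* by sqrt(2/n), giving a standard deviation of about 0.05 for n = 784. Dividing by sqrt(2/n) instead gives a standard deviation of sqrt(n/2), about 19.8, which leads to very large logits."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import math\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"n_in = 784  # input features: 28*28 pixels\n",
"\n",
"w_mul = torch.randn(n_in, 10) * math.sqrt(2.0 / n_in)  # He init, std ~ 0.0505\n",
"w_div = torch.randn(n_in, 10) / math.sqrt(2.0 / n_in)  # std ~ sqrt(784/2) ~ 19.8\n",
"\n",
"print(w_mul.std().item(), w_div.std().item())"
],
"execution_count": 0,
"outputs": []
},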
{
"cell_type": "markdown",
"metadata": {
"id": "iWqIQHo-sCzQ",
"colab_type": "text"
},
"source": [
"### Define the model and loss function using `torch.nn.functional`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "vne2ESqirNPt",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch.nn.functional as F\n",
"\n",
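"def learn(x):\n",
"    # linear model: one logit (unnormalized score) per digit class;\n",
"    # no softmax here because F.cross_entropy applies log_softmax itself\n",
"    z = x @ w + b\n",
"    return z\n",
"\n",
"# `F.cross_entropy` combines log_softmax and nll_loss in a single function.\n",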
"\n",
"loss_func = F.cross_entropy"
],
"execution_count": 0,
"outputs": []
},
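{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A small self-contained check that `F.cross_entropy` combines the two steps, using random logits as a stand-in for model output: it gives the same value as `F.log_softmax` over the class dimension followed by `F.nll_loss` (both average over the batch by default)."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"logits = torch.randn(4, 10)           # 4 examples, 10 classes (stand-in for model output)\n",
"targets = torch.tensor([3, 0, 9, 1])  # int64 class labels\n",
"\n",
"ce = F.cross_entropy(logits, targets)\n",
"two_step = F.nll_loss(F.log_softmax(logits, dim=1), targets)\n",
"print(ce.item(), two_step.item())  # equal up to floating-point rounding"
],
"execution_count": 0,
"outputs": []
},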
{
"cell_type": "markdown",
"metadata": {
"id": "Fd6zW0_bvKtZ",
"colab_type": "text"
},
"source": [
"#### Forward Pass using Mini-Batch"
]
},
{
"cell_type": "code",
"metadata": {
"id": "m9n7nW7fsN6L",
"colab_type": "code",
"outputId": "98611a5a-4005-4402-cd15-524219bc4481",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 69
}
},
"source": [
"bs = 64  # mini-batch size\n",
"xb = x_train[0:bs]  # first mini-batch of inputs\n",
"preds = learn(xb)   # logits, shape (bs, 10)\n",
"print(preds[0], preds.shape)"
],
"execution_count": 54,
"outputs": [
{
"output_type": "stream",
"text": [
"\n",
"tensor([-238.8471, 183.6411, 116.2699, 356.3234, 394.4142, 168.6241,\n",
" 52.5215, -102.6483, -103.0229, 45.8241], grad_fn=<SelectBackward>) torch.Size([64, 10])\n"
],
"name": "stdout"
}
]
},
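{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The training loop later walks over the training set in consecutive slices of `bs` rows. A small worked example of that slicing arithmetic, assuming the 50,000 training examples and `bs = 64` used in this notebook: `(N - 1) // bs + 1` is integer ceiling division, so the final, shorter batch is still visited."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"N, batch = 50000, 64  # training-set size and mini-batch size\n",
"\n",
"n_batches = (N - 1) // batch + 1  # ceil(N / batch) using integers only\n",
"last_start = (n_batches - 1) * batch\n",
"last_size = min(last_start + batch, N) - last_start\n",
"\n",
"print(n_batches, last_size)  # 782 16"
],
"execution_count": 0,
"outputs": []
},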
{
"cell_type": "markdown",
"metadata": {
"id": "GV-NyCocxS1V",
"colab_type": "text"
},
"source": [
"Let’s also implement a function to calculate the accuracy of our model. For each prediction, if the index with the largest value matches the target value, then the prediction was correct.\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1WREhgjZw2h1",
"colab_type": "code",
"colab": {}
},
"source": [
"def accuracy(out, yb):\n",
"    preds = torch.argmax(out, dim=1)  # predicted class = index of the largest logit\n",
"    return (preds == yb).float().mean()"
],
"execution_count": 0,
"outputs": []
},
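{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A tiny worked example of this accuracy computation on hand-made logits (values chosen only for illustration): the row-wise argmax gives predictions 1, 0, 2 against labels 1, 0, 1, so two of three are correct."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"\n",
"def accuracy(out, yb):\n",
"    preds = torch.argmax(out, dim=1)  # index of the largest logit in each row\n",
"    return (preds == yb).float().mean()\n",
"\n",
"toy_out = torch.tensor([[0.1, 2.0, -1.0],   # argmax 1\n",
"                        [3.0, 0.5, 0.2],    # argmax 0\n",
"                        [0.2, 0.1, 0.4]])   # argmax 2\n",
"toy_yb = torch.tensor([1, 0, 1])\n",
"acc = accuracy(toy_out, toy_yb)\n",
"print(acc)  # tensor(0.6667)"
],
"execution_count": 0,
"outputs": []
},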
{
"cell_type": "markdown",
"metadata": {
"id": "4VPg6S6txcB4",
"colab_type": "text"
},
"source": [
"Check the loss and accuracy of the untrained model on the first mini-batch."
]
},
{
"cell_type": "code",
"metadata": {
"id": "mhMvRfa8xTy9",
"colab_type": "code",
"outputId": "4839e86e-ba1d-40d4-d142-c4e81ad65a11",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"yb = y_train[0:bs]  # labels for the same mini-batch\n",
"print(loss_func(learn(xb), yb), accuracy(learn(xb), yb))"
],
"execution_count": 56,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor(275.6947, grad_fn=<NllLossBackward>) tensor(0.0781)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Oh8MlCKjyKe3",
"colab_type": "text"
},
"source": [
"\n",
"\n",
"---\n",
"\n",
"### Refactor entire code with `nn.Module` and `nn.Parameter`\n",
"\n",
"Next up, we’ll use `nn.Module` and `nn.Parameter` for a clearer and more concise training loop. **We subclass `nn.Module`, which is itself a class and can keep track of state such as our parameters.**\n",
"\n",
"In this case, we want to create a class that **holds our weights, bias, and method for the forward step**. `nn.Module` has a number of attributes and methods (such as `.parameters()` and `.zero_grad()`) that we will use."
]
},
{
"cell_type": "code",
"metadata": {
"id": "9Ds4PmpIxvP7",
"colab_type": "code",
"colab": {}
},
"source": [
"from torch import nn\n",
"\n",
"class MNIST_learn(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.w = nn.Parameter(torch.randn(n, 10) * np.sqrt(2.0/n))  # He initialization\n",
"        self.b = nn.Parameter(torch.zeros(10))\n",
"\n",
"    def forward(self, xb):\n",
"        return xb @ self.w + self.b\n"
],
"execution_count": 0,
"outputs": []
},
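{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Assigning an `nn.Parameter` to an attribute in `__init__` is what registers it with the module, which is how `.parameters()` and `.zero_grad()` find it later; a plain tensor attribute is not registered. A minimal, self-contained sketch (`TinyLinear` is a hypothetical class used only for illustration):"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import math\n",
"import torch\n",
"from torch import nn\n",
"\n",
"class TinyLinear(nn.Module):  # hypothetical example class\n",
"    def __init__(self, n_in, n_out):\n",
"        super().__init__()\n",
"        self.w = nn.Parameter(torch.randn(n_in, n_out) * math.sqrt(2.0 / n_in))\n",
"        self.b = nn.Parameter(torch.zeros(n_out))\n",
"        self.scale = torch.ones(1)  # plain tensor: NOT registered as a parameter\n",
"\n",
"    def forward(self, xb):\n",
"        return xb @ self.w + self.b\n",
"\n",
"tiny = TinyLinear(784, 10)\n",
"names = [name for name, _ in tiny.named_parameters()]\n",
"print(names)                            # ['w', 'b']\n",
"print(tiny(torch.zeros(2, 784)).shape)  # torch.Size([2, 10])"
],
"execution_count": 0,
"outputs": []
},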
{
"cell_type": "code",
"metadata": {
"id": "LF8Obh0G0LFf",
"colab_type": "code",
"colab": {}
},
"source": [
"model = MNIST_learn()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "1dNrL7H70p33",
"colab_type": "text"
},
"source": [
"Now we can calculate the loss in the same way as before. **Note that `nn.Module` objects are used as if they were functions (i.e. they are callable)**; behind the scenes, PyTorch calls our `forward` method automatically."
]
},
{
"cell_type": "code",
"metadata": {
"id": "R9GGaIfu0j1t",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "da3e4489-7579-48a3-9103-8b658cd0351f"
},
"source": [
"print(loss_func(model(xb), yb))"
],
"execution_count": 64,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor(285.2584, grad_fn=<NllLossBackward>)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "46hnetWT0yoP",
"colab_type": "code",
"colab": {}
},
"source": [
"def train():\n",
"    n_train = x_train.shape[0]  # number of training examples (n above is the feature count)\n",
"    for epoch in range(epochs):\n",
"        for i in range((n_train - 1) // bs + 1):\n",
"            xb = x_train[i*bs : i*bs+bs]\n",
"            yb = y_train[i*bs : i*bs+bs]\n",
"            loss = loss_func(model(xb), yb)\n",
"\n",
"            loss.backward()\n",
"            with torch.no_grad():\n",
"                for param in model.parameters():\n",
"                    param -= lr * param.grad  # in-place update of the registered parameter\n",
"                model.zero_grad()\n",
"\n",
"lr = 0.75  # learning rate\n",
"epochs = 5  # how many epochs to train for\n",
"\n",
"train()"
],
"execution_count": 0,
"outputs": []
},
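{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Why the parameter update has to happen in place: inside `torch.no_grad()`, `param = param - lr * param.grad` builds a new tensor and only rebinds the loop variable, so the module's weights never change, while `param -= lr * param.grad` modifies the registered tensor itself. A self-contained sketch on a single `nn.Linear` layer with a toy squared-output loss (the layer size, input and learning rate are arbitrary choices for illustration):"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"layer = nn.Linear(3, 1)\n",
"toy_lr = 0.1\n",
"before = layer.weight.detach().clone()\n",
"\n",
"toy_loss = layer(torch.ones(1, 3)).pow(2).sum()  # toy loss, just to get gradients\n",
"toy_loss.backward()\n",
"\n",
"with torch.no_grad():\n",
"    for p in layer.parameters():\n",
"        p = p - toy_lr * p.grad  # rebinds p to a new tensor; layer is untouched\n",
"unchanged = torch.equal(layer.weight, before)\n",
"\n",
"with torch.no_grad():\n",
"    for p in layer.parameters():\n",
"        p -= toy_lr * p.grad     # in-place: updates the layer's own tensor\n",
"changed = not torch.equal(layer.weight, before)\n",
"\n",
"print(unchanged, changed)  # True True"
],
"execution_count": 0,
"outputs": []
},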
{
"cell_type": "code",
"metadata": {
"id": "BfNirG-w2qM-",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "839a1cba-d24c-43a7-dce7-ce2647080c90"
},
"source": [
"print(loss_func(model(xb), yb))"
],
"execution_count": 70,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor(285.2584, grad_fn=<NllLossBackward>)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "8pLj9nvy2wi3",
"colab_type": "code",
"colab": {}
},
"source": [
"class MNIST_learn(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # nn.Linear replaces the hand-made self.w and self.b:\n",
"        # it creates, registers and initializes its own weight and bias\n",
"        self.lin = nn.Linear(784, 10)\n",
"\n",
"    def forward(self, xb):\n",
"        return self.lin(xb)"
],
"execution_count": 0,
"outputs": []
},
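{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"`nn.Linear(784, 10)` stores its weight as `(out_features, in_features)`, i.e. `(10, 784)`, and computes `xb @ weight.T + bias`: the same affine map as the hand-written `xb @ self.w + self.b` with the weight transposed, plus its own default initialization. A small self-contained check on random stand-in inputs:"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"lin = nn.Linear(784, 10)\n",
"xb_demo = torch.randn(4, 784)  # 4 random stand-in inputs\n",
"\n",
"print(lin.weight.shape, lin.bias.shape)  # torch.Size([10, 784]) torch.Size([10])\n",
"by_hand = xb_demo @ lin.weight.T + lin.bias\n",
"same = torch.allclose(lin(xb_demo), by_hand, atol=1e-5)\n",
"print(same)  # True"
],
"execution_count": 0,
"outputs": []
},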
{
"cell_type": "code",
"metadata": {
"id": "xLkCctwP3W-R",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"outputId": "736ed9dc-e2c5-41b4-be20-be6cc55bd0b6"
},
"source": [
"model = MNIST_learn()\n",
"print(loss_func(model(xb), yb))\n",
"\n",
"train()\n",
"\n",
"print(loss_func(model(xb), yb))\n"
],
"execution_count": 74,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor(2.3072, grad_fn=<NllLossBackward>)\n",
"tensor(2.3072, grad_fn=<NllLossBackward>)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "sPODFdFC3j90",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}