Skip to content

Instantly share code, notes, and snippets.

@georgehc
Last active December 3, 2020 03:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save georgehc/6d0bff4a4445c57baea240cda50935d9 to your computer and use it in GitHub Desktop.
Save georgehc/6d0bff4a4445c57baea240cda50935d9 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 94-775/95-865: Handwritten Digit Recognition with Neural Nets\n",
"\n",
"Author: George H. Chen (georgechen [at symbol] cmu.edu)\n",
"\n",
"This demo shows how to train and evaluate four neural net models using PyTorch:\n",
"\n",
"1. Flatten -> fully connected -> softmax activation*\n",
"\n",
"2. Flatten -> fully connected -> ReLU -> fully connected -> softmax activation*\n",
"\n",
"3. Conv2d -> ReLU -> MaxPool2d -> flatten -> fully connected -> softmax activation*\n",
"\n",
"4. Conv2d -> ReLU -> MaxPool2d -> Conv2d -> ReLU -> MaxPool2d -> flatten -> fully connected -> softmax activation*\n",
"\n",
"*In PyTorch, softmax activation is automatically done as part of using the cross entropy loss."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x130739530>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# the next two lines are needed on my Intel-based MacBook Air to get the code to run; you likely don't need these two lines...\n",
"# (in fact I used to not need these two lines)\n",
"import os\n",
"os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from torchsummaryX import summary\n",
"\n",
"from UDA_pytorch_utils import UDA_pytorch_classifier_fit, \\\n",
" UDA_plot_train_val_accuracy_vs_epoch, UDA_pytorch_classifier_predict, \\\n",
" UDA_compute_accuracy\n",
"\n",
"np.random.seed(0)\n",
"torch.manual_seed(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading in the data and a quick data inspection¶"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = torchvision.datasets.MNIST(root='data/',\n",
" train=True,\n",
" transform=transforms.ToTensor(),\n",
" download=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_images = torch.tensor([image.numpy() for image, label in train_dataset])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train_labels = torch.tensor([label for image, label in train_dataset])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([60000, 1, 28, 28])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_images.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([60000])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_labels.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We first take a look at the data."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x120424090>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAN80lEQVR4nO3df6hcdXrH8c+ncf3DrBpTMYasNhuRWBWbLRqLSl2RrD9QNOqWDVgsBrN/GHChhEr6xyolEuqP0qAsuYu6sWyzLqgYZVkVo6ZFCF5j1JjU1YrdjV6SSozG+KtJnv5xT+Su3vnOzcyZOZP7vF9wmZnzzJnzcLife87Md879OiIEYPL7k6YbANAfhB1IgrADSRB2IAnCDiRxRD83ZpuP/oEeiwiPt7yrI7vtS22/aftt27d281oAesudjrPbniLpd5IWSNou6SVJiyJia2EdjuxAj/XiyD5f0tsR8U5EfCnpV5Ku6uL1APRQN2GfJekPYx5vr5b9EdtLbA/bHu5iWwC61M0HdOOdKnzjND0ihiQNSZzGA03q5si+XdJJYx5/R9L73bUDoFe6CftLkk61/V3bR0r6kaR19bQFoG4dn8ZHxD7bSyU9JWmKpAci4o3aOgNQq46H3jraGO/ZgZ7ryZdqABw+CDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii4ymbcXiYMmVKsX7sscf2dPtLly5tWTvqqKOK686dO7dYv/nmm4v1u+66q2Vt0aJFxXU///zzYn3lypXF+u23316sN6GrsNt+V9IeSfsl7YuIs+toCkD96jiyXxQRH9TwOgB6iPfsQBLdhj0kPW37ZdtLxnuC7SW2h20Pd7ktAF3o9jT+/Ih43/YJkp6x/V8RsWHsEyJiSNKQJNmOLrcHoENdHdkj4v3qdqekxyTNr6MpAPXrOOy2p9o++uB9ST+QtKWuxgDUq5vT+BmSHrN98HX+PSJ+W0tXk8zJJ59crB955JHF+nnnnVesX3DBBS1r06ZNK6577bXXFutN2r59e7G+atWqYn3hwoUta3v27Cmu++qrrxbrL7zwQrE+iDoOe0S8I+kvauwFQA8x9AYkQdiBJAg7kARhB5Ig7EASjujfl9om6zfo5s2bV6yvX7++WO/1ZaaD6sCBA8X6jTfeWKx/8sknHW97ZGSkWP/www+L9TfffLPjbfdaRHi85RzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlrMH369GJ948aNxfqcOXPqbKdW7XrfvXt3sX7RRRe1rH355ZfFdbN+/6BbjLMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBJM2VyDXbt2FevLli0r1q+44opi/ZVXXinW2/1L5ZLNmzcX6wsWLCjW9+7dW6yfccYZLWu33HJLcV3UiyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTB9ewD4JhjjinW200vvHr16pa1xYsXF9e9/vrri/W1a9cW6xg8HV/PbvsB2zttbxmzbLrtZ2y/Vd0eV2ezAOo3kdP4X0i69GvLbpX0bEScKunZ6jGAAdY27BGxQdLXvw96laQ11f01kq6uuS8ANev0u/EzImJEkiJixPYJrZ5oe4mkJR1uB0BNen4hTEQMSRqS+IAOaFKnQ287bM+UpOp2Z30tAeiFTsO+TtIN1f0bJD1eTzsAeqXtabzttZK+L+l429sl/VTSSkm/tr1Y0u8l/bCXTU52H3/8cVfrf/TRRx2ve9NNNxXrDz/8cLHebo51DI62YY+IRS1KF9fcC4Ae4uuyQBKEHUiCsANJEHYgCcIOJMElrpPA1KlTW9aeeOKJ4roXXnhhsX7ZZZcV608//XSxjv5jymYgOcIOJEHYgS
QIO5AEYQeSIOxAEoQdSIJx9knulFNOKdY3bdpUrO/evbtYf+6554r14eHhlrX77ruvuG4/fzcnE8bZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtmTW7hwYbH+4IMPFutHH310x9tevnx5sf7QQw8V6yMjIx1vezJjnB1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcHUVnnnlmsX7PPfcU6xdf3Plkv6tXry7WV6xYUay/9957HW/7cNbxOLvtB2zvtL1lzLLbbL9ne3P1c3mdzQKo30RO438h6dJxlv9LRMyrfn5Tb1sA6tY27BGxQdKuPvQCoIe6+YBuqe3XqtP841o9yfYS28O2W/8zMgA912nYfybpFEnzJI1IurvVEyNiKCLOjoizO9wWgBp0FPaI2BER+yPigKSfS5pfb1sA6tZR2G3PHPNwoaQtrZ4LYDC0HWe3vVbS9yUdL2mHpJ9Wj+dJCknvSvpxRLS9uJhx9sln2rRpxfqVV17ZstbuWnl73OHir6xfv75YX7BgQbE+WbUaZz9iAisuGmfx/V13BKCv+LoskARhB5Ig7EAShB1IgrADSXCJKxrzxRdfFOtHHFEeLNq3b1+xfskll7SsPf/888V1D2f8K2kgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLtVW/I7ayzzirWr7vuumL9nHPOaVlrN47eztatW4v1DRs2dPX6kw1HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2SW7u3LnF+tKlS4v1a665plg/8cQTD7mnidq/f3+xPjJS/u/lBw4cqLOdwx5HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2w0C7sexFi8abaHdUu3H02bNnd9JSLYaHh4v1FStWFOvr1q2rs51Jr+2R3fZJtp+zvc32G7ZvqZZPt/2M7beq2+N63y6ATk3kNH6fpL+PiD+X9FeSbrZ9uqRbJT0bEadKerZ6DGBAtQ17RIxExKbq/h5J2yTNknSVpDXV09ZIurpXTQLo3iG9Z7c9W9L3JG2UNCMiRqTRPwi2T2ixzhJJS7prE0C3Jhx229+W9Iikn0TEx/a4c8d9Q0QMSRqqXoOJHYGGTGjozfa3NBr0X0bEo9XiHbZnVvWZknb2pkUAdWh7ZPfoIfx+Sdsi4p4xpXWSbpC0srp9vCcdTgIzZswo1k8//fRi/d577y3WTzvttEPuqS4bN24s1u+8886WtccfL//KcIlqvSZyGn++pL+V9LrtzdWy5RoN+a9tL5b0e0k/7E2LAOrQNuwR8Z+SWr1Bv7jedgD0Cl+XBZIg7EAShB1IgrADSRB2IAkucZ2g6dOnt6ytXr26uO68efOK9Tlz5nTUUx1efPHFYv3uu+8u1p966qli/bPPPjvkntAbHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IIk04+znnntusb5s2bJiff78+S1rs2bN6qinunz66acta6tWrSque8cddxTre/fu7agnDB6O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRJpx9oULF3ZV78bWrVuL9SeffLJY37dvX7FeuuZ89+7dxXWRB0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUjCEVF+gn2SpIcknSjpgKShiPhX27dJuknS/1ZPXR4Rv2nzWuWNAehaRIw76/JEwj5T0syI2GT7aEkvS7pa0t9I+iQi7ppoE4Qd6L1WYZ/I/Owjkkaq+3tsb5PU7L9mAXDIDuk9u+3Zkr4naWO1aKnt12w/YPu4FusssT1se7irTgF0pe1p/FdPtL8t6QVJKyLiUdszJH0gKST9k0ZP9W9s8xqcxgM91vF7dkmy/S1JT0p6KiLuGac+W9KTEXFmm9ch7ECPtQp729N425Z0v6RtY4NefXB30EJJW7ptEkDvTOTT+Ask/Yek1zU69CZJyyUtkjRPo6fx70r6cfVhXum1OLIDPdbVaXxdCDvQex2fxgOYHAg7kARhB5Ig7E
AShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9HvK5g8k/c+Yx8dXywbRoPY2qH1J9NapOnv7s1aFvl7P/o2N28MRcXZjDRQMam+D2pdEb53qV2+cxgNJEHYgiabDPtTw9ksGtbdB7Uuit071pbdG37MD6J+mj+wA+oSwA0k0Enbbl9p+0/bbtm9toodWbL9r+3Xbm5uen66aQ2+n7S1jlk23/Yztt6rbcefYa6i322y/V+27zbYvb6i3k2w/Z3ub7Tds31Itb3TfFfrqy37r+3t221Mk/U7SAknbJb0kaVFEbO1rIy3YflfS2RHR+BcwbP+1pE8kPXRwai3b/yxpV0SsrP5QHhcR/zAgvd2mQ5zGu0e9tZpm/O/U4L6rc/rzTjRxZJ8v6e2IeCcivpT0K0lXNdDHwIuIDZJ2fW3xVZLWVPfXaPSXpe9a9DYQImIkIjZV9/dIOjjNeKP7rtBXXzQR9lmS/jDm8XYN1nzvIelp2y/bXtJ0M+OYcXCarer2hIb7+bq203j309emGR+YfdfJ9OfdaiLs401NM0jjf+dHxF9KukzSzdXpKibmZ5JO0egcgCOS7m6ymWqa8Uck/SQiPm6yl7HG6asv+62JsG+XdNKYx9+R9H4DfYwrIt6vbndKekyjbzsGyY6DM+hWtzsb7ucrEbEjIvZHxAFJP1eD+66aZvwRSb+MiEerxY3vu/H66td+ayLsL0k61fZ3bR8p6UeS1jXQxzfYnlp9cCLbUyX9QIM3FfU6STdU92+Q9HiDvfyRQZnGu9U042p43zU+/XlE9P1H0uUa/UT+vyX9YxM9tOhrjqRXq583mu5N0lqNntb9n0bPiBZL+lNJz0p6q7qdPkC9/ZtGp/Z+TaPBmtlQbxdo9K3ha5I2Vz+XN73vCn31Zb/xdVkgCb5BByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/D+f1mbtgJ8kQQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(train_images[0][0], cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(5)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_labels[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Basics of working with neural nets"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"simple_model = nn.Sequential(nn.Flatten(),\n",
" nn.Linear(in_features=784, out_features=10))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==================================================\n",
" Kernel Shape Output Shape Params Mult-Adds\n",
"Layer \n",
"0_0 - [1, 784] - -\n",
"1_1 [784, 10] [1, 10] 7.85k 7.84k\n",
"--------------------------------------------------\n",
" Totals\n",
"Total params 7.85k\n",
"Trainable params 7.85k\n",
"Non-trainable params 0.0\n",
"Mult-Adds 7.84k\n",
"==================================================\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Kernel Shape</th>\n",
" <th>Output Shape</th>\n",
" <th>Params</th>\n",
" <th>Mult-Adds</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Layer</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0_0</th>\n",
" <td>-</td>\n",
" <td>[1, 784]</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1_1</th>\n",
" <td>[784, 10]</td>\n",
" <td>[1, 10]</td>\n",
" <td>7850.0</td>\n",
" <td>7840.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Kernel Shape Output Shape Params Mult-Adds\n",
"Layer \n",
"0_0 - [1, 784] NaN NaN\n",
"1_1 [784, 10] [1, 10] 7850.0 7840.0"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary(simple_model, torch.zeros((1, 1, 28, 28))) # (batch size, num channels, height, width)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"proper_train_size = int(0.8 * len(train_dataset))\n",
"val_size = len(train_dataset) - proper_train_size\n",
"proper_train_dataset, val_dataset = torch.utils.data.random_split(train_dataset,\n",
" [proper_train_size,\n",
" val_size])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 [==================================================] 48000/48000\n",
" Train accuracy: 0.8920\n",
" Validation accuracy: 0.8883\n",
"Epoch 2 [==================================================] 48000/48000\n",
" Train accuracy: 0.9057\n",
" Validation accuracy: 0.9033\n",
"Epoch 3 [==================================================] 48000/48000\n",
" Train accuracy: 0.9138\n",
" Validation accuracy: 0.9114\n",
"Epoch 4 [==================================================] 48000/48000\n",
" Train accuracy: 0.9179\n",
" Validation accuracy: 0.9127\n",
"Epoch 5 [==================================================] 48000/48000\n",
" Train accuracy: 0.9207\n",
" Validation accuracy: 0.9153\n",
"Epoch 6 [==================================================] 48000/48000\n",
" Train accuracy: 0.9233\n",
" Validation accuracy: 0.9173\n",
"Epoch 7 [==================================================] 48000/48000\n",
" Train accuracy: 0.9240\n",
" Validation accuracy: 0.9173\n",
"Epoch 8 [==================================================] 48000/48000\n",
" Train accuracy: 0.9258\n",
" Validation accuracy: 0.9187\n",
"Epoch 9 [==================================================] 48000/48000\n",
" Train accuracy: 0.9264\n",
" Validation accuracy: 0.9193\n",
"Epoch 10 [==================================================] 48000/48000\n",
" Train accuracy: 0.9277\n",
" Validation accuracy: 0.9185\n"
]
}
],
"source": [
"num_epochs = 10 # during optimization, how many times we look at training data\n",
"batch_size = 128 # during optimization, how many training data to use at each step\n",
"learning_rate = 0.001 # during optimization, how much we nudge our solution at each step\n",
"\n",
"train_accuracies, val_accuracies = \\\n",
" UDA_pytorch_classifier_fit(simple_model,\n",
" torch.optim.Adam(simple_model.parameters(),\n",
" lr=learning_rate),\n",
" nn.CrossEntropyLoss(), # includes softmax\n",
" proper_train_dataset, val_dataset,\n",
" num_epochs, batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"UDA_plot_train_val_accuracy_vs_epoch(train_accuracies, val_accuracies)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"====================================================\n",
" Kernel Shape Output Shape Params Mult-Adds\n",
"Layer \n",
"0_0 - [1, 784] - -\n",
"1_1 [784, 512] [1, 512] 401.92k 401.408k\n",
"2_2 - [1, 512] - -\n",
"3_3 [512, 10] [1, 10] 5.13k 5.12k\n",
"----------------------------------------------------\n",
" Totals\n",
"Total params 407.05k\n",
"Trainable params 407.05k\n",
"Non-trainable params 0.0\n",
"Mult-Adds 406.528k\n",
"====================================================\n",
"Epoch 1 [==================================================] 48000/48000\n",
" Train accuracy: 0.9528\n",
" Validation accuracy: 0.9458\n",
"Epoch 2 [==================================================] 48000/48000\n",
" Train accuracy: 0.9650\n",
" Validation accuracy: 0.9565\n",
"Epoch 3 [==================================================] 48000/48000\n",
" Train accuracy: 0.9796\n",
" Validation accuracy: 0.9693\n",
"Epoch 4 [==================================================] 48000/48000\n",
" Train accuracy: 0.9833\n",
" Validation accuracy: 0.9723\n",
"Epoch 5 [==================================================] 48000/48000\n",
" Train accuracy: 0.9912\n",
" Validation accuracy: 0.9768\n",
"Epoch 6 [==================================================] 48000/48000\n",
" Train accuracy: 0.9924\n",
" Validation accuracy: 0.9789\n",
"Epoch 7 [==================================================] 48000/48000\n",
" Train accuracy: 0.9949\n",
" Validation accuracy: 0.9782\n",
"Epoch 8 [==================================================] 48000/48000\n",
" Train accuracy: 0.9955\n",
" Validation accuracy: 0.9774\n",
"Epoch 9 [==================================================] 48000/48000\n",
" Train accuracy: 0.9975\n",
" Validation accuracy: 0.9791\n",
"Epoch 10 [==================================================] 48000/48000\n",
" Train accuracy: 0.9980\n",
" Validation accuracy: 0.9797\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"deeper_model = nn.Sequential(nn.Flatten(),\n",
" nn.Linear(in_features=784, out_features=512),\n",
" nn.ReLU(),\n",
" nn.Linear(in_features=512, out_features=10))\n",
"summary(deeper_model, torch.zeros((1, 1, 28, 28))) # (batch size, num channels, height, width)\n",
"\n",
"train_accuracies, val_accuracies = \\\n",
" UDA_pytorch_classifier_fit(deeper_model,\n",
" torch.optim.Adam(deeper_model.parameters(),\n",
" lr=learning_rate),\n",
" nn.CrossEntropyLoss(), # includes softmax\n",
" proper_train_dataset, val_dataset,\n",
" num_epochs, batch_size)\n",
"\n",
"UDA_plot_train_val_accuracy_vs_epoch(train_accuracies, val_accuracies)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Convnets"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=========================================================\n",
" Kernel Shape Output Shape Params Mult-Adds\n",
"Layer \n",
"0_0 [1, 32, 3, 3] [1, 32, 26, 26] 320.0 194.688k\n",
"1_1 - [1, 32, 26, 26] - -\n",
"2_2 - [1, 32, 13, 13] - -\n",
"3_3 - [1, 5408] - -\n",
"4_4 [5408, 10] [1, 10] 54.09k 54.08k\n",
"---------------------------------------------------------\n",
" Totals\n",
"Total params 54.41k\n",
"Trainable params 54.41k\n",
"Non-trainable params 0.0\n",
"Mult-Adds 248.768k\n",
"=========================================================\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Kernel Shape</th>\n",
" <th>Output Shape</th>\n",
" <th>Params</th>\n",
" <th>Mult-Adds</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Layer</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0_0</th>\n",
" <td>[1, 32, 3, 3]</td>\n",
" <td>[1, 32, 26, 26]</td>\n",
" <td>320.0</td>\n",
" <td>194688.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1_1</th>\n",
" <td>-</td>\n",
" <td>[1, 32, 26, 26]</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2_2</th>\n",
" <td>-</td>\n",
" <td>[1, 32, 13, 13]</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3_3</th>\n",
" <td>-</td>\n",
" <td>[1, 5408]</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4_4</th>\n",
" <td>[5408, 10]</td>\n",
" <td>[1, 10]</td>\n",
" <td>54090.0</td>\n",
" <td>54080.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Kernel Shape Output Shape Params Mult-Adds\n",
"Layer \n",
"0_0 [1, 32, 3, 3] [1, 32, 26, 26] 320.0 194688.0\n",
"1_1 - [1, 32, 26, 26] NaN NaN\n",
"2_2 - [1, 32, 13, 13] NaN NaN\n",
"3_3 - [1, 5408] NaN NaN\n",
"4_4 [5408, 10] [1, 10] 54090.0 54080.0"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"simple_convnet = nn.Sequential(nn.Conv2d(1, 32, 3),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(2),\n",
" nn.Flatten(),\n",
" nn.Linear(in_features=5408, out_features=10))\n",
"summary(simple_convnet, torch.zeros((1, 1, 28, 28))) # (batch size, num channels, height, width)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 [==================================================] 48000/48000\n",
" Train accuracy: 0.9555\n",
" Validation accuracy: 0.9508\n",
"Epoch 2 [==================================================] 48000/48000\n",
" Train accuracy: 0.9733\n",
" Validation accuracy: 0.9674\n",
"Epoch 3 [==================================================] 48000/48000\n",
" Train accuracy: 0.9812\n",
" Validation accuracy: 0.9757\n",
"Epoch 4 [==================================================] 48000/48000\n",
" Train accuracy: 0.9837\n",
" Validation accuracy: 0.9788\n",
"Epoch 5 [==================================================] 48000/48000\n",
" Train accuracy: 0.9870\n",
" Validation accuracy: 0.9801\n",
"Epoch 6 [==================================================] 48000/48000\n",
" Train accuracy: 0.9876\n",
" Validation accuracy: 0.9794\n",
"Epoch 7 [==================================================] 48000/48000\n",
" Train accuracy: 0.9893\n",
" Validation accuracy: 0.9801\n",
"Epoch 8 [==================================================] 48000/48000\n",
" Train accuracy: 0.9910\n",
" Validation accuracy: 0.9816\n",
"Epoch 9 [==================================================] 48000/48000\n",
" Train accuracy: 0.9899\n",
" Validation accuracy: 0.9821\n",
"Epoch 10 [==================================================] 48000/48000\n",
" Train accuracy: 0.9920\n",
" Validation accuracy: 0.9813\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"train_accuracies, val_accuracies = \\\n",
" UDA_pytorch_classifier_fit(simple_convnet,\n",
" torch.optim.Adam(simple_convnet.parameters(),\n",
" lr=learning_rate),\n",
" nn.CrossEntropyLoss(), # includes softmax\n",
" proper_train_dataset, val_dataset,\n",
" num_epochs, batch_size)\n",
"\n",
"UDA_plot_train_val_accuracy_vs_epoch(train_accuracies, val_accuracies)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=========================================================\n",
" Kernel Shape Output Shape Params Mult-Adds\n",
"Layer \n",
"0_0 [1, 32, 3, 3] [1, 32, 26, 26] 320.0 194.688k\n",
"1_1 - [1, 32, 26, 26] - -\n",
"2_2 - [1, 32, 13, 13] - -\n",
"3_3 [32, 16, 3, 3] [1, 16, 11, 11] 4.624k 557.568k\n",
"4_4 - [1, 16, 11, 11] - -\n",
"5_5 - [1, 16, 5, 5] - -\n",
"6_6 - [1, 400] - -\n",
"7_7 [400, 10] [1, 10] 4.01k 4.0k\n",
"---------------------------------------------------------\n",
" Totals\n",
"Total params 8.954k\n",
"Trainable params 8.954k\n",
"Non-trainable params 0.0\n",
"Mult-Adds 756.256k\n",
"=========================================================\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Kernel Shape</th>\n",
" <th>Output Shape</th>\n",
" <th>Params</th>\n",
" <th>Mult-Adds</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Layer</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0_0</th>\n",
" <td>[1, 32, 3, 3]</td>\n",
" <td>[1, 32, 26, 26]</td>\n",
" <td>320.0</td>\n",
" <td>194688.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1_1</th>\n",
" <td>-</td>\n",
" <td>[1, 32, 26, 26]</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2_2</th>\n",
" <td>-</td>\n",
" <td>[1, 32, 13, 13]</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3_3</th>\n",
" <td>[32, 16, 3, 3]</td>\n",
" <td>[1, 16, 11, 11]</td>\n",
" <td>4624.0</td>\n",
" <td>557568.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4_4</th>\n",
" <td>-</td>\n",
" <td>[1, 16, 11, 11]</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5_5</th>\n",
" <td>-</td>\n",
" <td>[1, 16, 5, 5]</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6_6</th>\n",
" <td>-</td>\n",
" <td>[1, 400]</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7_7</th>\n",
" <td>[400, 10]</td>\n",
" <td>[1, 10]</td>\n",
" <td>4010.0</td>\n",
" <td>4000.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Kernel Shape Output Shape Params Mult-Adds\n",
"Layer \n",
"0_0 [1, 32, 3, 3] [1, 32, 26, 26] 320.0 194688.0\n",
"1_1 - [1, 32, 26, 26] NaN NaN\n",
"2_2 - [1, 32, 13, 13] NaN NaN\n",
"3_3 [32, 16, 3, 3] [1, 16, 11, 11] 4624.0 557568.0\n",
"4_4 - [1, 16, 11, 11] NaN NaN\n",
"5_5 - [1, 16, 5, 5] NaN NaN\n",
"6_6 - [1, 400] NaN NaN\n",
"7_7 [400, 10] [1, 10] 4010.0 4000.0"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"deeper_convnet = nn.Sequential(nn.Conv2d(1, 32, 3),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(2),\n",
" nn.Conv2d(32, 16, 3),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(2),\n",
" nn.Flatten(),\n",
" nn.Linear(in_features=400, out_features=10))\n",
"summary(deeper_convnet, torch.zeros((1, 1, 28, 28))) # (batch size, num channels, height, width)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 [==================================================] 48000/48000\n",
" Train accuracy: 0.9578\n",
" Validation accuracy: 0.9557\n",
"Epoch 2 [==================================================] 48000/48000\n",
" Train accuracy: 0.9709\n",
" Validation accuracy: 0.9708\n",
"Epoch 3 [==================================================] 48000/48000\n",
" Train accuracy: 0.9774\n",
" Validation accuracy: 0.9768\n",
"Epoch 4 [==================================================] 48000/48000\n",
" Train accuracy: 0.9782\n",
" Validation accuracy: 0.9782\n",
"Epoch 5 [==================================================] 48000/48000\n",
" Train accuracy: 0.9806\n",
" Validation accuracy: 0.9797\n",
"Epoch 6 [==================================================] 48000/48000\n",
" Train accuracy: 0.9844\n",
" Validation accuracy: 0.9822\n",
"Epoch 7 [==================================================] 48000/48000\n",
" Train accuracy: 0.9854\n",
" Validation accuracy: 0.9835\n",
"Epoch 8 [==================================================] 48000/48000\n",
" Train accuracy: 0.9877\n",
" Validation accuracy: 0.9854\n",
"Epoch 9 [==================================================] 48000/48000\n",
" Train accuracy: 0.9867\n",
" Validation accuracy: 0.9833\n",
"Epoch 10 [==================================================] 48000/48000\n",
" Train accuracy: 0.9896\n",
" Validation accuracy: 0.9875\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEGCAYAAABy53LJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nOzdeVyVddr48c/FIqCiKLiCiuaKGxrZnlultmnaZsu0jk9NTdP01KQz/XqapkZnbNpmmpoWrWwxMzMr0xyz3CrFcMU1RWVREQUXQDlw/f64b/SIqKDncBCu9+t1Xpx7PddNdi6+u6gqxhhjTEUFBToAY4wxZxdLHMYYYyrFEocxxphKscRhjDGmUixxGGOMqZSQQAdQFWJiYjQ+Pj7QYRhjzFll2bJlu1W1Sdn9tSJxxMfHk5ycHOgwjDHmrCIiW8vbb1VVxhhjKsUShzHGmEqxxGGMMaZS/NrGISKDgZeBYOAtVR1X5ngbYALQBNgD3K6q6e6xvwFXu6f+RVU/dve3BSYDjYGfgTtU9XBlYysqKiI9PZ3CwsLTejZzvPDwcOLi4ggNDQ10KMYYP/Jb4hCRYOBV4AogHVgqIjNUNdXrtOeB91T1XREZAIwF7hCRq4HeQCIQBnwvIl+r6j7gb8CLqjpZRF4H7gVeq2x86enpREZGEh8fj4icyaMaQFXJyckhPT2dtm3bBjocY4wf+bOqqg+wSVU3uyWCycDQMuckAHPd9/O8jicA36uqR1UPAiuAweJ8ww8AprrnvQsMO53gCgsLiY6OtqThIyJCdHS0leCMqQamp2Rw8bhvaTv6Ky4e9y3TUzJ8en9/Jo5YYLvXdrq7z9sKYIT7/nogUkSi3f1DRKSuiMQA/YFWQDSQq6qek9wTABEZJSLJIpKcnZ1dboCWNHzLfp/GBN70lAzGTFtFRm4BCmTkFjBm2iqfJg9/Jo7yvkXKzuH+GNBXRFKAvkAG4FHVb4CZwGLgI+AHwFPBezo7Vd9Q1SRVTWrS5LjxK8YYUyONn72egqLiY/YVFBUzfvZ6n32GPxNHOk4poVQckOl9gqpmqupwVe0F/Mndl+f+fE5VE1X1CpyEsRHYDUSJSMiJ7nm2yMnJITExkcTERJo3b05sbOyR7cOHK9bWf/fdd7N+/cn/Mbz66qt88MEHvgjZGHMWyMwtqNT+0+HPXlVLgQ5uL6gM4BbgVu8T3GqoPapaAozB6WFV2rAepao5ItID6AF8o6oqIvOAG3DaTO4EPvfjMxwxPSWD8bPXk5lbQMuoCB4f1IlhvcqtJauQ6Oholi9fDsDTTz9N/fr1eeyxx445R1VRVYKCys/vEydOPOXnPPjgg6cdozHm7JGRW8C/vt1YfhUM0DIqwmef5bcSh9sO8RAwG1gLTFHVNSLyjIhc557WD1gvIhuAZsBz7v5QYIGIpAJv4HTTLW3XeAJ4VEQ24bR5vO2vZyhVFXWGpTZt2kS3bt24//776d27N1lZWYwaNYqkpCS6du3KM888c+TcSy65hOXLl+PxeIiKimL06NH07NmTCy+8kF27dgHw5JNP8tJLLx05f/To0fTp04dOnTqxePFiAA4ePMiIESPo2bMnI0eOJCkp6UhSM8ZUbzv3FfLU56vpP/47pi5L55L20YSHHPvVHhEazOODOvnsM/06jkNVZ+K0VXjve8rr/VSO9pDyPqcQp2dVeffcjNNjy2f+/MUaUjP3nfB4yrZcDheXHLOvoKiYP0xdyUdLtpV7TULLBvzftV1PK57U1FQmTpzI66+/DsC4ceNo3LgxHo+H/v37c8MNN5CQcOyvJy8vj759+zJu3DgeffRRJkyYwOjRo4+7t6qyZMkSZsyYwTPPPMOsWbP45z//SfPmzfn0009ZsWIFvXv3Pq24jTFVZ/eBQ7z+3S9M+nErxSXKjUlxPDSgA7FRET6vISmrVkxyeKbKJo1T7T9T55xzDuedd96R7Y
8++oi3334bj8dDZmYmqampxyWOiIgIhgwZAsC5557LggULyr338OHDj5yTlpYGwMKFC3niiScA6NmzJ127nl7CM8b4396Dh3ljwWbeXZxGYVEx1/eK43cDO9A6uu6Rc4b1ivVpoijLEgecsmRw8bhvySinYSk2KoKP/+dCn8dTr169I+83btzIyy+/zJIlS4iKiuL2228vd6xEnTp1jrwPDg7G4/Ecdw5AWFjYceeonqhW1BhTXewrLOKtBVuYsHALBw97uKZHSx65vAPnNKlf5bHYXFUV8PigTkSEBh+zz9d1hieyb98+IiMjadCgAVlZWcyePdvnn3HJJZcwZcoUAFatWkVqauoprjDGVJWDhzy8Om8Tl/5tHq/M3cgl7WOY9bvL+OfIXgFJGmAljgopLfL5s87wRHr37k1CQgLdunWjXbt2XHzxxT7/jN/+9rf86le/okePHvTu3Ztu3brRsGFDn3+OMabiCg4XM+nHNF7/fjN7Dh5mYOem/P6KjnSLDfz/m1IbqimSkpK07EJOa9eupUuXLgGKqHrxeDx4PB7Cw8PZuHEjV155JRs3biQkpPJ/V9jv1ZgzU1hUzEdLtvHv734he/8hLu0Qw++v6Ejv1o2qPBYRWaaqSWX3W4nDcODAAQYOHIjH40FV+c9//nNaScOYQPJ3TyJ/O+wp4ZNl2/nXt5vIyivk/LaNefXW3vRp2zjQoR3Hvh0MUVFRLFu2LNBhGHPaSsdalU61UTrWCqj2ycNTXMK0lAxembuR9L0F9G4dxfM39uSic3wwCeu8sdB/jG8C9WKJwxhz1vvrzLXlzs/05y/W0L5pfdo3rU94mQ4ugVZcony5MpOX/ruRLbsP0j22IX8Z2o1+nZr4bsLQ78dZ4jDGmFJ5BUXMWJHJJ8nb2bX/ULnn7M0v4pp/LiRIID66Hh2bRdKxeSSdmkXSqXl94qPrERJctZ1LS0qUWWt28OKcDWzcdYDOzSP5zx3ncmVCM98lDM8hmOvOMrFjFTTv7pv7uixxGGPOGiUlyg+bc5iSvJ1Zq3dwyFNC5+aRNIwIIa/g+LFLTSPDePq6rqzfsZ8NO/ezfud+vkndQYnbJ6hOcBDtmtSjU/NIOjYrTSiRxEZFEBTk22UCVJW5a3fxjzkbWJu1j3Oa1ONft/biqm4tfPtZX/0vLH3r6Pbrlzg/+472WenDEocxptpL35vPp8sy+GTZdtL3FtAgPISbz2vFTUmt6NqyAZ8vzzymjQOcsVZ/vKoLV3VvwVXdWxzZX1hUzKZdB44kkg079pOctpfPlx+daLtenWA6uImktITSsXl9mtQPq3SpQFWZv3E3L8zZwIrtubSJrssLN/VkaGIswb5MGKqwbCKkfAB1o2Hoq/DRLfB0nu8+w2WJI0D69evHmDFjGDRo0JF9L730Ehs2bODf//53udfUr1+fAwcOkJmZycMPP8zUqcdN80W/fv14/vnnSUo6rgfdMZ8zatQo6tZ1pii46qqr+PDDD4mKijrDpzLGdwqLivkmdSefJG9n4abdqMIl7WN4fFAnBnVtfkybRWXGWoWHBtMttuFx4yH2FRaxcaebUNwSytx1O/k4+eh6dI3qhjolk9ISSvNIOjaNpGHd0CPnePfuiq5fh8iwELbk5BMbFcHfRnRneO84Qn1dPZa/B2b8FtZ9Ce36w/WvQ2Rz336GF0scleWjXgojR45k8uTJxySOyZMnM378+FNe27Jly3KTRkW99NJL3H777UcSx8yZM09xhTFVZ3VGHlOStzM9JYN9hR5ioyL43cAOjOgdR6vGdU943ZnOz9QgPJRz2zTi3DbHjpfYfeAQG9ySyXo3sXz2cwb7Dx2tGmveIJyOzSMJFmXhphyKitW99jC7DxzmhnPj+Ov13akT4of2lM3fw2f/Awd3w5XPwQW/gdKlGPoeP9GpL1jiqCwf9VK44YYbePLJJzl06BBhYWGkpaWRmZlJYmIiAwcOZO/evRQVFfHss88ydOixS7
WnpaVxzTXXsHr1agoKCrj77rtJTU2lS5cuFBQcnVPrgQceYOnSpRQUFHDDDTfw5z//mVdeeYXMzEz69+9PTEwM8+bNIz4+nuTkZGJiYnjhhReYMGECAPfddx+PPPIIaWlpDBkyhEsuuYTFixcTGxvL559/TkSE7+b3N7Xb3oOH+Xx5BlOS00nN2kedkCAGd23OTUmtuOicaJ+3N1RGTP0wYuqHcdE5MUf2qSpZeYVHqrrW73RKKKszyp9l+4dfcnyfNDyHYd6zsOgViOkAt34MLXoee44felSBJQ7H16OdngcVNfHqU5/TvDsMGXfCw9HR0fTp04dZs2YxdOhQJk+ezM0330xERASfffYZDRo0YPfu3VxwwQVcd911J6xXfe2116hbty4rV65k5cqVx0yJ/txzz9G4cWOKi4sZOHAgK1eu5OGHH+aFF15g3rx5xMTEHHOvZcuWMXHiRH766SdUlfPPP5++ffvSqFEjNm7cyEcffcSbb77JTTfdxKeffsrtt99esd+XMeUoLlEWbdrNx8nbmbNmJ4eLS9wuqV25rmfsMdU/1Y2I0DIqgpZREfTv1PTI/rajvyp3ISVfrr4HwO5N8Om9kLUczr0bBv0V6py4NOZrljgqIncr5B2t52TrQudnw1YQ1ea0b1taXVWaOCZMmICq8sc//pH58+cTFBRERkYGO3fupHnz8usr58+fz8MPPwxAjx496NGjx5FjU6ZM4Y033sDj8ZCVlUVqauoxx8tauHAh119//ZHZeYcPH86CBQu47rrraNu2LYmJicCxU7IbU1nbcvKZumw7U5elk5lXSFTdUG49vzU3JbUioWWDQId3RlpGRZQ7k7bPVt9ThZT34es/QEgY3Pw+dLnWN/euBEsccNKSwXGebuizXgrDhg3j0Ucf5eeff6agoIDevXvzzjvvkJ2dzbJlywgNDSU+Pr7cadS9lVca2bJlC88//zxLly6lUaNG3HXXXae8z8nmLSudjh2cKdm9q8SMOZWCw8XMWpPFlKXp/LA5BxG4rEMT/nR1ApcnNCUspHoNzjtdjw/qVG7vLp/MpF2wF774HaR+DvGXwvA3oEHLM7/vabDEEUD169enX79+3HPPPYwcORJwVvJr2rQpoaGhzJs3j61bt570HpdddhkffPAB/fv3Z/Xq1axcuRJwpmOvV68eDRs2ZOfOnXz99df069cPgMjISPbv339cVdVll13GXXfdxejRo1FVPvvsMyZNmuT7Bze1gqqyMt1p6J6xPJP9hzy0blyXx67syPDecT5dA7u68NtM2mmLYNooOLADLn8aLnoYggKXbC1xVJaPeymMHDmS4cOHM3nyZABuu+02rr32WpKSkkhMTKRz584nvf6BBx7g7rvvpkePHiQmJtKnj7Oqbs+ePenVqxddu3Y9bjr2UaNGMWTIEFq0aMG8efOO7O/duzd33XXXkXvcd9999OrVy6qlzEmVnVzwgX7tKCwq4ZPkdNbv3E94aBBXdWvBjUmtOL9tY/83dPtpfqaK8unqe8VF8N04WPAPaNwW7v0GYs/1zb3PgE2rbnzKfq+1S9nJBb0ltoripqRWXNOzBQ3Cq7Ch24fVyQG1ZzN8+mvISIZet8Pgv0FY1S7cFJBp1UVkMPAyEAy8parjyhxvA0wAmgB7gNtVNd099nfgapxVCucAv1NVFZHvgBZAaSX7laq6y5/PYYw56sAhD+t37CM1az9jy5lcEJypPqY/6PtFx8ql6vQuSv0cNnzj7FvzGXQcDKFnYXWYKqz82Jk6JCgYbnwHul4f6KiO4bfEISLBwKvAFUA6sFREZqiq97qkzwPvqeq7IjIAGAvcISIXARcDpV2AFgJ9ge/c7dtU9dgihDHGp1SV9L0FpGbtY+2R13627ck/5bXZJ5h00IfBQcbPkDrdSRi5ZdoCP7nL+dm8uzMoLv6SgLYJVFhhHnz5KKyeCm0uhuv/A1GtAh3VcfxZ4ugDbFLVzQAiMhkYCngnjgTg9+
77ecB0970C4UAdQIBQYKevA1RV381GaU7aK8tUb/mHPazbsZ91WfuPJIl1O/ZzwB0dLe7sst1iG3DjuXF0adGAzi0iufk/P5CRe3xvPb80fJeUQMayo8kibzsEhUC7fnDZ49D5aqjb2Kmq+tXnsHIKpM6A966DyJbQ/QbocTM07+b72Hxh208w7T7Iy4ABT8Ilj1bbZOfPxBELeA1+IB04v8w5K4ARONVZ1wORIhKtqj+IyDwgCydx/EtV13pdN1FEioFPgWe1nG8sERkFjAJo3br1ccGFh4eTk5NDdLQPFksxqCo5OTmEh4cHOhRzEqpKRm4Ba7P2sy5rH2t3OKWItJyDlP5fVD8shM7NI7m+VyxdWjSgSwtnTqa6dY7/unh8UGf/dT8FJ1mkL4E102HtDNiXAUGhcM4A6DcGOl8FEeUsqdqun/O6+h+w/msnifz4b1j8CjRNgB43QfcboWGcb+I8E8UemD8e5v8dolrDPbOh1XmBjuqk/Jk4yvs2LvsF/xjwLxG5C5gPZAAeEWkPdAFK/6vOEZHLVHU+TjVVhohE4iSOO4D3jvsg1TeAN8BpHC97PC4ujvT0dLKzs0/r4czxwsPDiYurBv8j1hKnWiq1sKiY9Tv2Hyk9pGbtY13WPvYVHp1jqU10Xbo0b8DQxJZ0adGAhBYNiGsUUeE/pvzS/bSkGLb96JQq1s6A/VkQXAfaXw4Dn3LaLiJOMiGnd8/H0AjoNtx5HcyBNdOcJPLfp+G/f3aqsLrfCAlDT35Pf9m7Fab9Grb/BD1ugavGQ3j1HwTpt15VInIh8LSqDnK3xwCo6tgTnF8fWKeqcSLyOBCuqn9xjz0FFKrq38tccxeQpKoPnSyW8npVGXM2K683U53gIK5IaIqIsDZrH1t2Hzyy7kTdOsF0bh7pliBKSxENqB9WTXrklxTD1kVusvgCDuyE4DDocAUkDIOOg3z7hbpnM6ya6jRC52xyPqvjIKck0uFKZ1S2v62aCl+6NfVXvwA9bvT/Z1bSiXpV+TNxhAAbgIE4JYmlwK2qusbrnBhgj6qWiMhzQLGqPiUiNwO/BgbjlFxmAS8BXwNRqrpbREKBj4D/qurrJ4vFEoepaS4e9225U1sAxDWKOJIgElpE0rl5A1o3rhvQiQLLVexxpu9ZM92ZDvxgNoREOMmi6zDnCzws0r8xqEJmilMKWT3ViSG8odOLqcfN0OqCozPN+sqh/TDzcVjxEbQ63xkB3ijet5/hI1XeHVdVPSLyEDAbpzvuBFVdIyLPAMmqOgPoB4wVEcWpqnrQvXwqMABYhVO9NUtVvxCResBsN2kEA/8F3vTXMxhTXZ0oaQiw8IkBVRtMZRQXwZb5Tsli3ZeQnwOhdZ2/9hOGOsmiTr2qi0cEYns7ryufhc3fwaopTiJZ9g40bH20Ub3pyQfjVkh6sjM5Ye42p43m0scguJqU+iqh1g4ANOZslL3/EE9/sYavVmaVezw2KoJFo6tZ4vAchi3fO72h1n3lzLlUp77TVpEw1Gm7qMKZXSvk0AFYP9NJIL98C1oMzXs4VVndboAGLU59D28lxbDwBWdUe4NYGPEmtL7AP7H7UJVXVVUnljjM2U5VmbosnWe/WkvB4WIu79KUb9fvorCo5Mg5EaHBjB3e3XfTXVSW91QfnkPwyzynZLH+K2d8QlgD6DTESRbnDDh7Bucd2AWrpzntIZk/AwLt+jqlkM7XnLrtJS/dmWdq6yLoNsJpzwhEQ/xpsMRhicOcpbbl5DPms5Us2pTDefGNGDu8B+2b1j9lr6oq93RDuOUjp2Sx/ms4tA/CGjpdZhOGwTn9q6bR2Z92b3Krsj6GvWlOm0ynIU4SaT8Qgr2mVpk3Fpp2gS8edkocVz0PPW9xqsfOEpY4LHGYs4ynuISJi9L4x5z1hAQF8cSQztzWp3X1aeQ+dMCZ6iM9GdKXOm0WAOFRzl/iXYdB274QUiewcfqDqvPMK6fA6k+hYA/UjYauw53qrKYJMN
ZN4rHnwoi3oHG7wMZ8GgIyV5Ux5vSsycxj9KerWJWRx+VdmvGXYV1p0TCAVTslxZC93plwLz3ZGcG9KxW05PhzC3OdgXUdrqj6OKuKCLTq47wGj4VNc51SSMokWPqm070XnMbvfqOPLYnUAFbiMKYaKSwq5uW5G3lj/mYa1Q3lz9d146ruzat+doP9O9wE4SaKzOVweL9zLLyh81d0bBLEJTnv68XUnFlpz8Scp2HRi8fv7zs6oFO9ny4rcRhTzf24OYcx01axZfdBbkqK449XdSGqbhVU8xzOP1rllJEM6ctgX7pzLCgEmnWDnjcfTRSNz/H92Iaa4oqnnRfU6ERqicOYAMsrKGLc12v5aMl2Wjeuywf3nc/F7WNOfeHpKCmBnI3HliZ2rnG6m4IzV1KrPhD3GydRtOhR8d5PPl7kzFRfljiMCaBZq3fw1Oer2X3gEP9zWTseubwjEXUqOSPqyVa8O5Dt1S6RDBkpcMj9KzisgTPw7ZLfH61yqt/09B/mLKyK8asanEgtcRgTADv3FfJ/n69h1podJLRowNt3nkf3uIand7Pvxzlf2kUFkLXy2ESRu805R4KhWVfoPuJolVN0B6ty8qcanEgtcRhThVSVyUu389eZaznsKeGJwZ2579K2hAafxhe4qjOqGeA/fWHnaihxZ75t2MopQfQZ5VY59ax+o7PNWcsShzFVZMvug4yZtpIfN+/hgnaNGTu8B21jTmNepqJCZ4W7DV8f3Ze13PnZdbjTPTSyuU9iNqY8ljiM8bOi4hLeXLCZl/67kbCQIMYN787N57WqfBfbA9mQ/DYseRPyd0Oz7nDhgzD9/hrbe8dUT5Y4jPGjlem5PPHpKtZm7WNIt+b8+bquNG1QyVUSd62FH151RikXH4IOg5yE0fYyZyDa9Pv9E7wxJ2CJwxg/yD/s4cU5G3h74RZi6ofx+u3nMrhbJaqPStsvfngVfpnrzInU6zY4/wFo0vHYc2tw7x1TPVniMMbHFm7czZjPVrJ9TwEj+7Rm9JDONIyo4JQTRYXOJHo//Buy10L9ZjDgSTj3HqgXXf41Nbj3jqmeLHEY4yO5+Yd59qu1TF2WTtuYekwedQEXtDvBl31ZB7Jh6VvOq7T9YtjrzlrZZ/uMsqbGscRhzBlSVb5cmcWfv1jD3vwiftPvHB4e2IHw0AoM5DtV+4Ux1ZAlDmMqyXsdjGYNwoiuX4c1mfvpHtuQ9+45n4SWp1jYR9Vpt/jhVacd42TtF8ZUQ5Y4jKmE6SkZjJm2ioIiZ26nHfsOsWPfIYYltuT5G3sScrKBfKfTfmFMNWSJw5hK+NusdUeShrelaXtPnDSs/cLUMH5NHCIyGHgZCAbeUtVxZY63ASYATYA9wO2qmu4e+ztwNRAEzAF+p6oqIucC7wARwMzS/f58DlO7qSrLt+cy6cetZOUVlntOZm7B8Tut/cLUUH5LHCISDLwKXAGkA0tFZIaqpnqd9jzwnqq+KyIDgLHAHSJyEXAx0MM9byHQF/gOeA0YBfyIkzgGA15zLxjjGwWHi/liRSaTftzKqow86tUJpl6dYA4ePr7E0TLKnXrc2i9MLeDPEkcfYJOqbgYQkcnAUMA7cSQAv3ffzwOmu+8VCAfqAAKEAjtFpAXQQFV/cO/5HjAMSxzGh7bsPsgHP27lk2Xp5BUU0bFZff4ytCvX947jv6k7j7RxPBIylZc8NxARGswTl8fDz+9Z+4WpFfyZOGKB7V7b6cD5Zc5ZAYzAqc66HogUkWhV/UFE5gFZOInjX6q6VkSS3Pt43zO2vA8XkVE4JRNat27tg8cxNVlxifLtul1M+nEr8zdkExIkDOrWnDsuaMP5bRsfmVdqWC/nn9v42et5pHAac+pdyz/il9L52wet/cLUGv5MHOVV4pZti3gM+JeI3AXMBzIAj4i0B7oAce55c0TkMqCciuTj7unsVH0DeAOcNccrHb2pFXYfOMTHS7fz4U/byHC71/7+8o7c0qcVzU4wp9
Swro0Y1gCYBF957of11n5hahd/Jo50oJXXdhyQ6X2CqmYCwwFEpD4wQlXz3NLCj6p6wD32NXABMImjyaTcexpzKqrKz9v2MumHrcxctYPDxSVcdE40T17dhcsTmh2/NobnsLMo0pb5kPI+5HkVpIsPOT9b9oJ2favuIYwJIH8mjqVABxFpi1OSuAW41fsEEYkB9qhqCTAGp4cVwDbg1yIyFqfk0hd4SVWzRGS/iFwA/AT8CvinH5/B1CD5hz18vjyTST9sJTVrH5FhIdx6fmtuv6A17ZtGHj2xpNhZ32LLfOe17UcoygfEWRCp6/XQti98MMKmMze1kt8Sh6p6ROQhYDZOd9wJqrpGRJ4BklV1BtAPGCsiilNV9aB7+VRgALAKpypqlqp+4R57gKPdcb/GGsbNKfySfYD3f9zK1GXp7C/00Ll5JM9d341hibHUCwtxekLtTHUTxfeQtujoutxNukCvO5wqqPiLIaJRYB/GmGpAasMQiKSkJE1OTg50GKYKeYpL+O/aXbz/41YWbtpNaLAwpFsL7riwDUmto5C9W46WKLbMdxq2ARq1dZJE28sg/lKIbHbiD5k31mamNTWaiCxT1aSy+23kuKlRsvcfYvKSbXy4ZBtZeYW0bBjOY1d2ZGTnEKJ3/Qgpb8C0+bDP7ZwX2QLaX+4mi0shqhI98CxpmFrKEoc566kqS9P2MunHrcxanUVRsXJVuxBe67WLHkUrCFo9H+b/4pwc0dhNEo867RTR51gvKGMqyRKHOWuUzkp744FJfFL/Dh4e0J6iEuX9H7eSsWMnfcM3MCl2K72KVxKWudbpb1cn0mmbOO9eJ2E07QpBJ5mI0BhzSpY4zFnBe1baR8Kn8XrutXwx/UMuClrDy2Hr6BC+iSBKYE84tL4AEm90ShQtEiHY/pkb40v2f5SpdopLlKy8Arbm5JOWc5CtOfl8sngdvUo2cH7IWgBWht1HHSnGQzDBsX2Qttc7JYq482zEtjF+ZonDBERRcQnpewtIyznINq8EkZZzkPQ9BYQVH+DcoPWcH7SOIcHreCJ4E8EhR3sA1hFnosH/eK7mwXsmBeoxjKmVLHGYCvFe9a5lVASPD+p0ZN6mEyksKmb7nnzScvLZmnPwSJE6gIUAACAASURBVHLYmpNPRm4BxSVHE0HLOgcZErmF20LW07XhaprmbySIEjQoFGLPZVJmD+YWtGdZSQdWh99HfOGHAMRGRRwZ/GOMqRqWOMwplV31LiO3gDHTVgFweUIztnolBO8EsWNfId7DhBqEh9A2ph6JraK4NaEOiSWptMtfTuPsZEJy1sFBICTcqW6KHw5tLkJik6BOXRqkZLBk2ioKSo5OaR4RGszjgzpV5a/CGIMlDlMB42evP27Vu4KiYh6dspySMuNHY+qH0Sa6LheeE018dD3aRNelTXQ92obk0HDXUkj7ArYuhvVu99g69d3G7JugzcXOnE/ltFF4z0r70oHhxFaw1GOM8T1LHOaUyl3dDihR+MPgTsckiPqlU3jk/AJbF8HmxTBv0dGJAcOjoM1FkHS3kyia96hwr6dhvWLdRDGAR3z0bMaYyrPEYU4pqm4oe/OLAI4sXgRO+8Jv+rWHkhLIXgcrFjnJYutiOLDTubheEydBXPSwkzCaJtg4CmPOcpY4zEkt3rSbfQVFBIlTwngkZBqveIbTK3Q7T7fPhY/egm2LoWCvc0GDWGf8RJuLIP4SiG5vI7ONqWEscZgTSs3cx6hJyzinaX3uurA1O/77KnhgZfgo6pMPq3EmBex0tTM6u81FENXGEoUxNZwlDlOu7XvyuWviEiLDQ/i08/dEzvrHkWP1yXfeXPAgDP5rgCI0xgSKJQ5znD0HD3PnxCUUFhUz9f4LiUyZ4xy45FFY+IItXmRMLWetlOYYBYeLuffdpWTsLeDtO5PouOp5+Ok1p3Qx8KlAh2eMqQYscZgjPMUlPPThz6zYnssrI3txXtobsOhlSLoXBj3ntF30HR3oMI
0xAXbKxCEiD4mIrZdZw6kqf/psNXPX7eKZod0YtOcD+H4c9Lodrnr+aIO3LV5kTK1XkRJHc2CpiEwRkcEi1mWmJnpxzgY+Tt7OwwPac7t+CXOfge43wbWv2LgLY8wxTvmNoKpPAh2At4G7gI0i8lcROcfPsZkq8v6PW3nl203cnNSK30fNh9l/hIShMOw1CAoOdHjGmGqmQn9KqqoCO9yXB2gETBWRv5/sOreEsl5ENonIcZXjItJGROaKyEoR+U5E4tz9/UVkuderUESGucfeEZEtXscSK/nMxsus1Tt46vPVDOzclL+2XY7MfAw6XQUj3rYFkIwx5TrlN4OIPAzcCewG3gIeV9UiEQkCNgJ/OMF1wcCrwBVAOk511wxVTfU67XngPVV9V0QGAGOBO1R1HpDo3qcxsAn4xuu6x1V1auUe1ZS1NG0PD09OoWerKF7rsYngGQ9D+8vhxncgODTQ4RljqqmK/EkZAwxX1a3eO1W1RESuOcl1fYBNqroZQEQmA0MB78SRAPzefT8PmF7OfW4AvlbV/ArEaipow8793PvOUuIaRfDeBZnU+eI3zhQhN79vK+gZY06qIlVVM4E9pRsiEiki5wOo6tqTXBcLbPfaTnf3eVsBjHDfXw9Eikh0mXNuAT4qs+85t3rrRREp91tOREaJSLKIJGdnZ58kzNonK6+AOycsISw0mI8v20Pkl/dDq/Ph1o8hNCLQ4RljqrmKJI7XgANe2wfdfadSXu+rMqs38BjQV0RSgL5ABk4binMDkRZAd2C21zVjgM7AeUBj4InyPlxV31DVJFVNatKkSQXCrR3y8ou4c8ISDhR6+HTgQZp8PQpa9IRbp0CdeoEOzxhzFqhIVZW4jePAkSqqilyXDrTy2o4DMr1PUNVMYDiAiNQHRqiq93wWNwGfqWqR1zVZ7ttDIjIRJ/mYCigsKubX7yWTtjufz4YcpvV/R0HTLnD7pxDeINDhGWPOEhUpcWwWkYdFJNR9/Q7YXIHrlgIdRKStiNTBqXKa4X2CiMS4jezglCQmlLnHSMpUU7mlENzxJMNw5mg1p1BcojwyeTlL0vYwcUARXb+/Hxq3gzumQ4SN7zTGVFxFEsf9wEU41UjpwPnAqFNdpKoe4CGcaqa1wBRVXSMiz4jIde5p/YD1IrIBaAY8V3q9iMTjlFi+L3PrD0RkFbAKp+H+2Qo8Q62mqjw9Yw2z1uzgn5d6uPinB5x1M371OdQr26RkjDEnJ161UDVWUlKSJicnBzqMgHl13ibGz17PU+ce5p5Nv4O6jeHumdCgZaBDM8ZUYyKyTFWTyu6vyDiOcOBeoCsQXrpfVe/xaYTGL6Ykb2f87PX8pkshd2/+XwhvCHd+YUnDGHPaKlJVNQlnvqpBONVGccB+fwZlfOPbdTsZM20VN8fn8/jOPyAhEXDn5xDV6tQXG2PMCVQkcbRX1f8HHFTVd4GrcbrImmosZdtefvPBzwxseoCxB/6ESBDcOcNpEDfGmDNQkcRR2hU2V0S6AQ2BeL9FZM7Y5uwD3PPOUnrUy+O14qcJKvE4SSOmQ6BDM8bUABUZj/GGux7HkzjdaesD/8+vUZnTtmtfIb+asITm7OGDOs8RfPgA3PmlM17DGGN84KSJwx1jsU9V9wLzAavnqMb2FxZx58SlBB3cyfSovxFauNdp02jRI9ChGWNqkJNWValqCc5YDFPNHfIUc//7y8jZmcHXUc8TVrALbp8KsecGOjRjTA1TkTaOOSLymIi0EpHGpS+/R2YqrKREeeyTlazetJVvYl6gXn66M2Fh6wsCHZoxpgaqSBtH6XiNB732KVZtVW08N3Mt363YxNwmLxF1cAuMnAxtLw10WMaYGuqUiUNV21ZFIOb0vDl/M5MXpjIr+iWaHNwAN38A7QcGOixjTA1WkZHjvypvv6q+5/twTGVMT8ngHzOXM6PRK8Tlr0VunAidBgc6LGNMDVeRqqrzvN6HAwOBnwFLHAG0YGM2f/xkKVMa/p
MOBSuREW9BwtBAh2WMqQUqUlX1W+9tEWmIMw2JCZDVGXn8dtJPTKz7T3oc+hmG/hu63xDosIwxtURFelWVlQ/YEOQA2ZpzkHsnLOblkFc435MM17wIvW4LdFjGmFqkIm0cX3B0ydcgIAGY4s+gTPl2HzjE3W//wNPF/6QvP8HgcZBkkxQbY6pWRdo4nvd67wG2qmq6n+IxZUxPyWD87PVk5hYQGqyMDfoPQ4IXweV/hgseCHR4xphaqCKJYxuQpaqFACISISLxqprm18gM01MyGDNtFQVFxTwSMpWm5DIieD5rOz9El0seCXR4xphaqiKJ4xOcpWNLFbv7ziv/dOMr42evp6CoGFAeCZkGwKue6/hwywAWBTY0Y0wtVpHEEaKqh0s3VPWwiNTxY0zGlZlbAMD/hnwCwFueIYz33IzkFQYyLGNMLVeRXlXZInJd6YaIDAV2+y8kU+rJep+TFn4rvw2ZDsB9IV+TFn4bT9b7PMCRGWNqs4okjvuBP4rINhHZBjwB/E9Fbi4ig0VkvYhsEpHR5RxvIyJzRWSliHwnInHu/v4istzrVSgiw9xjbUXkJxHZKCIf1+TST9gVf2Js0cgj2/GFH9Kl+GOir/m/AEZljKntTpk4VPUXVb0ApxtuV1W9SFU3neo6EQkGXgWGuNeOFJGEMqc9D7ynqj2AZ4Cx7mfOU9VEVU0EBuCMHfnGveZvwIuq2gHYC9xbgec8K+UePMRNwd+xQjoBEBsVwdjh3RnWKzbAkRljarNTJg4R+auIRKnqAVXdLyKNROTZCty7D7BJVTe7bSSTgbJzYiQAc93388o5DnAD8LWq5ouI4CSSqe6xd4FhFYjlrHPIU0zK4m84JyiLntf+FvqOZtHoAZY0jDEBV5GqqiGqmlu64a4GeFUFrosFtnttp7v7vK0ARrjvrwciRSS6zDm3AB+576OBXFX1nOSeNcKXK7IYdOgbikPqQtfrof+YQIdkjDFAxRJHsIiElW6ISAQQdpLzj5xazj4ts/0Y0FdEUoC+QAbOIMPSz2oBdAdmV+KepdeOEpFkEUnOzs6uQLjVh6rywYI1XBvyE0HdhkNY/UCHZIwxR1SkO+77wFwRmehu341TRXQq6UArr+04INP7BFXNBIYDiEh9YISq5nmdchPwmaoWudu7gSgRCXFLHcfd0+vebwBvACQlJZWbXKqrHzbn0D57DhGhhdC73FntjTEmYCrSOP534FmgC06bxCygTQXuvRTo4PaCqoNT5TTD+wQRiRGR0hjGABPK3GMkR6upUFXFaQspnQr2TqDG9U19e8EWbgv9npLoDtCqT6DDMcaYY1R0dtwdQAlOe8RAYO2pLnBLBA/hVDOtBaao6hoRecZrXEg/YL2IbACaAc+VXi8i8Tgllu/L3PoJ4FER2YTT5vF2BZ/hrLA5+wBp61PoyQaCet8BUl7tnDHGBM4Jq6pEpCNOKWEkkAN8DIiq9q/ozVV1JjCzzL6nvN5P5WgPqbLXplFOw7eqbsbpsVUjTVi0hZEh36NBIUjPkae+wBhjqtjJ2jjWAQuAa0vHbYjI76skqloqN/8w05el8UOdRUiHwVC/aaBDMsaY45ysqmoEThXVPBF5U0QGUn6vJuMjH/y0jYuKlxFZvBd63R7ocIwxplwnTByq+pmq3gx0Br4Dfg80E5HXROTKKoqv1jjsKeHdxWk80GAx1G8O7a8IdEjGGFOuivSqOqiqH6jqNTjdX5cDx807Zc7MV6syYf8OEg8thcSREFyRntLGGFP1KrXmuKruUdX/qOoAfwVUG6kqby3YwqiGPyFaAolWTWWMqb4qlTiMf/y0ZQ9rMvO4Kfh7aH0RxLQPdEjGGHNCljiqgbcWbGFgxCYa5G+F3ncEOhxjjDkpq0gPsC27DzJ33U6+iFsCeZGQUN4EwcYYU31YiSPAJi7aQlRQIV33fgvdhkOdeoEOyRhjTsoSRwDl5RfxSXI6f2ydingKbEJDY8xZwRJHAH24ZBsFRc
VcXTwXmnSG2HMDHZIxxpySJY4AKSp2Bvzd1GY/dXelQC+b0NAYc3awxBEgM1dlsWNfIQ9F/QBBIdDzlkCHZIwxFWKJIwBKB/x1jKlDq+1fQKchUC8m0GEZY0yFWOIIgKVpe1mVkcef2m9F8nOglzWKG2POHpY4AuCtBZuJqhvKJfu/hsiW0H5goEMyxpgKs8RRxbbmHGTO2p3cnxhO8OZvnQkNg4IDHZYxxlSYJY4qNnFRGiFBwu0Ri0BLbN0NY8xZxxJHFcorKGJK8nau69Gc+qmTIf5SaNwu0GEZY0ylWOKoQpOXbCP/cDEPnbML9qZZacMYc1bya+IQkcEisl5ENonIcYs/iUgbEZkrIitF5DsRifM61lpEvhGRtSKSKiLx7v53RGSLiCx3X4n+fAZfKSou4Z3FaVzYLpq226ZBWAPocl2gwzLGmErzW+IQkWDgVWAIkACMFJGEMqc9D7ynqj2AZ4CxXsfeA8arahegD7DL69jjqprovpb76xl86evVO8jKK2RUnxhI/Ry63wB16gY6LGOMqTR/ljj6AJtUdbOqHgYmA2XnDE8A5rrv55UedxNMiKrOAVDVA6qa78dY/UpVeXvBZtrG1KPv4e/BU+hMMWKMMWchfyaOWGC713a6u8/bCmCE+/56IFJEooGOQK6ITBORFBEZ75ZgSj3nVm+9KCJh/noAX1m2dS8r0vO45+J4gpZPgqZdoWWvQIdljDGnxZ+Jo7wZ+7TM9mNAXxFJAfoCGYAHZ4GpS93j5wHtgLvca8YAnd39jYEnyv1wkVEikiwiydnZ2Wf2JGforQVbaBgRyg2t8iAzxVnlzyY0NMacpfyZONKBVl7bcUCm9wmqmqmqw1W1F/And1+ee22KW83lAaYDvd3jWeo4BEzEqRI7jqq+oapJqprUpEkTXz9bhW3Lyeeb1B3cdn5rIlZ9BMF1oMfNAYvHGGPOlD8Tx1Kgg4i0FZE6wC3ADO8TRCRGREpjGANM8Lq2kYiUfuMPAFLda1q4PwUYBqz24zOcsYmLtxAkwq/OawErJ0Onq6Bu40CHZYwxp81vicMtKTwEzAbWAlNUdY2IPCMipf1Q+wHrRWQD0Ax4zr22GKeaaq6IrMKp9nrTveYDd98qIAZ41l/PcKb2FRYxZel2ru3ZkuY7voWCvU41lTHGnMVC/HlzVZ0JzCyz7ymv91OBqSe4dg7Qo5z9A3wcpt98vGQ7Bw8Xc+8lbeHbZ6BBHLTrH+iwjDHmjNjIcT/xuAP+zm/bmG719sEv30LirTahoTHmrGeJw09mrdlBRm6BU9pY8RGg0Ou2QIdljDFnzBKHn7y1YAvx0XUZ2LkJpEyCtn2hUXygwzLGmDNmicMPlm3dy/LtudxzSVuCty6A3G02UtwYU2NY4vCDtxdupkF4CCN6x0HK+xDeELpcE+iwjDHGJyxx+Nj2PfnMWr2DW89vQ72S/ZA6A7rfBKERgQ7NGGN8whKHj72zOI0gEe68qA2smgrFh2zdDWNMjWKJw4f2Fxbx8dLtXN2jBS0aRjiN4s27Q8uzYskQY4ypEEscPvTx0u0cOORxuuBmrYSsFdDrV4EOyxhjfMoSh494ikuYuCiNPvGN6REX5ZQ2gsOcBZuMMaYGscThI9+k7nQG/F3aFooKYeUUpyeVTWhojKlhLHH4yFsLNtO6cV0u79IM1n0Jhbk2dsMYUyNZ4vCBn7ft5edtudxzcTzBQeJUUzVs7YwWN8aYGsYShw+8vXALkeEh3JjUCvZuhc3fO/NSBdmv1xhT89g32xlK35vP16uyuLVPa+qFhcDyD50DibcGNjBjjPETSxxn6N3FaYgId14UDyXFsPwDaNcPoloHODJjjPEPSxxn4MAhD5OXbOeq7i1oGRUBW76HvO22yp8xpkazxHEGpizdzv7SAX8AP0+CiEbQ2SY0NMbUXJY4TlNxiTJx8RaS2jQisVUU5O9xuuF2vwlCwgIdnjHG+I0ljtM0J3UH2/cUcN+lbmlj1S
dQfNiqqYwxNZ4ljtP01oIttGocwRUJzUHVqaZqkehMamiMMTWYXxOHiAwWkfUisklERpdzvI2IzBWRlSLynYjEeR1rLSLfiMhaEUkVkXh3f1sR+UlENorIxyJSx5/PUJ7l23NJ3rqXuy9q6wz4y1oOO1fZ9OnGmFrBb4lDRIKBV4EhQAIwUkQSypz2PPCeqvYAngHGeh17Dxivql2APsAud//fgBdVtQOwF7jXX89wIm8v3EJkWAg3ndfK2ZHyPoSEQ/cbqzoUY4ypcv4scfQBNqnqZlU9DEwGhpY5JwGY676fV3rcTTAhqjoHQFUPqGq+iAgwAJjqXvMuMMyPz3CcjNwCZq7K4pY+ragfFgJFBbDyE+hyHUREVWUoxhgTEP5MHLHAdq/tdHeftxXACPf99UCkiEQDHYFcEZkmIikiMt4twUQDuarqOck9ARCRUSKSLCLJ2dnZPnokeG9xGoAz4A9g7ZdwKM+qqYwxtYY/E4eUs0/LbD8G9BWRFKAvkAF4gBDgUvf4eUA74K4K3tPZqfqGqiapalKTJk1O6wHKOnjIw4dLtjG4W3PiGtV1dqa8B1FtIP5Sn3yGMcZUd/5MHOlAK6/tOCDT+wRVzVTV4araC/iTuy/PvTbFrebyANOB3sBuIEpEQk50T3/6JHk7+ws93Fc64G/PFtgy35k+3SY0NMbUEv78tlsKdHB7QdUBbgFmeJ8gIjEiUhrDGGCC17WNRKS0qDAASFVVxWkLKV1W707gcz8+wxHFJcqERWn0bh1Fr9aNnJ3LPwQEEkdWRQjGGFMt+C1xuCWFh4DZwFpgiqquEZFnROQ697R+wHoR2QA0A55zry3GqaaaKyKrcKqo3nSveQJ4VEQ24bR5vO2vZ/D237U72bYnn/subefsKJ3QsP1AaBh38ouNMaYGCTn1KadPVWcCM8vse8rr/VSO9pAqe+0coEc5+zfj9NiqUm8v2EJsVARXJjRzdvwyD/ZlwKC/VnUoxhgTUFYxXwEr03NZkraHuy+OJyTY/ZWlTIKIxtBpSGCDM8aYKmaJowLeXriF+mEh3Fw64O9gDqz7CnreYhMaGmNqHUscp5CVV8BXK7O4+bxWRIaHOjtXfgwlRU5vKmOMqWX82sZxNpueksH42evJyC0AoGVUuHNA1ZlipGVvaFZ2BhVjjKn5LHGUY3pKBmOmraKgqPjIvudnbyC6XhjDmu6AXWvgmhcDGKExxgSOVVWVY/zs9cckDYCComLGz17vTJ8eEgHdRpzgamOMqdkscZQj062eKmtPbi6s/hQShkJ4wyqOyhhjqgdLHOVoGRVR7v6Rkcvh0D5b5c8YU6tZ4ijH44M6EREafMy+iNBgftNgMTRuB20uDlBkxhgTeJY4yjGsVyxjh3cnNioCAWKjInjlygbE5CyFxNtAypuk1xhjagfrVXUCw3rFMqyX11Ifc58BCYLEWwMXlDHGVANW4qiIYo8zE277K6BBy0BHY4wxAWWJoyJ++Rb2Z9kqf8YYgyWOikl5D+rGQMfBgY7EGGMCzhLHqRzIhvVfuxMa1gl0NMYYE3CWOE5l5cdQ4rEJDY0xxmWJ42RUYeELEHceNO0c6GiMMaZasMRxMunJkJ9jpQ1jjPFiieNkUiY5P7teH9g4jDGmGrEBgOWZNxa+H3d0e5y78l/f0dB/TGBiMsaYasKviUNEBgMvA8HAW6o6rszxNsAEoAmwB7hdVdPdY8XAKvfUbap6nbv/HaAvkOceu0tVl/s08P5jjiaIpxvC03knP98YY2oRvyUOEQkGXgWuANKBpSIyQ1VTvU57HnhPVd8VkQHAWKC0QaFAVRNPcPvHVXWqv2I3xhhzYv5s4+gDbFLVzap6GJgMDC1zTgIw130/r5zjgdd3dKAjMMaYasWfiSMW2O61ne7u87YCKF1K73ogUkSi3e1wEUkWkR9FZFiZ654TkZUi8qKIhJX34SIyyr0+OTs7+/Sfwto0jDHmGP5MHOXNPa5lth8D+opICk
67RQbgcY+1VtUk4FbgJRE5x90/BugMnAc0Bp4o78NV9Q1VTVLVpCZNmpzZkxhjjDnCn4kjHWjltR0HZHqfoKqZqjpcVXsBf3L35ZUec39uBr4DernbWeo4BEzEqRIzxhhTRfyZOJYCHUSkrYjUAW4BZnifICIxIlIawxicHlaISKPSKigRiQEuBlLd7RbuTwGGAav9+AzGGGPK8FuvKlX1iMhDwGyc7rgTVHWNiDwDJKvqDKAfMFZEFJgPPOhe3gX4j4iU4CS3cV69sT4QkSY4VWHLgfv99QzGGGOOJ6plmx1qnqSkJE1OTg50GMYYc1YRkWVuW/Ox+2tD4hCRbGDraV4eA+z2YTiny+KoXjGAxVGWxXGs6hDHmcbQRlWP611UKxLHmRCR5PIyrsVRu2OwOCyOsyEOf8VgkxwaY4ypFEscxhhjKsUSx6m9EegAXBbHUdUhBrA4yrI4jlUd4vBLDNbGYYwxplKsxGGMMaZSLHEYY4ypFEscJyAiE0Rkl4gEbEoTEWklIvNEZK2IrBGR3wUojnARWSIiK9w4/hyIOLziCRaRFBH5MoAxpInIKhFZLiIBG10qIlEiMlVE1rn/Ti4MQAyd3N9D6WufiDwSgDh+7/77XC0iH4lIeFXH4MbxOzeGNVX5eyjvO0tEGovIHBHZ6P5s5IvPssRxYu8AgwMcgwf4X1XtAlwAPCgiCQGI4xAwQFV7AonAYBG5IABxlPodsDaAn1+qv6omBriv/svALFXtDPQkAL8XVV3v/h4SgXOBfOCzqoxBRGKBh4EkVe2GM83RLVUZgxtHN+DXOJOv9gSuEZEOVfTx73D8d9ZoYK6qdsBZ+8gnCwxZ4jgBVZ2Ps5xtIGPIUtWf3ff7cb4Uyq5pUhVxqKoecDdD3VdAelWISBxwNfBWID6/OhGRBsBlwNsAqnpYVXMDGxUDgV9U9XRnajgTIUCEiIQAdSkzG3cV6QL8qKr5quoBvsdZa8jvTvCdNRR4133/Ls7EsGfMEsdZQkTicaaW/ylAnx8sIsuBXcAcVQ1IHMBLwB+AkgB9fikFvhGRZSIyKkAxtAOygYlu1d1bIlIvQLGUugX4qKo/VFUzcJai3gZkAXmq+k1Vx4EzW/dlIhItInWBqzh2eYmq1kxVs8D5QxRo6oubWuI4C4hIfeBT4BFV3ReIGFS12K2KiAP6uEXyKiUi1wC7VHVZVX92OS5W1d7AEJwqxMsCEEMI0Bt4zV3T5iA+qoo4He7yCdcBnwTgsxvh/HXdFmgJ1BOR26s6DlVdC/wNmAPMwlnl1HPSi85CljiqOREJxUkaH6jqtEDH41aFfEdg2n8uBq4TkTScNewHiMj7AYjDe6GxXTj1+YFYUCwdSPcq/U3FSSSBMgT4WVV3BuCzLwe2qGq2qhYB04CLAhAHqvq2qvZW1ctwqo42BiIO106vNYxa4NQYnDFLHNWYu1jV28BaVX0hgHE0EZEo930Ezv+k66o6DlUdo6pxqhqPUyXyrapW+V+VIlJPRCJL3wNXEoAFxVR1B7Bd/n97988aRRSFYfx5iX+IiE2CISAxhcFCsBARwTL4BRQJQSSIhYiolQg2NhbaBm0UBQUVghJsRIQUgiixUFRsg4gQwYAigoQQjsWc1SUm4MDMjsr7a+bu3WXmzsLumbkzc460NbuGyYJnDRmlgWmq9B7YLWld/m6GaegGCkkbczkA7KO57wSK4nlj2R4D7lex0toKOf3rJN2hKDTVK+kDcC4irnV4GHuAQ8CbvL4AcDYiHnR4HP3ADUldFAcbExHR2K2wf4E+YLL4f2IVcDsiHjY0lhMUxc3WADPA4SYGkfP5e4GjTWw/IqYl3QVeUEwNvaS5lB/3JPUAC8DxiPjciY0u958FXAAmJB2hCK4HKtmWU46YmVkZnqoyM7NSHDjMzKwUBw4zMyvFgcPMzEpx4DAzs1IcOMwqIGlxSYbYyp7gljTYZJZms6X8HIdZNb5nShaz/57POMxqlHU7LmY9k+
eStmT/ZklTkl7nciD7+yRNZu2TV5JaaTO6JF3NGg+P8gl+s0Y4cJhVo3vJVNVI23tfI2IX1yonkwAAAQtJREFUcIkiuy/ZvhkR24FbwHj2jwOPs/bJDuBt9g8BlyNiG/AF2F/z/pityE+Om1VA0reIWL9M/zuKIlgzmbDyY0T0SJoD+iNiIftnI6JX0idgU0TMt61jkCKV/VC+PgOsjojz9e+Z2e98xmFWv1ihvdJnljPf1l7E1yetQQ4cZvUbaVs+y/ZTfpU2PQg8yfYUcAx+Fs/a0KlBmv0pH7WYVaO7LYMxFHXAW7fkrpU0TXGgNpp9J4Hrkk5TVPFrZbU9BVzJbKaLFEFktvbRm5XgaxxmNcprHDsjYq7psZhVxVNVZmZWis84zMysFJ9xmJlZKQ4cZmZWigOHmZmV4sBhZmalOHCYmVkpPwAByIRfWrmnxwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"train_accuracies, val_accuracies = \\\n",
" UDA_pytorch_classifier_fit(deeper_convnet,\n",
" torch.optim.Adam(deeper_convnet.parameters(),\n",
" lr=learning_rate),\n",
" nn.CrossEntropyLoss(), # includes softmax\n",
" proper_train_dataset, val_dataset,\n",
" num_epochs, batch_size)\n",
"\n",
"UDA_plot_train_val_accuracy_vs_epoch(train_accuracies, val_accuracies)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Finally evaluate on test data"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"test_dataset = torchvision.datasets.MNIST(root='data/',\n",
" train=False,\n",
" transform=transforms.ToTensor(),\n",
" download=True)\n",
"test_images = torch.tensor([image.numpy() for image, label in test_dataset])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9243\n"
]
}
],
"source": [
"predicted_test_labels = UDA_pytorch_classifier_predict(simple_model, test_images)\n",
"print('Test accuracy:', UDA_compute_accuracy(predicted_test_labels, test_dataset.targets))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9783\n"
]
}
],
"source": [
"predicted_test_labels = UDA_pytorch_classifier_predict(deeper_model, test_images)\n",
"print('Test accuracy:', UDA_compute_accuracy(predicted_test_labels, test_dataset.targets))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9804\n"
]
}
],
"source": [
"predicted_test_labels = UDA_pytorch_classifier_predict(simple_convnet, test_images)\n",
"print('Test accuracy:', UDA_compute_accuracy(predicted_test_labels, test_dataset.targets))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.9861\n"
]
}
],
"source": [
"predicted_test_labels = UDA_pytorch_classifier_predict(deeper_convnet, test_images)\n",
"print('Test accuracy:', UDA_compute_accuracy(predicted_test_labels, test_dataset.targets))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
"""
Helper code for Carnegie Mellon University's Unstructured Data Analytics course
Author: George H. Chen (georgechen [at symbol] cmu.edu)
I wrote this code for my class to make teaching how to use PyTorch as simple as
using Keras. Note that this code has only been tested with categorical
cross-entropy loss.
"""
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import torch.nn as nn
from matplotlib.ticker import MaxNLocator
from torchnlp.encoders.text import stack_and_pad_tensors
from torchnlp.samplers import BucketBatchSampler
from torchnlp.utils import collate_tensors
def UDA_pytorch_classifier_fit(model, optimizer, loss,
                               proper_train_dataset, val_dataset,
                               num_epochs, batch_size, device=None,
                               sequence=False):
    """
    Trains a neural net classifier `model` using an `optimizer` such as Adam or
    stochastic gradient descent. We specifically minimize the given `loss`
    using the data given by `proper_train_dataset` using the number of epochs
    given by `num_epochs` and a batch size given by `batch_size`.

    Accuracies on the (proper) training data (`proper_train_dataset`) and
    validation data (`val_dataset`) are computed at the end of each epoch;
    `val_dataset` can be set to None if you don't want to use a validation set
    (the returned validation accuracies are then all zeros).

    The function outputs the training and validation accuracies as two NumPy
    arrays of length `num_epochs`.

    You can manually set which device (CPU or GPU) to use with the optional
    `device` argument (e.g., setting `device=torch.device('cpu')` or
    `device=torch.device('cuda')`). By default, the code tries to use a GPU if
    it is available.

    Lastly, the boolean argument `sequence` says whether we are looking at time
    series data (set this True for working with recurrent neural nets).

    Only `nn.CrossEntropyLoss` is supported; any other loss raises an
    Exception.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    if loss._get_name() != 'CrossEntropyLoss':
        raise Exception('Unsupported loss: ' + loss._get_name())

    if not sequence:
        # PyTorch uses DataLoader to load data in batches; with shuffle=True a
        # new random order is drawn every time the loader is iterated
        proper_train_loader = \
            torch.utils.data.DataLoader(dataset=proper_train_dataset,
                                        batch_size=batch_size,
                                        shuffle=True)
    # No validation loader is built here: the validation set is batched by
    # UDA_pytorch_classifier_evaluate at the end of each epoch.

    proper_train_size = len(proper_train_dataset)
    train_accuracies = np.zeros(num_epochs)
    val_accuracies = np.zeros(num_epochs)
    for epoch_idx in range(num_epochs):
        if sequence:
            # Sequence batches are materialized as a list, so rebuild them
            # every epoch; otherwise the bucket-shuffled batch order would be
            # frozen after the first epoch.
            proper_train_loader = \
                UDA_get_batches_sequence(proper_train_dataset,
                                         batch_size,
                                         shuffle=True,
                                         device=device)

        # Make sure layers such as dropout / batch norm run in training mode
        # (evaluating the model may have switched it to eval mode).
        model.train()

        # go through training data
        num_training_examples_so_far = 0
        for batch_features, batch_labels in proper_train_loader:
            # make sure the data are stored on the right device
            batch_features = batch_features.to(device)
            batch_labels = batch_labels.to(device)

            # make predictions for current batch and compute loss
            batch_outputs = model(batch_features)
            batch_loss = loss(batch_outputs, batch_labels)

            # update model parameters
            optimizer.zero_grad()  # reset which direction optimizer is going
            batch_loss.backward()  # compute new direction optimizer should go
            optimizer.step()  # move the optimizer

            # draw fancy progress bar
            num_training_examples_so_far += batch_features.shape[0]
            sys.stdout.write('\r')
            sys.stdout.write("Epoch %d [%-50s] %d/%d"
                             % (epoch_idx + 1,
                                '=' * int(num_training_examples_so_far
                                          / proper_train_size * 50),
                                num_training_examples_so_far,
                                proper_train_size))
            sys.stdout.flush()

        # draw fancy progress bar at 100%
        sys.stdout.write('\r')
        sys.stdout.write("Epoch %d [%-50s] %d/%d"
                         % (epoch_idx + 1,
                            '=' * 50,
                            num_training_examples_so_far, proper_train_size))
        sys.stdout.flush()

        sys.stdout.write('\n')
        sys.stdout.flush()

        # compute proper training and validation set raw accuracies
        train_accuracy = \
            UDA_pytorch_classifier_evaluate(model,
                                            proper_train_dataset,
                                            device=device,
                                            batch_size=batch_size,
                                            sequence=sequence)
        print('  Train accuracy: %.4f' % train_accuracy, flush=True)
        train_accuracies[epoch_idx] = train_accuracy

        if val_dataset is not None:
            val_accuracy = \
                UDA_pytorch_classifier_evaluate(model,
                                                val_dataset,
                                                device=device,
                                                batch_size=batch_size,
                                                sequence=sequence)
            print('  Validation accuracy: %.4f' % val_accuracy, flush=True)
            val_accuracies[epoch_idx] = val_accuracy

    return train_accuracies, val_accuracies
def UDA_pytorch_model_transform(model, inputs, device=None, batch_size=128,
                                sequence=False):
    """
    Given a neural net `model`, evaluate the model given `inputs`, which should
    *not* be already batched. This helper function automatically batches the
    data, feeds each batch through the neural net, and then unbatches the
    outputs. The outputs are stored as a PyTorch tensor.

    The model is run in evaluation mode (`model.eval()`) so that layers such as
    dropout and batch norm behave deterministically. The model's previous
    training/eval mode is restored afterwards, even if an error occurs.

    You can manually set which device (CPU or GPU) to use with the optional
    `device` argument (e.g., setting `device=torch.device('cpu')` or
    `device=torch.device('cuda')`). By default, the code tries to use a GPU if
    it is available.

    You can also manually set `batch_size`; this is less critical than in
    training since we are, at this point, just evaluating the model without
    updating its parameters.

    Lastly, the boolean argument `sequence` says whether we are looking at time
    series data (set this True for working with recurrent neural nets).
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # batch the inputs
    if not sequence:
        feature_loader = torch.utils.data.DataLoader(dataset=inputs,
                                                     batch_size=batch_size,
                                                     shuffle=False)
    else:
        feature_loader = \
            UDA_get_batches_from_encoded_text(inputs,
                                              None,
                                              batch_size,
                                              shuffle=False,
                                              device=device)

    outputs = []
    # Switch to eval mode for inference, remembering the caller's mode so a
    # subsequent training loop is not silently left in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch_features in feature_loader:
                batch_features = batch_features.to(device)
                outputs.append(model(batch_features))
    finally:
        model.train(was_training)
    return torch.cat(outputs, 0)
def UDA_pytorch_classifier_predict(model, inputs, device=None, batch_size=128,
                                   sequence=False):
    """
    Predict class labels with the neural net classifier `model`.

    `inputs` must be unbatched; batching, running the network and
    concatenating the per-batch outputs is delegated to
    `UDA_pytorch_model_transform`. The predicted label of each input is the
    index of its largest output score. The result is a 1D PyTorch tensor.

    `device` selects where computation happens (e.g.
    `torch.device('cpu')` or `torch.device('cuda')`); when left as None a GPU
    is used if one is available.

    `batch_size` only affects how many inputs are pushed through the network
    at once; since no parameters are updated here it matters much less than
    during training.

    Set `sequence=True` for time series / text data that is fed to recurrent
    neural nets.
    """
    scores = UDA_pytorch_model_transform(model, inputs,
                                         device=device,
                                         batch_size=batch_size,
                                         sequence=sequence)
    with torch.no_grad():
        predicted = scores.argmax(axis=1)
        return predicted.view(-1)
def UDA_pytorch_classifier_evaluate(model, dataset, device=None,
                                    batch_size=128, sequence=False):
    """
    Evaluate the raw accuracy of a neural net classifier `model` for a
    `dataset`, which should be a list of pairs of the format (input, label).

    The model is run in evaluation mode (`model.eval()`) so that layers such as
    dropout and batch norm do not perturb the predictions; the model's
    previous training/eval mode is restored afterwards.

    You can manually set which device (CPU or GPU) to use with the optional
    `device` argument (e.g., setting `device=torch.device('cpu')` or
    `device=torch.device('cuda')`). By default, the code tries to use a GPU if
    it is available.

    You can also manually set `batch_size`; this is less critical than in
    training since we are, at this point, just evaluating the model without
    updating its parameters.

    Lastly, the boolean argument `sequence` says whether we are looking at time
    series data (set this True for working with recurrent neural nets).

    Returns the fraction of examples whose argmax prediction equals the label.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    if not sequence:
        loader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=False)
    else:
        loader = UDA_get_batches_sequence(dataset,
                                          batch_size,
                                          shuffle=False,
                                          device=device)

    # Evaluate in eval mode, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            num_correct = 0.
            for batch_features, batch_labels in loader:
                batch_features = batch_features.to(device)
                batch_outputs = model(batch_features)
                batch_predicted_labels = batch_outputs.argmax(axis=1)
                if isinstance(batch_labels, np.ndarray):
                    # compare on the CPU against the NumPy labels
                    batch_predicted_labels = \
                        batch_predicted_labels.view(-1).cpu().numpy()
                    num_correct += (batch_predicted_labels
                                    == batch_labels).sum()
                else:
                    num_correct += \
                        (batch_predicted_labels.view(-1)
                         == batch_labels.to(device).view(-1)).sum().item()
    finally:
        model.train(was_training)
    return num_correct / len(dataset)
def UDA_plot_train_val_accuracy_vs_epoch(train_accuracies, val_accuracies):
    """
    Plot (proper) training and validation accuracy against the epoch number
    in a new matplotlib figure.

    Both `train_accuracies` and `val_accuracies` are expected to hold one
    value per epoch (so they should have the same length); epochs are
    numbered starting from 1.
    """
    ax = plt.figure().gca()
    epochs = np.arange(1, len(train_accuracies) + 1)
    curves = ((train_accuracies, '-o', 'Training'),
              (val_accuracies, '-+', 'Validation'))
    for accuracies, fmt, name in curves:
        plt.plot(epochs, accuracies, fmt, label=name)
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    # epochs are whole numbers, so only place x-axis ticks at integers
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
def _to_flat_numpy_labels(labels):
    """Return `labels` (tensor, array, or sequence) as a flat 1D NumPy array."""
    if isinstance(labels, torch.Tensor):
        # reshape (unlike view) also works for non-contiguous tensors, e.g.
        # the result of a transpose
        return labels.detach().reshape(-1).cpu().numpy()
    return np.asarray(labels).flatten()


def UDA_compute_accuracy(labels1, labels2):
    """
    Computes the raw accuracy of two label sequences `labels1` and `labels2`
    agreeing. This helper function coerces both label sequences to be on the
    CPU, flattened, and stored as 1D NumPy arrays before computing the average
    agreement.

    Raises ValueError if the two flattened label sequences have different
    lengths (instead of silently broadcasting one against the other).
    """
    flat1 = _to_flat_numpy_labels(labels1)
    flat2 = _to_flat_numpy_labels(labels2)
    if flat1.shape[0] != flat2.shape[0]:
        raise ValueError('Label sequences have different lengths: %d vs %d'
                         % (flat1.shape[0], flat2.shape[0]))
    return np.mean(flat1 == flat2)
class UDA_LSTMforSequential(nn.Module):
    """
    Wrap `nn.LSTM` so that it can be used as a layer inside `nn.Sequential()`.

    `nn.LSTM` returns a tuple `(outputs, (h_n, c_n))`, which `nn.Sequential`
    cannot pass on to the next layer; this module keeps only `outputs`.
    With `return_sequences=True` the output at every time step is returned,
    otherwise only the output at the last time step.
    """

    def __init__(self, input_size, hidden_size, return_sequences=False):
        super().__init__()
        self.return_sequences = return_sequences
        # batch_first=True: axis 0 of the input indexes examples in the batch
        self.model = nn.LSTM(input_size=input_size,
                             hidden_size=hidden_size,
                             batch_first=True)

    def forward(self, x):
        # x is expected as (batch size, sequence length, feature dimension)
        per_step = self.model(x)[0]
        return per_step if self.return_sequences else per_step[:, -1, :]
def UDA_get_batches_sequence(dataset, batch_size, shuffle=True, device=None):
    """
    Batch a dataset of (encoded text, label) pairs.

    This is a thin adapter around `UDA_get_batches_from_encoded_text()`: it
    splits the pairs into a list of encoded sequences and a parallel list of
    labels, then forwards everything to that function. See its documentation
    for the meaning of `batch_size`, `shuffle` and `device`.
    """
    sequences = []
    targets = []
    for pair in dataset:
        seq, target = pair
        sequences.append(seq)
        targets.append(target)
    return UDA_get_batches_from_encoded_text(sequences, targets, batch_size,
                                             shuffle=shuffle, device=device)
def UDA_get_batches_from_encoded_text(text_encoded, labels, batch_size,
                                      shuffle=True, device=None):
    """
    Group encoded sequences into padded batches ready for a neural net.

    `text_encoded` holds sequences that are already encoded as word indices
    into a vocabulary. Sequences in one batch may have different lengths, so
    each batch is padded to a common length. If `labels` is None, the result
    is a list of padded feature tensors. Otherwise the i-th label belongs to
    the i-th sequence, and the result is a list of (padded features,
    label tensor) pairs. Labels are converted to a flat long tensor.
    `batch_size` sets the number of sequences per batch.

    With `shuffle=True`, batches are formed by bucket sampling, which groups
    sequences of similar length to reduce padding while still adding some
    randomness. With `shuffle=False`, batches follow the original order.

    `device` chooses where the tensors live (e.g. `torch.device('cpu')` or
    `torch.device('cuda')`). When it is None, a GPU is used if available.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if not shuffle:
        index_batches = torch.utils.data.DataLoader(
            dataset=list(range(len(text_encoded))),
            batch_size=batch_size,
            shuffle=False)
    else:
        # bucket sampling keeps similarly sized sequences together so that
        # less padding is needed per batch
        base_sampler = torch.utils.data.sampler.SequentialSampler(text_encoded)
        index_batches = BucketBatchSampler(
            base_sampler, batch_size=batch_size, drop_last=False,
            sort_key=lambda i: text_encoded[i].shape[0])

    def padded_features(batch):
        # stack the batch's sequences, padding them to equal length
        return collate_tensors([text_encoded[i] for i in batch],
                               stack_tensors=stack_and_pad_tensors
                               ).tensor.to(device)

    if labels is None:
        return [padded_features(batch) for batch in index_batches]

    def label_tensor(batch):
        return torch.tensor([labels[i] for i in batch],
                            dtype=torch.long).to(device).view(-1)

    return [(padded_features(batch), label_tensor(batch))
            for batch in index_batches]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment